Add LLM inference support to JMLC API #2430

Open

kubraaksux wants to merge 24 commits into apache:main from kubraaksux:llm-api

Conversation

kubraaksux commented Feb 13, 2026

Adds LLM text generation to the JMLC API, using Py4J to bridge Java and Python (HuggingFace models). A usage sketch follows the change list below.

Changes

  • Connection.java: loadModel() / releaseModel() to start and stop the Python worker (300s timeout for large models)
  • PreparedScript.java: generateBatchWithMetrics() for batch inference via FrameBlock — now uses a single generateBatch() call to the Python worker instead of a per-prompt loop
  • LLMCallback.java: Java interface for the Py4J callback, including generateBatch() for batched GPU inference
  • llm_worker.py: Python worker that loads HuggingFace models and serves inference requests, with batched tokenization and model.generate() for GPU parallelism
  • JMLCLLMInferenceTest.java: Integration test using distilgpt2
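
For orientation, here is a minimal, hypothetical sketch of how the Java side could be driven, assuming the API surface described above (loadModel(modelName, workerScriptPath), generateBatchWithMetrics(FrameBlock, boolean), releaseModel()). Import paths, the DML body, and the FrameBlock handling are illustrative assumptions, not the PR's final code:

```java
import org.apache.sysds.api.jmlc.Connection;
import org.apache.sysds.api.jmlc.PreparedScript;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.frame.data.FrameBlock;

public class LlmJmlcSketch {
  public static void main(String[] args) throws Exception {
    try (Connection conn = new Connection()) {
      // Start the Python worker (model name and script path are illustrative).
      conn.loadModel("distilgpt2", "src/main/python/llm_worker.py");

      // DML with read()/write(); variable names must match the names registered below.
      String dml = "prompts = read(\"./in\", data_type=\"frame\");\n"
                 + "results = prompts;\n"                               // placeholder body
                 + "write(results, \"./out\", data_type=\"frame\");\n";
      PreparedScript ps = conn.prepareScript(dml,
          new String[]{"prompts"}, new String[]{"results"});

      // One prompt per row of a single-column string frame.
      FrameBlock prompts = new FrameBlock(new ValueType[]{ValueType.STRING});
      prompts.appendRow(new Object[]{"Write a haiku about data systems."});

      // batched=true selects the GPU-batched path described below.
      FrameBlock metrics = ps.generateBatchWithMetrics(prompts, true);
      System.out.println(metrics.getNumRows() + " result rows");

      conn.releaseModel();
    }
  }
}
```

The DML here is only a placeholder; the point is the overall lifecycle (loadModel, prepareScript, batch call, releaseModel) and that the prompts/results names line up with the registered inputs/outputs (see the naming note further down the thread).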

GPU batching

The latest update adds true GPU batching: all prompts are tokenized together (with padding) and processed in a single model.generate() call. This achieves 3-14x speedup over the previous sequential per-prompt approach on NVIDIA H100, making SystemDS JMLC faster than sequential vLLM for batch workloads. See #2431 for full benchmark results.

Test

mvn test -Dtest=JMLCLLMInferenceTest -pl .

Also evaluated with Qwen/Qwen2.5-3B-Instruct and mistralai/Mistral-7B-Instruct-v0.3 on NVIDIA H100 in the benchmarking framework (#2431).

- Connection.java: Changed loadModel(modelName) to loadModel(modelName, workerScriptPath)
- Connection.java: Removed findPythonScript() method
- LLMCallback.java: Added Javadoc for generate() method
- JMLCLLMInferenceTest.java: Updated to pass script path to loadModel()
- Connection.java: Auto-find available ports for Py4J communication
- Connection.java: Add loadModel() overload for manual port override
- Connection.java: Use destroyForcibly() with waitFor() for clean shutdown
- llm_worker.py: Accept python_port as command line argument
- Move worker script from src/main/python/systemds/ to src/main/python/ to avoid shadowing the Python stdlib operator module
- Add generateWithTokenCount() returning JSON with input/output token counts
- Update generateBatchWithMetrics() to include input_tokens and output_tokens columns
- Add CUDA auto-detection with device_map=auto for multi-GPU support in llm_worker.py
- Check Python process liveness during startup instead of a blind 60s timeout, since 7B+ models need more time to load weights into GPU memory (see the sketch after this list)
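
The liveness check could look roughly like the following (a hypothetical helper, not the actual Connection.java code): poll the worker process while waiting for it to become ready, and fail fast if it dies, instead of sleeping through one fixed timeout.

```java
import java.util.concurrent.TimeUnit;

// Hypothetical helper illustrating the startup check; the real loadModel()
// logic in Connection.java may differ.
final class WorkerStartup {
  static void awaitWorkerReady(Process worker, long timeoutMs) throws InterruptedException {
    long deadline = System.currentTimeMillis() + timeoutMs;
    while (System.currentTimeMillis() < deadline) {
      if (!worker.isAlive())
        throw new IllegalStateException(
          "Python worker exited during startup, exit code " + worker.exitValue());
      if (workerRespondsToPing())       // placeholder for the actual readiness probe
        return;
      TimeUnit.MILLISECONDS.sleep(200); // poll instead of one blind sleep
    }
    throw new IllegalStateException("Python worker not ready after " + timeoutMs + " ms");
  }

  // Placeholder: in the PR this would be a Py4J handshake or port check.
  private static boolean workerRespondsToPing() { return false; }
}
```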
Contributor

Hi, how are these changes related to the llm inference?

Author

Hi, you're right. They seem to be from Nicolas Korjahn's shampoo optimizer code that was already in my branch when I branched off main, and they got accidentally included in my commit. I've reverted them now, so the file should be back to its original state. Sorry about that!

This file was accidentally modified in a prior commit. Restoring the
original vectorized SIMD implementation.
- LLMCallback.java: add generateBatch() interface method
- PreparedScript.java: replace per-prompt for-loop with single batch call
- llm_worker.py: implement batched tokenization and model.generate()

Achieves 3-14x speedup over sequential inference on H100.
generateBatchWithMetrics() now accepts a boolean batched parameter:
true for GPU-batched (new), false for original sequential for-loop.
@e-strauss
Contributor

Hi @kubraaksux , thanks for the contribution!
I have a concern about the current approach: I’m not sure moving LLM inference into Python is the right direction, especially since most calls still go through Python wrapper functions and there’s additional overhead from using Py4J.
Also, as implemented now, it seems we’re bypassing SystemDS’s core functionality entirely.
Looping in @mboehm7 .

@kubraaksux
Author

> Hi @kubraaksux, thanks for the contribution! I have a concern about the current approach: I’m not sure moving LLM inference into Python is the right direction, especially since most calls still go through Python wrapper functions and there’s additional overhead from using Py4J. Also, as implemented now, it seems we’re bypassing SystemDS’s core functionality entirely. Looping in @mboehm7.

Hi @e-strauss, thanks for the feedback. Both points are valid.

I redesigned the approach. Instead of the Py4J bridge, llmPredict is now a native parameterized built-in. The DML goes through the full compilation pipeline: parser → hops → lops → CP instruction. The instruction makes HTTP calls directly via java.net.HttpURLConnection.
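
Roughly, the kind of call the CP instruction makes could look like this (a hedged sketch: the endpoint, payload shape, timeout values, and class/method names are assumptions, not the actual instruction code):

```java
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;

// Hypothetical sketch of the HTTP call pattern used by the llmPredict instruction.
final class LlmHttpSketch {
  static String post(String endpoint, String jsonBody) throws Exception {
    HttpURLConnection conn = (HttpURLConnection) new URL(endpoint).openConnection();
    try {
      conn.setRequestMethod("POST");
      conn.setRequestProperty("Content-Type", "application/json");
      conn.setConnectTimeout(10_000);   // connect timeout guards against an unreachable server
      conn.setReadTimeout(300_000);     // generous read timeout for long generations
      conn.setDoOutput(true);
      // try-with-resources ensures the request/response streams are always closed
      try (OutputStream os = conn.getOutputStream()) {
        os.write(jsonBody.getBytes(StandardCharsets.UTF_8));
      }
      StringBuilder sb = new StringBuilder();
      try (BufferedReader br = new BufferedReader(
          new InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8))) {
        for (String line; (line = br.readLine()) != null; )
          sb.append(line);
      }
      return sb.toString();
    }
    finally {
      conn.disconnect();
    }
  }
}
```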

Thanks again for catching this early.

- Use proper imports instead of inline fully-qualified class names
- Add try-with-resources for HTTP streams to prevent resource leaks
- Add connect/read timeouts to HTTP calls
- Add lineage tracing support for llmPredict
- Add checkInvalidParameters validation in parser
- Remove .claude/.env/meeting_notes from .gitignore
- Trim verbose docstrings
@e-strauss
Contributor

Hey @kubraaksux, just sharing my thoughts here — not trying to push in any direction, since I’m not the project supervisor. Let’s wait for Matthias’s feedback.

Supports parallel HTTP calls to the inference server via
ExecutorService. Default concurrency=1 keeps sequential behavior.
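
A minimal sketch of that fan-out pattern, reusing the hypothetical post() helper from the HTTP sketch above (class and parameter names are assumptions about intent, not the PR's actual code):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical sketch: one HTTP request per prompt, issued through a bounded pool.
// concurrency=1 degenerates to the original sequential behavior.
final class ParallelLlmCalls {
  static List<String> generateAll(List<String> payloads, String endpoint, int concurrency)
      throws Exception {
    ExecutorService pool = Executors.newFixedThreadPool(Math.max(1, concurrency));
    try {
      List<Callable<String>> tasks = new ArrayList<>();
      for (String p : payloads)
        tasks.add(() -> LlmHttpSketch.post(endpoint, p)); // post() from the sketch above
      List<String> results = new ArrayList<>();
      for (Future<String> f : pool.invokeAll(tasks))      // invokeAll preserves input order
        results.add(f.get());
      return results;
    }
    finally {
      pool.shutdown();
    }
  }
}
```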
@kubraaksux
Author

kubraaksux commented Feb 16, 2026

> Hey @kubraaksux, just sharing my thoughts here — not trying to push in any direction, since I’m not the project supervisor. Let’s wait for Matthias’s feedback.

Your points were helpful. I've reworked the approach accordingly. Looking forward to @mboehm7's input.

JMLC requires the LHS variable name in read() assignments to match
the input name registered in prepareScript(). Changed X/R to
prompts/results so RewriteRemovePersistentReadWrite correctly
converts persistent reads to transient reads.
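
To make the constraint concrete, a hypothetical sketch (not the test's actual code): the left-hand sides of read()/write() in the DML must use the same names registered as inputs and outputs in prepareScript(), otherwise the rewrite cannot bind them.

```java
import org.apache.sysds.api.jmlc.Connection;
import org.apache.sysds.api.jmlc.PreparedScript;

// Hypothetical sketch of the JMLC naming constraint; the DML body is a placeholder.
final class NamingSketch {
  static PreparedScript prepare(Connection conn) throws Exception {
    String dml =
        "prompts = read(\"./in\", data_type=\"frame\");\n"   // LHS matches input "prompts"
      + "results = prompts;\n"                                // placeholder body
      + "write(results, \"./out\", data_type=\"frame\");\n";  // LHS matches output "results"
    return conn.prepareScript(dml,
        new String[]{"prompts"},   // registered input names
        new String[]{"results"});  // registered output names
  }
}
```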