Changes from all commits (24 commits)
8e7d6da  Add LLM inference support to JMLC API via Py4J bridge (kubraaksux, Feb 12, 2026)
47dd0db  Refactor loadModel to accept worker script path as parameter (kubraaksux, Feb 13, 2026)
672a3fa  Add dynamic port allocation and improve resource cleanup (kubraaksux, Feb 13, 2026)
dacdc1c  Move llm_worker.py to fix Python module collision (kubraaksux, Feb 13, 2026)
29f657c  Use python3 with fallback to python in Connection.java (kubraaksux, Feb 14, 2026)
e40e4f2  Add batch inference with FrameBlock and metrics support (kubraaksux, Feb 14, 2026)
fdd1684  Clean up test: extract constants and shared setup method (kubraaksux, Feb 14, 2026)
b9ba3e0  Add token counts, GPU support, and improve error handling (kubraaksux, Feb 14, 2026)
2e984a2  Increase worker startup timeout to 300s for larger models (kubraaksux, Feb 16, 2026)
bf666c2  Revert accidental changes to MatrixBlockDictionary.java (kubraaksux, Feb 16, 2026)
5faa691  Add GPU batching support to JMLC LLM inference (kubraaksux, Feb 16, 2026)
c9c85d4  Keep both sequential and batched inference modes in PreparedScript (kubraaksux, Feb 16, 2026)
4b44dd1  Add gitignore rules for .env files, meeting notes, and local tool config (kubraaksux, Feb 16, 2026)
72bc334  Add llmPredict builtin, opcode and ParamBuiltinOp entries (kubraaksux, Feb 16, 2026)
0ad1b56  Add llmPredict parser validation in ParameterizedBuiltinFunctionExpre… (kubraaksux, Feb 16, 2026)
1e48362  Wire llmPredict through hop, lop and instruction generation (kubraaksux, Feb 16, 2026)
de675ac  Add llmPredict CP instruction with HTTP-based inference (kubraaksux, Feb 16, 2026)
5eab87d  Remove Py4J-based LLM inference from JMLC API (kubraaksux, Feb 16, 2026)
bea062a  Rewrite LLM test to use llmPredict DML built-in (kubraaksux, Feb 16, 2026)
edf4e39  Add OpenAI-compatible HTTP inference server for HuggingFace models (kubraaksux, Feb 16, 2026)
45882e2  Fix llmPredict code quality issues (kubraaksux, Feb 16, 2026)
c3e9a1f  Add concurrency parameter to llmPredict built-in (kubraaksux, Feb 16, 2026)
53e3feb  Remove license header from test, clarify llm_server.py docstring (kubraaksux, Feb 16, 2026)
e872f22  Fix JMLC frame binding: match DML variable names to registered inputs (kubraaksux, Feb 16, 2026)
1 change: 0 additions & 1 deletion .gitignore
@@ -156,4 +156,3 @@ docker/mountFolder/*.bin
docker/mountFolder/*.bin.mtd

SEAL-*/

1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/common/Builtins.java
@@ -226,6 +226,7 @@ public enum Builtins {
LMDS("lmDS", true),
LMPREDICT("lmPredict", true),
LMPREDICT_STATS("lmPredictStats", true),
LLMPREDICT("llmPredict", false, true),
LOCAL("local", false),
LOG("log", false),
LOGSUMEXP("logSumExp", true),
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/common/Opcodes.java
@@ -204,6 +204,7 @@ public enum Opcodes {
GROUPEDAGG("groupedagg", InstructionType.ParameterizedBuiltin),
RMEMPTY("rmempty", InstructionType.ParameterizedBuiltin),
REPLACE("replace", InstructionType.ParameterizedBuiltin),
LLMPREDICT("llmpredict", InstructionType.ParameterizedBuiltin),
LOWERTRI("lowertri", InstructionType.ParameterizedBuiltin),
UPPERTRI("uppertri", InstructionType.ParameterizedBuiltin),
REXPAND("rexpand", InstructionType.ParameterizedBuiltin),
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/common/Types.java
@@ -805,7 +805,7 @@ public static ReOrgOp valueOfByOpcode(String opcode) {

/** Parameterized operations that require named variable arguments */
public enum ParamBuiltinOp {
AUTODIFF, CDF, CONTAINS, INVALID, INVCDF, GROUPEDAGG, RMEMPTY, REPLACE, REXPAND,
AUTODIFF, CDF, CONTAINS, INVALID, INVCDF, GROUPEDAGG, LLMPREDICT, RMEMPTY, REPLACE, REXPAND,
LOWER_TRI, UPPER_TRI,
TRANSFORMAPPLY, TRANSFORMDECODE, TRANSFORMCOLMAP, TRANSFORMMETA,
TOKENIZE, TOSTRING, LIST, PARAMSERV
@@ -187,6 +187,7 @@ public Lop constructLops()
case LOWER_TRI:
case UPPER_TRI:
case TOKENIZE:
case LLMPREDICT:
case TRANSFORMAPPLY:
case TRANSFORMDECODE:
case TRANSFORMCOLMAP:
@@ -758,7 +759,7 @@ && getTargetHop().areDimsBelowThreshold() ) {
if (_op == ParamBuiltinOp.TRANSFORMCOLMAP || _op == ParamBuiltinOp.TRANSFORMMETA
|| _op == ParamBuiltinOp.TOSTRING || _op == ParamBuiltinOp.LIST
|| _op == ParamBuiltinOp.CDF || _op == ParamBuiltinOp.INVCDF
|| _op == ParamBuiltinOp.PARAMSERV) {
|| _op == ParamBuiltinOp.PARAMSERV || _op == ParamBuiltinOp.LLMPREDICT) {
_etype = ExecType.CP;
}

@@ -768,7 +769,7 @@ && getTargetHop().areDimsBelowThreshold() ) {
switch(_op) {
case CONTAINS:
if(getTargetHop().optFindExecType() == ExecType.SPARK)
_etype = ExecType.SPARK;
break;
default:
// Do not change execution type.
@@ -176,6 +176,7 @@ public String getInstructions(String output)
case CONTAINS:
case REPLACE:
case TOKENIZE:
case LLMPREDICT:
case TRANSFORMAPPLY:
case TRANSFORMDECODE:
case TRANSFORMCOLMAP:
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2007,6 +2007,7 @@ private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFu
case LOWER_TRI:
case UPPER_TRI:
case TOKENIZE:
case LLMPREDICT:
case TRANSFORMAPPLY:
case TRANSFORMDECODE:
case TRANSFORMCOLMAP:
@@ -61,6 +61,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
pbHopMap.put(Builtins.GROUPEDAGG, ParamBuiltinOp.GROUPEDAGG);
pbHopMap.put(Builtins.RMEMPTY, ParamBuiltinOp.RMEMPTY);
pbHopMap.put(Builtins.REPLACE, ParamBuiltinOp.REPLACE);
pbHopMap.put(Builtins.LLMPREDICT, ParamBuiltinOp.LLMPREDICT);
pbHopMap.put(Builtins.LOWER_TRI, ParamBuiltinOp.LOWER_TRI);
pbHopMap.put(Builtins.UPPER_TRI, ParamBuiltinOp.UPPER_TRI);

@@ -211,6 +212,10 @@ public void validateExpression(HashMap<String, DataIdentifier> ids, HashMap<Stri
validateOrder(output, conditional);
break;

case LLMPREDICT:
validateLlmPredict(output, conditional);
break;

case TOKENIZE:
validateTokenize(output, conditional);
break;
@@ -614,6 +619,18 @@ private void validateTokenize(DataIdentifier output, boolean conditional)
output.setDimensions(-1, -1);
}

private void validateLlmPredict(DataIdentifier output, boolean conditional)
{
Set<String> valid = new HashSet<>(Arrays.asList(
"target", "url", "max_tokens", "temperature", "top_p", "concurrency"));
checkInvalidParameters(getOpCode(), getVarParams(), valid);
checkDataType(false, "llmPredict", TF_FN_PARAM_DATA, DataType.FRAME, conditional);
checkStringParam(false, "llmPredict", "url", conditional);
output.setDataType(DataType.FRAME);
output.setValueType(ValueType.STRING);
output.setDimensions(-1, -1);
}

// example: A = transformapply(target=X, meta=M, spec=s)
private void validateTransformApply(DataIdentifier output, boolean conditional)
{
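Taken together, the parser changes fix the llmPredict surface: target (a frame of prompts) and url are required, while max_tokens, temperature, top_p and concurrency are optional, so a DML call takes the usual parameterized-builtin form, e.g. R = llmPredict(target=P, url="http://localhost:8080/v1/completions"). A minimal Python sketch of that parameter contract, using the defaults from the CP instruction below; the helper is illustrative, not SystemDS API:

# Illustrative mirror of validateLlmPredict's checks (not SystemDS API).
VALID = {"target", "url", "max_tokens", "temperature", "top_p", "concurrency"}
DEFAULTS = {"max_tokens": 512, "temperature": 0.0, "top_p": 0.9, "concurrency": 1}

def check_llm_predict_params(params):
    unknown = set(params) - VALID
    if unknown:
        raise ValueError("invalid llmPredict parameters: %s" % sorted(unknown))
    for required in ("target", "url"):  # non-optional, as in the validator
        if required not in params:
            raise ValueError("llmPredict requires parameter '%s'" % required)
    return {**DEFAULTS, **params}  # fill in the instruction-level defaults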
@@ -19,14 +19,25 @@

package org.apache.sysds.runtime.instructions.cp;

import java.io.InputStream;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.apache.wink.json4j.JSONObject;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -154,7 +165,7 @@ else if(opcode.equalsIgnoreCase(Opcodes.RMEMPTY.toString()) || opcode.equalsIgno
}
else if(opcode.equals(Opcodes.TRANSFORMAPPLY.toString()) || opcode.equals(Opcodes.TRANSFORMDECODE.toString())
|| opcode.equalsIgnoreCase(Opcodes.CONTAINS.toString()) || opcode.equals(Opcodes.TRANSFORMCOLMAP.toString())
|| opcode.equals(Opcodes.TRANSFORMMETA.toString()) || opcode.equals(Opcodes.TOKENIZE.toString())
|| opcode.equals(Opcodes.TRANSFORMMETA.toString()) || opcode.equals(Opcodes.TOKENIZE.toString()) || opcode.equals(Opcodes.LLMPREDICT.toString())
|| opcode.equals(Opcodes.TOSTRING.toString()) || opcode.equals(Opcodes.NVLIST.toString()) || opcode.equals(Opcodes.AUTODIFF.toString())) {
return new ParameterizedBuiltinCPInstruction(null, paramsMap, out, opcode, str);
}
@@ -324,6 +335,60 @@ else if(opcode.equalsIgnoreCase(Opcodes.TOKENIZE.toString())) {
ec.setFrameOutput(output.getName(), fbout);
ec.releaseFrameInput(params.get("target"));
}

else if(opcode.equalsIgnoreCase(Opcodes.LLMPREDICT.toString())) {
FrameBlock prompts = ec.getFrameInput(params.get("target"));
String url = params.get("url");
int maxTokens = params.containsKey("max_tokens") ?
Integer.parseInt(params.get("max_tokens")) : 512;
double temperature = params.containsKey("temperature") ?
Double.parseDouble(params.get("temperature")) : 0.0;
double topP = params.containsKey("top_p") ?
Double.parseDouble(params.get("top_p")) : 0.9;
int concurrency = params.containsKey("concurrency") ?
Integer.parseInt(params.get("concurrency")) : 1;

int n = prompts.getNumRows();
String[][] data = new String[n][];

// build one callable per prompt
List<Callable<String[]>> tasks = new ArrayList<>(n);
for(int i = 0; i < n; i++) {
String prompt = prompts.get(i, 0).toString();
tasks.add(() -> callLlmEndpoint(prompt, url, maxTokens, temperature, topP));
}

try {
if(concurrency <= 1) {
// sequential
for(int i = 0; i < n; i++)
data[i] = tasks.get(i).call();
}
else {
// parallel
ExecutorService pool = Executors.newFixedThreadPool(
Math.min(concurrency, n));
List<Future<String[]>> futures = pool.invokeAll(tasks);
pool.shutdown();
for(int i = 0; i < n; i++)
data[i] = futures.get(i).get();
}
}
catch(Exception e) {
throw new DMLRuntimeException("llmPredict failed: " + e.getMessage(), e);
}

ValueType[] schema = {ValueType.STRING, ValueType.STRING,
ValueType.INT64, ValueType.INT64, ValueType.INT64};
String[] colNames = {"prompt", "generated_text", "time_ms", "input_tokens", "output_tokens"};
FrameBlock fbout = new FrameBlock(schema, colNames);
for(String[] row : data)
fbout.appendRow(row);

ec.setFrameOutput(output.getName(), fbout);
ec.releaseFrameInput(params.get("target"));
}

else if(opcode.equalsIgnoreCase(Opcodes.TRANSFORMAPPLY.toString())) {
// acquire locks
FrameBlock data = ec.getFrameInput(params.get("target"));
@@ -488,6 +553,50 @@ private void warnOnTrunction(TensorBlock data, int rows, int cols) {
}
}

private static String[] callLlmEndpoint(String prompt, String url,
int maxTokens, double temperature, double topP) throws Exception {
long t0 = System.nanoTime();
JSONObject req = new JSONObject();
req.put("prompt", prompt);
req.put("max_tokens", maxTokens);
req.put("temperature", temperature);
req.put("top_p", topP);

HttpURLConnection conn = (HttpURLConnection)
new URI(url).toURL().openConnection();
conn.setRequestMethod("POST");
conn.setRequestProperty("Content-Type", "application/json");
conn.setConnectTimeout(10_000);
conn.setReadTimeout(120_000);
conn.setDoOutput(true);

try(OutputStream os = conn.getOutputStream()) {
os.write(req.toString().getBytes(StandardCharsets.UTF_8));
}
if(conn.getResponseCode() != 200)
throw new DMLRuntimeException(
"LLM endpoint returned HTTP " + conn.getResponseCode());

String body;
try(InputStream is = conn.getInputStream()) {
body = new String(is.readAllBytes(), StandardCharsets.UTF_8);
}
conn.disconnect();

JSONObject resp = new JSONObject(body);
String text = resp.getJSONArray("choices")
.getJSONObject(0).getString("text");
long elapsed = (System.nanoTime() - t0) / 1_000_000;
int inTok = 0, outTok = 0;
if(resp.has("usage")) {
JSONObject usage = resp.getJSONObject("usage");
inTok = usage.has("prompt_tokens") ? usage.getInt("prompt_tokens") : 0;
outTok = usage.has("completion_tokens") ? usage.getInt("completion_tokens") : 0;
}
return new String[]{prompt, text,
String.valueOf(elapsed), String.valueOf(inTok), String.valueOf(outTok)};
}

@Override
public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
String opcode = getOpcode();
@@ -549,6 +658,12 @@ else if(opcode.equalsIgnoreCase(Opcodes.TRANSFORMDECODE.toString()) || opcode.eq
return Pair.of(output.getName(),
new LineageItem(getOpcode(), LineageItemUtils.getLineage(ec, target, meta, spec)));
}
else if(opcode.equalsIgnoreCase(Opcodes.LLMPREDICT.toString())) {
CPOperand target = new CPOperand(params.get("target"), ValueType.STRING, DataType.FRAME);
CPOperand urlOp = getStringLiteral("url");
return Pair.of(output.getName(),
new LineageItem(getOpcode(), LineageItemUtils.getLineage(ec, target, urlOp)));
}
else if (opcode.equalsIgnoreCase(Opcodes.NVLIST.toString()) || opcode.equalsIgnoreCase(Opcodes.AUTODIFF.toString())) {
List<String> names = new ArrayList<>(params.keySet());
CPOperand[] listOperands = names.stream().map(n -> ec.containsVariable(params.get(n))
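For reference, callLlmEndpoint speaks a small OpenAI-style completions protocol: one POST per prompt carrying prompt, max_tokens, temperature and top_p, answered by JSON with choices[0].text and an optional usage object, which the instruction folds into an output frame with columns prompt, generated_text, time_ms, input_tokens and output_tokens. A minimal Python client sketch of the same exchange, useful for debugging an endpoint outside SystemDS (the localhost URL and the defaults simply mirror the Java code; nothing here is SystemDS API):

import json
import urllib.request

def call_llm_endpoint(prompt, url="http://localhost:8080/v1/completions",
                      max_tokens=512, temperature=0.0, top_p=0.9):
    # Same request body the Java instruction sends.
    data = json.dumps({"prompt": prompt, "max_tokens": max_tokens,
                       "temperature": temperature, "top_p": top_p}).encode("utf-8")
    req = urllib.request.Request(url, data=data,
                                 headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req, timeout=120) as resp:
        body = json.load(resp)
    # Same fields the Java side extracts; token counts default to 0.
    usage = body.get("usage", {})
    return (body["choices"][0]["text"],
            usage.get("prompt_tokens", 0),
            usage.get("completion_tokens", 0))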
96 changes: 96 additions & 0 deletions src/main/python/llm_server.py
@@ -0,0 +1,96 @@
"""Local inference server for llmPredict. Loads a HuggingFace model
and serves it at http://localhost:PORT/v1/completions.

Usage: python llm_server.py distilgpt2 --port 8080
"""

import argparse
import json
import sys
import time
from http.server import HTTPServer, BaseHTTPRequestHandler

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class InferenceHandler(BaseHTTPRequestHandler):

def do_POST(self):
if self.path != "/v1/completions":
self.send_error(404)
return
length = int(self.headers.get("Content-Length", 0))
body = json.loads(self.rfile.read(length))

prompt = body.get("prompt", "")
max_tokens = int(body.get("max_tokens", 512))
temperature = float(body.get("temperature", 0.0))
top_p = float(body.get("top_p", 0.9))

model = self.server.model
tokenizer = self.server.tokenizer

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
input_len = inputs["input_ids"].shape[1]
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_tokens,
temperature=temperature if temperature > 0 else 1.0,
top_p=top_p,
do_sample=temperature > 0,
)
new_tokens = outputs[0][input_len:]
text = tokenizer.decode(new_tokens, skip_special_tokens=True)

resp = {
"choices": [{"text": text}],
"usage": {
"prompt_tokens": input_len,
"completion_tokens": len(new_tokens),
},
}
payload = json.dumps(resp).encode("utf-8")
self.send_response(200)
self.send_header("Content-Type", "application/json")
self.send_header("Content-Length", str(len(payload)))
self.end_headers()
self.wfile.write(payload)

def log_message(self, fmt, *args):
sys.stderr.write("[llm_server] %s\n" % (fmt % args))


def main():
parser = argparse.ArgumentParser(description="OpenAI-compatible LLM server")
parser.add_argument("model", help="HuggingFace model name")
parser.add_argument("--port", type=int, default=8080)
args = parser.parse_args()

print(f"Loading model: {args.model}", flush=True)
tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if torch.cuda.is_available():
print(f"CUDA available: {torch.cuda.device_count()} GPU(s)", flush=True)
model = AutoModelForCausalLM.from_pretrained(
args.model, device_map="auto", torch_dtype=torch.float16)
else:
model = AutoModelForCausalLM.from_pretrained(args.model)
model.eval()
print(f"Model loaded on {next(model.parameters()).device}", flush=True)

server = HTTPServer(("0.0.0.0", args.port), InferenceHandler)
server.model = model
server.tokenizer = tokenizer
print(f"Serving on http://0.0.0.0:{args.port}/v1/completions", flush=True)
try:
server.serve_forever()
except KeyboardInterrupt:
print("Shutting down", flush=True)
server.server_close()


if __name__ == "__main__":
main()
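Given the docstring's usage line, a hedged end-to-end smoke test could look like the sketch below; distilgpt2 and port 8081 are assumptions, and the loop simply retries until the model has finished loading:

import json
import subprocess
import sys
import time
import urllib.request

PORT = 8081  # assumption: any free local port
proc = subprocess.Popen([sys.executable, "llm_server.py", "distilgpt2",
                         "--port", str(PORT)])
try:
    payload = json.dumps({"prompt": "The answer is", "max_tokens": 8}).encode("utf-8")
    req = urllib.request.Request("http://localhost:%d/v1/completions" % PORT,
                                 data=payload,
                                 headers={"Content-Type": "application/json"})
    for _ in range(120):  # retry while the server is still loading the model
        try:
            with urllib.request.urlopen(req, timeout=10) as resp:
                print(json.load(resp))  # choices + usage, as served above
            break
        except OSError:
            time.sleep(1)
finally:
    proc.terminate()
    proc.wait()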