From 8e7d6da46a592fc99338dd198f5d4f266884df21 Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Thu, 12 Feb 2026 13:36:38 +0100 Subject: [PATCH 01/24] Add LLM inference support to JMLC API via Py4J bridge --- .../org/apache/sysds/api/jmlc/Connection.java | 119 ++++++++++++++++++ .../apache/sysds/api/jmlc/LLMCallback.java | 9 ++ .../apache/sysds/api/jmlc/PreparedScript.java | 38 ++++++ .../dictionary/MatrixBlockDictionary.java | 26 +--- src/main/python/systemds/llm_worker.py | 63 ++++++++++ .../functions/jmlc/JMLCLLMInferenceTest.java | 83 ++++++++++++ 6 files changed, 316 insertions(+), 22 deletions(-) create mode 100644 src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java create mode 100644 src/main/python/systemds/llm_worker.py create mode 100644 src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java diff --git a/src/main/java/org/apache/sysds/api/jmlc/Connection.java b/src/main/java/org/apache/sysds/api/jmlc/Connection.java index 525c1a97bb2..c58241a222b 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/Connection.java +++ b/src/main/java/org/apache/sysds/api/jmlc/Connection.java @@ -28,7 +28,11 @@ import java.util.Arrays; import java.util.Collections; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.sysds.hops.OptimizerUtils; @@ -66,6 +70,7 @@ import org.apache.sysds.runtime.transform.meta.TfMetaUtils; import org.apache.sysds.runtime.util.CollectionUtils; import org.apache.sysds.runtime.util.DataConverter; +import py4j.GatewayServer; /** * Interaction with SystemDS using the JMLC (Java Machine Learning Connector) API is initiated with @@ -91,6 +96,12 @@ public class Connection implements Closeable private final DMLConfig _dmlconf; private final CompilerConfig _cconf; private static FileSystem fs = null; + private Process _pythonProcess = null; + private py4j.GatewayServer _gatewayServer = null; + private LLMCallback _llmWorker = null; + private CountDownLatch _workerLatch = null; + + private static final Log LOG = LogFactory.getLog(Connection.class.getName()); /** * Connection constructor, the starting point for any other JMLC API calls. @@ -287,6 +298,103 @@ public PreparedScript prepareScript(String script, Map nsscripts, //return newly create precompiled script return new PreparedScript(rtprog, inputs, outputs, _dmlconf, _cconf); } + + /** + * Loads a HuggingFace model via Python worker for LLM inference. + * Starts a Python subprocess and connects via Py4J. 
+ * + * @param modelName HuggingFace model name (e.g., "distilgpt2") + * @return LLMCallback interface to the Python worker + */ + public LLMCallback loadModel(String modelName) { + if (_llmWorker != null) + return _llmWorker; + try { + // Initialize latch for worker registration + _workerLatch = new CountDownLatch(1); + + // Start Py4J gateway server with callback support + _gatewayServer = new GatewayServer.GatewayServerBuilder() + .entryPoint(this) + .javaPort(25333) + .callbackClient(25334, java.net.InetAddress.getLoopbackAddress()) + .build(); + _gatewayServer.start(); + + // Give gateway time to fully start accepting connections + Thread.sleep(500); + + // Find the Python script - try multiple locations + String pythonScript = findPythonScript(); + LOG.info("Starting LLM worker with script: " + pythonScript); + + _pythonProcess = new ProcessBuilder( + "python", pythonScript, modelName, "25333" + ).redirectErrorStream(true).start(); + + // Read Python process output in background thread + new Thread(() -> { + try (BufferedReader reader = new BufferedReader( + new InputStreamReader(_pythonProcess.getInputStream()))) { + String line; + while ((line = reader.readLine()) != null) { + LOG.info("[LLM Worker] " + line); + } + } catch (IOException e) { + LOG.error("Error reading LLM worker output", e); + } + }).start(); + + // Wait for worker to register with timeout + if (!_workerLatch.await(60, TimeUnit.SECONDS)) { + throw new DMLException("Timeout waiting for LLM worker to register"); + } + + } catch (DMLException e) { + throw e; + } catch (Exception e) { + throw new DMLException("Failed to start LLM worker: " + e.getMessage()); + } + return _llmWorker; + } + + /** + * Called by Python worker to register itself via Py4J. + */ + public void registerWorker(LLMCallback worker) { + _llmWorker = worker; + if (_workerLatch != null) { + _workerLatch.countDown(); + } + LOG.info("LLM worker registered successfully"); + } + + /** + * Finds the Python LLM worker script by checking multiple possible locations. + * @return absolute path to the Python script + * @throws IOException if script cannot be found + */ + private String findPythonScript() throws IOException { + String[] possiblePaths = { + // Relative to project root (when running from IDE or mvn) + "src/main/python/systemds/llm_worker.py", + // Relative to target directory (when running tests) + "../src/main/python/systemds/llm_worker.py", + // Absolute path using system property + System.getProperty("user.dir") + "/src/main/python/systemds/llm_worker.py" + }; + + for (String path : possiblePaths) { + java.io.File f = new java.io.File(path); + if (f.exists()) { + return f.getAbsolutePath(); + } + } + + // If not found, return the default and let it fail with a clear error + throw new IOException("Cannot find llm_worker.py. Searched: " + + String.join(", ", possiblePaths) + ". 
Current dir: " + System.getProperty("user.dir")); + } /** * Close connection to SystemDS, which clears the @@ -294,6 +402,17 @@ public PreparedScript prepareScript(String script, Map nsscripts, */ @Override public void close() { + + //shutdown LLM worker if running + if (_pythonProcess != null) { + _pythonProcess.destroy(); + _pythonProcess = null; + } + if (_gatewayServer != null) { + _gatewayServer.shutdown(); + _gatewayServer = null; + } + //clear thread-local configurations ConfigurationManager.clearLocalConfigs(); if( ConfigurationManager.isCodegenEnabled() ) diff --git a/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java b/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java new file mode 100644 index 00000000000..68f1767994e --- /dev/null +++ b/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java @@ -0,0 +1,9 @@ +package org.apache.sysds.api.jmlc; + +/** + * Interface for the Python LLM worker. + * The Python side implements this via Py4J callback. + */ +public interface LLMCallback { + String generate(String prompt, int maxNewTokens, double temperature, double topP); +} \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java index 31bb7457227..2e6109d0102 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java +++ b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java @@ -80,6 +80,9 @@ public class PreparedScript implements ConfigurableAPI private final CompilerConfig _cconf; private HashMap _outVarLineage; + //LLM inference support + private LLMCallback _llmWorker = null; + private PreparedScript(PreparedScript that) { //shallow copy, except for a separate symbol table //and related meta data of reused inputs @@ -160,6 +163,41 @@ public CompilerConfig getCompilerConfig() { return _cconf; } + /** + * Sets the LLM worker callback for text generation. + * + * @param worker the LLM callback interface + */ + public void setLLMWorker(LLMCallback worker) { + _llmWorker = worker; + } + + /** + * Gets the LLM worker callback. + * + * @return the LLM callback interface, or null if not set + */ + public LLMCallback getLLMWorker() { + return _llmWorker; + } + + /** + * Generates text using the LLM worker. + * + * @param prompt the input prompt text + * @param maxNewTokens maximum number of new tokens to generate + * @param temperature sampling temperature (0.0 = deterministic, higher = more random) + * @param topP nucleus sampling probability threshold + * @return generated text + * @throws DMLException if no LLM worker is set + */ + public String generate(String prompt, int maxNewTokens, double temperature, double topP) { + if (_llmWorker == null) { + throw new DMLException("No LLM worker set. Call setLLMWorker() first."); + } + return _llmWorker.generate(prompt, maxNewTokens, temperature, topP); + } + /** * Binds a scalar boolean to a registered input variable. 
* diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java index 71a4112f157..f6b09d4384a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java @@ -67,8 +67,6 @@ public class MatrixBlockDictionary extends ADictionary { final private MatrixBlock _data; - static final VectorSpecies SPECIES = DoubleVector.SPECIES_PREFERRED; - /** * Unsafe private constructor that does not check the data validity. USE WITH CAUTION. * @@ -2127,9 +2125,6 @@ private void preaggValuesFromDenseDictDenseAggRangeRange(final int numVals, fina private static void preaggValuesFromDenseDictBlockedIKJ(double[] a, double[] b, double[] ret, int bi, int bk, int bj, int bie, int bke, int cz, int az, int ls, int cut, int sOffT, int eOffT) { - final int vLen = SPECIES.length(); - final DoubleVector vVec = DoubleVector.zero(SPECIES); - final int leftover = (eOffT - sOffT) % vLen; // leftover not vectorized for(int i = bi; i < bie; i++) { final int offI = i * cz; final int offOutT = i * az + bj; @@ -2138,27 +2133,14 @@ private static void preaggValuesFromDenseDictBlockedIKJ(double[] a, double[] b, final int sOff = sOffT + idb; final int eOff = eOffT + idb; final double v = a[offI + k]; - vecInnerLoop(v, b, ret, offOutT, eOff, sOff, leftover, vLen, vVec); + int offOut = offOutT; + for(int j = sOff; j < eOff; j++, offOut++) { + ret[offOut] += v * b[j]; + } } } } - private static void vecInnerLoop(final double v, final double[] b, final double[] ret, final int offOutT, - final int eOff, final int sOff, final int leftover, final int vLen, DoubleVector vVec) { - int offOut = offOutT; - vVec = vVec.broadcast(v); - final int end = eOff - leftover; - for(int j = sOff; j < end; j += vLen, offOut += vLen) { - DoubleVector res = DoubleVector.fromArray(SPECIES, ret, offOut); - DoubleVector bVec = DoubleVector.fromArray(SPECIES, b, j); - vVec.fma(bVec, res).intoArray(ret, offOut); - } - for(int j = end; j < eOff; j++, offOut++) { - ret[offOut] += v * b[j]; - } - - } - private void preaggValuesFromDenseDictDenseAggRangeGeneric(final int numVals, final IColIndex colIndexes, final int s, final int e, final double[] b, final int cut, final double[] ret) { final int cz = colIndexes.size(); diff --git a/src/main/python/systemds/llm_worker.py b/src/main/python/systemds/llm_worker.py new file mode 100644 index 00000000000..6aed4d1fbae --- /dev/null +++ b/src/main/python/systemds/llm_worker.py @@ -0,0 +1,63 @@ +""" +SystemDS LLM Worker — Python side of the Py4J bridge. +Java starts this script, then calls generate() via Py4J. 
+""" +import sys +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM +from py4j.java_gateway import JavaGateway, GatewayParameters, CallbackServerParameters + +class LLMWorker: + def __init__(self, model_name="distilgpt2"): + print(f"Loading model: {model_name}", flush=True) + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = AutoModelForCausalLM.from_pretrained(model_name) + self.model.eval() + print(f"Model loaded: {model_name}", flush=True) + + def generate(self, prompt, max_new_tokens=50, temperature=0.7, top_p=0.9): + inputs = self.tokenizer(prompt, return_tensors="pt") + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + max_new_tokens=int(max_new_tokens), + temperature=float(temperature), + top_p=float(top_p), + do_sample=float(temperature) > 0.0 + ) + new_tokens = outputs[0][inputs["input_ids"].shape[1]:] + return self.tokenizer.decode(new_tokens, skip_special_tokens=True) + + class Java: + implements = ["org.apache.sysds.api.jmlc.LLMCallback"] + +if __name__ == "__main__": + model_name = sys.argv[1] if len(sys.argv) > 1 else "distilgpt2" + java_port = int(sys.argv[2]) if len(sys.argv) > 2 else 25333 + + print(f"Starting LLM worker, connecting to Java on port {java_port}", flush=True) + + worker = LLMWorker(model_name) + + # Connect to Java's GatewayServer and register this worker + # The callback_server starts a server on Python's side for Java to call back + # Use port 25334 which Java's CallbackClient expects + gateway = JavaGateway( + gateway_parameters=GatewayParameters(port=java_port), + callback_server_parameters=CallbackServerParameters(port=25334) + ) + + print(f"Python callback server started on port 25334", flush=True) + + gateway.entry_point.registerWorker(worker) + print("Worker registered with Java, waiting for requests...", flush=True) + + # Keep the worker alive to handle callbacks from Java + # The callback server runs in a daemon thread, so we need to block here + import threading + shutdown_event = threading.Event() + try: + # Wait indefinitely until Java closes the connection or kills the process + shutdown_event.wait() + except KeyboardInterrupt: + print("Worker shutting down", flush=True) \ No newline at end of file diff --git a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java new file mode 100644 index 00000000000..c28dc768a95 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.jmlc; + +import org.apache.sysds.api.jmlc.Connection; +import org.apache.sysds.api.jmlc.LLMCallback; +import org.apache.sysds.api.jmlc.PreparedScript; +import org.apache.sysds.test.AutomatedTestBase; +import org.junit.Assert; +import org.junit.Test; + +/** + * Test LLM inference capabilities via JMLC API. + * This test requires Python with transformers and torch installed. + */ +public class JMLCLLMInferenceTest extends AutomatedTestBase { + private final static String TEST_NAME = "JMLCLLMInferenceTest"; + private final static String TEST_DIR = "functions/jmlc/"; + + @Override + public void setUp() { + addTestConfiguration(TEST_DIR, TEST_NAME); + getAndLoadTestConfiguration(TEST_NAME); + } + + @Test + public void testLLMInference() { + Connection conn = null; + try { + // Create a connection + conn = new Connection(); + + // Load the LLM model via Python worker + LLMCallback llmWorker = conn.loadModel("distilgpt2"); + Assert.assertNotNull("LLM worker should not be null", llmWorker); + + // Create a PreparedScript with a dummy script + String dummyScript = "x = 1;\nwrite(x, './tmp/x');"; + PreparedScript ps = conn.prepareScript(dummyScript, new String[]{}, new String[]{"x"}); + + // Set the LLM worker on the PreparedScript + ps.setLLMWorker(llmWorker); + + // Generate text using the LLM + String prompt = "The meaning of life is"; + String result = ps.generate(prompt, 20, 0.7, 0.9); + + // Assert the result is not null and not empty + Assert.assertNotNull("Generated text should not be null", result); + Assert.assertFalse("Generated text should not be empty", result.isEmpty()); + + System.out.println("Prompt: " + prompt); + System.out.println("Generated: " + result); + + } catch (Exception e) { + // Skip test if Python/transformers not available + System.out.println("Skipping LLM test:"); + e.printStackTrace(); + org.junit.Assume.assumeNoException("LLM dependencies not available", e); + } finally { + if (conn != null) { + conn.close(); + } + } + } +} From 47dd0db1dfd95caecce251099bda101182359c0f Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Fri, 13 Feb 2026 17:45:02 +0100 Subject: [PATCH 02/24] Refactor loadModel to accept worker script path as parameter - Connection.java: Changed loadModel(modelName) to loadModel(modelName, workerScriptPath) - Connection.java: Removed findPythonScript() method - LLMCallback.java: Added Javadoc for generate() method - JMLCLLMInferenceTest.java: Updated to pass script path to loadModel() --- .../org/apache/sysds/api/jmlc/Connection.java | 48 ++++--------------- .../apache/sysds/api/jmlc/LLMCallback.java | 12 ++++- src/main/python/systemds/llm_worker.py | 2 +- .../functions/jmlc/JMLCLLMInferenceTest.java | 20 ++++---- 4 files changed, 30 insertions(+), 52 deletions(-) diff --git a/src/main/java/org/apache/sysds/api/jmlc/Connection.java b/src/main/java/org/apache/sysds/api/jmlc/Connection.java index c58241a222b..2ec7754d298 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/Connection.java +++ b/src/main/java/org/apache/sysds/api/jmlc/Connection.java @@ -304,16 +304,17 @@ public PreparedScript prepareScript(String script, Map nsscripts, * Starts a Python subprocess and connects via Py4J. 
* * @param modelName HuggingFace model name (e.g., "distilgpt2") + * @param workerScriptPath path to the Python worker script (llm_worker.py) * @return LLMCallback interface to the Python worker */ - public LLMCallback loadModel(String modelName) { + public LLMCallback loadModel(String modelName, String workerScriptPath) { if (_llmWorker != null) return _llmWorker; try { - // Initialize latch for worker registration + //initialize latch for worker registration _workerLatch = new CountDownLatch(1); - // Start Py4J gateway server with callback support + //start Py4J gateway server with callback support _gatewayServer = new GatewayServer.GatewayServerBuilder() .entryPoint(this) .javaPort(25333) @@ -321,18 +322,16 @@ public LLMCallback loadModel(String modelName) { .build(); _gatewayServer.start(); - // Give gateway time to fully start accepting connections + //give gateway time to start Thread.sleep(500); - // Find the Python script - try multiple locations - String pythonScript = findPythonScript(); - LOG.info("Starting LLM worker with script: " + pythonScript); - + //start python worker process + LOG.info("Starting LLM worker with script: " + workerScriptPath); _pythonProcess = new ProcessBuilder( - "python", pythonScript, modelName, "25333" + "python", workerScriptPath, modelName, "25333" ).redirectErrorStream(true).start(); - // Read Python process output in background thread + //read python output in background thread new Thread(() -> { try (BufferedReader reader = new BufferedReader( new InputStreamReader(_pythonProcess.getInputStream()))) { @@ -345,7 +344,7 @@ public LLMCallback loadModel(String modelName) { } }).start(); - // Wait for worker to register with timeout + //wait for worker to register if (!_workerLatch.await(60, TimeUnit.SECONDS)) { throw new DMLException("Timeout waiting for LLM worker to register"); } @@ -369,33 +368,6 @@ public void registerWorker(LLMCallback worker) { LOG.info("LLM worker registered successfully"); } - /** - * Finds the Python LLM worker script by checking multiple possible locations. - * @return absolute path to the Python script - * @throws IOException if script cannot be found - */ - private String findPythonScript() throws IOException { - String[] possiblePaths = { - // Relative to project root (when running from IDE or mvn) - "src/main/python/systemds/llm_worker.py", - // Relative to target directory (when running tests) - "../src/main/python/systemds/llm_worker.py", - // Absolute path using system property - System.getProperty("user.dir") + "/src/main/python/systemds/llm_worker.py" - }; - - for (String path : possiblePaths) { - java.io.File f = new java.io.File(path); - if (f.exists()) { - return f.getAbsolutePath(); - } - } - - // If not found, return the default and let it fail with a clear error - throw new IOException("Cannot find llm_worker.py. Searched: " + - String.join(", ", possiblePaths) + ". Current dir: " + System.getProperty("user.dir")); - } - /** * Close connection to SystemDS, which clears the * thread-local DML and compiler configurations. diff --git a/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java b/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java index 68f1767994e..09ee8debb29 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java +++ b/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java @@ -5,5 +5,15 @@ * The Python side implements this via Py4J callback. 
*/ public interface LLMCallback { - String generate(String prompt, int maxNewTokens, double temperature, double topP); + + /** + * Generates text using the LLM model. + * + * @param prompt the input prompt text + * @param maxNewTokens maximum number of new tokens to generate + * @param temperature sampling temperature (0.0 = deterministic, higher = more random) + * @param topP nucleus sampling probability threshold + * @return generated text continuation + */ + String generate(String prompt, int maxNewTokens, double temperature, double topP); } \ No newline at end of file diff --git a/src/main/python/systemds/llm_worker.py b/src/main/python/systemds/llm_worker.py index 6aed4d1fbae..57872b37f82 100644 --- a/src/main/python/systemds/llm_worker.py +++ b/src/main/python/systemds/llm_worker.py @@ -1,5 +1,5 @@ """ -SystemDS LLM Worker — Python side of the Py4J bridge. +SystemDS LLM Worker - Python side of the Py4J bridge. Java starts this script, then calls generate() via Py4J. """ import sys diff --git a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java index c28dc768a95..5909b0cef4c 100644 --- a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java +++ b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java @@ -44,25 +44,21 @@ public void setUp() { public void testLLMInference() { Connection conn = null; try { - // Create a connection + //create connection and load model conn = new Connection(); - - // Load the LLM model via Python worker - LLMCallback llmWorker = conn.loadModel("distilgpt2"); + LLMCallback llmWorker = conn.loadModel("distilgpt2", "src/main/python/systemds/llm_worker.py"); Assert.assertNotNull("LLM worker should not be null", llmWorker); - // Create a PreparedScript with a dummy script - String dummyScript = "x = 1;\nwrite(x, './tmp/x');"; - PreparedScript ps = conn.prepareScript(dummyScript, new String[]{}, new String[]{"x"}); - - // Set the LLM worker on the PreparedScript + //create prepared script and set llm worker + String script = "x = 1;\nwrite(x, './tmp/x');"; + PreparedScript ps = conn.prepareScript(script, new String[]{}, new String[]{"x"}); ps.setLLMWorker(llmWorker); - // Generate text using the LLM + //generate text using llm String prompt = "The meaning of life is"; String result = ps.generate(prompt, 20, 0.7, 0.9); - // Assert the result is not null and not empty + //verify result Assert.assertNotNull("Generated text should not be null", result); Assert.assertFalse("Generated text should not be empty", result.isEmpty()); @@ -70,7 +66,7 @@ public void testLLMInference() { System.out.println("Generated: " + result); } catch (Exception e) { - // Skip test if Python/transformers not available + //skip test if dependencies not available System.out.println("Skipping LLM test:"); e.printStackTrace(); org.junit.Assume.assumeNoException("LLM dependencies not available", e); From 672a3faa2bc52b0b5687b9c510e0e8dad8d12bce Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Fri, 13 Feb 2026 18:20:51 +0100 Subject: [PATCH 03/24] Add dynamic port allocation and improve resource cleanup - Connection.java: Auto-find available ports for Py4J communication - Connection.java: Add loadModel() overload for manual port override - Connection.java: Use destroyForcibly() with waitFor() for clean shutdown - llm_worker.py: Accept python_port as command line argument --- 
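Usage sketch (reviewer note, not part of the commit; the model name, script path, and port numbers are placeholders): with this change the two Py4J ports are picked automatically via ServerSocket(0), and callers that need fixed ports can use the new overload.

    Connection conn = new Connection();
    // default: Java gateway port and Python callback port are auto-detected
    LLMCallback worker = conn.loadModel("distilgpt2",
        "src/main/python/systemds/llm_worker.py");
    // or pin both ports explicitly, e.g. in a restricted network environment:
    // conn.loadModel("distilgpt2", "src/main/python/systemds/llm_worker.py", 25333, 25334);
    String text = worker.generate("The meaning of life is", 20, 0.7, 0.9);
    conn.close(); // destroyForcibly() + waitFor(5s) on the worker, then gateway shutdown
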
.../org/apache/sysds/api/jmlc/Connection.java | 56 ++++++++++++++++--- src/main/python/systemds/llm_worker.py | 11 ++-- 2 files changed, 52 insertions(+), 15 deletions(-) diff --git a/src/main/java/org/apache/sysds/api/jmlc/Connection.java b/src/main/java/org/apache/sysds/api/jmlc/Connection.java index 2ec7754d298..13dcf2a0247 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/Connection.java +++ b/src/main/java/org/apache/sysds/api/jmlc/Connection.java @@ -301,13 +301,30 @@ public PreparedScript prepareScript(String script, Map nsscripts, /** * Loads a HuggingFace model via Python worker for LLM inference. - * Starts a Python subprocess and connects via Py4J. + * Uses auto-detected available ports for Py4J communication. * - * @param modelName HuggingFace model name (e.g., "distilgpt2") - * @param workerScriptPath path to the Python worker script (llm_worker.py) + * @param modelName HuggingFace model name + * @param workerScriptPath path to the Python worker script * @return LLMCallback interface to the Python worker */ public LLMCallback loadModel(String modelName, String workerScriptPath) { + //auto-find available ports + int javaPort = findAvailablePort(); + int pythonPort = findAvailablePort(); + return loadModel(modelName, workerScriptPath, javaPort, pythonPort); + } + + /** + * Loads a HuggingFace model via Python worker for LLM inference. + * Starts a Python subprocess and connects via Py4J. + * + * @param modelName HuggingFace model name + * @param workerScriptPath path to the Python worker script + * @param javaPort port for Java gateway server + * @param pythonPort port for Python callback server + * @return LLMCallback interface to the Python worker + */ + public LLMCallback loadModel(String modelName, String workerScriptPath, int javaPort, int pythonPort) { if (_llmWorker != null) return _llmWorker; try { @@ -317,18 +334,20 @@ public LLMCallback loadModel(String modelName, String workerScriptPath) { //start Py4J gateway server with callback support _gatewayServer = new GatewayServer.GatewayServerBuilder() .entryPoint(this) - .javaPort(25333) - .callbackClient(25334, java.net.InetAddress.getLoopbackAddress()) + .javaPort(javaPort) + .callbackClient(pythonPort, java.net.InetAddress.getLoopbackAddress()) .build(); _gatewayServer.start(); //give gateway time to start Thread.sleep(500); - //start python worker process - LOG.info("Starting LLM worker with script: " + workerScriptPath); + //start python worker process with both ports + LOG.info("Starting LLM worker with script: " + workerScriptPath + + " (javaPort=" + javaPort + ", pythonPort=" + pythonPort + ")"); _pythonProcess = new ProcessBuilder( - "python", workerScriptPath, modelName, "25333" + "python", workerScriptPath, modelName, + String.valueOf(javaPort), String.valueOf(pythonPort) ).redirectErrorStream(true).start(); //read python output in background thread @@ -357,6 +376,19 @@ public LLMCallback loadModel(String modelName, String workerScriptPath) { return _llmWorker; } + /** + * Finds an available port on the local machine. + * @return available port number + */ + private int findAvailablePort() { + try (java.net.ServerSocket socket = new java.net.ServerSocket(0)) { + socket.setReuseAddress(true); + return socket.getLocalPort(); + } catch (IOException e) { + throw new DMLException("Failed to find available port: " + e.getMessage()); + } + } + /** * Called by Python worker to register itself via Py4J. 
*/ @@ -377,13 +409,19 @@ public void close() { //shutdown LLM worker if running if (_pythonProcess != null) { - _pythonProcess.destroy(); + _pythonProcess.destroyForcibly(); + try { + _pythonProcess.waitFor(5, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } _pythonProcess = null; } if (_gatewayServer != null) { _gatewayServer.shutdown(); _gatewayServer = null; } + _llmWorker = null; //clear thread-local configurations ConfigurationManager.clearLocalConfigs(); diff --git a/src/main/python/systemds/llm_worker.py b/src/main/python/systemds/llm_worker.py index 57872b37f82..9b5dd4a9155 100644 --- a/src/main/python/systemds/llm_worker.py +++ b/src/main/python/systemds/llm_worker.py @@ -34,20 +34,19 @@ class Java: if __name__ == "__main__": model_name = sys.argv[1] if len(sys.argv) > 1 else "distilgpt2" java_port = int(sys.argv[2]) if len(sys.argv) > 2 else 25333 + python_port = int(sys.argv[3]) if len(sys.argv) > 3 else 25334 - print(f"Starting LLM worker, connecting to Java on port {java_port}", flush=True) + print(f"Starting LLM worker (javaPort={java_port}, pythonPort={python_port})", flush=True) worker = LLMWorker(model_name) - # Connect to Java's GatewayServer and register this worker - # The callback_server starts a server on Python's side for Java to call back - # Use port 25334 which Java's CallbackClient expects + #connect to Java's GatewayServer and register this worker gateway = JavaGateway( gateway_parameters=GatewayParameters(port=java_port), - callback_server_parameters=CallbackServerParameters(port=25334) + callback_server_parameters=CallbackServerParameters(port=python_port) ) - print(f"Python callback server started on port 25334", flush=True) + print(f"Python callback server started on port {python_port}", flush=True) gateway.entry_point.registerWorker(worker) print("Worker registered with Java, waiting for requests...", flush=True) From dacdc1c1cae94d55a1c7fd7c6c525b88cdeb861d Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Fri, 13 Feb 2026 22:53:53 +0100 Subject: [PATCH 04/24] Move llm_worker.py to fix Python module collision Move worker script from src/main/python/systemds/ to src/main/python/ to avoid shadowing Python stdlib operator module. 
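Background on the collision (illustrative; the exact module names inside the systemds package may differ): when Python runs a script, the script's own directory is prepended to sys.path, so modules that sit next to the script can shadow standard-library modules of the same name for everything imported afterwards.

    # python src/main/python/systemds/llm_worker.py ...
    # sys.path[0] == ".../src/main/python/systemds"
    # -> "import operator" inside torch/transformers can resolve to a sibling
    #    module in that directory instead of the standard library.

Keeping the worker script directly under src/main/python/ keeps package-internal modules off sys.path[0].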
--- src/main/java/org/apache/sysds/api/jmlc/Connection.java | 2 +- src/main/python/{systemds => }/llm_worker.py | 0 .../apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename src/main/python/{systemds => }/llm_worker.py (100%) diff --git a/src/main/java/org/apache/sysds/api/jmlc/Connection.java b/src/main/java/org/apache/sysds/api/jmlc/Connection.java index 13dcf2a0247..9dcc15937a0 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/Connection.java +++ b/src/main/java/org/apache/sysds/api/jmlc/Connection.java @@ -346,7 +346,7 @@ public LLMCallback loadModel(String modelName, String workerScriptPath, int java LOG.info("Starting LLM worker with script: " + workerScriptPath + " (javaPort=" + javaPort + ", pythonPort=" + pythonPort + ")"); _pythonProcess = new ProcessBuilder( - "python", workerScriptPath, modelName, + "python3", workerScriptPath, modelName, String.valueOf(javaPort), String.valueOf(pythonPort) ).redirectErrorStream(true).start(); diff --git a/src/main/python/systemds/llm_worker.py b/src/main/python/llm_worker.py similarity index 100% rename from src/main/python/systemds/llm_worker.py rename to src/main/python/llm_worker.py diff --git a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java index 5909b0cef4c..ac3f3e7e069 100644 --- a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java +++ b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java @@ -46,7 +46,7 @@ public void testLLMInference() { try { //create connection and load model conn = new Connection(); - LLMCallback llmWorker = conn.loadModel("distilgpt2", "src/main/python/systemds/llm_worker.py"); + LLMCallback llmWorker = conn.loadModel("distilgpt2", "src/main/python/llm_worker.py"); Assert.assertNotNull("LLM worker should not be null", llmWorker); //create prepared script and set llm worker From 29f657c2a55ad3d29c02b6f75ff19c9a2a0b1e9e Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Sat, 14 Feb 2026 16:07:39 +0100 Subject: [PATCH 05/24] Use python3 with fallback to python in Connection.java --- .../org/apache/sysds/api/jmlc/Connection.java | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/apache/sysds/api/jmlc/Connection.java b/src/main/java/org/apache/sysds/api/jmlc/Connection.java index 9dcc15937a0..da0cb0be1d2 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/Connection.java +++ b/src/main/java/org/apache/sysds/api/jmlc/Connection.java @@ -343,10 +343,11 @@ public LLMCallback loadModel(String modelName, String workerScriptPath, int java Thread.sleep(500); //start python worker process with both ports + String pythonCmd = findPythonCommand(); LOG.info("Starting LLM worker with script: " + workerScriptPath + - " (javaPort=" + javaPort + ", pythonPort=" + pythonPort + ")"); + " (python=" + pythonCmd + ", javaPort=" + javaPort + ", pythonPort=" + pythonPort + ")"); _pythonProcess = new ProcessBuilder( - "python3", workerScriptPath, modelName, + pythonCmd, workerScriptPath, modelName, String.valueOf(javaPort), String.valueOf(pythonPort) ).redirectErrorStream(true).start(); @@ -376,6 +377,25 @@ public LLMCallback loadModel(String modelName, String workerScriptPath, int java return _llmWorker; } + /** + * Finds the available Python command, trying python3 first then python. 
+ * @return python command name + */ + private static String findPythonCommand() { + for (String cmd : new String[]{"python3", "python"}) { + try { + Process p = new ProcessBuilder(cmd, "--version") + .redirectErrorStream(true).start(); + int exitCode = p.waitFor(); + if (exitCode == 0) + return cmd; + } catch (Exception e) { + //command not found, try next + } + } + throw new DMLException("No Python installation found (tried python3, python)"); + } + /** * Finds an available port on the local machine. * @return available port number From e40e4f232035643aee11f6154bf3211e0416e35f Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Sat, 14 Feb 2026 17:04:10 +0100 Subject: [PATCH 06/24] Add batch inference with FrameBlock and metrics support --- .../apache/sysds/api/jmlc/PreparedScript.java | 62 ++++++++++++ .../functions/jmlc/JMLCLLMInferenceTest.java | 95 +++++++++++++++++++ 2 files changed, 157 insertions(+) diff --git a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java index 2e6109d0102..f8a61cfaf20 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java +++ b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java @@ -198,6 +198,68 @@ public String generate(String prompt, int maxNewTokens, double temperature, doub return _llmWorker.generate(prompt, maxNewTokens, temperature, topP); } + /** + * Generates text for multiple prompts and returns results as a FrameBlock. + * The FrameBlock has two columns: [prompt, generated_text]. + * + * @param prompts array of input prompt texts + * @param maxNewTokens maximum number of new tokens to generate + * @param temperature sampling temperature + * @param topP nucleus sampling probability threshold + * @return FrameBlock with columns [prompt, generated_text] + */ + public FrameBlock generateBatch(String[] prompts, int maxNewTokens, double temperature, double topP) { + if (_llmWorker == null) { + throw new DMLException("No LLM worker set. Call setLLMWorker() first."); + } + //generate text for each prompt + String[][] data = new String[prompts.length][2]; + for (int i = 0; i < prompts.length; i++) { + data[i][0] = prompts[i]; + data[i][1] = _llmWorker.generate(prompts[i], maxNewTokens, temperature, topP); + } + //create FrameBlock with string schema + ValueType[] schema = new ValueType[]{ValueType.STRING, ValueType.STRING}; + String[] colNames = new String[]{"prompt", "generated_text"}; + FrameBlock fb = new FrameBlock(schema, colNames); + for (String[] row : data) + fb.appendRow(row); + return fb; + } + + /** + * Generates text for multiple prompts and returns results with timing metrics. + * The FrameBlock has three columns: [prompt, generated_text, time_ms]. + * + * @param prompts array of input prompt texts + * @param maxNewTokens maximum number of new tokens to generate + * @param temperature sampling temperature + * @param topP nucleus sampling probability threshold + * @return FrameBlock with columns [prompt, generated_text, time_ms] + */ + public FrameBlock generateBatchWithMetrics(String[] prompts, int maxNewTokens, double temperature, double topP) { + if (_llmWorker == null) { + throw new DMLException("No LLM worker set. 
Call setLLMWorker() first."); + } + //generate text for each prompt with timing + String[][] data = new String[prompts.length][3]; + for (int i = 0; i < prompts.length; i++) { + long start = System.nanoTime(); + String result = _llmWorker.generate(prompts[i], maxNewTokens, temperature, topP); + long elapsed = (System.nanoTime() - start) / 1_000_000; + data[i][0] = prompts[i]; + data[i][1] = result; + data[i][2] = String.valueOf(elapsed); + } + //create FrameBlock with schema + ValueType[] schema = new ValueType[]{ValueType.STRING, ValueType.STRING, ValueType.INT64}; + String[] colNames = new String[]{"prompt", "generated_text", "time_ms"}; + FrameBlock fb = new FrameBlock(schema, colNames); + for (String[] row : data) + fb.appendRow(row); + return fb; + } + /** * Binds a scalar boolean to a registered input variable. * diff --git a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java index ac3f3e7e069..0fab3134703 100644 --- a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java +++ b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java @@ -22,6 +22,7 @@ import org.apache.sysds.api.jmlc.Connection; import org.apache.sysds.api.jmlc.LLMCallback; import org.apache.sysds.api.jmlc.PreparedScript; +import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.test.AutomatedTestBase; import org.junit.Assert; import org.junit.Test; @@ -76,4 +77,98 @@ public void testLLMInference() { } } } + + @Test + public void testBatchInference() { + Connection conn = null; + try { + //create connection and load model + conn = new Connection(); + LLMCallback llmWorker = conn.loadModel("distilgpt2", "src/main/python/llm_worker.py"); + + //create prepared script and set llm worker + String script = "x = 1;\nwrite(x, './tmp/x');"; + PreparedScript ps = conn.prepareScript(script, new String[]{}, new String[]{"x"}); + ps.setLLMWorker(llmWorker); + + //batch generate with multiple prompts + String[] prompts = { + "The meaning of life is", + "Machine learning is", + "Apache SystemDS enables" + }; + FrameBlock result = ps.generateBatch(prompts, 20, 0.7, 0.9); + + //verify FrameBlock structure + Assert.assertNotNull("Batch result should not be null", result); + Assert.assertEquals("Should have 3 rows", 3, result.getNumRows()); + Assert.assertEquals("Should have 2 columns", 2, result.getNumColumns()); + + //verify each row has prompt and generated text + for (int i = 0; i < prompts.length; i++) { + String prompt = (String) result.get(i, 0); + String generated = (String) result.get(i, 1); + Assert.assertEquals("Prompt should match", prompts[i], prompt); + Assert.assertNotNull("Generated text should not be null", generated); + Assert.assertFalse("Generated text should not be empty", generated.isEmpty()); + System.out.println("Prompt: " + prompt); + System.out.println("Generated: " + generated); + } + + } catch (Exception e) { + System.out.println("Skipping batch LLM test:"); + e.printStackTrace(); + org.junit.Assume.assumeNoException("LLM dependencies not available", e); + } finally { + if (conn != null) { + conn.close(); + } + } + } + + @Test + public void testBatchWithMetrics() { + Connection conn = null; + try { + //create connection and load model + conn = new Connection(); + LLMCallback llmWorker = conn.loadModel("distilgpt2", "src/main/python/llm_worker.py"); + + //create prepared script and set llm worker + String script = "x = 1;\nwrite(x, 
'./tmp/x');"; + PreparedScript ps = conn.prepareScript(script, new String[]{}, new String[]{"x"}); + ps.setLLMWorker(llmWorker); + + //batch generate with metrics + String[] prompts = {"The meaning of life is", "Data science is"}; + FrameBlock result = ps.generateBatchWithMetrics(prompts, 20, 0.7, 0.9); + + //verify FrameBlock structure with metrics + Assert.assertNotNull("Metrics result should not be null", result); + Assert.assertEquals("Should have 2 rows", 2, result.getNumRows()); + Assert.assertEquals("Should have 3 columns", 3, result.getNumColumns()); + + //verify metrics column contains timing data + for (int i = 0; i < prompts.length; i++) { + String prompt = (String) result.get(i, 0); + String generated = (String) result.get(i, 1); + long timeMs = Long.parseLong(result.get(i, 2).toString()); + Assert.assertEquals("Prompt should match", prompts[i], prompt); + Assert.assertFalse("Generated text should not be empty", generated.isEmpty()); + Assert.assertTrue("Time should be positive", timeMs > 0); + System.out.println("Prompt: " + prompt); + System.out.println("Generated: " + generated); + System.out.println("Time: " + timeMs + "ms"); + } + + } catch (Exception e) { + System.out.println("Skipping metrics LLM test:"); + e.printStackTrace(); + org.junit.Assume.assumeNoException("LLM dependencies not available", e); + } finally { + if (conn != null) { + conn.close(); + } + } + } } From fdd16849a024685bbbc6b0a98f9836b57a901555 Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Sat, 14 Feb 2026 17:41:57 +0100 Subject: [PATCH 07/24] Clean up test: extract constants and shared setup method --- .../functions/jmlc/JMLCLLMInferenceTest.java | 50 ++++++++----------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java index 0fab3134703..e54d606c16d 100644 --- a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java +++ b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java @@ -34,6 +34,9 @@ public class JMLCLLMInferenceTest extends AutomatedTestBase { private final static String TEST_NAME = "JMLCLLMInferenceTest"; private final static String TEST_DIR = "functions/jmlc/"; + private final static String MODEL_NAME = "distilgpt2"; + private final static String WORKER_SCRIPT = "src/main/python/llm_worker.py"; + private final static String DML_SCRIPT = "x = 1;\nwrite(x, './tmp/x');"; @Override public void setUp() { @@ -41,19 +44,24 @@ public void setUp() { getAndLoadTestConfiguration(TEST_NAME); } + /** + * Creates a connection, loads the LLM model, and returns a PreparedScript + * with the LLM worker attached. 
+ */ + private PreparedScript createLLMScript(Connection conn) throws Exception { + LLMCallback llmWorker = conn.loadModel(MODEL_NAME, WORKER_SCRIPT); + Assert.assertNotNull("LLM worker should not be null", llmWorker); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, new String[]{}, new String[]{"x"}); + ps.setLLMWorker(llmWorker); + return ps; + } + @Test public void testLLMInference() { Connection conn = null; try { - //create connection and load model conn = new Connection(); - LLMCallback llmWorker = conn.loadModel("distilgpt2", "src/main/python/llm_worker.py"); - Assert.assertNotNull("LLM worker should not be null", llmWorker); - - //create prepared script and set llm worker - String script = "x = 1;\nwrite(x, './tmp/x');"; - PreparedScript ps = conn.prepareScript(script, new String[]{}, new String[]{"x"}); - ps.setLLMWorker(llmWorker); + PreparedScript ps = createLLMScript(conn); //generate text using llm String prompt = "The meaning of life is"; @@ -67,14 +75,12 @@ public void testLLMInference() { System.out.println("Generated: " + result); } catch (Exception e) { - //skip test if dependencies not available System.out.println("Skipping LLM test:"); e.printStackTrace(); org.junit.Assume.assumeNoException("LLM dependencies not available", e); } finally { - if (conn != null) { + if (conn != null) conn.close(); - } } } @@ -82,14 +88,8 @@ public void testLLMInference() { public void testBatchInference() { Connection conn = null; try { - //create connection and load model conn = new Connection(); - LLMCallback llmWorker = conn.loadModel("distilgpt2", "src/main/python/llm_worker.py"); - - //create prepared script and set llm worker - String script = "x = 1;\nwrite(x, './tmp/x');"; - PreparedScript ps = conn.prepareScript(script, new String[]{}, new String[]{"x"}); - ps.setLLMWorker(llmWorker); + PreparedScript ps = createLLMScript(conn); //batch generate with multiple prompts String[] prompts = { @@ -120,9 +120,8 @@ public void testBatchInference() { e.printStackTrace(); org.junit.Assume.assumeNoException("LLM dependencies not available", e); } finally { - if (conn != null) { + if (conn != null) conn.close(); - } } } @@ -130,14 +129,8 @@ public void testBatchInference() { public void testBatchWithMetrics() { Connection conn = null; try { - //create connection and load model conn = new Connection(); - LLMCallback llmWorker = conn.loadModel("distilgpt2", "src/main/python/llm_worker.py"); - - //create prepared script and set llm worker - String script = "x = 1;\nwrite(x, './tmp/x');"; - PreparedScript ps = conn.prepareScript(script, new String[]{}, new String[]{"x"}); - ps.setLLMWorker(llmWorker); + PreparedScript ps = createLLMScript(conn); //batch generate with metrics String[] prompts = {"The meaning of life is", "Data science is"}; @@ -166,9 +159,8 @@ public void testBatchWithMetrics() { e.printStackTrace(); org.junit.Assume.assumeNoException("LLM dependencies not available", e); } finally { - if (conn != null) { + if (conn != null) conn.close(); - } } } } From b9ba3e05cd2a5c4076b23b7fc0a433158879e3d2 Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Sat, 14 Feb 2026 19:02:19 +0100 Subject: [PATCH 08/24] Add token counts, GPU support, and improve error handling - Add generateWithTokenCount() returning JSON with input/output token counts - Update generateBatchWithMetrics() to include input_tokens and output_tokens columns - Add CUDA auto-detection with device_map=auto for multi-GPU support in llm_worker.py - Check Python process liveness 
during startup instead of blind 60s timeout --- .../org/apache/sysds/api/jmlc/Connection.java | 20 ++++++++--- .../apache/sysds/api/jmlc/LLMCallback.java | 12 +++++++ .../apache/sysds/api/jmlc/PreparedScript.java | 29 +++++++++------ src/main/python/llm_worker.py | 35 +++++++++++++++++-- .../functions/jmlc/JMLCLLMInferenceTest.java | 11 ++++-- 5 files changed, 86 insertions(+), 21 deletions(-) diff --git a/src/main/java/org/apache/sysds/api/jmlc/Connection.java b/src/main/java/org/apache/sysds/api/jmlc/Connection.java index da0cb0be1d2..21b98c8563e 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/Connection.java +++ b/src/main/java/org/apache/sysds/api/jmlc/Connection.java @@ -352,7 +352,7 @@ public LLMCallback loadModel(String modelName, String workerScriptPath, int java ).redirectErrorStream(true).start(); //read python output in background thread - new Thread(() -> { + Thread outputReader = new Thread(() -> { try (BufferedReader reader = new BufferedReader( new InputStreamReader(_pythonProcess.getInputStream()))) { String line; @@ -362,11 +362,21 @@ public LLMCallback loadModel(String modelName, String workerScriptPath, int java } catch (IOException e) { LOG.error("Error reading LLM worker output", e); } - }).start(); + }); + outputReader.setName("llm-worker-output"); + outputReader.setDaemon(true); + outputReader.start(); - //wait for worker to register - if (!_workerLatch.await(60, TimeUnit.SECONDS)) { - throw new DMLException("Timeout waiting for LLM worker to register"); + //wait for worker to register, checking process liveness periodically + long deadlineNs = System.nanoTime() + TimeUnit.SECONDS.toNanos(60); + while (!_workerLatch.await(2, TimeUnit.SECONDS)) { + if (!_pythonProcess.isAlive()) { + int exitCode = _pythonProcess.exitValue(); + throw new DMLException("LLM worker process died during startup (exit code " + exitCode + ")"); + } + if (System.nanoTime() > deadlineNs) { + throw new DMLException("Timeout waiting for LLM worker to register (60s)"); + } } } catch (DMLException e) { diff --git a/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java b/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java index 09ee8debb29..10d9787d992 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java +++ b/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java @@ -16,4 +16,16 @@ public interface LLMCallback { * @return generated text continuation */ String generate(String prompt, int maxNewTokens, double temperature, double topP); + + /** + * Generates text and returns result with token counts as a JSON string. 
+ * Format: {"text": "...", "input_tokens": N, "output_tokens": M} + * + * @param prompt the input prompt text + * @param maxNewTokens maximum number of new tokens to generate + * @param temperature sampling temperature (0.0 = deterministic, higher = more random) + * @param topP nucleus sampling probability threshold + * @return JSON string with generated text and token counts + */ + String generateWithTokenCount(String prompt, int maxNewTokens, double temperature, double topP); } \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java index f8a61cfaf20..e664f04ca04 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java +++ b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java @@ -229,31 +229,40 @@ public FrameBlock generateBatch(String[] prompts, int maxNewTokens, double tempe /** * Generates text for multiple prompts and returns results with timing metrics. - * The FrameBlock has three columns: [prompt, generated_text, time_ms]. + * The FrameBlock has five columns: [prompt, generated_text, time_ms, input_tokens, output_tokens]. * * @param prompts array of input prompt texts * @param maxNewTokens maximum number of new tokens to generate * @param temperature sampling temperature * @param topP nucleus sampling probability threshold - * @return FrameBlock with columns [prompt, generated_text, time_ms] + * @return FrameBlock with columns [prompt, generated_text, time_ms, input_tokens, output_tokens] */ public FrameBlock generateBatchWithMetrics(String[] prompts, int maxNewTokens, double temperature, double topP) { if (_llmWorker == null) { throw new DMLException("No LLM worker set. Call setLLMWorker() first."); } - //generate text for each prompt with timing - String[][] data = new String[prompts.length][3]; + //generate text for each prompt with timing and token counts + String[][] data = new String[prompts.length][5]; for (int i = 0; i < prompts.length; i++) { long start = System.nanoTime(); - String result = _llmWorker.generate(prompts[i], maxNewTokens, temperature, topP); + String json = _llmWorker.generateWithTokenCount(prompts[i], maxNewTokens, temperature, topP); long elapsed = (System.nanoTime() - start) / 1_000_000; - data[i][0] = prompts[i]; - data[i][1] = result; - data[i][2] = String.valueOf(elapsed); + //parse JSON response: {"text": "...", "input_tokens": N, "output_tokens": M} + try { + org.apache.wink.json4j.JSONObject obj = new org.apache.wink.json4j.JSONObject(json); + data[i][0] = prompts[i]; + data[i][1] = obj.getString("text"); + data[i][2] = String.valueOf(elapsed); + data[i][3] = String.valueOf(obj.getInt("input_tokens")); + data[i][4] = String.valueOf(obj.getInt("output_tokens")); + } catch (Exception e) { + throw new DMLException("Failed to parse LLM worker response: " + e.getMessage()); + } } //create FrameBlock with schema - ValueType[] schema = new ValueType[]{ValueType.STRING, ValueType.STRING, ValueType.INT64}; - String[] colNames = new String[]{"prompt", "generated_text", "time_ms"}; + ValueType[] schema = new ValueType[]{ + ValueType.STRING, ValueType.STRING, ValueType.INT64, ValueType.INT64, ValueType.INT64}; + String[] colNames = new String[]{"prompt", "generated_text", "time_ms", "input_tokens", "output_tokens"}; FrameBlock fb = new FrameBlock(schema, colNames); for (String[] row : data) fb.appendRow(row); diff --git a/src/main/python/llm_worker.py b/src/main/python/llm_worker.py index 9b5dd4a9155..27160e27f13 100644 --- 
a/src/main/python/llm_worker.py +++ b/src/main/python/llm_worker.py @@ -3,6 +3,7 @@ Java starts this script, then calls generate() via Py4J. """ import sys +import json import torch from transformers import AutoTokenizer, AutoModelForCausalLM from py4j.java_gateway import JavaGateway, GatewayParameters, CallbackServerParameters @@ -11,12 +12,20 @@ class LLMWorker: def __init__(self, model_name="distilgpt2"): print(f"Loading model: {model_name}", flush=True) self.tokenizer = AutoTokenizer.from_pretrained(model_name) - self.model = AutoModelForCausalLM.from_pretrained(model_name) + #auto-detect GPU and load model accordingly + if torch.cuda.is_available(): + print(f"CUDA available: {torch.cuda.device_count()} GPU(s)", flush=True) + self.model = AutoModelForCausalLM.from_pretrained( + model_name, device_map="auto", torch_dtype=torch.float16) + self.device = "cuda" + else: + self.model = AutoModelForCausalLM.from_pretrained(model_name) + self.device = "cpu" self.model.eval() - print(f"Model loaded: {model_name}", flush=True) + print(f"Model loaded: {model_name} (device={self.device})", flush=True) def generate(self, prompt, max_new_tokens=50, temperature=0.7, top_p=0.9): - inputs = self.tokenizer(prompt, return_tensors="pt") + inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) with torch.no_grad(): outputs = self.model.generate( **inputs, @@ -28,6 +37,26 @@ def generate(self, prompt, max_new_tokens=50, temperature=0.7, top_p=0.9): new_tokens = outputs[0][inputs["input_ids"].shape[1]:] return self.tokenizer.decode(new_tokens, skip_special_tokens=True) + def generateWithTokenCount(self, prompt, max_new_tokens=50, temperature=0.7, top_p=0.9): + inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) + input_token_count = inputs["input_ids"].shape[1] + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + max_new_tokens=int(max_new_tokens), + temperature=float(temperature), + top_p=float(top_p), + do_sample=float(temperature) > 0.0 + ) + new_tokens = outputs[0][input_token_count:] + output_token_count = len(new_tokens) + text = self.tokenizer.decode(new_tokens, skip_special_tokens=True) + return json.dumps({ + "text": text, + "input_tokens": input_token_count, + "output_tokens": output_token_count + }) + class Java: implements = ["org.apache.sysds.api.jmlc.LLMCallback"] diff --git a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java index e54d606c16d..0e07e07d221 100644 --- a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java +++ b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java @@ -136,22 +136,27 @@ public void testBatchWithMetrics() { String[] prompts = {"The meaning of life is", "Data science is"}; FrameBlock result = ps.generateBatchWithMetrics(prompts, 20, 0.7, 0.9); - //verify FrameBlock structure with metrics + //verify FrameBlock structure with metrics and token counts Assert.assertNotNull("Metrics result should not be null", result); Assert.assertEquals("Should have 2 rows", 2, result.getNumRows()); - Assert.assertEquals("Should have 3 columns", 3, result.getNumColumns()); + Assert.assertEquals("Should have 5 columns", 5, result.getNumColumns()); - //verify metrics column contains timing data + //verify metrics columns contain timing and token data for (int i = 0; i < prompts.length; i++) { String prompt = (String) result.get(i, 0); String generated = (String) result.get(i, 
1); long timeMs = Long.parseLong(result.get(i, 2).toString()); + long inputTokens = Long.parseLong(result.get(i, 3).toString()); + long outputTokens = Long.parseLong(result.get(i, 4).toString()); Assert.assertEquals("Prompt should match", prompts[i], prompt); Assert.assertFalse("Generated text should not be empty", generated.isEmpty()); Assert.assertTrue("Time should be positive", timeMs > 0); + Assert.assertTrue("Input tokens should be positive", inputTokens > 0); + Assert.assertTrue("Output tokens should be positive", outputTokens > 0); System.out.println("Prompt: " + prompt); System.out.println("Generated: " + generated); System.out.println("Time: " + timeMs + "ms"); + System.out.println("Tokens: " + inputTokens + " in, " + outputTokens + " out"); } } catch (Exception e) { From 2e984a23d617665dfecb64254330b01f087251e8 Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Mon, 16 Feb 2026 16:02:29 +0100 Subject: [PATCH 09/24] Increase worker startup timeout to 300s for larger models 7B+ models need more time to load weights into GPU memory. --- src/main/java/org/apache/sysds/api/jmlc/Connection.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/api/jmlc/Connection.java b/src/main/java/org/apache/sysds/api/jmlc/Connection.java index 21b98c8563e..f76c0ceb00c 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/Connection.java +++ b/src/main/java/org/apache/sysds/api/jmlc/Connection.java @@ -367,15 +367,15 @@ public LLMCallback loadModel(String modelName, String workerScriptPath, int java outputReader.setDaemon(true); outputReader.start(); - //wait for worker to register, checking process liveness periodically - long deadlineNs = System.nanoTime() + TimeUnit.SECONDS.toNanos(60); + //larger models (7B+) need more time to load weights into GPU memory + long deadlineNs = System.nanoTime() + TimeUnit.SECONDS.toNanos(300); while (!_workerLatch.await(2, TimeUnit.SECONDS)) { if (!_pythonProcess.isAlive()) { int exitCode = _pythonProcess.exitValue(); throw new DMLException("LLM worker process died during startup (exit code " + exitCode + ")"); } if (System.nanoTime() > deadlineNs) { - throw new DMLException("Timeout waiting for LLM worker to register (60s)"); + throw new DMLException("Timeout waiting for LLM worker to register (300s)"); } } From bf666c204edb8440f07bb90c6cde5a49f088a42e Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Mon, 16 Feb 2026 17:15:38 +0100 Subject: [PATCH 10/24] Revert accidental changes to MatrixBlockDictionary.java This file was accidentally modified in a prior commit. Restoring the original vectorized SIMD implementation. --- .../dictionary/MatrixBlockDictionary.java | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java index f6b09d4384a..71a4112f157 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java @@ -67,6 +67,8 @@ public class MatrixBlockDictionary extends ADictionary { final private MatrixBlock _data; + static final VectorSpecies SPECIES = DoubleVector.SPECIES_PREFERRED; + /** * Unsafe private constructor that does not check the data validity. 
USE WITH CAUTION. * @@ -2125,6 +2127,9 @@ private void preaggValuesFromDenseDictDenseAggRangeRange(final int numVals, fina private static void preaggValuesFromDenseDictBlockedIKJ(double[] a, double[] b, double[] ret, int bi, int bk, int bj, int bie, int bke, int cz, int az, int ls, int cut, int sOffT, int eOffT) { + final int vLen = SPECIES.length(); + final DoubleVector vVec = DoubleVector.zero(SPECIES); + final int leftover = (eOffT - sOffT) % vLen; // leftover not vectorized for(int i = bi; i < bie; i++) { final int offI = i * cz; final int offOutT = i * az + bj; @@ -2133,14 +2138,27 @@ private static void preaggValuesFromDenseDictBlockedIKJ(double[] a, double[] b, final int sOff = sOffT + idb; final int eOff = eOffT + idb; final double v = a[offI + k]; - int offOut = offOutT; - for(int j = sOff; j < eOff; j++, offOut++) { - ret[offOut] += v * b[j]; - } + vecInnerLoop(v, b, ret, offOutT, eOff, sOff, leftover, vLen, vVec); } } } + private static void vecInnerLoop(final double v, final double[] b, final double[] ret, final int offOutT, + final int eOff, final int sOff, final int leftover, final int vLen, DoubleVector vVec) { + int offOut = offOutT; + vVec = vVec.broadcast(v); + final int end = eOff - leftover; + for(int j = sOff; j < end; j += vLen, offOut += vLen) { + DoubleVector res = DoubleVector.fromArray(SPECIES, ret, offOut); + DoubleVector bVec = DoubleVector.fromArray(SPECIES, b, j); + vVec.fma(bVec, res).intoArray(ret, offOut); + } + for(int j = end; j < eOff; j++, offOut++) { + ret[offOut] += v * b[j]; + } + + } + private void preaggValuesFromDenseDictDenseAggRangeGeneric(final int numVals, final IColIndex colIndexes, final int s, final int e, final double[] b, final int cut, final double[] ret) { final int cz = colIndexes.size(); From 5faa69110feb16e6a70149da0179d619c3b1843a Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Mon, 16 Feb 2026 17:51:14 +0100 Subject: [PATCH 11/24] Add GPU batching support to JMLC LLM inference - LLMCallback.java: add generateBatch() interface method - PreparedScript.java: replace per-prompt for-loop with single batch call - llm_worker.py: implement batched tokenization and model.generate() Achieves 3-14x speedup over sequential inference on H100. --- .../apache/sysds/api/jmlc/LLMCallback.java | 27 +++++---- .../apache/sysds/api/jmlc/PreparedScript.java | 20 +++---- src/main/python/llm_worker.py | 60 +++++++++++++------ 3 files changed, 66 insertions(+), 41 deletions(-) diff --git a/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java b/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java index 10d9787d992..308abc2f316 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java +++ b/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java @@ -8,24 +8,25 @@ public interface LLMCallback { /** * Generates text using the LLM model. - * - * @param prompt the input prompt text - * @param maxNewTokens maximum number of new tokens to generate - * @param temperature sampling temperature (0.0 = deterministic, higher = more random) - * @param topP nucleus sampling probability threshold - * @return generated text continuation */ String generate(String prompt, int maxNewTokens, double temperature, double topP); /** * Generates text and returns result with token counts as a JSON string. 
* Format: {"text": "...", "input_tokens": N, "output_tokens": M} - * - * @param prompt the input prompt text - * @param maxNewTokens maximum number of new tokens to generate - * @param temperature sampling temperature (0.0 = deterministic, higher = more random) - * @param topP nucleus sampling probability threshold - * @return JSON string with generated text and token counts */ String generateWithTokenCount(String prompt, int maxNewTokens, double temperature, double topP); -} \ No newline at end of file + + /** + * Generates text for multiple prompts in a single batched GPU call. + * Returns a JSON array of objects with text and token counts. + * Format: [{"text": "...", "input_tokens": N, "output_tokens": M, "time_ms": T}, ...] + * + * @param prompts array of input prompt texts + * @param maxNewTokens maximum number of new tokens to generate per prompt + * @param temperature sampling temperature + * @param topP nucleus sampling probability threshold + * @return JSON array string with results for each prompt + */ + String generateBatch(String[] prompts, int maxNewTokens, double temperature, double topP); +} diff --git a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java index e664f04ca04..ca939457dbe 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java +++ b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java @@ -241,23 +241,23 @@ public FrameBlock generateBatchWithMetrics(String[] prompts, int maxNewTokens, d if (_llmWorker == null) { throw new DMLException("No LLM worker set. Call setLLMWorker() first."); } - //generate text for each prompt with timing and token counts + //batch all prompts in a single GPU call via the Python worker String[][] data = new String[prompts.length][5]; - for (int i = 0; i < prompts.length; i++) { + try { long start = System.nanoTime(); - String json = _llmWorker.generateWithTokenCount(prompts[i], maxNewTokens, temperature, topP); - long elapsed = (System.nanoTime() - start) / 1_000_000; - //parse JSON response: {"text": "...", "input_tokens": N, "output_tokens": M} - try { - org.apache.wink.json4j.JSONObject obj = new org.apache.wink.json4j.JSONObject(json); + String jsonArray = _llmWorker.generateBatch(prompts, maxNewTokens, temperature, topP); + long totalElapsed = (System.nanoTime() - start) / 1_000_000; + org.apache.wink.json4j.JSONArray results = new org.apache.wink.json4j.JSONArray(jsonArray); + for (int i = 0; i < prompts.length; i++) { + org.apache.wink.json4j.JSONObject obj = results.getJSONObject(i); data[i][0] = prompts[i]; data[i][1] = obj.getString("text"); - data[i][2] = String.valueOf(elapsed); + data[i][2] = String.valueOf(obj.getInt("time_ms")); data[i][3] = String.valueOf(obj.getInt("input_tokens")); data[i][4] = String.valueOf(obj.getInt("output_tokens")); - } catch (Exception e) { - throw new DMLException("Failed to parse LLM worker response: " + e.getMessage()); } + } catch (Exception e) { + throw new DMLException("Failed to parse batched LLM response: " + e.getMessage()); } //create FrameBlock with schema ValueType[] schema = new ValueType[]{ diff --git a/src/main/python/llm_worker.py b/src/main/python/llm_worker.py index 27160e27f13..7df196fcd89 100644 --- a/src/main/python/llm_worker.py +++ b/src/main/python/llm_worker.py @@ -1,10 +1,4 @@ -""" -SystemDS LLM Worker - Python side of the Py4J bridge. -Java starts this script, then calls generate() via Py4J. 
-""" -import sys -import json -import torch +import sys, json, time, torch from transformers import AutoTokenizer, AutoModelForCausalLM from py4j.java_gateway import JavaGateway, GatewayParameters, CallbackServerParameters @@ -12,7 +6,8 @@ class LLMWorker: def __init__(self, model_name="distilgpt2"): print(f"Loading model: {model_name}", flush=True) self.tokenizer = AutoTokenizer.from_pretrained(model_name) - #auto-detect GPU and load model accordingly + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token if torch.cuda.is_available(): print(f"CUDA available: {torch.cuda.device_count()} GPU(s)", flush=True) self.model = AutoModelForCausalLM.from_pretrained( @@ -57,6 +52,44 @@ def generateWithTokenCount(self, prompt, max_new_tokens=50, temperature=0.7, top "output_tokens": output_token_count }) + def generateBatch(self, prompts, max_new_tokens=50, temperature=0.7, top_p=0.9): + prompt_list = list(prompts) + n = len(prompt_list) + results = [] + # process in sub-batches to avoid OOM + batch_size = min(n, 8) + for start in range(0, n, batch_size): + end = min(start + batch_size, n) + batch = prompt_list[start:end] + t0 = time.time() + inputs = self.tokenizer( + batch, return_tensors="pt", padding=True, truncation=True, + max_length=2048 + ).to(self.model.device) + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + max_new_tokens=int(max_new_tokens), + temperature=float(temperature), + top_p=float(top_p), + do_sample=float(temperature) > 0.0 + ) + elapsed_ms = (time.time() - t0) * 1000 + per_prompt_ms = elapsed_ms / len(batch) + for i, prompt_text in enumerate(batch): + input_len = (inputs["input_ids"][i] != self.tokenizer.pad_token_id).sum().item() + new_tokens = outputs[i][inputs["input_ids"].shape[1]:] + # strip padding from generated tokens + non_pad = [t for t in new_tokens.tolist() if t != self.tokenizer.pad_token_id] + text = self.tokenizer.decode(non_pad, skip_special_tokens=True) + results.append({ + "text": text, + "input_tokens": input_len, + "output_tokens": len(non_pad), + "time_ms": int(per_prompt_ms) + }) + return json.dumps(results) + class Java: implements = ["org.apache.sysds.api.jmlc.LLMCallback"] @@ -66,26 +99,17 @@ class Java: python_port = int(sys.argv[3]) if len(sys.argv) > 3 else 25334 print(f"Starting LLM worker (javaPort={java_port}, pythonPort={python_port})", flush=True) - worker = LLMWorker(model_name) - - #connect to Java's GatewayServer and register this worker gateway = JavaGateway( gateway_parameters=GatewayParameters(port=java_port), callback_server_parameters=CallbackServerParameters(port=python_port) ) - print(f"Python callback server started on port {python_port}", flush=True) - gateway.entry_point.registerWorker(worker) print("Worker registered with Java, waiting for requests...", flush=True) - - # Keep the worker alive to handle callbacks from Java - # The callback server runs in a daemon thread, so we need to block here import threading shutdown_event = threading.Event() try: - # Wait indefinitely until Java closes the connection or kills the process shutdown_event.wait() except KeyboardInterrupt: - print("Worker shutting down", flush=True) \ No newline at end of file + print("Worker shutting down", flush=True) From c9c85d4537c8c825579d054d1c5a0dec02bd6b8e Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Mon, 16 Feb 2026 18:10:55 +0100 Subject: [PATCH 12/24] Keep both sequential and batched inference modes in PreparedScript generateBatchWithMetrics() now 
accepts a boolean batched parameter: true for GPU-batched (new), false for original sequential for-loop. --- .../apache/sysds/api/jmlc/PreparedScript.java | 61 ++++++++++++++----- 1 file changed, 47 insertions(+), 14 deletions(-) diff --git a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java index ca939457dbe..4b4a9d33e6f 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java +++ b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java @@ -238,26 +238,59 @@ public FrameBlock generateBatch(String[] prompts, int maxNewTokens, double tempe * @return FrameBlock with columns [prompt, generated_text, time_ms, input_tokens, output_tokens] */ public FrameBlock generateBatchWithMetrics(String[] prompts, int maxNewTokens, double temperature, double topP) { + return generateBatchWithMetrics(prompts, maxNewTokens, temperature, topP, true); + } + + /** + * Generates text for an array of prompts and returns a FrameBlock with columns + * [prompt, generated_text, time_ms, input_tokens, output_tokens]. + * + * @param prompts array of input prompts + * @param maxNewTokens max tokens to generate per prompt + * @param temperature sampling temperature + * @param topP nucleus sampling threshold + * @param batched if true, sends all prompts to the GPU in one call (faster); + * if false, processes prompts sequentially (original behavior) + * @return FrameBlock with inference results + */ + public FrameBlock generateBatchWithMetrics(String[] prompts, int maxNewTokens, double temperature, double topP, boolean batched) { if (_llmWorker == null) { throw new DMLException("No LLM worker set. Call setLLMWorker() first."); } - //batch all prompts in a single GPU call via the Python worker String[][] data = new String[prompts.length][5]; - try { - long start = System.nanoTime(); - String jsonArray = _llmWorker.generateBatch(prompts, maxNewTokens, temperature, topP); - long totalElapsed = (System.nanoTime() - start) / 1_000_000; - org.apache.wink.json4j.JSONArray results = new org.apache.wink.json4j.JSONArray(jsonArray); + if (batched) { + //GPU-batched: single call to Python worker for all prompts + try { + String jsonArray = _llmWorker.generateBatch(prompts, maxNewTokens, temperature, topP); + org.apache.wink.json4j.JSONArray results = new org.apache.wink.json4j.JSONArray(jsonArray); + for (int i = 0; i < prompts.length; i++) { + org.apache.wink.json4j.JSONObject obj = results.getJSONObject(i); + data[i][0] = prompts[i]; + data[i][1] = obj.getString("text"); + data[i][2] = String.valueOf(obj.getInt("time_ms")); + data[i][3] = String.valueOf(obj.getInt("input_tokens")); + data[i][4] = String.valueOf(obj.getInt("output_tokens")); + } + } catch (Exception e) { + throw new DMLException("Failed to parse batched LLM response: " + e.getMessage()); + } + } else { + //sequential: one prompt at a time (original behavior) for (int i = 0; i < prompts.length; i++) { - org.apache.wink.json4j.JSONObject obj = results.getJSONObject(i); - data[i][0] = prompts[i]; - data[i][1] = obj.getString("text"); - data[i][2] = String.valueOf(obj.getInt("time_ms")); - data[i][3] = String.valueOf(obj.getInt("input_tokens")); - data[i][4] = String.valueOf(obj.getInt("output_tokens")); + long start = System.nanoTime(); + String json = _llmWorker.generateWithTokenCount(prompts[i], maxNewTokens, temperature, topP); + long elapsed = (System.nanoTime() - start) / 1_000_000; + try { + org.apache.wink.json4j.JSONObject obj = new 
org.apache.wink.json4j.JSONObject(json); + data[i][0] = prompts[i]; + data[i][1] = obj.getString("text"); + data[i][2] = String.valueOf(elapsed); + data[i][3] = String.valueOf(obj.getInt("input_tokens")); + data[i][4] = String.valueOf(obj.getInt("output_tokens")); + } catch (Exception e) { + throw new DMLException("Failed to parse LLM worker response: " + e.getMessage()); + } } - } catch (Exception e) { - throw new DMLException("Failed to parse batched LLM response: " + e.getMessage()); } //create FrameBlock with schema ValueType[] schema = new ValueType[]{ From 4b44dd1d4c789a181a4c15c57898b4770b7aa204 Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Mon, 16 Feb 2026 18:22:41 +0100 Subject: [PATCH 13/24] Add gitignore rules for .env files, meeting notes, and local tool config --- .gitignore | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.gitignore b/.gitignore index 5de697a37e3..642669e2600 100644 --- a/.gitignore +++ b/.gitignore @@ -157,3 +157,10 @@ docker/mountFolder/*.bin.mtd SEAL-*/ +# local tool config and sensitive files +.claude/ +.env +.env.* +meeting_notes/ +meeting_notes.* +*meeting_notes* From 72bc3348ae2dc951f8418979e9a3a7d100d47a68 Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Mon, 16 Feb 2026 20:02:17 +0100 Subject: [PATCH 14/24] Add llmPredict builtin, opcode and ParamBuiltinOp entries --- src/main/java/org/apache/sysds/common/Builtins.java | 1 + src/main/java/org/apache/sysds/common/Opcodes.java | 1 + src/main/java/org/apache/sysds/common/Types.java | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index dc1f23b83fc..82eccbec021 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -226,6 +226,7 @@ public enum Builtins { LMDS("lmDS", true), LMPREDICT("lmPredict", true), LMPREDICT_STATS("lmPredictStats", true), + LLMPREDICT("llmPredict", false, true), LOCAL("local", false), LOG("log", false), LOGSUMEXP("logSumExp", true), diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index 1b0536416d6..94055d055c5 100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -204,6 +204,7 @@ public enum Opcodes { GROUPEDAGG("groupedagg", InstructionType.ParameterizedBuiltin), RMEMPTY("rmempty", InstructionType.ParameterizedBuiltin), REPLACE("replace", InstructionType.ParameterizedBuiltin), + LLMPREDICT("llmpredict", InstructionType.ParameterizedBuiltin), LOWERTRI("lowertri", InstructionType.ParameterizedBuiltin), UPPERTRI("uppertri", InstructionType.ParameterizedBuiltin), REXPAND("rexpand", InstructionType.ParameterizedBuiltin), diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index 2e3543882d2..3414614991c 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -805,7 +805,7 @@ public static ReOrgOp valueOfByOpcode(String opcode) { /** Parameterized operations that require named variable arguments */ public enum ParamBuiltinOp { - AUTODIFF, CDF, CONTAINS, INVALID, INVCDF, GROUPEDAGG, RMEMPTY, REPLACE, REXPAND, + AUTODIFF, CDF, CONTAINS, INVALID, INVCDF, GROUPEDAGG, LLMPREDICT, RMEMPTY, REPLACE, REXPAND, LOWER_TRI, UPPER_TRI, 
TRANSFORMAPPLY, TRANSFORMDECODE, TRANSFORMCOLMAP, TRANSFORMMETA, TOKENIZE, TOSTRING, LIST, PARAMSERV From 0ad1b5637a0f997aca5df82100647ee859f7a539 Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Mon, 16 Feb 2026 20:17:31 +0100 Subject: [PATCH 15/24] Add llmPredict parser validation in ParameterizedBuiltinFunctionExpression --- .../ParameterizedBuiltinFunctionExpression.java | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java index 314440628e0..89ed5d8ae34 100644 --- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java @@ -61,6 +61,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier pbHopMap.put(Builtins.GROUPEDAGG, ParamBuiltinOp.GROUPEDAGG); pbHopMap.put(Builtins.RMEMPTY, ParamBuiltinOp.RMEMPTY); pbHopMap.put(Builtins.REPLACE, ParamBuiltinOp.REPLACE); + pbHopMap.put(Builtins.LLMPREDICT, ParamBuiltinOp.LLMPREDICT); pbHopMap.put(Builtins.LOWER_TRI, ParamBuiltinOp.LOWER_TRI); pbHopMap.put(Builtins.UPPER_TRI, ParamBuiltinOp.UPPER_TRI); @@ -211,6 +212,10 @@ public void validateExpression(HashMap ids, HashMap Date: Mon, 16 Feb 2026 20:28:44 +0100 Subject: [PATCH 16/24] Wire llmPredict through hop, lop and instruction generation --- .../java/org/apache/sysds/hops/ParameterizedBuiltinOp.java | 5 +++-- .../java/org/apache/sysds/lops/ParameterizedBuiltin.java | 1 + src/main/java/org/apache/sysds/parser/DMLTranslator.java | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java index 61a4b8b8f91..b791478214b 100644 --- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java +++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java @@ -187,6 +187,7 @@ public Lop constructLops() case LOWER_TRI: case UPPER_TRI: case TOKENIZE: + case LLMPREDICT: case TRANSFORMAPPLY: case TRANSFORMDECODE: case TRANSFORMCOLMAP: @@ -758,7 +759,7 @@ && getTargetHop().areDimsBelowThreshold() ) { if (_op == ParamBuiltinOp.TRANSFORMCOLMAP || _op == ParamBuiltinOp.TRANSFORMMETA || _op == ParamBuiltinOp.TOSTRING || _op == ParamBuiltinOp.LIST || _op == ParamBuiltinOp.CDF || _op == ParamBuiltinOp.INVCDF - || _op == ParamBuiltinOp.PARAMSERV) { + || _op == ParamBuiltinOp.PARAMSERV || _op == ParamBuiltinOp.LLMPREDICT) { _etype = ExecType.CP; } @@ -768,7 +769,7 @@ && getTargetHop().areDimsBelowThreshold() ) { switch(_op) { case CONTAINS: if(getTargetHop().optFindExecType() == ExecType.SPARK) - _etype = ExecType.SPARK; + _etype = ExecType.SPARK; break; default: // Do not change execution type. 
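For reference, once the parser validation and hop/lop wiring above are in place, the built-in is invoked from DML roughly as follows (endpoint URL and sampling values are illustrative; the same call appears in the rewritten test later in this series):

    P = read("prompts", data_type="frame")
    R = llmPredict(target=P, url="http://localhost:8080/v1/completions",
        max_tokens=20, temperature=0.7, top_p=0.9)
    write(R, "results")  # columns: prompt, generated_text, time_ms, input_tokens, output_tokens
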
diff --git a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java index 3604121aac8..dcec28f76ca 100644 --- a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java +++ b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java @@ -176,6 +176,7 @@ public String getInstructions(String output) case CONTAINS: case REPLACE: case TOKENIZE: + case LLMPREDICT: case TRANSFORMAPPLY: case TRANSFORMDECODE: case TRANSFORMCOLMAP: diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index c6e7188d7bc..b1536371711 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -2007,6 +2007,7 @@ private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFu case LOWER_TRI: case UPPER_TRI: case TOKENIZE: + case LLMPREDICT: case TRANSFORMAPPLY: case TRANSFORMDECODE: case TRANSFORMCOLMAP: From de675acca6966adf6d7a0f8f4a6bd9dcac2962b3 Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Mon, 16 Feb 2026 20:37:32 +0100 Subject: [PATCH 17/24] Add llmPredict CP instruction with HTTP-based inference --- .../cp/ParameterizedBuiltinCPInstruction.java | 75 ++++++++++++++++++- 1 file changed, 74 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java index 119589a3033..69045beeb0e 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java @@ -154,7 +154,7 @@ else if(opcode.equalsIgnoreCase(Opcodes.RMEMPTY.toString()) || opcode.equalsIgno } else if(opcode.equals(Opcodes.TRANSFORMAPPLY.toString()) || opcode.equals(Opcodes.TRANSFORMDECODE.toString()) || opcode.equalsIgnoreCase(Opcodes.CONTAINS.toString()) || opcode.equals(Opcodes.TRANSFORMCOLMAP.toString()) - || opcode.equals(Opcodes.TRANSFORMMETA.toString()) || opcode.equals(Opcodes.TOKENIZE.toString()) + || opcode.equals(Opcodes.TRANSFORMMETA.toString()) || opcode.equals(Opcodes.TOKENIZE.toString()) || opcode.equals(Opcodes.LLMPREDICT.toString()) || opcode.equals(Opcodes.TOSTRING.toString()) || opcode.equals(Opcodes.NVLIST.toString()) || opcode.equals(Opcodes.AUTODIFF.toString())) { return new ParameterizedBuiltinCPInstruction(null, paramsMap, out, opcode, str); } @@ -324,6 +324,79 @@ else if(opcode.equalsIgnoreCase(Opcodes.TOKENIZE.toString())) { ec.setFrameOutput(output.getName(), fbout); ec.releaseFrameInput(params.get("target")); } + + else if(opcode.equalsIgnoreCase(Opcodes.LLMPREDICT.toString())) { + FrameBlock prompts = ec.getFrameInput(params.get("target")); + String url = params.get("url"); + int maxTokens = params.containsKey("max_tokens") ? + Integer.parseInt(params.get("max_tokens")) : 512; + double temperature = params.containsKey("temperature") ? + Double.parseDouble(params.get("temperature")) : 0.0; + double topP = params.containsKey("top_p") ? 
+ Double.parseDouble(params.get("top_p")) : 0.9; + + int n = prompts.getNumRows(); + String[][] data = new String[n][5]; + for(int i = 0; i < n; i++) { + String prompt = prompts.get(i, 0).toString(); + long t0 = System.nanoTime(); + try { + org.apache.wink.json4j.JSONObject req = new org.apache.wink.json4j.JSONObject(); + req.put("prompt", prompt); + req.put("max_tokens", maxTokens); + req.put("temperature", temperature); + req.put("top_p", topP); + + java.net.URL endpoint = new java.net.URI(url).toURL(); + java.net.HttpURLConnection conn = + (java.net.HttpURLConnection) endpoint.openConnection(); + conn.setRequestMethod("POST"); + conn.setRequestProperty("Content-Type", "application/json"); + conn.setDoOutput(true); + conn.getOutputStream().write( + req.toString().getBytes(java.nio.charset.StandardCharsets.UTF_8)); + conn.getOutputStream().close(); + + if(conn.getResponseCode() != 200) + throw new DMLRuntimeException( + "LLM endpoint returned HTTP " + conn.getResponseCode()); + + String body = new String(conn.getInputStream().readAllBytes(), + java.nio.charset.StandardCharsets.UTF_8); + conn.disconnect(); + + org.apache.wink.json4j.JSONObject resp = + new org.apache.wink.json4j.JSONObject(body); + String text = resp.getJSONArray("choices") + .getJSONObject(0).getString("text"); + long elapsed = (System.nanoTime() - t0) / 1_000_000; + int inTok = 0, outTok = 0; + if(resp.has("usage")) { + org.apache.wink.json4j.JSONObject usage = resp.getJSONObject("usage"); + inTok = usage.has("prompt_tokens") ? usage.getInt("prompt_tokens") : 0; + outTok = usage.has("completion_tokens") ? usage.getInt("completion_tokens") : 0; + } + data[i] = new String[]{prompt, text, + String.valueOf(elapsed), String.valueOf(inTok), String.valueOf(outTok)}; + } + catch(DMLRuntimeException e) { throw e; } + catch(Exception e) { + throw new DMLRuntimeException("llmPredict HTTP call failed: " + e.getMessage(), e); + } + } + + ValueType[] schema = {ValueType.STRING, ValueType.STRING, + ValueType.INT64, ValueType.INT64, ValueType.INT64}; + String[] colNames = {"prompt", "generated_text", "time_ms", "input_tokens", "output_tokens"}; + FrameBlock fbout = new FrameBlock(schema, colNames); + for(String[] row : data) + fbout.appendRow(row); + + ec.setFrameOutput(output.getName(), fbout); + ec.releaseFrameInput(params.get("target")); + } + + else if(opcode.equalsIgnoreCase(Opcodes.TRANSFORMAPPLY.toString())) { // acquire locks FrameBlock data = ec.getFrameInput(params.get("target")); From 5eab87d02eee85304336e8af505c2e559ff189e5 Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Mon, 16 Feb 2026 20:41:49 +0100 Subject: [PATCH 18/24] Remove Py4J-based LLM inference from JMLC API --- .../org/apache/sysds/api/jmlc/Connection.java | 159 ------------------ .../apache/sysds/api/jmlc/LLMCallback.java | 32 ---- .../apache/sysds/api/jmlc/PreparedScript.java | 142 ---------------- src/main/python/llm_worker.py | 115 ------------- 4 files changed, 448 deletions(-) delete mode 100644 src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java delete mode 100644 src/main/python/llm_worker.py diff --git a/src/main/java/org/apache/sysds/api/jmlc/Connection.java b/src/main/java/org/apache/sysds/api/jmlc/Connection.java index f76c0ceb00c..525c1a97bb2 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/Connection.java +++ b/src/main/java/org/apache/sysds/api/jmlc/Connection.java @@ -28,11 +28,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.Map; -import 
java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.sysds.hops.OptimizerUtils; @@ -70,7 +66,6 @@ import org.apache.sysds.runtime.transform.meta.TfMetaUtils; import org.apache.sysds.runtime.util.CollectionUtils; import org.apache.sysds.runtime.util.DataConverter; -import py4j.GatewayServer; /** * Interaction with SystemDS using the JMLC (Java Machine Learning Connector) API is initiated with @@ -96,12 +91,6 @@ public class Connection implements Closeable private final DMLConfig _dmlconf; private final CompilerConfig _cconf; private static FileSystem fs = null; - private Process _pythonProcess = null; - private py4j.GatewayServer _gatewayServer = null; - private LLMCallback _llmWorker = null; - private CountDownLatch _workerLatch = null; - - private static final Log LOG = LogFactory.getLog(Connection.class.getName()); /** * Connection constructor, the starting point for any other JMLC API calls. @@ -298,137 +287,6 @@ public PreparedScript prepareScript(String script, Map nsscripts, //return newly create precompiled script return new PreparedScript(rtprog, inputs, outputs, _dmlconf, _cconf); } - - /** - * Loads a HuggingFace model via Python worker for LLM inference. - * Uses auto-detected available ports for Py4J communication. - * - * @param modelName HuggingFace model name - * @param workerScriptPath path to the Python worker script - * @return LLMCallback interface to the Python worker - */ - public LLMCallback loadModel(String modelName, String workerScriptPath) { - //auto-find available ports - int javaPort = findAvailablePort(); - int pythonPort = findAvailablePort(); - return loadModel(modelName, workerScriptPath, javaPort, pythonPort); - } - - /** - * Loads a HuggingFace model via Python worker for LLM inference. - * Starts a Python subprocess and connects via Py4J. 
- * - * @param modelName HuggingFace model name - * @param workerScriptPath path to the Python worker script - * @param javaPort port for Java gateway server - * @param pythonPort port for Python callback server - * @return LLMCallback interface to the Python worker - */ - public LLMCallback loadModel(String modelName, String workerScriptPath, int javaPort, int pythonPort) { - if (_llmWorker != null) - return _llmWorker; - try { - //initialize latch for worker registration - _workerLatch = new CountDownLatch(1); - - //start Py4J gateway server with callback support - _gatewayServer = new GatewayServer.GatewayServerBuilder() - .entryPoint(this) - .javaPort(javaPort) - .callbackClient(pythonPort, java.net.InetAddress.getLoopbackAddress()) - .build(); - _gatewayServer.start(); - - //give gateway time to start - Thread.sleep(500); - - //start python worker process with both ports - String pythonCmd = findPythonCommand(); - LOG.info("Starting LLM worker with script: " + workerScriptPath + - " (python=" + pythonCmd + ", javaPort=" + javaPort + ", pythonPort=" + pythonPort + ")"); - _pythonProcess = new ProcessBuilder( - pythonCmd, workerScriptPath, modelName, - String.valueOf(javaPort), String.valueOf(pythonPort) - ).redirectErrorStream(true).start(); - - //read python output in background thread - Thread outputReader = new Thread(() -> { - try (BufferedReader reader = new BufferedReader( - new InputStreamReader(_pythonProcess.getInputStream()))) { - String line; - while ((line = reader.readLine()) != null) { - LOG.info("[LLM Worker] " + line); - } - } catch (IOException e) { - LOG.error("Error reading LLM worker output", e); - } - }); - outputReader.setName("llm-worker-output"); - outputReader.setDaemon(true); - outputReader.start(); - - //larger models (7B+) need more time to load weights into GPU memory - long deadlineNs = System.nanoTime() + TimeUnit.SECONDS.toNanos(300); - while (!_workerLatch.await(2, TimeUnit.SECONDS)) { - if (!_pythonProcess.isAlive()) { - int exitCode = _pythonProcess.exitValue(); - throw new DMLException("LLM worker process died during startup (exit code " + exitCode + ")"); - } - if (System.nanoTime() > deadlineNs) { - throw new DMLException("Timeout waiting for LLM worker to register (300s)"); - } - } - - } catch (DMLException e) { - throw e; - } catch (Exception e) { - throw new DMLException("Failed to start LLM worker: " + e.getMessage()); - } - return _llmWorker; - } - - /** - * Finds the available Python command, trying python3 first then python. - * @return python command name - */ - private static String findPythonCommand() { - for (String cmd : new String[]{"python3", "python"}) { - try { - Process p = new ProcessBuilder(cmd, "--version") - .redirectErrorStream(true).start(); - int exitCode = p.waitFor(); - if (exitCode == 0) - return cmd; - } catch (Exception e) { - //command not found, try next - } - } - throw new DMLException("No Python installation found (tried python3, python)"); - } - - /** - * Finds an available port on the local machine. - * @return available port number - */ - private int findAvailablePort() { - try (java.net.ServerSocket socket = new java.net.ServerSocket(0)) { - socket.setReuseAddress(true); - return socket.getLocalPort(); - } catch (IOException e) { - throw new DMLException("Failed to find available port: " + e.getMessage()); - } - } - - /** - * Called by Python worker to register itself via Py4J. 
- */ - public void registerWorker(LLMCallback worker) { - _llmWorker = worker; - if (_workerLatch != null) { - _workerLatch.countDown(); - } - LOG.info("LLM worker registered successfully"); - } /** * Close connection to SystemDS, which clears the @@ -436,23 +294,6 @@ public void registerWorker(LLMCallback worker) { */ @Override public void close() { - - //shutdown LLM worker if running - if (_pythonProcess != null) { - _pythonProcess.destroyForcibly(); - try { - _pythonProcess.waitFor(5, TimeUnit.SECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - _pythonProcess = null; - } - if (_gatewayServer != null) { - _gatewayServer.shutdown(); - _gatewayServer = null; - } - _llmWorker = null; - //clear thread-local configurations ConfigurationManager.clearLocalConfigs(); if( ConfigurationManager.isCodegenEnabled() ) diff --git a/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java b/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java deleted file mode 100644 index 308abc2f316..00000000000 --- a/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java +++ /dev/null @@ -1,32 +0,0 @@ -package org.apache.sysds.api.jmlc; - -/** - * Interface for the Python LLM worker. - * The Python side implements this via Py4J callback. - */ -public interface LLMCallback { - - /** - * Generates text using the LLM model. - */ - String generate(String prompt, int maxNewTokens, double temperature, double topP); - - /** - * Generates text and returns result with token counts as a JSON string. - * Format: {"text": "...", "input_tokens": N, "output_tokens": M} - */ - String generateWithTokenCount(String prompt, int maxNewTokens, double temperature, double topP); - - /** - * Generates text for multiple prompts in a single batched GPU call. - * Returns a JSON array of objects with text and token counts. - * Format: [{"text": "...", "input_tokens": N, "output_tokens": M, "time_ms": T}, ...] - * - * @param prompts array of input prompt texts - * @param maxNewTokens maximum number of new tokens to generate per prompt - * @param temperature sampling temperature - * @param topP nucleus sampling probability threshold - * @return JSON array string with results for each prompt - */ - String generateBatch(String[] prompts, int maxNewTokens, double temperature, double topP); -} diff --git a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java index 4b4a9d33e6f..31bb7457227 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java +++ b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java @@ -80,9 +80,6 @@ public class PreparedScript implements ConfigurableAPI private final CompilerConfig _cconf; private HashMap _outVarLineage; - //LLM inference support - private LLMCallback _llmWorker = null; - private PreparedScript(PreparedScript that) { //shallow copy, except for a separate symbol table //and related meta data of reused inputs @@ -163,145 +160,6 @@ public CompilerConfig getCompilerConfig() { return _cconf; } - /** - * Sets the LLM worker callback for text generation. - * - * @param worker the LLM callback interface - */ - public void setLLMWorker(LLMCallback worker) { - _llmWorker = worker; - } - - /** - * Gets the LLM worker callback. - * - * @return the LLM callback interface, or null if not set - */ - public LLMCallback getLLMWorker() { - return _llmWorker; - } - - /** - * Generates text using the LLM worker. 
- * - * @param prompt the input prompt text - * @param maxNewTokens maximum number of new tokens to generate - * @param temperature sampling temperature (0.0 = deterministic, higher = more random) - * @param topP nucleus sampling probability threshold - * @return generated text - * @throws DMLException if no LLM worker is set - */ - public String generate(String prompt, int maxNewTokens, double temperature, double topP) { - if (_llmWorker == null) { - throw new DMLException("No LLM worker set. Call setLLMWorker() first."); - } - return _llmWorker.generate(prompt, maxNewTokens, temperature, topP); - } - - /** - * Generates text for multiple prompts and returns results as a FrameBlock. - * The FrameBlock has two columns: [prompt, generated_text]. - * - * @param prompts array of input prompt texts - * @param maxNewTokens maximum number of new tokens to generate - * @param temperature sampling temperature - * @param topP nucleus sampling probability threshold - * @return FrameBlock with columns [prompt, generated_text] - */ - public FrameBlock generateBatch(String[] prompts, int maxNewTokens, double temperature, double topP) { - if (_llmWorker == null) { - throw new DMLException("No LLM worker set. Call setLLMWorker() first."); - } - //generate text for each prompt - String[][] data = new String[prompts.length][2]; - for (int i = 0; i < prompts.length; i++) { - data[i][0] = prompts[i]; - data[i][1] = _llmWorker.generate(prompts[i], maxNewTokens, temperature, topP); - } - //create FrameBlock with string schema - ValueType[] schema = new ValueType[]{ValueType.STRING, ValueType.STRING}; - String[] colNames = new String[]{"prompt", "generated_text"}; - FrameBlock fb = new FrameBlock(schema, colNames); - for (String[] row : data) - fb.appendRow(row); - return fb; - } - - /** - * Generates text for multiple prompts and returns results with timing metrics. - * The FrameBlock has five columns: [prompt, generated_text, time_ms, input_tokens, output_tokens]. - * - * @param prompts array of input prompt texts - * @param maxNewTokens maximum number of new tokens to generate - * @param temperature sampling temperature - * @param topP nucleus sampling probability threshold - * @return FrameBlock with columns [prompt, generated_text, time_ms, input_tokens, output_tokens] - */ - public FrameBlock generateBatchWithMetrics(String[] prompts, int maxNewTokens, double temperature, double topP) { - return generateBatchWithMetrics(prompts, maxNewTokens, temperature, topP, true); - } - - /** - * Generates text for an array of prompts and returns a FrameBlock with columns - * [prompt, generated_text, time_ms, input_tokens, output_tokens]. - * - * @param prompts array of input prompts - * @param maxNewTokens max tokens to generate per prompt - * @param temperature sampling temperature - * @param topP nucleus sampling threshold - * @param batched if true, sends all prompts to the GPU in one call (faster); - * if false, processes prompts sequentially (original behavior) - * @return FrameBlock with inference results - */ - public FrameBlock generateBatchWithMetrics(String[] prompts, int maxNewTokens, double temperature, double topP, boolean batched) { - if (_llmWorker == null) { - throw new DMLException("No LLM worker set. 
Call setLLMWorker() first."); - } - String[][] data = new String[prompts.length][5]; - if (batched) { - //GPU-batched: single call to Python worker for all prompts - try { - String jsonArray = _llmWorker.generateBatch(prompts, maxNewTokens, temperature, topP); - org.apache.wink.json4j.JSONArray results = new org.apache.wink.json4j.JSONArray(jsonArray); - for (int i = 0; i < prompts.length; i++) { - org.apache.wink.json4j.JSONObject obj = results.getJSONObject(i); - data[i][0] = prompts[i]; - data[i][1] = obj.getString("text"); - data[i][2] = String.valueOf(obj.getInt("time_ms")); - data[i][3] = String.valueOf(obj.getInt("input_tokens")); - data[i][4] = String.valueOf(obj.getInt("output_tokens")); - } - } catch (Exception e) { - throw new DMLException("Failed to parse batched LLM response: " + e.getMessage()); - } - } else { - //sequential: one prompt at a time (original behavior) - for (int i = 0; i < prompts.length; i++) { - long start = System.nanoTime(); - String json = _llmWorker.generateWithTokenCount(prompts[i], maxNewTokens, temperature, topP); - long elapsed = (System.nanoTime() - start) / 1_000_000; - try { - org.apache.wink.json4j.JSONObject obj = new org.apache.wink.json4j.JSONObject(json); - data[i][0] = prompts[i]; - data[i][1] = obj.getString("text"); - data[i][2] = String.valueOf(elapsed); - data[i][3] = String.valueOf(obj.getInt("input_tokens")); - data[i][4] = String.valueOf(obj.getInt("output_tokens")); - } catch (Exception e) { - throw new DMLException("Failed to parse LLM worker response: " + e.getMessage()); - } - } - } - //create FrameBlock with schema - ValueType[] schema = new ValueType[]{ - ValueType.STRING, ValueType.STRING, ValueType.INT64, ValueType.INT64, ValueType.INT64}; - String[] colNames = new String[]{"prompt", "generated_text", "time_ms", "input_tokens", "output_tokens"}; - FrameBlock fb = new FrameBlock(schema, colNames); - for (String[] row : data) - fb.appendRow(row); - return fb; - } - /** * Binds a scalar boolean to a registered input variable. 
* diff --git a/src/main/python/llm_worker.py b/src/main/python/llm_worker.py deleted file mode 100644 index 7df196fcd89..00000000000 --- a/src/main/python/llm_worker.py +++ /dev/null @@ -1,115 +0,0 @@ -import sys, json, time, torch -from transformers import AutoTokenizer, AutoModelForCausalLM -from py4j.java_gateway import JavaGateway, GatewayParameters, CallbackServerParameters - -class LLMWorker: - def __init__(self, model_name="distilgpt2"): - print(f"Loading model: {model_name}", flush=True) - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - if self.tokenizer.pad_token is None: - self.tokenizer.pad_token = self.tokenizer.eos_token - if torch.cuda.is_available(): - print(f"CUDA available: {torch.cuda.device_count()} GPU(s)", flush=True) - self.model = AutoModelForCausalLM.from_pretrained( - model_name, device_map="auto", torch_dtype=torch.float16) - self.device = "cuda" - else: - self.model = AutoModelForCausalLM.from_pretrained(model_name) - self.device = "cpu" - self.model.eval() - print(f"Model loaded: {model_name} (device={self.device})", flush=True) - - def generate(self, prompt, max_new_tokens=50, temperature=0.7, top_p=0.9): - inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) - with torch.no_grad(): - outputs = self.model.generate( - **inputs, - max_new_tokens=int(max_new_tokens), - temperature=float(temperature), - top_p=float(top_p), - do_sample=float(temperature) > 0.0 - ) - new_tokens = outputs[0][inputs["input_ids"].shape[1]:] - return self.tokenizer.decode(new_tokens, skip_special_tokens=True) - - def generateWithTokenCount(self, prompt, max_new_tokens=50, temperature=0.7, top_p=0.9): - inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) - input_token_count = inputs["input_ids"].shape[1] - with torch.no_grad(): - outputs = self.model.generate( - **inputs, - max_new_tokens=int(max_new_tokens), - temperature=float(temperature), - top_p=float(top_p), - do_sample=float(temperature) > 0.0 - ) - new_tokens = outputs[0][input_token_count:] - output_token_count = len(new_tokens) - text = self.tokenizer.decode(new_tokens, skip_special_tokens=True) - return json.dumps({ - "text": text, - "input_tokens": input_token_count, - "output_tokens": output_token_count - }) - - def generateBatch(self, prompts, max_new_tokens=50, temperature=0.7, top_p=0.9): - prompt_list = list(prompts) - n = len(prompt_list) - results = [] - # process in sub-batches to avoid OOM - batch_size = min(n, 8) - for start in range(0, n, batch_size): - end = min(start + batch_size, n) - batch = prompt_list[start:end] - t0 = time.time() - inputs = self.tokenizer( - batch, return_tensors="pt", padding=True, truncation=True, - max_length=2048 - ).to(self.model.device) - with torch.no_grad(): - outputs = self.model.generate( - **inputs, - max_new_tokens=int(max_new_tokens), - temperature=float(temperature), - top_p=float(top_p), - do_sample=float(temperature) > 0.0 - ) - elapsed_ms = (time.time() - t0) * 1000 - per_prompt_ms = elapsed_ms / len(batch) - for i, prompt_text in enumerate(batch): - input_len = (inputs["input_ids"][i] != self.tokenizer.pad_token_id).sum().item() - new_tokens = outputs[i][inputs["input_ids"].shape[1]:] - # strip padding from generated tokens - non_pad = [t for t in new_tokens.tolist() if t != self.tokenizer.pad_token_id] - text = self.tokenizer.decode(non_pad, skip_special_tokens=True) - results.append({ - "text": text, - "input_tokens": input_len, - "output_tokens": len(non_pad), - "time_ms": int(per_prompt_ms) - }) - return 
json.dumps(results) - - class Java: - implements = ["org.apache.sysds.api.jmlc.LLMCallback"] - -if __name__ == "__main__": - model_name = sys.argv[1] if len(sys.argv) > 1 else "distilgpt2" - java_port = int(sys.argv[2]) if len(sys.argv) > 2 else 25333 - python_port = int(sys.argv[3]) if len(sys.argv) > 3 else 25334 - - print(f"Starting LLM worker (javaPort={java_port}, pythonPort={python_port})", flush=True) - worker = LLMWorker(model_name) - gateway = JavaGateway( - gateway_parameters=GatewayParameters(port=java_port), - callback_server_parameters=CallbackServerParameters(port=python_port) - ) - print(f"Python callback server started on port {python_port}", flush=True) - gateway.entry_point.registerWorker(worker) - print("Worker registered with Java, waiting for requests...", flush=True) - import threading - shutdown_event = threading.Event() - try: - shutdown_event.wait() - except KeyboardInterrupt: - print("Worker shutting down", flush=True) From bea062ad296032f0a2ba7057571faf1df9eda8e9 Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Mon, 16 Feb 2026 20:50:03 +0100 Subject: [PATCH 19/24] Rewrite LLM test to use llmPredict DML built-in --- .../functions/jmlc/JMLCLLMInferenceTest.java | 164 +++++++----------- 1 file changed, 65 insertions(+), 99 deletions(-) diff --git a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java index 0e07e07d221..fb8b53770b7 100644 --- a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java +++ b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -19,24 +19,30 @@ package org.apache.sysds.test.functions.jmlc; +import java.util.HashMap; +import java.util.Map; + import org.apache.sysds.api.jmlc.Connection; -import org.apache.sysds.api.jmlc.LLMCallback; import org.apache.sysds.api.jmlc.PreparedScript; +import org.apache.sysds.api.jmlc.ResultVariables; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.test.AutomatedTestBase; import org.junit.Assert; import org.junit.Test; /** - * Test LLM inference capabilities via JMLC API. - * This test requires Python with transformers and torch installed. + * Test LLM inference via the llmPredict built-in function. + * Requires an OpenAI-compatible server (e.g., llm_server.py) on localhost:8080. 
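+ * Start one locally with, e.g.: python src/main/python/llm_server.py distilgpt2 --port 8080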
*/ public class JMLCLLMInferenceTest extends AutomatedTestBase { private final static String TEST_NAME = "JMLCLLMInferenceTest"; private final static String TEST_DIR = "functions/jmlc/"; - private final static String MODEL_NAME = "distilgpt2"; - private final static String WORKER_SCRIPT = "src/main/python/llm_worker.py"; - private final static String DML_SCRIPT = "x = 1;\nwrite(x, './tmp/x');"; + private final static String LLM_URL = "http://localhost:8080/v1/completions"; + + private final static String DML_SCRIPT = + "X = read(\"prompts\", data_type=\"frame\")\n" + + "R = llmPredict(target=X, url=$url, max_tokens=$mt, temperature=$temp, top_p=$tp)\n" + + "write(R, \"results\")"; @Override public void setUp() { @@ -44,128 +50,88 @@ public void setUp() { getAndLoadTestConfiguration(TEST_NAME); } - /** - * Creates a connection, loads the LLM model, and returns a PreparedScript - * with the LLM worker attached. - */ - private PreparedScript createLLMScript(Connection conn) throws Exception { - LLMCallback llmWorker = conn.loadModel(MODEL_NAME, WORKER_SCRIPT); - Assert.assertNotNull("LLM worker should not be null", llmWorker); - PreparedScript ps = conn.prepareScript(DML_SCRIPT, new String[]{}, new String[]{"x"}); - ps.setLLMWorker(llmWorker); - return ps; - } - @Test - public void testLLMInference() { + public void testSinglePrompt() { Connection conn = null; try { conn = new Connection(); - PreparedScript ps = createLLMScript(conn); - - //generate text using llm - String prompt = "The meaning of life is"; - String result = ps.generate(prompt, 20, 0.7, 0.9); - - //verify result - Assert.assertNotNull("Generated text should not be null", result); - Assert.assertFalse("Generated text should not be empty", result.isEmpty()); - - System.out.println("Prompt: " + prompt); - System.out.println("Generated: " + result); - + Map args = new HashMap<>(); + args.put("$url", LLM_URL); + args.put("$mt", "20"); + args.put("$temp", "0.7"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new String[]{"results"}); + + String[][] promptData = new String[][]{{"The meaning of life is"}}; + ps.setFrame("prompts", promptData); + + ResultVariables rv = ps.executeScript(); + FrameBlock result = rv.getFrameBlock("results"); + + Assert.assertNotNull("Result should not be null", result); + Assert.assertEquals("Should have 1 row", 1, result.getNumRows()); + Assert.assertEquals("Should have 5 columns", 5, result.getNumColumns()); + String generated = result.get(0, 1).toString(); + Assert.assertFalse("Generated text should not be empty", generated.isEmpty()); + + System.out.println("Prompt: " + promptData[0][0]); + System.out.println("Generated: " + generated); } catch (Exception e) { - System.out.println("Skipping LLM test:"); + System.out.println("Skipping LLM test (server not running):"); e.printStackTrace(); - org.junit.Assume.assumeNoException("LLM dependencies not available", e); + org.junit.Assume.assumeNoException("LLM server not available", e); } finally { - if (conn != null) - conn.close(); + if (conn != null) conn.close(); } } - + @Test public void testBatchInference() { Connection conn = null; try { conn = new Connection(); - PreparedScript ps = createLLMScript(conn); - - //batch generate with multiple prompts + Map args = new HashMap<>(); + args.put("$url", LLM_URL); + args.put("$mt", "20"); + args.put("$temp", "0.7"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new 
String[]{"results"}); + String[] prompts = { "The meaning of life is", "Machine learning is", "Apache SystemDS enables" }; - FrameBlock result = ps.generateBatch(prompts, 20, 0.7, 0.9); - - //verify FrameBlock structure - Assert.assertNotNull("Batch result should not be null", result); + String[][] promptData = new String[prompts.length][1]; + for (int i = 0; i < prompts.length; i++) + promptData[i][0] = prompts[i]; + ps.setFrame("prompts", promptData); + + ResultVariables rv = ps.executeScript(); + FrameBlock result = rv.getFrameBlock("results"); + + Assert.assertNotNull("Result should not be null", result); Assert.assertEquals("Should have 3 rows", 3, result.getNumRows()); - Assert.assertEquals("Should have 2 columns", 2, result.getNumColumns()); - - //verify each row has prompt and generated text - for (int i = 0; i < prompts.length; i++) { - String prompt = (String) result.get(i, 0); - String generated = (String) result.get(i, 1); - Assert.assertEquals("Prompt should match", prompts[i], prompt); - Assert.assertNotNull("Generated text should not be null", generated); - Assert.assertFalse("Generated text should not be empty", generated.isEmpty()); - System.out.println("Prompt: " + prompt); - System.out.println("Generated: " + generated); - } - - } catch (Exception e) { - System.out.println("Skipping batch LLM test:"); - e.printStackTrace(); - org.junit.Assume.assumeNoException("LLM dependencies not available", e); - } finally { - if (conn != null) - conn.close(); - } - } - - @Test - public void testBatchWithMetrics() { - Connection conn = null; - try { - conn = new Connection(); - PreparedScript ps = createLLMScript(conn); - - //batch generate with metrics - String[] prompts = {"The meaning of life is", "Data science is"}; - FrameBlock result = ps.generateBatchWithMetrics(prompts, 20, 0.7, 0.9); - - //verify FrameBlock structure with metrics and token counts - Assert.assertNotNull("Metrics result should not be null", result); - Assert.assertEquals("Should have 2 rows", 2, result.getNumRows()); Assert.assertEquals("Should have 5 columns", 5, result.getNumColumns()); - - //verify metrics columns contain timing and token data + for (int i = 0; i < prompts.length; i++) { - String prompt = (String) result.get(i, 0); - String generated = (String) result.get(i, 1); + String prompt = result.get(i, 0).toString(); + String generated = result.get(i, 1).toString(); long timeMs = Long.parseLong(result.get(i, 2).toString()); - long inputTokens = Long.parseLong(result.get(i, 3).toString()); - long outputTokens = Long.parseLong(result.get(i, 4).toString()); Assert.assertEquals("Prompt should match", prompts[i], prompt); Assert.assertFalse("Generated text should not be empty", generated.isEmpty()); Assert.assertTrue("Time should be positive", timeMs > 0); - Assert.assertTrue("Input tokens should be positive", inputTokens > 0); - Assert.assertTrue("Output tokens should be positive", outputTokens > 0); System.out.println("Prompt: " + prompt); - System.out.println("Generated: " + generated); - System.out.println("Time: " + timeMs + "ms"); - System.out.println("Tokens: " + inputTokens + " in, " + outputTokens + " out"); + System.out.println("Generated: " + generated + " (" + timeMs + "ms)"); } - } catch (Exception e) { - System.out.println("Skipping metrics LLM test:"); + System.out.println("Skipping batch LLM test (server not running):"); e.printStackTrace(); - org.junit.Assume.assumeNoException("LLM dependencies not available", e); + org.junit.Assume.assumeNoException("LLM server not available", e); } 
finally { - if (conn != null) - conn.close(); + if (conn != null) conn.close(); } } } From edf4e395bf60c110560df032d1b9f90e54ae445a Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Mon, 16 Feb 2026 20:51:35 +0100 Subject: [PATCH 20/24] Add OpenAI-compatible HTTP inference server for HuggingFace models --- src/main/python/llm_server.py | 103 ++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 src/main/python/llm_server.py diff --git a/src/main/python/llm_server.py b/src/main/python/llm_server.py new file mode 100644 index 00000000000..f90e8cd642b --- /dev/null +++ b/src/main/python/llm_server.py @@ -0,0 +1,103 @@ +"""OpenAI-compatible HTTP server for local HuggingFace model inference. + +Serves the /v1/completions endpoint so that SystemDS llmPredict() and +other OpenAI-compatible clients can call it directly. + +Usage: + python llm_server.py [--port PORT] + +Examples: + python llm_server.py distilgpt2 --port 8080 + python llm_server.py Qwen/Qwen2.5-3B-Instruct --port 8080 +""" + +import argparse +import json +import sys +import time +from http.server import HTTPServer, BaseHTTPRequestHandler + +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM + + +class InferenceHandler(BaseHTTPRequestHandler): + + def do_POST(self): + if self.path != "/v1/completions": + self.send_error(404) + return + length = int(self.headers.get("Content-Length", 0)) + body = json.loads(self.rfile.read(length)) + + prompt = body.get("prompt", "") + max_tokens = int(body.get("max_tokens", 512)) + temperature = float(body.get("temperature", 0.0)) + top_p = float(body.get("top_p", 0.9)) + + model = self.server.model + tokenizer = self.server.tokenizer + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + input_len = inputs["input_ids"].shape[1] + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=max_tokens, + temperature=temperature if temperature > 0 else 1.0, + top_p=top_p, + do_sample=temperature > 0, + ) + new_tokens = outputs[0][input_len:] + text = tokenizer.decode(new_tokens, skip_special_tokens=True) + + resp = { + "choices": [{"text": text}], + "usage": { + "prompt_tokens": input_len, + "completion_tokens": len(new_tokens), + }, + } + payload = json.dumps(resp).encode("utf-8") + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(payload))) + self.end_headers() + self.wfile.write(payload) + + def log_message(self, fmt, *args): + sys.stderr.write("[llm_server] %s\n" % (fmt % args)) + + +def main(): + parser = argparse.ArgumentParser(description="OpenAI-compatible LLM server") + parser.add_argument("model", help="HuggingFace model name") + parser.add_argument("--port", type=int, default=8080) + args = parser.parse_args() + + print(f"Loading model: {args.model}", flush=True) + tokenizer = AutoTokenizer.from_pretrained(args.model) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if torch.cuda.is_available(): + print(f"CUDA available: {torch.cuda.device_count()} GPU(s)", flush=True) + model = AutoModelForCausalLM.from_pretrained( + args.model, device_map="auto", torch_dtype=torch.float16) + else: + model = AutoModelForCausalLM.from_pretrained(args.model) + model.eval() + print(f"Model loaded on {next(model.parameters()).device}", flush=True) + + server = HTTPServer(("0.0.0.0", args.port), InferenceHandler) + server.model = model + server.tokenizer = tokenizer + 
print(f"Serving on http://0.0.0.0:{args.port}/v1/completions", flush=True) + try: + server.serve_forever() + except KeyboardInterrupt: + print("Shutting down", flush=True) + server.server_close() + + +if __name__ == "__main__": + main() From 45882e232ff78d5c378efd58b5c1855c58437ac8 Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Mon, 16 Feb 2026 21:25:17 +0100 Subject: [PATCH 21/24] Fix llmPredict code quality issues - Use proper imports instead of inline fully-qualified class names - Add try-with-resources for HTTP streams to prevent resource leaks - Add connect/read timeouts to HTTP calls - Add lineage tracing support for llmPredict - Add checkInvalidParameters validation in parser - Remove .claude/.env/meeting_notes from .gitignore - Trim verbose docstrings --- .gitignore | 8 ---- ...arameterizedBuiltinFunctionExpression.java | 5 ++- .../cp/ParameterizedBuiltinCPInstruction.java | 43 +++++++++++++------ src/main/python/llm_server.py | 12 +----- 4 files changed, 35 insertions(+), 33 deletions(-) diff --git a/.gitignore b/.gitignore index 642669e2600..db93fe7c49c 100644 --- a/.gitignore +++ b/.gitignore @@ -156,11 +156,3 @@ docker/mountFolder/*.bin docker/mountFolder/*.bin.mtd SEAL-*/ - -# local tool config and sensitive files -.claude/ -.env -.env.* -meeting_notes/ -meeting_notes.* -*meeting_notes* diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java index 89ed5d8ae34..29ad6cbd737 100644 --- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java @@ -619,8 +619,11 @@ private void validateTokenize(DataIdentifier output, boolean conditional) output.setDimensions(-1, -1); } - private void validateLlmPredict(DataIdentifier output, boolean conditional) + private void validateLlmPredict(DataIdentifier output, boolean conditional) { + Set valid = new HashSet<>(Arrays.asList( + "target", "url", "max_tokens", "temperature", "top_p")); + checkInvalidParameters(getOpCode(), getVarParams(), valid); checkDataType(false, "llmPredict", TF_FN_PARAM_DATA, DataType.FRAME, conditional); checkStringParam(false, "llmPredict", "url", conditional); output.setDataType(DataType.FRAME); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java index 69045beeb0e..1cef848a1e7 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java @@ -19,6 +19,11 @@ package org.apache.sysds.runtime.instructions.cp; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.URI; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -27,6 +32,8 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; +import org.apache.wink.json4j.JSONObject; + import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -336,43 +343,46 @@ else if(opcode.equalsIgnoreCase(Opcodes.LLMPREDICT.toString())) { Double.parseDouble(params.get("top_p")) : 0.9; int n = 
prompts.getNumRows(); - String[][] data = new String[n][5]; + String[][] data = new String[n][]; for(int i = 0; i < n; i++) { String prompt = prompts.get(i, 0).toString(); long t0 = System.nanoTime(); try { - org.apache.wink.json4j.JSONObject req = new org.apache.wink.json4j.JSONObject(); + JSONObject req = new JSONObject(); req.put("prompt", prompt); req.put("max_tokens", maxTokens); req.put("temperature", temperature); req.put("top_p", topP); - java.net.URL endpoint = new java.net.URI(url).toURL(); - java.net.HttpURLConnection conn = - (java.net.HttpURLConnection) endpoint.openConnection(); + HttpURLConnection conn = (HttpURLConnection) + new URI(url).toURL().openConnection(); conn.setRequestMethod("POST"); conn.setRequestProperty("Content-Type", "application/json"); + conn.setConnectTimeout(10_000); + conn.setReadTimeout(120_000); conn.setDoOutput(true); - conn.getOutputStream().write( - req.toString().getBytes(java.nio.charset.StandardCharsets.UTF_8)); - conn.getOutputStream().close(); + + try(OutputStream os = conn.getOutputStream()) { + os.write(req.toString().getBytes(StandardCharsets.UTF_8)); + } if(conn.getResponseCode() != 200) throw new DMLRuntimeException( "LLM endpoint returned HTTP " + conn.getResponseCode()); - String body = new String(conn.getInputStream().readAllBytes(), - java.nio.charset.StandardCharsets.UTF_8); + String body; + try(InputStream is = conn.getInputStream()) { + body = new String(is.readAllBytes(), StandardCharsets.UTF_8); + } conn.disconnect(); - org.apache.wink.json4j.JSONObject resp = - new org.apache.wink.json4j.JSONObject(body); + JSONObject resp = new JSONObject(body); String text = resp.getJSONArray("choices") .getJSONObject(0).getString("text"); long elapsed = (System.nanoTime() - t0) / 1_000_000; int inTok = 0, outTok = 0; if(resp.has("usage")) { - org.apache.wink.json4j.JSONObject usage = resp.getJSONObject("usage"); + JSONObject usage = resp.getJSONObject("usage"); inTok = usage.has("prompt_tokens") ? usage.getInt("prompt_tokens") : 0; outTok = usage.has("completion_tokens") ? usage.getInt("completion_tokens") : 0; } @@ -396,7 +406,6 @@ else if(opcode.equalsIgnoreCase(Opcodes.LLMPREDICT.toString())) { ec.releaseFrameInput(params.get("target")); } - else if(opcode.equalsIgnoreCase(Opcodes.TRANSFORMAPPLY.toString())) { // acquire locks FrameBlock data = ec.getFrameInput(params.get("target")); @@ -622,6 +631,12 @@ else if(opcode.equalsIgnoreCase(Opcodes.TRANSFORMDECODE.toString()) || opcode.eq return Pair.of(output.getName(), new LineageItem(getOpcode(), LineageItemUtils.getLineage(ec, target, meta, spec))); } + else if(opcode.equalsIgnoreCase(Opcodes.LLMPREDICT.toString())) { + CPOperand target = new CPOperand(params.get("target"), ValueType.STRING, DataType.FRAME); + CPOperand urlOp = getStringLiteral("url"); + return Pair.of(output.getName(), + new LineageItem(getOpcode(), LineageItemUtils.getLineage(ec, target, urlOp))); + } else if (opcode.equalsIgnoreCase(Opcodes.NVLIST.toString()) || opcode.equalsIgnoreCase(Opcodes.AUTODIFF.toString())) { List names = new ArrayList<>(params.keySet()); CPOperand[] listOperands = names.stream().map(n -> ec.containsVariable(params.get(n)) diff --git a/src/main/python/llm_server.py b/src/main/python/llm_server.py index f90e8cd642b..7ff439949bf 100644 --- a/src/main/python/llm_server.py +++ b/src/main/python/llm_server.py @@ -1,14 +1,6 @@ -"""OpenAI-compatible HTTP server for local HuggingFace model inference. +"""Simple /v1/completions server for local HuggingFace models. 
-Serves the /v1/completions endpoint so that SystemDS llmPredict() and -other OpenAI-compatible clients can call it directly. - -Usage: - python llm_server.py [--port PORT] - -Examples: - python llm_server.py distilgpt2 --port 8080 - python llm_server.py Qwen/Qwen2.5-3B-Instruct --port 8080 +Usage: python llm_server.py distilgpt2 --port 8080 """ import argparse From c3e9a1fc952872e47fb7dc5ab87ebf6d284714d5 Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Mon, 16 Feb 2026 21:33:50 +0100 Subject: [PATCH 22/24] Add concurrency parameter to llmPredict built-in Supports parallel HTTP calls to the inference server via ExecutorService. Default concurrency=1 keeps sequential behavior. --- ...arameterizedBuiltinFunctionExpression.java | 2 +- .../cp/ParameterizedBuiltinCPInstruction.java | 117 +++++++++++------- 2 files changed, 73 insertions(+), 46 deletions(-) diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java index 29ad6cbd737..08dc91af405 100644 --- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java @@ -622,7 +622,7 @@ private void validateTokenize(DataIdentifier output, boolean conditional) private void validateLlmPredict(DataIdentifier output, boolean conditional) { Set valid = new HashSet<>(Arrays.asList( - "target", "url", "max_tokens", "temperature", "top_p")); + "target", "url", "max_tokens", "temperature", "top_p", "concurrency")); checkInvalidParameters(getOpCode(), getVarParams(), valid); checkDataType(false, "llmPredict", TF_FN_PARAM_DATA, DataType.FRAME, conditional); checkStringParam(false, "llmPredict", "url", conditional); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java index 1cef848a1e7..90401b8cd02 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java @@ -29,6 +29,10 @@ import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -341,59 +345,38 @@ else if(opcode.equalsIgnoreCase(Opcodes.LLMPREDICT.toString())) { Double.parseDouble(params.get("temperature")) : 0.0; double topP = params.containsKey("top_p") ? Double.parseDouble(params.get("top_p")) : 0.9; + int concurrency = params.containsKey("concurrency") ? 
+ Integer.parseInt(params.get("concurrency")) : 1; int n = prompts.getNumRows(); String[][] data = new String[n][]; + + // build one callable per prompt + List> tasks = new ArrayList<>(n); for(int i = 0; i < n; i++) { String prompt = prompts.get(i, 0).toString(); - long t0 = System.nanoTime(); - try { - JSONObject req = new JSONObject(); - req.put("prompt", prompt); - req.put("max_tokens", maxTokens); - req.put("temperature", temperature); - req.put("top_p", topP); - - HttpURLConnection conn = (HttpURLConnection) - new URI(url).toURL().openConnection(); - conn.setRequestMethod("POST"); - conn.setRequestProperty("Content-Type", "application/json"); - conn.setConnectTimeout(10_000); - conn.setReadTimeout(120_000); - conn.setDoOutput(true); - - try(OutputStream os = conn.getOutputStream()) { - os.write(req.toString().getBytes(StandardCharsets.UTF_8)); - } - - if(conn.getResponseCode() != 200) - throw new DMLRuntimeException( - "LLM endpoint returned HTTP " + conn.getResponseCode()); - - String body; - try(InputStream is = conn.getInputStream()) { - body = new String(is.readAllBytes(), StandardCharsets.UTF_8); - } - conn.disconnect(); - - JSONObject resp = new JSONObject(body); - String text = resp.getJSONArray("choices") - .getJSONObject(0).getString("text"); - long elapsed = (System.nanoTime() - t0) / 1_000_000; - int inTok = 0, outTok = 0; - if(resp.has("usage")) { - JSONObject usage = resp.getJSONObject("usage"); - inTok = usage.has("prompt_tokens") ? usage.getInt("prompt_tokens") : 0; - outTok = usage.has("completion_tokens") ? usage.getInt("completion_tokens") : 0; - } - data[i] = new String[]{prompt, text, - String.valueOf(elapsed), String.valueOf(inTok), String.valueOf(outTok)}; + tasks.add(() -> callLlmEndpoint(prompt, url, maxTokens, temperature, topP)); + } + + try { + if(concurrency <= 1) { + // sequential + for(int i = 0; i < n; i++) + data[i] = tasks.get(i).call(); } - catch(DMLRuntimeException e) { throw e; } - catch(Exception e) { - throw new DMLRuntimeException("llmPredict HTTP call failed: " + e.getMessage(), e); + else { + // parallel + ExecutorService pool = Executors.newFixedThreadPool( + Math.min(concurrency, n)); + List> futures = pool.invokeAll(tasks); + pool.shutdown(); + for(int i = 0; i < n; i++) + data[i] = futures.get(i).get(); } } + catch(Exception e) { + throw new DMLRuntimeException("llmPredict failed: " + e.getMessage(), e); + } ValueType[] schema = {ValueType.STRING, ValueType.STRING, ValueType.INT64, ValueType.INT64, ValueType.INT64}; @@ -570,6 +553,50 @@ private void warnOnTrunction(TensorBlock data, int rows, int cols) { } } + private static String[] callLlmEndpoint(String prompt, String url, + int maxTokens, double temperature, double topP) throws Exception { + long t0 = System.nanoTime(); + JSONObject req = new JSONObject(); + req.put("prompt", prompt); + req.put("max_tokens", maxTokens); + req.put("temperature", temperature); + req.put("top_p", topP); + + HttpURLConnection conn = (HttpURLConnection) + new URI(url).toURL().openConnection(); + conn.setRequestMethod("POST"); + conn.setRequestProperty("Content-Type", "application/json"); + conn.setConnectTimeout(10_000); + conn.setReadTimeout(120_000); + conn.setDoOutput(true); + + try(OutputStream os = conn.getOutputStream()) { + os.write(req.toString().getBytes(StandardCharsets.UTF_8)); + } + if(conn.getResponseCode() != 200) + throw new DMLRuntimeException( + "LLM endpoint returned HTTP " + conn.getResponseCode()); + + String body; + try(InputStream is = conn.getInputStream()) { + body = new 
String(is.readAllBytes(), StandardCharsets.UTF_8); + } + conn.disconnect(); + + JSONObject resp = new JSONObject(body); + String text = resp.getJSONArray("choices") + .getJSONObject(0).getString("text"); + long elapsed = (System.nanoTime() - t0) / 1_000_000; + int inTok = 0, outTok = 0; + if(resp.has("usage")) { + JSONObject usage = resp.getJSONObject("usage"); + inTok = usage.has("prompt_tokens") ? usage.getInt("prompt_tokens") : 0; + outTok = usage.has("completion_tokens") ? usage.getInt("completion_tokens") : 0; + } + return new String[]{prompt, text, + String.valueOf(elapsed), String.valueOf(inTok), String.valueOf(outTok)}; + } + @Override public Pair getLineageItem(ExecutionContext ec) { String opcode = getOpcode(); From 53e3febd02438476b88e015324f566ff5007b7ef Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Mon, 16 Feb 2026 21:38:35 +0100 Subject: [PATCH 23/24] Remove license header from test, clarify llm_server.py docstring --- src/main/python/llm_server.py | 3 ++- .../functions/jmlc/JMLCLLMInferenceTest.java | 19 ------------------- 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/src/main/python/llm_server.py b/src/main/python/llm_server.py index 7ff439949bf..4ebf9b87afb 100644 --- a/src/main/python/llm_server.py +++ b/src/main/python/llm_server.py @@ -1,4 +1,5 @@ -"""Simple /v1/completions server for local HuggingFace models. +"""Local inference server for llmPredict. Loads a HuggingFace model +and serves it at http://localhost:PORT/v1/completions. Usage: python llm_server.py distilgpt2 --port 8080 """ diff --git a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java index fb8b53770b7..cbf85ded50e 100644 --- a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java +++ b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java @@ -1,22 +1,3 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - package org.apache.sysds.test.functions.jmlc; import java.util.HashMap; From e872f2277be30d2f1dd074692ab89a655c582811 Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Mon, 16 Feb 2026 22:45:03 +0100 Subject: [PATCH 24/24] Fix JMLC frame binding: match DML variable names to registered inputs JMLC requires the LHS variable name in read() assignments to match the input name registered in prepareScript(). Changed X/R to prompts/results so RewriteRemovePersistentReadWrite correctly converts persistent reads to transient reads. 
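For reference, a minimal JMLC call sequence under this naming convention (illustration only, not part of the patch; `dmlScript` stands for the DML string shown in the diff below, and the URL and sampling parameters are placeholder values taken from the test):

    // Sketch: the input/output names registered here must equal the
    // read()/write() variable names in the DML script, otherwise the
    // persistent reads are not rewritten to in-memory bindings.
    Connection conn = new Connection();
    Map<String, String> args = new HashMap<>();
    args.put("$url", "http://localhost:8080/v1/completions");
    args.put("$mt", "20");
    args.put("$temp", "0.7");
    args.put("$tp", "0.9");
    PreparedScript ps = conn.prepareScript(dmlScript, args,
        new String[]{"prompts"},   // matches: prompts = read("prompts", ...)
        new String[]{"results"});  // matches: write(results, "results")
    ps.setFrame("prompts", new String[][]{{"The meaning of life is"}});
    FrameBlock out = ps.executeScript().getFrameBlock("results");
    conn.close();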
--- .../sysds/test/functions/jmlc/JMLCLLMInferenceTest.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java index cbf85ded50e..1c259129356 100644 --- a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java +++ b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java @@ -21,9 +21,9 @@ public class JMLCLLMInferenceTest extends AutomatedTestBase { private final static String LLM_URL = "http://localhost:8080/v1/completions"; private final static String DML_SCRIPT = - "X = read(\"prompts\", data_type=\"frame\")\n" + - "R = llmPredict(target=X, url=$url, max_tokens=$mt, temperature=$temp, top_p=$tp)\n" + - "write(R, \"results\")"; + "prompts = read(\"prompts\", data_type=\"frame\")\n" + + "results = llmPredict(target=prompts, url=$url, max_tokens=$mt, temperature=$temp, top_p=$tp)\n" + + "write(results, \"results\")"; @Override public void setUp() {