diff --git a/.gitignore b/.gitignore index 5de697a37e3..db93fe7c49c 100644 --- a/.gitignore +++ b/.gitignore @@ -156,4 +156,3 @@ docker/mountFolder/*.bin docker/mountFolder/*.bin.mtd SEAL-*/ - diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index dc1f23b83fc..82eccbec021 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -226,6 +226,7 @@ public enum Builtins { LMDS("lmDS", true), LMPREDICT("lmPredict", true), LMPREDICT_STATS("lmPredictStats", true), + LLMPREDICT("llmPredict", false, true), LOCAL("local", false), LOG("log", false), LOGSUMEXP("logSumExp", true), diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index 1b0536416d6..94055d055c5 100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -204,6 +204,7 @@ public enum Opcodes { GROUPEDAGG("groupedagg", InstructionType.ParameterizedBuiltin), RMEMPTY("rmempty", InstructionType.ParameterizedBuiltin), REPLACE("replace", InstructionType.ParameterizedBuiltin), + LLMPREDICT("llmpredict", InstructionType.ParameterizedBuiltin), LOWERTRI("lowertri", InstructionType.ParameterizedBuiltin), UPPERTRI("uppertri", InstructionType.ParameterizedBuiltin), REXPAND("rexpand", InstructionType.ParameterizedBuiltin), diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index 2e3543882d2..3414614991c 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -805,7 +805,7 @@ public static ReOrgOp valueOfByOpcode(String opcode) { /** Parameterized operations that require named variable arguments */ public enum ParamBuiltinOp { - AUTODIFF, CDF, CONTAINS, INVALID, INVCDF, GROUPEDAGG, RMEMPTY, REPLACE, REXPAND, + AUTODIFF, CDF, CONTAINS, 
INVALID, INVCDF, GROUPEDAGG, LLMPREDICT, RMEMPTY, REPLACE, REXPAND, LOWER_TRI, UPPER_TRI, TRANSFORMAPPLY, TRANSFORMDECODE, TRANSFORMCOLMAP, TRANSFORMMETA, TOKENIZE, TOSTRING, LIST, PARAMSERV diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java index 61a4b8b8f91..b791478214b 100644 --- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java +++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java @@ -187,6 +187,7 @@ public Lop constructLops() case LOWER_TRI: case UPPER_TRI: case TOKENIZE: + case LLMPREDICT: case TRANSFORMAPPLY: case TRANSFORMDECODE: case TRANSFORMCOLMAP: @@ -758,7 +759,7 @@ && getTargetHop().areDimsBelowThreshold() ) { if (_op == ParamBuiltinOp.TRANSFORMCOLMAP || _op == ParamBuiltinOp.TRANSFORMMETA || _op == ParamBuiltinOp.TOSTRING || _op == ParamBuiltinOp.LIST || _op == ParamBuiltinOp.CDF || _op == ParamBuiltinOp.INVCDF - || _op == ParamBuiltinOp.PARAMSERV) { + || _op == ParamBuiltinOp.PARAMSERV || _op == ParamBuiltinOp.LLMPREDICT) { _etype = ExecType.CP; } @@ -768,7 +769,7 @@ && getTargetHop().areDimsBelowThreshold() ) { switch(_op) { case CONTAINS: if(getTargetHop().optFindExecType() == ExecType.SPARK) - _etype = ExecType.SPARK; + _etype = ExecType.SPARK; break; default: // Do not change execution type. 
diff --git a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java index 3604121aac8..dcec28f76ca 100644 --- a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java +++ b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java @@ -176,6 +176,7 @@ public String getInstructions(String output) case CONTAINS: case REPLACE: case TOKENIZE: + case LLMPREDICT: case TRANSFORMAPPLY: case TRANSFORMDECODE: case TRANSFORMCOLMAP: diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index c6e7188d7bc..b1536371711 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -2007,6 +2007,7 @@ private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFu case LOWER_TRI: case UPPER_TRI: case TOKENIZE: + case LLMPREDICT: case TRANSFORMAPPLY: case TRANSFORMDECODE: case TRANSFORMCOLMAP: diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java index 314440628e0..08dc91af405 100644 --- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java @@ -61,6 +61,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier pbHopMap.put(Builtins.GROUPEDAGG, ParamBuiltinOp.GROUPEDAGG); pbHopMap.put(Builtins.RMEMPTY, ParamBuiltinOp.RMEMPTY); pbHopMap.put(Builtins.REPLACE, ParamBuiltinOp.REPLACE); + pbHopMap.put(Builtins.LLMPREDICT, ParamBuiltinOp.LLMPREDICT); pbHopMap.put(Builtins.LOWER_TRI, ParamBuiltinOp.LOWER_TRI); pbHopMap.put(Builtins.UPPER_TRI, ParamBuiltinOp.UPPER_TRI); @@ -211,6 +212,10 @@ public void validateExpression(HashMap ids, HashMap valid = new HashSet<>(Arrays.asList( + "target", 
"url", "max_tokens", "temperature", "top_p", "concurrency")); + checkInvalidParameters(getOpCode(), getVarParams(), valid); + checkDataType(false, "llmPredict", TF_FN_PARAM_DATA, DataType.FRAME, conditional); + checkStringParam(false, "llmPredict", "url", conditional); + output.setDataType(DataType.FRAME); + output.setValueType(ValueType.STRING); + output.setDimensions(-1, -1); + } + // example: A = transformapply(target=X, meta=M, spec=s) private void validateTransformApply(DataIdentifier output, boolean conditional) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java index 119589a3033..90401b8cd02 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java @@ -19,14 +19,25 @@ package org.apache.sysds.runtime.instructions.cp; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.URI; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; import java.util.stream.Collectors; import java.util.stream.IntStream; +import org.apache.wink.json4j.JSONObject; + import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -154,7 +165,7 @@ else if(opcode.equalsIgnoreCase(Opcodes.RMEMPTY.toString()) || opcode.equalsIgno } else if(opcode.equals(Opcodes.TRANSFORMAPPLY.toString()) || opcode.equals(Opcodes.TRANSFORMDECODE.toString()) || 
opcode.equalsIgnoreCase(Opcodes.CONTAINS.toString()) || opcode.equals(Opcodes.TRANSFORMCOLMAP.toString()) - || opcode.equals(Opcodes.TRANSFORMMETA.toString()) || opcode.equals(Opcodes.TOKENIZE.toString()) + || opcode.equals(Opcodes.TRANSFORMMETA.toString()) || opcode.equals(Opcodes.TOKENIZE.toString()) || opcode.equals(Opcodes.LLMPREDICT.toString()) || opcode.equals(Opcodes.TOSTRING.toString()) || opcode.equals(Opcodes.NVLIST.toString()) || opcode.equals(Opcodes.AUTODIFF.toString())) { return new ParameterizedBuiltinCPInstruction(null, paramsMap, out, opcode, str); } @@ -324,6 +335,60 @@ else if(opcode.equalsIgnoreCase(Opcodes.TOKENIZE.toString())) { ec.setFrameOutput(output.getName(), fbout); ec.releaseFrameInput(params.get("target")); } + + else if(opcode.equalsIgnoreCase(Opcodes.LLMPREDICT.toString())) { + FrameBlock prompts = ec.getFrameInput(params.get("target")); + String url = params.get("url"); + int maxTokens = params.containsKey("max_tokens") ? + Integer.parseInt(params.get("max_tokens")) : 512; + double temperature = params.containsKey("temperature") ? + Double.parseDouble(params.get("temperature")) : 0.0; + double topP = params.containsKey("top_p") ? + Double.parseDouble(params.get("top_p")) : 0.9; + int concurrency = params.containsKey("concurrency") ? 
+ Integer.parseInt(params.get("concurrency")) : 1; + + int n = prompts.getNumRows(); + String[][] data = new String[n][]; + + // build one callable per prompt + List<Callable<String[]>> tasks = new ArrayList<>(n); + for(int i = 0; i < n; i++) { + String prompt = prompts.get(i, 0).toString(); + tasks.add(() -> callLlmEndpoint(prompt, url, maxTokens, temperature, topP)); + } + + try { + if(concurrency <= 1) { + // sequential + for(int i = 0; i < n; i++) + data[i] = tasks.get(i).call(); + } + else { + // parallel + ExecutorService pool = Executors.newFixedThreadPool( + Math.min(concurrency, n)); + List<Future<String[]>> futures = pool.invokeAll(tasks); + pool.shutdown(); + for(int i = 0; i < n; i++) + data[i] = futures.get(i).get(); + } + } + catch(Exception e) { + throw new DMLRuntimeException("llmPredict failed: " + e.getMessage(), e); + } + + ValueType[] schema = {ValueType.STRING, ValueType.STRING, + ValueType.INT64, ValueType.INT64, ValueType.INT64}; + String[] colNames = {"prompt", "generated_text", "time_ms", "input_tokens", "output_tokens"}; + FrameBlock fbout = new FrameBlock(schema, colNames); + for(String[] row : data) + fbout.appendRow(row); + + ec.setFrameOutput(output.getName(), fbout); + ec.releaseFrameInput(params.get("target")); + } + else if(opcode.equalsIgnoreCase(Opcodes.TRANSFORMAPPLY.toString())) { // acquire locks FrameBlock data = ec.getFrameInput(params.get("target")); @@ -488,6 +553,50 @@ private void warnOnTrunction(TensorBlock data, int rows, int cols) { } } + private static String[] callLlmEndpoint(String prompt, String url, + int maxTokens, double temperature, double topP) throws Exception { + long t0 = System.nanoTime(); + JSONObject req = new JSONObject(); + req.put("prompt", prompt); + req.put("max_tokens", maxTokens); + req.put("temperature", temperature); + req.put("top_p", topP); + + HttpURLConnection conn = (HttpURLConnection) + new URI(url).toURL().openConnection(); + conn.setRequestMethod("POST"); + conn.setRequestProperty("Content-Type", "application/json"); + 
conn.setConnectTimeout(10_000); + conn.setReadTimeout(120_000); + conn.setDoOutput(true); + + try(OutputStream os = conn.getOutputStream()) { + os.write(req.toString().getBytes(StandardCharsets.UTF_8)); + } + if(conn.getResponseCode() != 200) + throw new DMLRuntimeException( + "LLM endpoint returned HTTP " + conn.getResponseCode()); + + String body; + try(InputStream is = conn.getInputStream()) { + body = new String(is.readAllBytes(), StandardCharsets.UTF_8); + } + conn.disconnect(); + + JSONObject resp = new JSONObject(body); + String text = resp.getJSONArray("choices") + .getJSONObject(0).getString("text"); + long elapsed = (System.nanoTime() - t0) / 1_000_000; + int inTok = 0, outTok = 0; + if(resp.has("usage")) { + JSONObject usage = resp.getJSONObject("usage"); + inTok = usage.has("prompt_tokens") ? usage.getInt("prompt_tokens") : 0; + outTok = usage.has("completion_tokens") ? usage.getInt("completion_tokens") : 0; + } + return new String[]{prompt, text, + String.valueOf(elapsed), String.valueOf(inTok), String.valueOf(outTok)}; + } + @Override public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) { String opcode = getOpcode(); @@ -549,6 +658,12 @@ else if(opcode.equalsIgnoreCase(Opcodes.TRANSFORMDECODE.toString()) || opcode.eq return Pair.of(output.getName(), new LineageItem(getOpcode(), LineageItemUtils.getLineage(ec, target, meta, spec))); } + else if(opcode.equalsIgnoreCase(Opcodes.LLMPREDICT.toString())) { + CPOperand target = new CPOperand(params.get("target"), ValueType.STRING, DataType.FRAME); + CPOperand urlOp = getStringLiteral("url"); + return Pair.of(output.getName(), + new LineageItem(getOpcode(), LineageItemUtils.getLineage(ec, target, urlOp))); + } else if (opcode.equalsIgnoreCase(Opcodes.NVLIST.toString()) || opcode.equalsIgnoreCase(Opcodes.AUTODIFF.toString())) { List<String> names = new ArrayList<>(params.keySet()); CPOperand[] listOperands = names.stream().map(n -> ec.containsVariable(params.get(n)) diff --git a/src/main/python/llm_server.py 
b/src/main/python/llm_server.py new file mode 100644 index 00000000000..4ebf9b87afb --- /dev/null +++ b/src/main/python/llm_server.py @@ -0,0 +1,96 @@ +"""Local inference server for llmPredict. Loads a HuggingFace model +and serves it at http://localhost:PORT/v1/completions. + +Usage: python llm_server.py distilgpt2 --port 8080 +""" + +import argparse +import json +import sys +import time +from http.server import HTTPServer, BaseHTTPRequestHandler + +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM + + +class InferenceHandler(BaseHTTPRequestHandler): + + def do_POST(self): + if self.path != "/v1/completions": + self.send_error(404) + return + length = int(self.headers.get("Content-Length", 0)) + body = json.loads(self.rfile.read(length)) + + prompt = body.get("prompt", "") + max_tokens = int(body.get("max_tokens", 512)) + temperature = float(body.get("temperature", 0.0)) + top_p = float(body.get("top_p", 0.9)) + + model = self.server.model + tokenizer = self.server.tokenizer + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + input_len = inputs["input_ids"].shape[1] + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=max_tokens, + temperature=temperature if temperature > 0 else 1.0, + top_p=top_p, + do_sample=temperature > 0, + ) + new_tokens = outputs[0][input_len:] + text = tokenizer.decode(new_tokens, skip_special_tokens=True) + + resp = { + "choices": [{"text": text}], + "usage": { + "prompt_tokens": input_len, + "completion_tokens": len(new_tokens), + }, + } + payload = json.dumps(resp).encode("utf-8") + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(payload))) + self.end_headers() + self.wfile.write(payload) + + def log_message(self, fmt, *args): + sys.stderr.write("[llm_server] %s\n" % (fmt % args)) + + +def main(): + parser = argparse.ArgumentParser(description="OpenAI-compatible LLM server") + 
parser.add_argument("model", help="HuggingFace model name") + parser.add_argument("--port", type=int, default=8080) + args = parser.parse_args() + + print(f"Loading model: {args.model}", flush=True) + tokenizer = AutoTokenizer.from_pretrained(args.model) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if torch.cuda.is_available(): + print(f"CUDA available: {torch.cuda.device_count()} GPU(s)", flush=True) + model = AutoModelForCausalLM.from_pretrained( + args.model, device_map="auto", torch_dtype=torch.float16) + else: + model = AutoModelForCausalLM.from_pretrained(args.model) + model.eval() + print(f"Model loaded on {next(model.parameters()).device}", flush=True) + + server = HTTPServer(("0.0.0.0", args.port), InferenceHandler) + server.model = model + server.tokenizer = tokenizer + print(f"Serving on http://0.0.0.0:{args.port}/v1/completions", flush=True) + try: + server.serve_forever() + except KeyboardInterrupt: + print("Shutting down", flush=True) + server.server_close() + + +if __name__ == "__main__": + main() diff --git a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java new file mode 100644 index 00000000000..1c259129356 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java @@ -0,0 +1,118 @@ +package org.apache.sysds.test.functions.jmlc; + +import java.util.HashMap; +import java.util.Map; + +import org.apache.sysds.api.jmlc.Connection; +import org.apache.sysds.api.jmlc.PreparedScript; +import org.apache.sysds.api.jmlc.ResultVariables; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.test.AutomatedTestBase; +import org.junit.Assert; +import org.junit.Test; + +/** + * Test LLM inference via the llmPredict built-in function. + * Requires an OpenAI-compatible server (e.g., llm_server.py) on localhost:8080. 
+ */ +public class JMLCLLMInferenceTest extends AutomatedTestBase { + private final static String TEST_NAME = "JMLCLLMInferenceTest"; + private final static String TEST_DIR = "functions/jmlc/"; + private final static String LLM_URL = "http://localhost:8080/v1/completions"; + + private final static String DML_SCRIPT = + "prompts = read(\"prompts\", data_type=\"frame\")\n" + + "results = llmPredict(target=prompts, url=$url, max_tokens=$mt, temperature=$temp, top_p=$tp)\n" + + "write(results, \"results\")"; + + @Override + public void setUp() { + addTestConfiguration(TEST_DIR, TEST_NAME); + getAndLoadTestConfiguration(TEST_NAME); + } + + @Test + public void testSinglePrompt() { + Connection conn = null; + try { + conn = new Connection(); + Map<String, String> args = new HashMap<>(); + args.put("$url", LLM_URL); + args.put("$mt", "20"); + args.put("$temp", "0.7"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new String[]{"results"}); + + String[][] promptData = new String[][]{{"The meaning of life is"}}; + ps.setFrame("prompts", promptData); + + ResultVariables rv = ps.executeScript(); + FrameBlock result = rv.getFrameBlock("results"); + + Assert.assertNotNull("Result should not be null", result); + Assert.assertEquals("Should have 1 row", 1, result.getNumRows()); + Assert.assertEquals("Should have 5 columns", 5, result.getNumColumns()); + String generated = result.get(0, 1).toString(); + Assert.assertFalse("Generated text should not be empty", generated.isEmpty()); + + System.out.println("Prompt: " + promptData[0][0]); + System.out.println("Generated: " + generated); + } catch (Exception e) { + System.out.println("Skipping LLM test (server not running):"); + e.printStackTrace(); + org.junit.Assume.assumeNoException("LLM server not available", e); + } finally { + if (conn != null) conn.close(); + } + } + + @Test + public void testBatchInference() { + Connection conn = null; + try { + conn = new Connection(); + Map<String, String> args 
= new HashMap<>(); + args.put("$url", LLM_URL); + args.put("$mt", "20"); + args.put("$temp", "0.7"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new String[]{"results"}); + + String[] prompts = { + "The meaning of life is", + "Machine learning is", + "Apache SystemDS enables" + }; + String[][] promptData = new String[prompts.length][1]; + for (int i = 0; i < prompts.length; i++) + promptData[i][0] = prompts[i]; + ps.setFrame("prompts", promptData); + + ResultVariables rv = ps.executeScript(); + FrameBlock result = rv.getFrameBlock("results"); + + Assert.assertNotNull("Result should not be null", result); + Assert.assertEquals("Should have 3 rows", 3, result.getNumRows()); + Assert.assertEquals("Should have 5 columns", 5, result.getNumColumns()); + + for (int i = 0; i < prompts.length; i++) { + String prompt = result.get(i, 0).toString(); + String generated = result.get(i, 1).toString(); + long timeMs = Long.parseLong(result.get(i, 2).toString()); + Assert.assertEquals("Prompt should match", prompts[i], prompt); + Assert.assertFalse("Generated text should not be empty", generated.isEmpty()); + Assert.assertTrue("Time should be positive", timeMs > 0); + System.out.println("Prompt: " + prompt); + System.out.println("Generated: " + generated + " (" + timeMs + "ms)"); + } + } catch (Exception e) { + System.out.println("Skipping batch LLM test (server not running):"); + e.printStackTrace(); + org.junit.Assume.assumeNoException("LLM server not available", e); + } finally { + if (conn != null) conn.close(); + } + } +}