diff --git a/astrbot/builtin_stars/python_interpreter/main.py b/astrbot/builtin_stars/python_interpreter/main.py deleted file mode 100644 index ec9d261b7..000000000 --- a/astrbot/builtin_stars/python_interpreter/main.py +++ /dev/null @@ -1,536 +0,0 @@ -import asyncio -import json -import os -import re -import shutil -import time -import uuid -from collections import defaultdict - -import aiodocker -import aiohttp - -from astrbot.api import llm_tool, logger, star -from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter -from astrbot.api.message_components import File, Image -from astrbot.api.provider import ProviderRequest -from astrbot.core.message.components import BaseMessageComponent -from astrbot.core.utils.astrbot_path import get_astrbot_data_path -from astrbot.core.utils.io import download_file, download_image_by_url - -PROMPT = """ -## Task -You need to generate python codes to solve user's problem: {prompt} - -{extra_input} - -## Limit -1. Available libraries: - - standard libs - - `Pillow` - - `requests` - - `numpy` - - `matplotlib` - - `scipy` - - `scikit-learn` - - `beautifulsoup4` - - `pandas` - - `opencv-python` - - `python-docx` - - `python-pptx` - - `pymupdf` (Do not use fpdf, reportlab, etc.) - - `mplfonts` - You can only use these libraries and the libraries that they depend on. -2. Do not generate malicious code. -3. Use given `shared.api` package to output the result. - It has 3 functions: `send_text(text: str)`, `send_image(image_path: str)`, `send_file(file_path: str)`. - For Image and file, you must save it to `output` folder. -4. You must only output the code, do not output the result of the code and any other information. -5. The output language is same as user's input language. -6. Please first provide relevant knowledge about user's problem appropriately. - -## Example -1. User's problem: `please solve the fabonacci sequence problem.` -Output: -```python -from shared.api import send_text, send_image, send_file - -def fabonacci(n): - if n <= 1: - return n - else: - return fabonacci(n-1) + fabonacci(n-2) - -result = fabonacci(10) -send_text("The fabonacci sequence is a series of numbers in which each number is the sum of the two preceding ones, starting from 0 and 1.") -send_text("Let's calculate the fabonacci sequence of 10: " + result) # send_text is a function to send pure text to user -``` - -2. User's problem: `please draw a sin(x) function.` -Output: -```python -from shared.api import send_text, send_image, send_file -import numpy as np -import matplotlib.pyplot as plt - -x = np.linspace(0, 2*np.pi, 100) -y = np.sin(x) -plt.plot(x, y) -plt.savefig("output/sin_x.png") -send_text("The sin(x) is a periodic function with a period of 2π, and the value range is [-1, 1]. The following is the image of sin(x).") -send_image("output/sin_x.png") # send_image is a function to send image to user -send_text("If you need more information, please let me know :)") -``` - -{extra_prompt} -""" - -DEFAULT_CONFIG = { - "sandbox": { - "image": "soulter/astrbot-code-interpreter-sandbox", - "docker_mirror": "", # cjie.eu.org - }, - "docker_host_astrbot_abs_path": "", -} -PATH = os.path.join(get_astrbot_data_path(), "config", "python_interpreter.json") - - -class Main(star.Star): - """基于 Docker 沙箱的 Python 代码执行器""" - - def __init__(self, context: star.Context) -> None: - self.context = context - self.curr_dir = os.path.dirname(os.path.abspath(__file__)) - - self.shared_path = os.path.join("data", "py_interpreter_shared") - if not os.path.exists(self.shared_path): - # 复制 api.py 到 shared 目录 - os.makedirs(self.shared_path, exist_ok=True) - shared_api_file = os.path.join(self.curr_dir, "shared", "api.py") - shutil.copy(shared_api_file, self.shared_path) - self.workplace_path = os.path.join("data", "py_interpreter_workplace") - os.makedirs(self.workplace_path, exist_ok=True) - - self.user_file_msg_buffer = defaultdict(list) - """存放用户上传的文件和图片""" - self.user_waiting = {} - """正在等待用户的文件或图片""" - - # 加载配置 - if not os.path.exists(PATH): - self.config = DEFAULT_CONFIG - self._save_config() - else: - with open(PATH) as f: - self.config = json.load(f) - - async def initialize(self): - ok = await self.is_docker_available() - if not ok: - logger.info( - "Docker 不可用,代码解释器将无法使用,astrbot-python-interpreter 将自动禁用。", - ) - # await self.context._star_manager.turn_off_plugin( - # "astrbot-python-interpreter" - # ) - - async def file_upload(self, file_path: str): - """上传图像文件到 S3""" - ext = os.path.splitext(file_path)[1] - S3_URL = "https://s3.neko.soulter.top/astrbot-s3" - with open(file_path, "rb") as f: - file = f.read() - - s3_file_url = f"{S3_URL}/{uuid.uuid4().hex}{ext}" - - async with ( - aiohttp.ClientSession( - headers={"Accept": "application/json"}, - trust_env=True, - ) as session, - session.put(s3_file_url, data=file) as resp, - ): - if resp.status != 200: - raise Exception(f"Failed to upload image: {resp.status}") - return s3_file_url - - async def is_docker_available(self) -> bool: - """Check if docker is available""" - try: - async with aiodocker.Docker() as docker: - await docker.version() - return True - except BaseException as e: - logger.info(f"检查 Docker 可用性: {e}") - return False - - async def get_image_name(self) -> str: - """Get the image name""" - if self.config["sandbox"]["docker_mirror"]: - return f"{self.config['sandbox']['docker_mirror']}/{self.config['sandbox']['image']}" - return self.config["sandbox"]["image"] - - def _save_config(self): - with open(PATH, "w") as f: - json.dump(self.config, f) - - async def gen_magic_code(self) -> str: - return uuid.uuid4().hex[:8] - - async def download_image( - self, - image_url: str, - workplace_path: str, - filename: str, - ) -> str: - """Download image from url to workplace_path""" - async with aiohttp.ClientSession(trust_env=True) as session: - async with session.get(image_url) as resp: - if resp.status != 200: - return "" - image_path = os.path.join(workplace_path, f"{filename}.jpg") - with open(image_path, "wb") as f: - f.write(await resp.read()) - return f"{filename}.jpg" - - async def tidy_code(self, code: str) -> str: - """Tidy the code""" - pattern = r"```(?:py|python)?\n(.*?)\n```" - match = re.search(pattern, code, re.DOTALL) - if match is None: - raise ValueError("The code is not in the code block.") - return match.group(1) - - @filter.event_message_type(filter.EventMessageType.ALL) - async def on_message(self, event: AstrMessageEvent): - """处理消息""" - uid = event.get_sender_id() - if uid not in self.user_waiting: - return - for comp in event.message_obj.message: - if isinstance(comp, File): - file_path = await comp.get_file() - if file_path.startswith("http"): - name = comp.name if comp.name else uuid.uuid4().hex[:8] - temp_dir = os.path.join(get_astrbot_data_path(), "temp") - path = os.path.join(temp_dir, name) - await download_file(file_path, path) - else: - path = file_path - self.user_file_msg_buffer[event.get_session_id()].append(path) - logger.debug(f"User {uid} uploaded file: {path}") - yield event.plain_result(f"代码执行器: 文件已经上传: {path}") - if uid in self.user_waiting: - del self.user_waiting[uid] - elif isinstance(comp, Image): - image_url = comp.url if comp.url else comp.file - if image_url is None: - raise ValueError("Image URL is None") - if image_url.startswith("http"): - image_path = await download_image_by_url(image_url) - elif image_url.startswith("file:///"): - image_path = image_url.replace("file:///", "") - else: - image_path = image_url - self.user_file_msg_buffer[event.get_session_id()].append(image_path) - logger.debug(f"User {uid} uploaded image: {image_path}") - yield event.plain_result(f"代码执行器: 图片已经上传: {image_path}") - if uid in self.user_waiting: - del self.user_waiting[uid] - - @filter.on_llm_request() - async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest): - if event.get_session_id() in self.user_file_msg_buffer: - files = self.user_file_msg_buffer[event.get_session_id()] - if not request.prompt: - request.prompt = "" - request.prompt += f"\nUser provided files: {files}" - - @filter.command_group("pi") - def pi(self): - """代码执行器配置""" - - @pi.command("absdir") - async def pi_absdir(self, event: AstrMessageEvent, path: str = ""): - """设置 Docker 宿主机绝对路径""" - if not path: - yield event.plain_result( - f"当前 Docker 宿主机绝对路径: {self.config.get('docker_host_astrbot_abs_path', '')}", - ) - else: - self.config["docker_host_astrbot_abs_path"] = path - self._save_config() - yield event.plain_result(f"设置 Docker 宿主机绝对路径成功: {path}") - - @pi.command("mirror") - async def pi_mirror(self, event: AstrMessageEvent, url: str = ""): - """Docker 镜像地址""" - if not url: - yield event.plain_result(f"""当前 Docker 镜像地址: {self.config["sandbox"]["docker_mirror"]}。 -使用 `pi mirror ` 来设置 Docker 镜像地址。 -您所设置的 Docker 镜像地址将会自动加在 Docker 镜像名前。如: `soulter/astrbot-code-interpreter-sandbox` -> `cjie.eu.org/soulter/astrbot-code-interpreter-sandbox`。 -""") - else: - self.config["sandbox"]["docker_mirror"] = url - self._save_config() - yield event.plain_result("设置 Docker 镜像地址成功。") - - @pi.command("repull") - async def pi_repull(self, event: AstrMessageEvent): - """重新拉取沙箱镜像""" - async with aiodocker.Docker() as docker: - image_name = await self.get_image_name() - try: - await docker.images.get(image_name) - await docker.images.delete(image_name, force=True) - except aiodocker.exceptions.DockerError: - pass - await docker.images.pull(image_name) - yield event.plain_result("重新拉取沙箱镜像成功。") - - @pi.command("file") - async def pi_file(self, event: AstrMessageEvent): - """在规定秒数(60s)内上传一个文件""" - uid = event.get_sender_id() - self.user_waiting[uid] = time.time() - tip = "文件" - yield event.plain_result(f"代码执行器: 请在 60s 内上传一个{tip}。") - await asyncio.sleep(60) - if uid in self.user_waiting: - yield event.plain_result( - f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 未在规定时间内上传{tip}。", - ) - self.user_waiting.pop(uid) - - @pi.command("clear", alias=["clean"]) - async def pi_file_clean(self, event: AstrMessageEvent): - """清理用户上传的文件""" - uid = event.get_sender_id() - if uid in self.user_waiting: - self.user_waiting.pop(uid) - yield event.plain_result( - f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 已清理。", - ) - else: - yield event.plain_result( - f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 没有等待上传文件。", - ) - - @pi.command("list") - async def pi_file_list(self, event: AstrMessageEvent): - """列出用户上传的文件""" - uid = event.get_sender_id() - if uid in self.user_file_msg_buffer: - files = self.user_file_msg_buffer[uid] - yield event.plain_result( - f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 上传的文件: {files}", - ) - else: - yield event.plain_result( - f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 没有上传文件。", - ) - - @llm_tool("python_interpreter") - async def python_interpreter(self, event: AstrMessageEvent): - """Use this tool only if user really want to solve a complex problem and the problem can be solved very well by Python code. - For example, user can use this tool to solve math problems, edit image, docx, pptx, pdf, etc. - """ - if not await self.is_docker_available(): - yield event.plain_result("Docker 在当前机器不可用,无法沙箱化执行代码。") - - plain_text = event.message_str - - # 创建必要的工作目录和幻术码 - magic_code = await self.gen_magic_code() - workplace_path = os.path.join(self.workplace_path, magic_code) - output_path = os.path.join(workplace_path, "output") - os.makedirs(workplace_path, exist_ok=True) - os.makedirs(output_path, exist_ok=True) - - files = [] - # 文件 - for file_path in self.user_file_msg_buffer[event.get_session_id()]: - if not file_path: - continue - elif not os.path.exists(file_path): - logger.warning(f"文件 {file_path} 不存在,已忽略。") - continue - # cp - file_name = os.path.basename(file_path) - shutil.copy(file_path, os.path.join(workplace_path, file_name)) - files.append(file_name) - - logger.debug(f"user query: {plain_text}, files: {files}") - - # 整理额外输入 - extra_inputs = "" - if files: - extra_inputs += f"User provided files: {files}\n" - - obs = "" - n = 5 - - async with aiodocker.Docker() as docker: - for i in range(n): - if i > 0: - logger.info(f"Try {i + 1}/{n}") - - PROMPT_ = PROMPT.format( - prompt=plain_text, - extra_input=extra_inputs, - extra_prompt=obs, - ) - provider = self.context.get_using_provider() - llm_response = await provider.text_chat( - prompt=PROMPT_, - session_id=f"{event.session_id}_{magic_code}_{i!s}", - ) - - logger.debug( - "code interpreter llm gened code:" + llm_response.completion_text, - ) - - # 整理代码并保存 - code_clean = await self.tidy_code(llm_response.completion_text) - with open(os.path.join(workplace_path, "exec.py"), "w") as f: - f.write(code_clean) - - # 检查有没有image - image_name = await self.get_image_name() - try: - await docker.images.get(image_name) - except aiodocker.exceptions.DockerError: - # 拉取镜像 - logger.info(f"未找到沙箱镜像,正在尝试拉取 {image_name}...") - await docker.images.pull(image_name) - - yield event.plain_result( - f"使用沙箱执行代码中,请稍等...(尝试次数: {i + 1}/{n})", - ) - - self.docker_host_astrbot_abs_path = self.config.get( - "docker_host_astrbot_abs_path", - "", - ) - if self.docker_host_astrbot_abs_path: - host_shared = os.path.join( - self.docker_host_astrbot_abs_path, - self.shared_path, - ) - host_output = os.path.join( - self.docker_host_astrbot_abs_path, - output_path, - ) - host_workplace = os.path.join( - self.docker_host_astrbot_abs_path, - workplace_path, - ) - - else: - host_shared = os.path.abspath(self.shared_path) - host_output = os.path.abspath(output_path) - host_workplace = os.path.abspath(workplace_path) - - logger.debug( - f"host_shared: {host_shared}, host_output: {host_output}, host_workplace: {host_workplace}", - ) - - container = await docker.containers.run( - { - "Image": image_name, - "Cmd": ["python", "exec.py"], - "Memory": 512 * 1024 * 1024, - "NanoCPUs": 1000000000, - "HostConfig": { - "Binds": [ - f"{host_shared}:/astrbot_sandbox/shared:ro", - f"{host_output}:/astrbot_sandbox/output:rw", - f"{host_workplace}:/astrbot_sandbox:rw", - ], - }, - "Env": [f"MAGIC_CODE={magic_code}"], - "AutoRemove": True, - }, - ) - - logger.debug(f"Container {container.id} created.") - logs = await self.run_container(container) - - logger.debug(f"Container {container.id} finished.") - logger.debug(f"Container {container.id} logs: {logs}") - - # 发送结果 - pattern = r"\[ASTRBOT_(TEXT|IMAGE|FILE)_OUTPUT#\w+\]: (.*)" - ok = False - traceback = "" - for idx, log in enumerate(logs): - match = re.match(pattern, log) - if match: - ok = True - if match.group(1) == "TEXT": - yield event.plain_result(match.group(2)) - elif match.group(1) == "IMAGE": - image_path = os.path.join(workplace_path, match.group(2)) - logger.debug(f"Sending image: {image_path}") - yield event.image_result(image_path) - elif match.group(1) == "FILE": - file_path = os.path.join(workplace_path, match.group(2)) - # logger.debug(f"Sending file: {file_path}") - # file_s3_url = await self.file_upload(file_path) - # logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}") - file_name = os.path.basename(file_path) - chain: list[BaseMessageComponent] = [ - File(name=file_name, file=file_path) - ] - yield event.set_result(MessageEventResult(chain=chain)) - - elif ( - "Traceback (most recent call last)" in log or "[Error]: " in log - ): - traceback = "\n".join(logs[idx:]) - - if not ok: - if traceback: - obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occurred:\n\n{traceback}\n Need to improve/fix the code." - else: - logger.warning( - f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}", - ) - break - else: - # 成功了 - self.user_file_msg_buffer.pop(event.get_session_id()) - return - - yield event.plain_result( - "经过多次尝试后,未从沙箱输出中捕获到合法的输出,请更换问法或者查看日志。", - ) - - @pi.command("cleanfile") - async def pi_cleanfile(self, event: AstrMessageEvent): - """清理用户上传的文件""" - for file in self.user_file_msg_buffer[event.get_session_id()]: - try: - os.remove(file) - except BaseException as e: - logger.error(f"删除文件 {file} 失败: {e}") - - self.user_file_msg_buffer.pop(event.get_session_id()) - yield event.plain_result(f"用户 {event.get_session_id()} 上传的文件已清理。") - - async def run_container( - self, - container: aiodocker.docker.DockerContainer, - timeout: int = 20, - ) -> list[str]: - """Run the container and get the output""" - try: - await container.wait(timeout=timeout) - logs = await container.log(stdout=True, stderr=True) - return logs - except asyncio.TimeoutError: - logger.warning(f"Container {container.id} timeout.") - await container.kill() - return [f"[Error]: Container has been killed due to timeout ({timeout}s)."] - finally: - await container.delete() diff --git a/astrbot/builtin_stars/python_interpreter/metadata.yaml b/astrbot/builtin_stars/python_interpreter/metadata.yaml deleted file mode 100644 index 4378f0ada..000000000 --- a/astrbot/builtin_stars/python_interpreter/metadata.yaml +++ /dev/null @@ -1,4 +0,0 @@ -name: astrbot-python-interpreter -desc: Python 代码执行器 -author: Soulter -version: 0.0.1 \ No newline at end of file diff --git a/astrbot/builtin_stars/python_interpreter/requirements.txt b/astrbot/builtin_stars/python_interpreter/requirements.txt deleted file mode 100644 index 44a8f5435..000000000 --- a/astrbot/builtin_stars/python_interpreter/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -aiodocker \ No newline at end of file diff --git a/astrbot/builtin_stars/python_interpreter/shared/api.py b/astrbot/builtin_stars/python_interpreter/shared/api.py deleted file mode 100644 index 287773fb0..000000000 --- a/astrbot/builtin_stars/python_interpreter/shared/api.py +++ /dev/null @@ -1,22 +0,0 @@ -import os - - -def _get_magic_code(): - """防止注入攻击""" - return os.getenv("MAGIC_CODE") - - -def send_text(text: str): - print(f"[ASTRBOT_TEXT_OUTPUT#{_get_magic_code()}]: {text}") - - -def send_image(image_path: str): - if not os.path.exists(image_path): - raise Exception(f"Image file not found: {image_path}") - print(f"[ASTRBOT_IMAGE_OUTPUT#{_get_magic_code()}]: {image_path}") - - -def send_file(file_path: str): - if not os.path.exists(file_path): - raise Exception(f"File not found: {file_path}") - print(f"[ASTRBOT_FILE_OUTPUT#{_get_magic_code()}]: {file_path}") diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 4a00dad41..1f2d067f8 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -113,6 +113,14 @@ "provider": "moonshotai", "moonshotai_api_key": "", }, + "sandbox": { + "enable": False, + "booter": "shipyard", + "shipyard_endpoint": "", + "shipyard_access_token": "", + "shipyard_ttl": 3600, + "shipyard_max_sessions": 10, + }, }, "provider_stt_settings": { "enable": False, @@ -2539,6 +2547,62 @@ class ChatProviderTemplate(TypedDict): # "provider_settings.enable": True, # }, # }, + "sandbox": { + "description": "Agent 沙箱环境", + "type": "object", + "items": { + "provider_settings.sandbox.enable": { + "description": "启用沙箱环境", + "type": "bool", + "hint": "启用后,Agent 可以使用沙箱环境中的工具和资源,如 Python 代码执行、Shell 等。", + }, + "provider_settings.sandbox.booter": { + "description": "沙箱环境驱动器", + "type": "string", + "options": ["shipyard"], + "condition": { + "provider_settings.sandbox.enable": True, + }, + }, + "provider_settings.sandbox.shipyard_endpoint": { + "description": "Shipyard API Endpoint", + "type": "string", + "hint": "Shipyard 服务的 API 访问地址。", + "condition": { + "provider_settings.sandbox.enable": True, + "provider_settings.sandbox.booter": "shipyard", + }, + "_special": "check_shipyard_connection", + }, + "provider_settings.sandbox.shipyard_access_token": { + "description": "Shipyard Access Token", + "type": "string", + "hint": "用于访问 Shipyard 服务的访问令牌。", + "condition": { + "provider_settings.sandbox.enable": True, + "provider_settings.sandbox.booter": "shipyard", + }, + }, + "provider_settings.sandbox.shipyard_ttl": { + "description": "Shipyard Session TTL", + "type": "int", + "hint": "Shipyard 会话的生存时间(秒)。", + "condition": { + "provider_settings.sandbox.enable": True, + "provider_settings.sandbox.booter": "shipyard", + }, + }, + "provider_settings.sandbox.shipyard_max_sessions": { + "description": "Shipyard Max Sessions", + "type": "int", + "hint": "Shipyard 最大会话数量。", + "condition": { + "provider_settings.sandbox.enable": True, + "provider_settings.sandbox.booter": "shipyard", + }, + }, + }, + }, "truncate_and_compress": { "description": "上下文管理策略", "type": "object", diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 09fd1409d..3e39dc8a1 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -2,10 +2,11 @@ import asyncio import json +import os from collections.abc import AsyncGenerator from astrbot.core import logger -from astrbot.core.agent.message import Message +from astrbot.core.agent.message import Message, TextPart from astrbot.core.agent.response import AgentStats from astrbot.core.agent.tool import ToolSet from astrbot.core.astr_agent_context import AstrAgentContext @@ -35,8 +36,13 @@ from ....context import PipelineContext, call_event_hook from ...stage import Stage from ...utils import ( + EXECUTE_SHELL_TOOL, + FILE_DOWNLOAD_TOOL, + FILE_UPLOAD_TOOL, KNOWLEDGE_BASE_QUERY_TOOL, LLM_SAFETY_MODE_SYSTEM_PROMPT, + PYTHON_TOOL, + SANDBOX_MODE_PROMPT, decoded_blocked, retrieve_knowledge_base, ) @@ -94,6 +100,8 @@ async def initialize(self, ctx: PipelineContext) -> None: "safety_mode_strategy", "system_prompt" ) + self.sandbox_cfg = settings.get("sandbox", {}) + self.conv_manager = ctx.plugin_manager.context.conversation_manager def _select_provider(self, event: AstrMessageEvent): @@ -458,6 +466,24 @@ def _apply_llm_safety_mode(self, req: ProviderRequest) -> None: f"Unsupported llm_safety_mode strategy: {self.safety_mode_strategy}.", ) + def _apply_sandbox_tools(self, req: ProviderRequest, session_id: str) -> None: + """Add sandbox tools to the provider request.""" + if req.func_tool is None: + req.func_tool = ToolSet() + if self.sandbox_cfg.get("booter") == "shipyard": + ep = self.sandbox_cfg.get("shipyard_endpoint", "") + at = self.sandbox_cfg.get("shipyard_access_token", "") + if not ep or not at: + logger.error("Shipyard sandbox configuration is incomplete.") + return + os.environ["SHIPYARD_ENDPOINT"] = ep + os.environ["SHIPYARD_ACCESS_TOKEN"] = at + req.func_tool.add_tool(EXECUTE_SHELL_TOOL) + req.func_tool.add_tool(PYTHON_TOOL) + req.func_tool.add_tool(FILE_UPLOAD_TOOL) + req.func_tool.add_tool(FILE_DOWNLOAD_TOOL) + req.system_prompt += f"\n{SANDBOX_MODE_PROMPT}\n" + async def process( self, event: AstrMessageEvent, provider_wake_prefix: str ) -> AsyncGenerator[None, None]: @@ -536,6 +562,20 @@ async def process( image_path = await comp.convert_to_file_path() req.image_urls.append(image_path) + req.extra_user_content_parts.append( + TextPart(text=f"[Image Attachment: path {image_path}]") + ) + elif isinstance(comp, File) and self.sandbox_cfg.get( + "enable", False + ): + file_path = await comp.get_file() + file_name = comp.name or os.path.basename(file_path) + req.extra_user_content_parts.append( + TextPart( + text=f"[File Attachment: name {file_name}, path {file_path}]" + ) + ) + conversation = await self._get_session_conv(event) req.conversation = conversation req.contexts = json.loads(conversation.history) @@ -586,6 +626,10 @@ async def process( if self.llm_safety_mode: self._apply_llm_safety_mode(req) + # apply sandbox tools + if self.sandbox_cfg.get("enable", False): + self._apply_sandbox_tools(req, req.session_id) + stream_to_general = ( self.unsupported_streaming_strategy == "turn_off" and not event.platform_meta.support_streaming_message diff --git a/astrbot/core/pipeline/process_stage/utils.py b/astrbot/core/pipeline/process_stage/utils.py index 107d9d640..d1dd22139 100644 --- a/astrbot/core/pipeline/process_stage/utils.py +++ b/astrbot/core/pipeline/process_stage/utils.py @@ -7,6 +7,12 @@ from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.agent.tool import FunctionTool, ToolExecResult from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.sandbox.tools import ( + ExecuteShellTool, + FileDownloadTool, + FileUploadTool, + PythonTool, +) from astrbot.core.star.context import Context LLM_SAFETY_MODE_SYSTEM_PROMPT = """You are running in Safe Mode. @@ -21,6 +27,20 @@ - Output same language as the user's input. """ +SANDBOX_MODE_PROMPT = ( + "You have access to a sandboxed environment and can execute shell commands and Python code securely." + # "Your have extended skills library, such as PDF processing, image generation, data analysis, etc. " + # "Before handling complex tasks, please retrieve and review the documentation in the in /app/skills/ directory. " + # "If the current task matches the description of a specific skill, prioritize following the workflow defined by that skill." + # "Use `ls /app/skills/` to list all available skills. " + # "Use `cat /app/skills/{skill_name}/SKILL.md` to read the documentation of a specific skill." + # "SKILL.md might be large, you can read the description first, which is located in the YAML frontmatter of the file." + # "Use shell commands such as grep, sed, awk to extract relevant information from the documentation as needed.\n" + "Note:\n" + "1. If you use shell, your command will always runs in the /home//workspace directory.\n" + "2. If you use IPython, you would better use absolute paths when dealing with files to avoid confusion.\n" +) + @dataclass class KnowledgeBaseQueryTool(FunctionTool[AstrAgentContext]): @@ -138,6 +158,11 @@ async def retrieve_knowledge_base( KNOWLEDGE_BASE_QUERY_TOOL = KnowledgeBaseQueryTool() +EXECUTE_SHELL_TOOL = ExecuteShellTool() +PYTHON_TOOL = PythonTool() +FILE_UPLOAD_TOOL = FileUploadTool() +FILE_DOWNLOAD_TOOL = FileDownloadTool() + # we prevent astrbot from connecting to known malicious hosts # these hosts are base64 encoded BLOCKED = {"dGZid2h2d3IuY2xvdWQuc2VhbG9zLmlv", "a291cmljaGF0"} diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 1ad68136e..e799e396e 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -93,7 +93,8 @@ async def send_by_session( session: MessageSesion, message_chain: MessageChain, ): - await WebChatMessageEvent._send(message_chain, session.session_id) + message_id = f"active_{str(uuid.uuid4())}" + await WebChatMessageEvent._send(message_id, message_chain, session.session_id) await super().send_by_session(session, message_chain) async def _get_message_history( @@ -196,7 +197,7 @@ async def convert_message(self, data: tuple) -> AstrBotMessage: abm.session_id = f"webchat!{username}!{cid}" - abm.message_id = str(uuid.uuid4()) + abm.message_id = payload.get("message_id") # 处理消息段列表 message_parts = payload.get("message", []) diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index 2e529bb1d..7d1c966e4 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -21,7 +21,10 @@ def __init__(self, message_str, message_obj, platform_meta, session_id): @staticmethod async def _send( - message: MessageChain | None, session_id: str, streaming: bool = False + message_id: str, + message: MessageChain | None, + session_id: str, + streaming: bool = False, ) -> str | None: cid = session_id.split("!")[-1] web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid) @@ -31,6 +34,7 @@ async def _send( "type": "end", "data": "", "streaming": False, + "message_id": message_id, }, # end means this request is finished ) return @@ -45,6 +49,7 @@ async def _send( "data": data, "streaming": streaming, "chain_type": message.type, + "message_id": message_id, }, ) elif isinstance(comp, Json): @@ -54,6 +59,7 @@ async def _send( "data": json.dumps(comp.data, ensure_ascii=False), "streaming": streaming, "chain_type": message.type, + "message_id": message_id, }, ) elif isinstance(comp, Image): @@ -69,6 +75,7 @@ async def _send( "type": "image", "data": data, "streaming": streaming, + "message_id": message_id, }, ) elif isinstance(comp, Record): @@ -84,6 +91,7 @@ async def _send( "type": "record", "data": data, "streaming": streaming, + "message_id": message_id, }, ) elif isinstance(comp, File): @@ -94,12 +102,13 @@ async def _send( filename = f"{uuid.uuid4()!s}{ext}" dest_path = os.path.join(imgs_dir, filename) shutil.copy2(file_path, dest_path) - data = f"[FILE]{filename}|{original_name}" + data = f"[FILE]{filename}" await web_chat_back_queue.put( { "type": "file", "data": data, "streaming": streaming, + "message_id": message_id, }, ) else: @@ -108,7 +117,8 @@ async def _send( return data async def send(self, message: MessageChain | None): - await WebChatMessageEvent._send(message, session_id=self.session_id) + message_id = self.message_obj.message_id + await WebChatMessageEvent._send(message_id, message, session_id=self.session_id) await super().send(MessageChain([])) async def send_streaming(self, generator, use_fallback: bool = False): @@ -116,6 +126,7 @@ async def send_streaming(self, generator, use_fallback: bool = False): reasoning_content = "" cid = self.session_id.split("!")[-1] web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid) + message_id = self.message_obj.message_id async for chain in generator: # if chain.type == "break" and final_data: # # 分割符 @@ -130,7 +141,8 @@ async def send_streaming(self, generator, use_fallback: bool = False): # continue r = await WebChatMessageEvent._send( - chain, + message_id=message_id, + message=chain, session_id=self.session_id, streaming=True, ) @@ -147,6 +159,7 @@ async def send_streaming(self, generator, use_fallback: bool = False): "data": final_data, "reasoning": reasoning_content, "streaming": True, + "message_id": message_id, }, ) await super().send_streaming(generator, use_fallback) diff --git a/astrbot/core/sandbox/booters/base.py b/astrbot/core/sandbox/booters/base.py new file mode 100644 index 000000000..e906bdf65 --- /dev/null +++ b/astrbot/core/sandbox/booters/base.py @@ -0,0 +1,31 @@ +from ..olayer import FileSystemComponent, PythonComponent, ShellComponent + + +class SandboxBooter: + @property + def fs(self) -> FileSystemComponent: ... + + @property + def python(self) -> PythonComponent: ... + + @property + def shell(self) -> ShellComponent: ... + + async def boot(self, session_id: str) -> None: ... + + async def shutdown(self) -> None: ... + + async def upload_file(self, path: str, file_name: str) -> dict: + """Upload file to sandbox. + + Should return a dict with `success` (bool) and `file_path` (str) keys. + """ + ... + + async def download_file(self, remote_path: str, local_path: str): + """Download file from sandbox.""" + ... + + async def available(self) -> bool: + """Check if the sandbox is available.""" + ... diff --git a/astrbot/core/sandbox/booters/boxlite.py b/astrbot/core/sandbox/booters/boxlite.py new file mode 100644 index 000000000..4d481c49e --- /dev/null +++ b/astrbot/core/sandbox/booters/boxlite.py @@ -0,0 +1,186 @@ +import asyncio +import random +from typing import Any + +import aiohttp +import boxlite +from shipyard.filesystem import FileSystemComponent as ShipyardFileSystemComponent +from shipyard.python import PythonComponent as ShipyardPythonComponent +from shipyard.shell import ShellComponent as ShipyardShellComponent + +from astrbot.api import logger + +from ..olayer import FileSystemComponent, PythonComponent, ShellComponent +from .base import SandboxBooter + + +class MockShipyardSandboxClient: + def __init__(self, sb_url: str) -> None: + self.sb_url = sb_url.rstrip("/") + + async def _exec_operation( + self, + ship_id: str, + operation_type: str, + payload: dict[str, Any], + session_id: str, + ) -> dict[str, Any]: + async with aiohttp.ClientSession() as session: + headers = {"X-SESSION-ID": session_id} + async with session.post( + f"{self.sb_url}/{operation_type}", + json=payload, + headers=headers, + ) as response: + if response.status == 200: + return await response.json() + else: + error_text = await response.text() + raise Exception( + f"Failed to exec operation: {response.status} {error_text}" + ) + + async def upload_file(self, path: str, remote_path: str) -> dict: + """Upload a file to the sandbox""" + url = f"http://{self.sb_url}/upload" + + try: + # Read file content + with open(path, "rb") as f: + file_content = f.read() + + # Create multipart form data + data = aiohttp.FormData() + data.add_field( + "file", + file_content, + filename=remote_path.split("/")[-1], + content_type="application/octet-stream", + ) + data.add_field("file_path", remote_path) + + timeout = aiohttp.ClientTimeout(total=120) # 2 minutes for file upload + + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(url, data=data) as response: + if response.status == 200: + return { + "success": True, + "message": "File uploaded successfully", + "file_path": remote_path, + } + else: + error_text = await response.text() + return { + "success": False, + "error": f"Server returned {response.status}: {error_text}", + "message": "File upload failed", + } + + except aiohttp.ClientError as e: + logger.error(f"Failed to upload file: {e}") + return { + "success": False, + "error": f"Connection error: {str(e)}", + "message": "File upload failed", + } + except asyncio.TimeoutError: + return { + "success": False, + "error": "File upload timeout", + "message": "File upload failed", + } + except FileNotFoundError: + logger.error(f"File not found: {path}") + return { + "success": False, + "error": f"File not found: {path}", + "message": "File upload failed", + } + except Exception as e: + logger.error(f"Unexpected error uploading file: {e}") + return { + "success": False, + "error": f"Internal error: {str(e)}", + "message": "File upload failed", + } + + async def wait_healthy(self, ship_id: str, session_id: str) -> None: + """Mock wait healthy""" + loop = 60 + while loop > 0: + try: + logger.info( + f"Checking health for sandbox {ship_id} on {self.sb_url}..." + ) + url = f"{self.sb_url}/health" + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + if response.status == 200: + logger.info(f"Sandbox {ship_id} is healthy") + return + except Exception: + await asyncio.sleep(1) + loop -= 1 + + +class BoxliteBooter(SandboxBooter): + async def boot(self, session_id: str) -> None: + logger.info( + f"Booting(Boxlite) for session: {session_id}, this may take a while..." + ) + random_port = random.randint(20000, 30000) + self.box = boxlite.SimpleBox( + image="soulter/shipyard-ship", + memory_mib=512, + cpus=1, + ports=[ + { + "host_port": random_port, + "guest_port": 8123, + } + ], + ) + await self.box.start() + logger.info(f"Boxlite booter started for session: {session_id}") + self.mocked = MockShipyardSandboxClient( + sb_url=f"http://127.0.0.1:{random_port}" + ) + self._fs = ShipyardFileSystemComponent( + client=self.mocked, # type: ignore + ship_id=self.box.id, + session_id=session_id, + ) + self._python = ShipyardPythonComponent( + client=self.mocked, # type: ignore + ship_id=self.box.id, + session_id=session_id, + ) + self._shell = ShipyardShellComponent( + client=self.mocked, # type: ignore + ship_id=self.box.id, + session_id=session_id, + ) + + await self.mocked.wait_healthy(self.box.id, session_id) + + async def shutdown(self) -> None: + logger.info(f"Shutting down Boxlite booter for ship: {self.box.id}") + self.box.shutdown() + logger.info(f"Boxlite booter for ship: {self.box.id} stopped") + + @property + def fs(self) -> FileSystemComponent: + return self._fs + + @property + def python(self) -> PythonComponent: + return self._python + + @property + def shell(self) -> ShellComponent: + return self._shell + + async def upload_file(self, path: str, file_name: str) -> dict: + """Upload file to sandbox""" + return await self.mocked.upload_file(path, file_name) diff --git a/astrbot/core/sandbox/booters/shipyard.py b/astrbot/core/sandbox/booters/shipyard.py new file mode 100644 index 000000000..5ca81af23 --- /dev/null +++ b/astrbot/core/sandbox/booters/shipyard.py @@ -0,0 +1,67 @@ +from shipyard import ShipyardClient, Spec + +from astrbot.api import logger + +from ..olayer import FileSystemComponent, PythonComponent, ShellComponent +from .base import SandboxBooter + + +class ShipyardBooter(SandboxBooter): + def __init__( + self, + endpoint_url: str, + access_token: str, + ttl: int = 3600, + session_num: int = 10, + ) -> None: + self._sandbox_client = ShipyardClient( + endpoint_url=endpoint_url, access_token=access_token + ) + self._ttl = ttl + self._session_num = session_num + + async def boot(self, session_id: str) -> None: + ship = await self._sandbox_client.create_ship( + ttl=self._ttl, + spec=Spec(cpus=1.0, memory="512m"), + max_session_num=self._session_num, + session_id=session_id, + ) + logger.info(f"Got sandbox ship: {ship.id} for session: {session_id}") + self._ship = ship + + async def shutdown(self) -> None: + pass + + @property + def fs(self) -> FileSystemComponent: + return self._ship.fs + + @property + def python(self) -> PythonComponent: + return self._ship.python + + @property + def shell(self) -> ShellComponent: + return self._ship.shell + + async def upload_file(self, path: str, file_name: str) -> dict: + """Upload file to sandbox""" + return await self._ship.upload_file(path, file_name) + + async def download_file(self, remote_path: str, local_path: str): + """Download file from sandbox.""" + return await self._ship.download_file(remote_path, local_path) + + async def available(self) -> bool: + """Check if the sandbox is available.""" + try: + ship_id = self._ship.id + data = await self._sandbox_client.get_ship(ship_id) + if not data: + return False + health = bool(data.get("status", 0) == 1) + return health + except Exception as e: + logger.error(f"Error checking Shipyard sandbox availability: {e}") + return False diff --git a/astrbot/core/sandbox/olayer/__init__.py b/astrbot/core/sandbox/olayer/__init__.py new file mode 100644 index 000000000..f099c079a --- /dev/null +++ b/astrbot/core/sandbox/olayer/__init__.py @@ -0,0 +1,5 @@ +from .filesystem import FileSystemComponent +from .python import PythonComponent +from .shell import ShellComponent + +__all__ = ["PythonComponent", "ShellComponent", "FileSystemComponent"] diff --git a/astrbot/core/sandbox/olayer/filesystem.py b/astrbot/core/sandbox/olayer/filesystem.py new file mode 100644 index 000000000..21f36d111 --- /dev/null +++ b/astrbot/core/sandbox/olayer/filesystem.py @@ -0,0 +1,33 @@ +""" +File system component +""" + +from typing import Any, Protocol + + +class FileSystemComponent(Protocol): + async def create_file( + self, path: str, content: str = "", mode: int = 0o644 + ) -> dict[str, Any]: + """Create a file with the specified content""" + ... + + async def read_file(self, path: str, encoding: str = "utf-8") -> dict[str, Any]: + """Read file content""" + ... + + async def write_file( + self, path: str, content: str, mode: str = "w", encoding: str = "utf-8" + ) -> dict[str, Any]: + """Write content to file""" + ... + + async def delete_file(self, path: str) -> dict[str, Any]: + """Delete file or directory""" + ... + + async def list_dir( + self, path: str = ".", show_hidden: bool = False + ) -> dict[str, Any]: + """List directory contents""" + ... diff --git a/astrbot/core/sandbox/olayer/python.py b/astrbot/core/sandbox/olayer/python.py new file mode 100644 index 000000000..625504146 --- /dev/null +++ b/astrbot/core/sandbox/olayer/python.py @@ -0,0 +1,19 @@ +""" +Python/IPython component +""" + +from typing import Any, Protocol + + +class PythonComponent(Protocol): + """Python/IPython operations component""" + + async def exec( + self, + code: str, + kernel_id: str | None = None, + timeout: int = 30, + silent: bool = False, + ) -> dict[str, Any]: + """Execute Python code""" + ... diff --git a/astrbot/core/sandbox/olayer/shell.py b/astrbot/core/sandbox/olayer/shell.py new file mode 100644 index 000000000..df2263b65 --- /dev/null +++ b/astrbot/core/sandbox/olayer/shell.py @@ -0,0 +1,21 @@ +""" +Shell component +""" + +from typing import Any, Protocol + + +class ShellComponent(Protocol): + """Shell operations component""" + + async def exec( + self, + command: str, + cwd: str | None = None, + env: dict[str, str] | None = None, + timeout: int | None = 30, + shell: bool = True, + background: bool = False, + ) -> dict[str, Any]: + """Execute shell command""" + ... diff --git a/astrbot/core/sandbox/sandbox_client.py b/astrbot/core/sandbox/sandbox_client.py new file mode 100644 index 000000000..f9937bbc1 --- /dev/null +++ b/astrbot/core/sandbox/sandbox_client.py @@ -0,0 +1,52 @@ +import uuid + +from astrbot.api import logger +from astrbot.core.star.context import Context + +from .booters.base import SandboxBooter + +session_booter: dict[str, SandboxBooter] = {} + + +async def get_booter( + context: Context, + session_id: str, +) -> SandboxBooter: + config = context.get_config(umo=session_id) + + sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {}) + booter_type = sandbox_cfg.get("booter", "shipyard") + + if session_id in session_booter: + booter = session_booter[session_id] + if not await booter.available(): + # rebuild + session_booter.pop(session_id, None) + if session_id not in session_booter: + uuid_str = uuid.uuid5(uuid.NAMESPACE_DNS, session_id).hex + if booter_type == "shipyard": + from .booters.shipyard import ShipyardBooter + + ep = sandbox_cfg.get("shipyard_endpoint", "") + token = sandbox_cfg.get("shipyard_access_token", "") + ttl = sandbox_cfg.get("shipyard_ttl", 3600) + max_sessions = sandbox_cfg.get("shipyard_max_sessions", 10) + + client = ShipyardBooter( + endpoint_url=ep, access_token=token, ttl=ttl, session_num=max_sessions + ) + elif booter_type == "boxlite": + from .booters.boxlite import BoxliteBooter + + client = BoxliteBooter() + else: + raise ValueError(f"Unknown booter type: {booter_type}") + + try: + await client.boot(uuid_str) + except Exception as e: + logger.error(f"Error booting sandbox for session {session_id}: {e}") + raise e + + session_booter[session_id] = client + return session_booter[session_id] diff --git a/astrbot/core/sandbox/tools/__init__.py b/astrbot/core/sandbox/tools/__init__.py new file mode 100644 index 000000000..8dacab79f --- /dev/null +++ b/astrbot/core/sandbox/tools/__init__.py @@ -0,0 +1,10 @@ +from .fs import FileDownloadTool, FileUploadTool +from .python import PythonTool +from .shell import ExecuteShellTool + +__all__ = [ + "FileUploadTool", + "PythonTool", + "ExecuteShellTool", + "FileDownloadTool", +] diff --git a/astrbot/core/sandbox/tools/fs.py b/astrbot/core/sandbox/tools/fs.py new file mode 100644 index 000000000..8f12a2f78 --- /dev/null +++ b/astrbot/core/sandbox/tools/fs.py @@ -0,0 +1,188 @@ +import os +from dataclasses import dataclass, field + +from astrbot.api import FunctionTool, logger +from astrbot.api.event import MessageChain +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.agent.tool import ToolExecResult +from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.message.components import File +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path + +from ..sandbox_client import get_booter + +# @dataclass +# class CreateFileTool(FunctionTool): +# name: str = "astrbot_create_file" +# description: str = "Create a new file in the sandbox." +# parameters: dict = field( +# default_factory=lambda: { +# "type": "object", +# "properties": { +# "path": { +# "path": "string", +# "description": "The path where the file should be created, relative to the sandbox root. Must not use absolute paths or traverse outside the sandbox.", +# }, +# "content": { +# "type": "string", +# "description": "The content to write into the file.", +# }, +# }, +# "required": ["path", "content"], +# } +# ) + +# async def call( +# self, context: ContextWrapper[AstrAgentContext], path: str, content: str +# ) -> ToolExecResult: +# sb = await get_booter( +# context.context.context, +# context.context.event.unified_msg_origin, +# ) +# try: +# result = await sb.fs.create_file(path, content) +# return json.dumps(result) +# except Exception as e: +# return f"Error creating file: {str(e)}" + + +# @dataclass +# class ReadFileTool(FunctionTool): +# name: str = "astrbot_read_file" +# description: str = "Read the content of a file in the sandbox." +# parameters: dict = field( +# default_factory=lambda: { +# "type": "object", +# "properties": { +# "path": { +# "type": "string", +# "description": "The path of the file to read, relative to the sandbox root. Must not use absolute paths or traverse outside the sandbox.", +# }, +# }, +# "required": ["path"], +# } +# ) + +# async def call(self, context: ContextWrapper[AstrAgentContext], path: str): +# sb = await get_booter( +# context.context.context, +# context.context.event.unified_msg_origin, +# ) +# try: +# result = await sb.fs.read_file(path) +# return result +# except Exception as e: +# return f"Error reading file: {str(e)}" + + +@dataclass +class FileUploadTool(FunctionTool): + name: str = "astrbot_upload_file" + description: str = "Upload a local file to the sandbox. The file must exist on the local filesystem." + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "local_path": { + "type": "string", + "description": "The local file path to upload. This must be an absolute path to an existing file on the local filesystem.", + }, + # "remote_path": { + # "type": "string", + # "description": "The filename to use in the sandbox. If not provided, file will be saved to the working directory with the same name as the local file.", + # }, + }, + "required": ["local_path"], + } + ) + + async def call( + self, + context: ContextWrapper[AstrAgentContext], + local_path: str, + ): + sb = await get_booter( + context.context.context, + context.context.event.unified_msg_origin, + ) + try: + # Check if file exists + if not os.path.exists(local_path): + return f"Error: File does not exist: {local_path}" + + if not os.path.isfile(local_path): + return f"Error: Path is not a file: {local_path}" + + # Use basename if sandbox_filename is not provided + remote_path = os.path.basename(local_path) + + # Upload file to sandbox + result = await sb.upload_file(local_path, remote_path) + logger.debug(f"Upload result: {result}") + success = result.get("success", False) + + if not success: + return f"Error uploading file: {result.get('message', 'Unknown error')}" + + file_path = result.get("file_path", "") + logger.info(f"File {local_path} uploaded to sandbox at {file_path}") + + return f"File uploaded successfully to {file_path}" + except Exception as e: + logger.error(f"Error uploading file {local_path}: {e}") + return f"Error uploading file: {str(e)}" + + +@dataclass +class FileDownloadTool(FunctionTool): + name: str = "astrbot_download_file" + description: str = "Download a file from the sandbox. Only call this when user explicitly need you to download a file." + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "remote_path": { + "type": "string", + "description": "The path of the file in the sandbox to download.", + } + }, + "required": ["remote_path"], + } + ) + + async def call( + self, + context: ContextWrapper[AstrAgentContext], + remote_path: str, + ) -> ToolExecResult: + sb = await get_booter( + context.context.context, + context.context.event.unified_msg_origin, + ) + try: + name = os.path.basename(remote_path) + + local_path = os.path.join(get_astrbot_temp_path(), name) + + # Download file from sandbox + await sb.download_file(remote_path, local_path) + logger.info(f"File {remote_path} downloaded from sandbox to {local_path}") + + try: + name = os.path.basename(local_path) + await context.context.event.send( + MessageChain(chain=[File(name=name, file=local_path)]) + ) + except Exception as e: + logger.error(f"Error sending file message: {e}") + + # remove + try: + os.remove(local_path) + except Exception as e: + logger.error(f"Error removing temp file {local_path}: {e}") + + return f"File downloaded successfully to {local_path}" + except Exception as e: + logger.error(f"Error downloading file {remote_path}: {e}") + return f"Error downloading file: {str(e)}" diff --git a/astrbot/core/sandbox/tools/python.py b/astrbot/core/sandbox/tools/python.py new file mode 100644 index 000000000..84825e2a1 --- /dev/null +++ b/astrbot/core/sandbox/tools/python.py @@ -0,0 +1,74 @@ +from dataclasses import dataclass, field + +import mcp + +from astrbot.api import FunctionTool +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.agent.tool import ToolExecResult +from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.sandbox.sandbox_client import get_booter + + +@dataclass +class PythonTool(FunctionTool): + name: str = "astrbot_execute_ipython" + description: str = "Execute a command in an IPython shell." + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "The Python code to execute.", + }, + "silent": { + "type": "boolean", + "description": "Whether to suppress the output of the code execution.", + "default": False, + }, + }, + "required": ["code"], + } + ) + + async def call( + self, context: ContextWrapper[AstrAgentContext], code: str, silent: bool = False + ) -> ToolExecResult: + sb = await get_booter( + context.context.context, + context.context.event.unified_msg_origin, + ) + try: + result = await sb.python.exec(code, silent=silent) + data = result.get("data", {}) + output = data.get("output", {}) + error = data.get("error", "") + images: list[dict] = output.get("images", []) + text: str = output.get("text", "") + + resp = mcp.types.CallToolResult(content=[]) + + if error: + resp.content.append( + mcp.types.TextContent(type="text", text=f"error: {error}") + ) + + if images: + for img in images: + resp.content.append( + mcp.types.ImageContent( + type="image", data=img["image/png"], mimeType="image/png" + ) + ) + if text: + resp.content.append(mcp.types.TextContent(type="text", text=text)) + + if not resp.content: + resp.content.append( + mcp.types.TextContent(type="text", text="No output.") + ) + + return resp + + except Exception as e: + return f"Error executing code: {str(e)}" diff --git a/astrbot/core/sandbox/tools/shell.py b/astrbot/core/sandbox/tools/shell.py new file mode 100644 index 000000000..88c1a25a5 --- /dev/null +++ b/astrbot/core/sandbox/tools/shell.py @@ -0,0 +1,55 @@ +import json +from dataclasses import dataclass, field + +from astrbot.api import FunctionTool +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.agent.tool import ToolExecResult +from astrbot.core.astr_agent_context import AstrAgentContext + +from ..sandbox_client import get_booter + + +@dataclass +class ExecuteShellTool(FunctionTool): + name: str = "astrbot_execute_shell" + description: str = "Execute a command in the shell." + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The bash command to execute. Equal to 'cd {working_dir} && {your_command}'.", + }, + "background": { + "type": "boolean", + "description": "Whether to run the command in the background.", + "default": False, + }, + "env": { + "type": "object", + "description": "Optional environment variables to set for the file creation process.", + "additionalProperties": {"type": "string"}, + "default": {}, + }, + }, + "required": ["command"], + } + ) + + async def call( + self, + context: ContextWrapper[AstrAgentContext], + command: str, + background: bool = False, + env: dict = {}, + ) -> ToolExecResult: + sb = await get_booter( + context.context.context, + context.context.event.unified_msg_origin, + ) + try: + result = await sb.shell.exec(command, background=background, env=env) + return json.dumps(result) + except Exception as e: + return f"Error executing command: {str(e)}" diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index c42bc4f64..de12daab9 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -296,6 +296,8 @@ async def chat(self): # 构建用户消息段(包含 path 用于传递给 adapter) message_parts = await self._build_user_message_parts(message) + message_id = str(uuid.uuid4()) + async def stream(): client_disconnected = False accumulated_parts = [] @@ -319,6 +321,13 @@ async def stream(): if not result: continue + if ( + "message_id" in result + and result["message_id"] != message_id + ): + logger.warning("webchat stream message_id mismatch") + continue + result_text = result["data"] msg_type = result.get("type") streaming = result.get("streaming", False) @@ -456,6 +465,7 @@ async def stream(): "selected_provider": selected_provider, "selected_model": selected_model, "enable_streaming": enable_streaming, + "message_id": message_id, }, ), ) diff --git a/dashboard/package.json b/dashboard/package.json index d4c0ef485..87bacdab7 100644 --- a/dashboard/package.json +++ b/dashboard/package.json @@ -28,7 +28,7 @@ "katex": "^0.16.27", "lodash": "4.17.21", "markdown-it": "^14.1.0", - "markstream-vue": "0.0.3-beta.7", + "markstream-vue": "^0.0.6-beta.1", "mermaid": "^11.12.2", "pinia": "2.1.6", "pinyin-pro": "^3.26.0", diff --git a/dashboard/src/components/chat/MessageList.vue b/dashboard/src/components/chat/MessageList.vue index 243758597..987827bf0 100644 --- a/dashboard/src/components/chat/MessageList.vue +++ b/dashboard/src/components/chat/MessageList.vue @@ -28,7 +28,7 @@
+ @click="openImagePreview(part.embedded_url)" />
@@ -147,24 +147,44 @@ borderTopColor: 'rgba(100, 140, 200, 0.3)', backgroundColor: 'rgba(30, 45, 70, 0.5)' } : {}"> -
- ID: - {{ toolCall.id - }} -
-
- Args: -
{{
-                                                            JSON.stringify(toolCall.args, null, 2) }}
-
-
- Result: -
{{ formatToolResult(toolCall.result) }}
+                                                
+                                                
+                                                
+                                                
+                                                
                                             
@@ -178,7 +198,7 @@
+ @click="openImagePreview(part.embedded_url)" />
@@ -289,6 +309,13 @@ + + + +
+ +
+