Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 162 additions & 44 deletions astrbot/core/star/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class Context:

registered_web_apis: list = []

# back compatibility
# 向后兼容的变量
_register_tasks: list[Awaitable] = []
_star_manager = None

Expand All @@ -73,12 +73,19 @@ def __init__(
self._db = db
"""AstrBot 数据库"""
self.provider_manager = provider_manager
"""模型提供商管理器"""
self.platform_manager = platform_manager
"""平台适配器管理器"""
self.conversation_manager = conversation_manager
"""会话管理器"""
self.message_history_manager = message_history_manager
"""平台消息历史管理器"""
self.persona_manager = persona_manager
"""人格角色设定管理器"""
self.astrbot_config_mgr = astrbot_config_mgr
"""配置文件管理器(非webui)"""
self.kb_manager = knowledge_base_manager
"""知识库管理器"""

async def llm_generate(
self,
Expand Down Expand Up @@ -226,14 +233,16 @@ async def tool_loop_agent(
return llm_resp

async def get_current_chat_provider_id(self, umo: str) -> str:
"""Get the ID of the currently used chat provider.
"""获取当前使用的聊天模型 Provider ID

Args:
umo(str): unified_message_origin value, if provided and user has enabled provider session isolation, the provider preferred by that session will be used.
umo: unified_message_origin。消息会话来源 ID。

Raises:
ProviderNotFoundError: If the specified chat provider is not found
Returns:
指定消息会话来源当前使用的聊天模型 Provider ID。

Raises:
ProviderNotFoundError: 未找到。
"""
prov = self.get_using_provider(umo)
if not prov:
Expand All @@ -255,20 +264,27 @@ def get_llm_tool_manager(self) -> FunctionToolManager:
return self.provider_manager.llm_tools

def activate_llm_tool(self, name: str) -> bool:
"""激活一个已经注册的函数调用工具。注册的工具默认是激活状态。
"""激活一个已经注册的函数调用工具。

Args:
name: 工具名称。

Returns:
如果没找到,会返回 False
如果成功激活返回 True,如果没找到工具返回 False

Note:
注册的工具默认是激活状态。
"""
return self.provider_manager.llm_tools.activate_llm_tool(name, star_map)

def deactivate_llm_tool(self, name: str) -> bool:
"""停用一个已经注册的函数调用工具。

Returns:
如果没找到,会返回 False
Args:
name: 工具名称。

Returns:
如果成功停用返回 True,如果没找到工具返回 False。
"""
return self.provider_manager.llm_tools.deactivate_llm_tool(name)

Expand All @@ -278,7 +294,17 @@ def get_provider_by_id(
) -> (
Provider | TTSProvider | STTProvider | EmbeddingProvider | RerankProvider | None
):
"""通过 ID 获取对应的 LLM Provider。"""
"""通过 ID 获取对应的 LLM Provider。

Args:
provider_id: 提供者 ID。

Returns:
提供者实例,如果未找到则返回 None。

Note:
如果提供者 ID 存在但未找到提供者,会记录警告日志。
"""
prov = self.provider_manager.inst_map.get(provider_id)
if provider_id and not prov:
logger.warning(
Expand All @@ -303,11 +329,20 @@ def get_all_embedding_providers(self) -> list[EmbeddingProvider]:
return self.provider_manager.embedding_provider_insts

def get_using_provider(self, umo: str | None = None) -> Provider:
"""获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
"""获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。

Args:
umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
umo: unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,
则使用该会话偏好的提供商。

Returns:
当前使用的文本生成提供者。

Raises:
ValueError: 返回的提供者不是 Provider 类型。

Note:
通过 /provider 指令可以切换提供者。
"""
prov = self.provider_manager.get_using_provider(
provider_type=ProviderType.CHAT_COMPLETION,
Expand All @@ -321,8 +356,13 @@ def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider | None:
"""获取当前使用的用于 TTS 任务的 Provider。

Args:
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
umo: unified_message_origin 值,如果传入,则使用该会话偏好的提供商。

Returns:
当前使用的 TTS 提供者,如果未设置则返回 None。

Raises:
ValueError: 返回的提供者不是 TTSProvider 类型。
"""
prov = self.provider_manager.get_using_provider(
provider_type=ProviderType.TEXT_TO_SPEECH,
Expand All @@ -336,8 +376,13 @@ def get_using_stt_provider(self, umo: str | None = None) -> STTProvider | None:
"""获取当前使用的用于 STT 任务的 Provider。

Args:
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
umo: unified_message_origin 值,如果传入,则使用该会话偏好的提供商。

Returns:
当前使用的 STT 提供者,如果未设置则返回 None。

Raises:
ValueError: 返回的提供者不是 STTProvider 类型。
"""
prov = self.provider_manager.get_using_provider(
provider_type=ProviderType.SPEECH_TO_TEXT,
Expand All @@ -348,9 +393,19 @@ def get_using_stt_provider(self, umo: str | None = None) -> STTProvider | None:
return prov

def get_config(self, umo: str | None = None) -> AstrBotConfig:
"""获取 AstrBot 的配置。"""
"""获取 AstrBot 的配置。

Args:
umo: unified_message_origin 值,用于获取特定会话的配置。

Returns:
AstrBot 配置对象。

Note:
如果不提供 umo 参数,将返回默认配置。
"""
if not umo:
# using default config
# 使用默认配置
return self._config
return self.astrbot_config_mgr.get_conf(umo)

Expand All @@ -361,14 +416,19 @@ async def send_message(
) -> bool:
"""根据 session(unified_msg_origin) 主动发送消息。

@param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
@param message_chain: 消息链。
Args:
session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
message_chain: 消息链。

@return: 是否找到匹配的平台。
Returns:
是否找到匹配的平台。

当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。
Raises:
ValueError: session 字符串不合法时抛出。

NOTE: qq_official(QQ 官方 API 平台) 不支持此方法
Note:
当 session 为字符串时,会尝试解析为 MessageSession 对象。(类名为MessageSesion是因为历史遗留拼写错误)
qq_official(QQ 官方 API 平台) 不支持此方法。
"""
if isinstance(session, str):
try:
Expand All @@ -383,7 +443,14 @@ async def send_message(
return False

def add_llm_tools(self, *tools: FunctionTool) -> None:
"""添加 LLM 工具。"""
"""添加 LLM 工具。

Args:
*tools: 要添加的函数工具对象。

Note:
如果工具已存在,会替换已存在的工具。
"""
tool_name = {tool.name for tool in self.provider_manager.llm_tools.func_list}
module_path = ""
for tool in tools:
Expand Down Expand Up @@ -416,6 +483,17 @@ def register_web_api(
methods: list,
desc: str,
):
"""注册 Web API。

Args:
route: API 路由路径。
view_handler: 异步视图处理函数。
methods: HTTP 方法列表。
desc: API 描述。

Note:
如果相同路由和方法已注册,会替换现有的 API。
"""
for idx, api in enumerate(self.registered_web_apis):
if api[0] == route and methods == api[2]:
self.registered_web_apis[idx] = (route, view_handler, methods, desc)
Expand All @@ -434,7 +512,14 @@ def get_event_queue(self) -> Queue:
def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None:
"""获取指定类型的平台适配器。

该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0)
Args:
platform_type: 平台类型或平台名称。

Returns:
平台适配器实例,如果未找到则返回 None。

Note:
该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0)
"""
for platform in self.platform_manager.platform_insts:
name = platform.meta().name
Expand All @@ -451,22 +536,32 @@ def get_platform_inst(self, platform_id: str) -> Platform | None:
"""获取指定 ID 的平台适配器实例。

Args:
platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取
platform_id: 平台适配器的唯一标识符。

Returns:
Platform: 平台适配器实例,如果未找到则返回 None。
平台适配器实例,如果未找到则返回 None。

Note:
可以通过 event.get_platform_id() 获取平台 ID。
"""
for platform in self.platform_manager.platform_insts:
if platform.meta().id == platform_id:
return platform

def get_db(self) -> BaseDatabase:
"""获取 AstrBot 数据库。"""
"""获取 AstrBot 数据库。

Returns:
数据库实例。
"""
return self._db

def register_provider(self, provider: Provider):
"""注册一个 LLM Provider(Chat_Completion 类型)。"""
"""注册一个 LLM Provider(Chat_Completion 类型)。

Args:
provider: 提供者实例。
"""
self.provider_manager.provider_insts.append(provider)

def register_llm_tool(
Expand All @@ -478,12 +573,16 @@ def register_llm_tool(
) -> None:
"""[DEPRECATED]为函数调用(function-calling / tools-use)添加工具。

@param name: 函数名
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
@param desc: 函数描述
@param func_obj: 异步处理函数。

异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
Args:
name: 函数名。
func_args: 函数参数列表,格式为
[{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]。
desc: 函数描述。
func_obj: 异步处理函数。

Note:
异步处理函数会接收到额外的关键词参数:event: AstrMessageEvent, context: Context。
该方法已弃用,请使用新的注册方式。
"""
md = StarHandlerMetadata(
event_type=EventType.OnLLMRequestEvent,
Expand All @@ -498,7 +597,15 @@ def register_llm_tool(
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj)

def unregister_llm_tool(self, name: str) -> None:
"""[DEPRECATED]删除一个函数调用工具。如果再要启用,需要重新注册。"""
"""[DEPRECATED]删除一个函数调用工具。

Args:
name: 工具名称。

Note:
如果再要启用,需要重新注册。
该方法已弃用。
"""
self.provider_manager.llm_tools.remove_func(name)

def register_commands(
Expand All @@ -511,16 +618,19 @@ def register_commands(
use_regex=False,
ignore_prefix=False,
):
"""注册一个命令。

[Deprecated] 推荐使用装饰器注册指令。该方法将在未来的版本中被移除。

@param star_name: 插件(Star)名称。
@param command_name: 命令名称。
@param desc: 命令描述。
@param priority: 优先级。1-10。
@param awaitable: 异步处理函数。
"""[DEPRECATED]注册一个命令。

Args:
star_name: 插件(Star)名称。
command_name: 命令名称。
desc: 命令描述。
priority: 优先级。1-10。
awaitable: 异步处理函数。
use_regex: 是否使用正则表达式匹配命令。
ignore_prefix: 是否忽略命令前缀。

Note:
推荐使用装饰器注册指令。该方法将在未来的版本中被移除。
"""
md = StarHandlerMetadata(
event_type=EventType.AdapterMessageEvent,
Expand All @@ -540,5 +650,13 @@ def register_commands(
star_handlers_registry.append(md)

def register_task(self, task: Awaitable, desc: str):
"""[DEPRECATED]注册一个异步任务。"""
"""[DEPRECATED]注册一个异步任务。

Args:
task: 异步任务。
desc: 任务描述。

Note:
该方法已弃用。
"""
self._register_tasks.append(task)