diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index a64d2a9ee..9c47ba3a7 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -49,7 +49,7 @@ class Context: registered_web_apis: list = [] - # back compatibility + # 向后兼容的变量 _register_tasks: list[Awaitable] = [] _star_manager = None @@ -73,12 +73,19 @@ def __init__( self._db = db """AstrBot 数据库""" self.provider_manager = provider_manager + """模型提供商管理器""" self.platform_manager = platform_manager + """平台适配器管理器""" self.conversation_manager = conversation_manager + """会话管理器""" self.message_history_manager = message_history_manager + """平台消息历史管理器""" self.persona_manager = persona_manager + """人格角色设定管理器""" self.astrbot_config_mgr = astrbot_config_mgr + """配置文件管理器(非webui)""" self.kb_manager = knowledge_base_manager + """知识库管理器""" async def llm_generate( self, @@ -226,14 +233,16 @@ async def tool_loop_agent( return llm_resp async def get_current_chat_provider_id(self, umo: str) -> str: - """Get the ID of the currently used chat provider. + """获取当前使用的聊天模型 Provider ID。 Args: - umo(str): unified_message_origin value, if provided and user has enabled provider session isolation, the provider preferred by that session will be used. + umo: unified_message_origin。消息会话来源 ID。 - Raises: - ProviderNotFoundError: If the specified chat provider is not found + Returns: + 指定消息会话来源当前使用的聊天模型 Provider ID。 + Raises: + ProviderNotFoundError: 未找到。 """ prov = self.get_using_provider(umo) if not prov: @@ -255,20 +264,27 @@ def get_llm_tool_manager(self) -> FunctionToolManager: return self.provider_manager.llm_tools def activate_llm_tool(self, name: str) -> bool: - """激活一个已经注册的函数调用工具。注册的工具默认是激活状态。 + """激活一个已经注册的函数调用工具。 + + Args: + name: 工具名称。 Returns: - 如果没找到,会返回 False + 如果成功激活返回 True,如果没找到工具返回 False。 + Note: + 注册的工具默认是激活状态。 """ return self.provider_manager.llm_tools.activate_llm_tool(name, star_map) def deactivate_llm_tool(self, name: str) -> bool: """停用一个已经注册的函数调用工具。 - Returns: - 如果没找到,会返回 False + Args: + name: 工具名称。 + Returns: + 如果成功停用返回 True,如果没找到工具返回 False。 """ return self.provider_manager.llm_tools.deactivate_llm_tool(name) @@ -278,7 +294,17 @@ def get_provider_by_id( ) -> ( Provider | TTSProvider | STTProvider | EmbeddingProvider | RerankProvider | None ): - """通过 ID 获取对应的 LLM Provider。""" + """通过 ID 获取对应的 LLM Provider。 + + Args: + provider_id: 提供者 ID。 + + Returns: + 提供者实例,如果未找到则返回 None。 + + Note: + 如果提供者 ID 存在但未找到提供者,会记录警告日志。 + """ prov = self.provider_manager.inst_map.get(provider_id) if provider_id and not prov: logger.warning( @@ -303,11 +329,20 @@ def get_all_embedding_providers(self) -> list[EmbeddingProvider]: return self.provider_manager.embedding_provider_insts def get_using_provider(self, umo: str | None = None) -> Provider: - """获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。 + """获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。 Args: - umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。 + umo: unified_message_origin 值,如果传入并且用户启用了提供商会话隔离, + 则使用该会话偏好的提供商。 + + Returns: + 当前使用的文本生成提供者。 + Raises: + ValueError: 返回的提供者不是 Provider 类型。 + + Note: + 通过 /provider 指令可以切换提供者。 """ prov = self.provider_manager.get_using_provider( provider_type=ProviderType.CHAT_COMPLETION, @@ -321,8 +356,13 @@ def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider | None: """获取当前使用的用于 TTS 任务的 Provider。 Args: - umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 + umo: unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 + + Returns: + 当前使用的 TTS 提供者,如果未设置则返回 None。 + Raises: + ValueError: 返回的提供者不是 TTSProvider 类型。 """ prov = self.provider_manager.get_using_provider( provider_type=ProviderType.TEXT_TO_SPEECH, @@ -336,8 +376,13 @@ def get_using_stt_provider(self, umo: str | None = None) -> STTProvider | None: """获取当前使用的用于 STT 任务的 Provider。 Args: - umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 + umo: unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 + + Returns: + 当前使用的 STT 提供者,如果未设置则返回 None。 + Raises: + ValueError: 返回的提供者不是 STTProvider 类型。 """ prov = self.provider_manager.get_using_provider( provider_type=ProviderType.SPEECH_TO_TEXT, @@ -348,9 +393,19 @@ def get_using_stt_provider(self, umo: str | None = None) -> STTProvider | None: return prov def get_config(self, umo: str | None = None) -> AstrBotConfig: - """获取 AstrBot 的配置。""" + """获取 AstrBot 的配置。 + + Args: + umo: unified_message_origin 值,用于获取特定会话的配置。 + + Returns: + AstrBot 配置对象。 + + Note: + 如果不提供 umo 参数,将返回默认配置。 + """ if not umo: - # using default config + # 使用默认配置 return self._config return self.astrbot_config_mgr.get_conf(umo) @@ -361,14 +416,19 @@ async def send_message( ) -> bool: """根据 session(unified_msg_origin) 主动发送消息。 - @param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。 - @param message_chain: 消息链。 + Args: + session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。 + message_chain: 消息链。 - @return: 是否找到匹配的平台。 + Returns: + 是否找到匹配的平台。 - 当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。 + Raises: + ValueError: session 字符串不合法时抛出。 - NOTE: qq_official(QQ 官方 API 平台) 不支持此方法 + Note: + 当 session 为字符串时,会尝试解析为 MessageSession 对象。(类名为MessageSesion是因为历史遗留拼写错误) + qq_official(QQ 官方 API 平台) 不支持此方法。 """ if isinstance(session, str): try: @@ -383,7 +443,14 @@ async def send_message( return False def add_llm_tools(self, *tools: FunctionTool) -> None: - """添加 LLM 工具。""" + """添加 LLM 工具。 + + Args: + *tools: 要添加的函数工具对象。 + + Note: + 如果工具已存在,会替换已存在的工具。 + """ tool_name = {tool.name for tool in self.provider_manager.llm_tools.func_list} module_path = "" for tool in tools: @@ -416,6 +483,17 @@ def register_web_api( methods: list, desc: str, ): + """注册 Web API。 + + Args: + route: API 路由路径。 + view_handler: 异步视图处理函数。 + methods: HTTP 方法列表。 + desc: API 描述。 + + Note: + 如果相同路由和方法已注册,会替换现有的 API。 + """ for idx, api in enumerate(self.registered_web_apis): if api[0] == route and methods == api[2]: self.registered_web_apis[idx] = (route, view_handler, methods, desc) @@ -434,7 +512,14 @@ def get_event_queue(self) -> Queue: def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None: """获取指定类型的平台适配器。 - 该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0) + Args: + platform_type: 平台类型或平台名称。 + + Returns: + 平台适配器实例,如果未找到则返回 None。 + + Note: + 该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0) """ for platform in self.platform_manager.platform_insts: name = platform.meta().name @@ -451,22 +536,32 @@ def get_platform_inst(self, platform_id: str) -> Platform | None: """获取指定 ID 的平台适配器实例。 Args: - platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。 + platform_id: 平台适配器的唯一标识符。 Returns: - Platform: 平台适配器实例,如果未找到则返回 None。 + 平台适配器实例,如果未找到则返回 None。 + Note: + 可以通过 event.get_platform_id() 获取平台 ID。 """ for platform in self.platform_manager.platform_insts: if platform.meta().id == platform_id: return platform def get_db(self) -> BaseDatabase: - """获取 AstrBot 数据库。""" + """获取 AstrBot 数据库。 + + Returns: + 数据库实例。 + """ return self._db def register_provider(self, provider: Provider): - """注册一个 LLM Provider(Chat_Completion 类型)。""" + """注册一个 LLM Provider(Chat_Completion 类型)。 + + Args: + provider: 提供者实例。 + """ self.provider_manager.provider_insts.append(provider) def register_llm_tool( @@ -478,12 +573,16 @@ def register_llm_tool( ) -> None: """[DEPRECATED]为函数调用(function-calling / tools-use)添加工具。 - @param name: 函数名 - @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...] - @param desc: 函数描述 - @param func_obj: 异步处理函数。 - - 异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。 + Args: + name: 函数名。 + func_args: 函数参数列表,格式为 + [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]。 + desc: 函数描述。 + func_obj: 异步处理函数。 + + Note: + 异步处理函数会接收到额外的关键词参数:event: AstrMessageEvent, context: Context。 + 该方法已弃用,请使用新的注册方式。 """ md = StarHandlerMetadata( event_type=EventType.OnLLMRequestEvent, @@ -498,7 +597,15 @@ def register_llm_tool( self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj) def unregister_llm_tool(self, name: str) -> None: - """[DEPRECATED]删除一个函数调用工具。如果再要启用,需要重新注册。""" + """[DEPRECATED]删除一个函数调用工具。 + + Args: + name: 工具名称。 + + Note: + 如果再要启用,需要重新注册。 + 该方法已弃用。 + """ self.provider_manager.llm_tools.remove_func(name) def register_commands( @@ -511,16 +618,19 @@ def register_commands( use_regex=False, ignore_prefix=False, ): - """注册一个命令。 - - [Deprecated] 推荐使用装饰器注册指令。该方法将在未来的版本中被移除。 - - @param star_name: 插件(Star)名称。 - @param command_name: 命令名称。 - @param desc: 命令描述。 - @param priority: 优先级。1-10。 - @param awaitable: 异步处理函数。 + """[DEPRECATED]注册一个命令。 + Args: + star_name: 插件(Star)名称。 + command_name: 命令名称。 + desc: 命令描述。 + priority: 优先级。1-10。 + awaitable: 异步处理函数。 + use_regex: 是否使用正则表达式匹配命令。 + ignore_prefix: 是否忽略命令前缀。 + + Note: + 推荐使用装饰器注册指令。该方法将在未来的版本中被移除。 """ md = StarHandlerMetadata( event_type=EventType.AdapterMessageEvent, @@ -540,5 +650,13 @@ def register_commands( star_handlers_registry.append(md) def register_task(self, task: Awaitable, desc: str): - """[DEPRECATED]注册一个异步任务。""" + """[DEPRECATED]注册一个异步任务。 + + Args: + task: 异步任务。 + desc: 任务描述。 + + Note: + 该方法已弃用。 + """ self._register_tasks.append(task)