From 5b4977deddc3466d3bfff662f95b9e1782fe9bea Mon Sep 17 00:00:00 2001 From: ulleo Date: Wed, 28 Jan 2026 11:00:25 +0800 Subject: [PATCH 1/2] feat: add more chat log operation steps --- .../versions/062_update_chat_log_dll.py | 36 +++++ backend/apps/chat/curd/chat.py | 24 ++- backend/apps/chat/models/chat_model.py | 8 +- backend/apps/chat/task/llm.py | 149 ++++++++++++------ .../apps/data_training/curd/data_training.py | 8 +- backend/apps/terminology/curd/terminology.py | 6 +- 6 files changed, 168 insertions(+), 63 deletions(-) create mode 100644 backend/alembic/versions/062_update_chat_log_dll.py diff --git a/backend/alembic/versions/062_update_chat_log_dll.py b/backend/alembic/versions/062_update_chat_log_dll.py new file mode 100644 index 000000000..9d729b875 --- /dev/null +++ b/backend/alembic/versions/062_update_chat_log_dll.py @@ -0,0 +1,36 @@ +"""062_update_chat_log_dll + +Revision ID: c9ab05247503 +Revises: 547df942eb90 +Create Date: 2026-01-27 14:20:35.069255 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'c9ab05247503' +down_revision = '547df942eb90' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('chat_log', sa.Column('local_operation', sa.Boolean(), nullable=True)) + sql = ''' + UPDATE chat_log SET local_operation = false + ''' + op.execute(sql) + op.alter_column('chat_log', 'local_operation', + existing_type=sa.BOOLEAN(), + nullable=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('chat_log', 'local_operation') + # ### end Alembic commands ### diff --git a/backend/apps/chat/curd/chat.py b/backend/apps/chat/curd/chat.py index a231bef40..1673bd2cb 100644 --- a/backend/apps/chat/curd/chat.py +++ b/backend/apps/chat/curd/chat.py @@ -1,5 +1,5 @@ import datetime -from typing import List, Optional +from typing import List, Optional, Union import orjson import sqlparse @@ -62,6 +62,7 @@ def list_recent_questions(session: SessionDep, current_user: CurrentUser, dataso ) return [record[0] for record in chat_records] if chat_records else [] + def rename_chat_with_user(session: SessionDep, current_user: CurrentUser, rename_object: RenameChat) -> str: chat = session.get(Chat, rename_object.id) if not chat: @@ -78,6 +79,7 @@ def rename_chat_with_user(session: SessionDep, current_user: CurrentUser, rename session.commit() return brief + def rename_chat(session: SessionDep, rename_object: RenameChat) -> str: chat = session.get(Chat, rename_object.id) if not chat: @@ -104,6 +106,7 @@ def delete_chat(session, chart_id) -> str: return f'Chat with id {chart_id} has been deleted' + def delete_chat_with_user(session, current_user: CurrentUser, chart_id) -> str: chat = session.query(Chat).filter(Chat.id == chart_id).first() if not chat: @@ -220,6 +223,7 @@ def get_chat_chart_config(session: SessionDep, chat_record_id: int): pass return {} + def get_chart_data_with_user(session: SessionDep, current_user: CurrentUser, chat_record_id: int): stmt = select(ChatRecord.data).where(and_(ChatRecord.id == chat_record_id, ChatRecord.create_by == current_user.id)) res = session.execute(stmt) @@ -230,6 +234,7 @@ def get_chart_data_with_user(session: SessionDep, current_user: CurrentUser, cha pass return {} + def get_chat_chart_data(session: SessionDep, chat_record_id: int): stmt = select(ChatRecord.data).where(and_(ChatRecord.id == chat_record_id)) res = session.execute(stmt) @@ -240,8 +245,10 @@ def get_chat_chart_data(session: SessionDep, chat_record_id: int): pass return {} + def get_chat_predict_data_with_user(session: SessionDep, current_user: CurrentUser, chat_record_id: int): - stmt = select(ChatRecord.predict_data).where(and_(ChatRecord.id == chat_record_id, ChatRecord.create_by == current_user.id)) + stmt = select(ChatRecord.predict_data).where( + and_(ChatRecord.id == chat_record_id, ChatRecord.create_by == current_user.id)) res = session.execute(stmt) for row in res: try: @@ -250,6 +257,7 @@ def get_chat_predict_data_with_user(session: SessionDep, current_user: CurrentUs pass return {} + def get_chat_predict_data(session: SessionDep, chat_record_id: int): stmt = select(ChatRecord.predict_data).where(and_(ChatRecord.id == chat_record_id)) res = session.execute(stmt) @@ -607,10 +615,11 @@ def save_analysis_predict_record(session: SessionDep, base_record: ChatRecord, a return result -def start_log(session: SessionDep, ai_modal_id: int, ai_modal_name: str, operate: OperationEnum, record_id: int, - full_message: list[dict]) -> ChatLog: +def start_log(session: SessionDep, ai_modal_id: int = None, ai_modal_name: str = None, operate: OperationEnum = None, + record_id: int = None, full_message: Union[list[dict], dict] = None, + local_operation: bool = False) -> ChatLog: log = ChatLog(type=TypeEnum.CHAT, operate=operate, pid=record_id, ai_modal_id=ai_modal_id, base_modal=ai_modal_name, - messages=full_message, start_time=datetime.datetime.now()) + messages=full_message, start_time=datetime.datetime.now(), local_operation=local_operation) result = ChatLog(**log.model_dump()) @@ -623,7 +632,8 @@ def start_log(session: SessionDep, ai_modal_id: int, ai_modal_name: str, operate return result -def end_log(session: SessionDep, log: ChatLog, full_message: list[dict], reasoning_content: str = None, +def end_log(session: SessionDep, log: ChatLog, full_message: Union[list[dict], dict, str], + reasoning_content: str = None, token_usage=None) -> ChatLog: if token_usage is None: token_usage = {} @@ -867,6 +877,8 @@ def save_error_message(session: SessionDep, record_id: int, message: str) -> Cha session.commit() + # todo log error finish + return result diff --git a/backend/apps/chat/models/chat_model.py b/backend/apps/chat/models/chat_model.py index 6cccdb360..81081feb7 100644 --- a/backend/apps/chat/models/chat_model.py +++ b/backend/apps/chat/models/chat_model.py @@ -40,7 +40,12 @@ class OperationEnum(Enum): GENERATE_SQL_WITH_PERMISSIONS = '5' CHOOSE_DATASOURCE = '6' GENERATE_DYNAMIC_SQL = '7' - + CHOOSE_TABLE = '8' + FILTER_TERMS = '9' + FILTER_SQL_EXAMPLE = '10' + FILTER_CUSTOM_PROMPT = '11' + EXECUTE_SQL = '12' + GENERATE_PICTURE = '13' class ChatFinishStep(Enum): GENERATE_SQL = 1 @@ -71,6 +76,7 @@ class ChatLog(SQLModel, table=True): start_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True)) finish_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True)) token_usage: Optional[dict | None | int] = Field(sa_column=Column(JSONB)) + local_operation: bool = Field(default=False) class Chat(SQLModel, table=True): diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index d5b6cb5f8..8c328a77a 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -128,14 +128,11 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C if not ds: raise SingleMessageError("No available datasource configuration found") chat_question.engine = ds.type + get_version(ds) - chat_question.db_schema = self.out_ds_instance.get_db_schema(ds.id, chat_question.question) else: ds = session.get(CoreDatasource, chat.datasource) if not ds: raise SingleMessageError("No available datasource configuration found") chat_question.engine = (ds.type_name if ds.type != 'excel' else 'PostgreSQL') + get_version(ds) - chat_question.db_schema = get_table_schema(session=session, current_user=current_user, ds=ds, - question=chat_question.question, embedding=embedding) self.generate_sql_logs = list_generate_sql_logs(session=session, chart_id=chat_id) self.generate_chart_logs = list_generate_chart_logs(session=session, chart_id=chat_id) @@ -204,7 +201,10 @@ def is_running(self, timeout=0.5): except Exception as e: return True - def init_messages(self): + def init_messages(self, session: Session): + + self.choose_table_schema(session) + last_sql_messages: List[dict[str, Any]] = self.generate_sql_logs[-1].messages if len( self.generate_sql_logs) > 0 else [] if self.chat_question.regenerate_record_id: @@ -267,6 +267,64 @@ def get_fields_from_chart(self, _session: Session): chart_info = get_chart_config(_session, self.record.id) return format_chart_fields(chart_info) + def filter_terminology_template(self, _session: Session, oid: int = None, ds_id: int = None): + self.current_logs[OperationEnum.FILTER_TERMS] = start_log(session=_session, + operate=OperationEnum.FILTER_TERMS, + record_id=self.record.id, local_operation=True) + self.chat_question.terminologies, term_list = get_terminology_template(_session, self.chat_question.question, + oid, ds_id) + self.current_logs[OperationEnum.FILTER_TERMS] = end_log(session=_session, + log=self.current_logs[OperationEnum.FILTER_TERMS], + full_message=term_list) + + def filter_custom_prompts(self, _session: Session, custom_prompt_type: CustomPromptTypeEnum, oid: int = None, + ds_id: int = None): + if SQLBotLicenseUtil.valid(): + self.current_logs[OperationEnum.FILTER_CUSTOM_PROMPT] = start_log(session=_session, + operate=OperationEnum.FILTER_CUSTOM_PROMPT, + record_id=self.record.id, + local_operation=True) + self.chat_question.custom_prompt, prompt_list = find_custom_prompts(_session, custom_prompt_type, oid, + ds_id) + self.current_logs[OperationEnum.FILTER_CUSTOM_PROMPT] = end_log(session=_session, + log=self.current_logs[ + OperationEnum.FILTER_CUSTOM_PROMPT], + full_message=prompt_list) + + def filter_training_template(self, _session: Session, oid: int = None, ds_id: int = None): + self.current_logs[OperationEnum.FILTER_SQL_EXAMPLE] = start_log(session=_session, + operate=OperationEnum.FILTER_SQL_EXAMPLE, + record_id=self.record.id, + local_operation=True) + if self.current_assistant and self.current_assistant.type == 1: + self.chat_question.data_training, example_list = get_training_template(_session, + self.chat_question.question, oid, + None, self.current_assistant.id) + else: + self.chat_question.data_training, example_list = get_training_template(_session, + self.chat_question.question, oid, + ds_id) + self.current_logs[OperationEnum.FILTER_SQL_EXAMPLE] = end_log(session=_session, + log=self.current_logs[ + OperationEnum.FILTER_SQL_EXAMPLE], + full_message=example_list) + + def choose_table_schema(self, _session: Session): + self.current_logs[OperationEnum.CHOOSE_TABLE] = start_log(session=_session, + operate=OperationEnum.CHOOSE_TABLE, + record_id=self.record.id, + local_operation=True) + self.chat_question.db_schema = self.out_ds_instance.get_db_schema( + self.ds.id, self.chat_question.question) if self.out_ds_instance else get_table_schema( + session=_session, + current_user=self.current_user, + ds=self.ds, + question=self.chat_question.question) + + self.current_logs[OperationEnum.CHOOSE_TABLE] = end_log(session=_session, + log=self.current_logs[OperationEnum.CHOOSE_TABLE], + full_message=self.chat_question.db_schema) + def generate_analysis(self, _session: Session): fields = self.get_fields_from_chart(_session) self.chat_question.fields = orjson.dumps(fields).decode() @@ -275,11 +333,10 @@ def generate_analysis(self, _session: Session): analysis_msg: List[Union[BaseMessage, dict[str, Any]]] = [] ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None - self.chat_question.terminologies = get_terminology_template(_session, self.chat_question.question, - self.current_user.oid, ds_id) - if SQLBotLicenseUtil.valid(): - self.chat_question.custom_prompt = find_custom_prompts(_session, CustomPromptTypeEnum.ANALYSIS, - self.current_user.oid, ds_id) + + self.filter_terminology_template(_session, self.current_user.oid, ds_id) + + self.filter_custom_prompts(_session, CustomPromptTypeEnum.ANALYSIS, self.current_user.oid, ds_id) analysis_msg.append(SystemMessage(content=self.chat_question.analysis_sys_question())) analysis_msg.append(HumanMessage(content=self.chat_question.analysis_user_question())) @@ -325,10 +382,8 @@ def generate_predict(self, _session: Session): data = get_chat_chart_data(_session, self.record.id) self.chat_question.data = orjson.dumps(data.get('data')).decode() - if SQLBotLicenseUtil.valid(): - ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None - self.chat_question.custom_prompt = find_custom_prompts(_session, CustomPromptTypeEnum.PREDICT_DATA, - self.current_user.oid, ds_id) + ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None + self.filter_custom_prompts(_session, CustomPromptTypeEnum.PREDICT_DATA, self.current_user.oid, ds_id) predict_msg: List[Union[BaseMessage, dict[str, Any]]] = [] predict_msg.append(SystemMessage(content=self.chat_question.predict_sys_question())) @@ -509,8 +564,7 @@ def select_datasource(self, _session: Session): _ds = self.out_ds_instance.get_ds(data['id']) self.ds = _ds self.chat_question.engine = _ds.type + get_version(self.ds) - self.chat_question.db_schema = self.out_ds_instance.get_db_schema(self.ds.id, - self.chat_question.question) + _engine_type = self.chat_question.engine _chat.engine_type = _ds.type else: @@ -521,9 +575,7 @@ def select_datasource(self, _session: Session): self.ds = CoreDatasource(**_ds.model_dump()) self.chat_question.engine = (_ds.type_name if _ds.type != 'excel' else 'PostgreSQL') + get_version( self.ds) - self.chat_question.db_schema = get_table_schema(session=_session, - current_user=self.current_user, ds=self.ds, - question=self.chat_question.question) + _engine_type = self.chat_question.engine _chat.engine_type = _ds.type_name # save chat @@ -555,19 +607,13 @@ def select_datasource(self, _session: Session): oid = self.ds.oid if isinstance(self.ds, CoreDatasource) else 1 ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None - self.chat_question.terminologies = get_terminology_template(_session, self.chat_question.question, oid, - ds_id) - if self.current_assistant and self.current_assistant.type == 1: - self.chat_question.data_training = get_training_template(_session, self.chat_question.question, - oid, None, self.current_assistant.id) - else: - self.chat_question.data_training = get_training_template(_session, self.chat_question.question, - oid, ds_id) - if SQLBotLicenseUtil.valid(): - self.chat_question.custom_prompt = find_custom_prompts(_session, CustomPromptTypeEnum.GENERATE_SQL, - oid, ds_id) + self.filter_terminology_template(_session, oid, ds_id) - self.init_messages() + self.filter_training_template(_session, oid, ds_id) + + self.filter_custom_prompts(_session, CustomPromptTypeEnum.GENERATE_SQL, oid, ds_id) + + self.init_messages(_session) if _error: raise _error @@ -994,19 +1040,14 @@ def run_task(self, in_chat: bool = True, stream: bool = True, if self.ds: oid = self.ds.oid if isinstance(self.ds, CoreDatasource) else 1 ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None - self.chat_question.terminologies = get_terminology_template(_session, self.chat_question.question, - oid, ds_id) - if self.current_assistant and self.current_assistant.type == 1: - self.chat_question.data_training = get_training_template(_session, self.chat_question.question, - oid, None, self.current_assistant.id) - else: - self.chat_question.data_training = get_training_template(_session, self.chat_question.question, - oid, ds_id) - if SQLBotLicenseUtil.valid(): - self.chat_question.custom_prompt = find_custom_prompts(_session, - CustomPromptTypeEnum.GENERATE_SQL, - oid, ds_id) - self.init_messages() + + self.filter_terminology_template(_session, oid, ds_id) + + self.filter_training_template(_session, oid, ds_id) + + self.filter_custom_prompts(_session, CustomPromptTypeEnum.GENERATE_SQL, oid, ds_id) + + self.init_messages(_session) # return id if in_chat: @@ -1038,12 +1079,6 @@ def run_task(self, in_chat: bool = True, stream: bool = True, 'engine_type': self.ds.type_name or self.ds.type, 'type': 'datasource'}).decode() + '\n\n' - self.chat_question.db_schema = self.out_ds_instance.get_db_schema( - self.ds.id, self.chat_question.question) if self.out_ds_instance else get_table_schema( - session=_session, - current_user=self.current_user, - ds=self.ds, - question=self.chat_question.question) else: self.validate_history_ds(_session) @@ -1140,7 +1175,14 @@ def run_task(self, in_chat: bool = True, stream: bool = True, yield json_result return + self.current_logs[OperationEnum.EXECUTE_SQL] = start_log(session=_session, + operate=OperationEnum.EXECUTE_SQL, + record_id=self.record.id, local_operation=True) result = self.execute_sql(sql=real_execute_sql) + self.current_logs[OperationEnum.EXECUTE_SQL] = end_log(session=_session, + log=self.current_logs[OperationEnum.EXECUTE_SQL], + full_message={'sql': real_execute_sql, + 'count': len(result.get('data'))}) _data = DataFormat.convert_large_numbers_in_object_array(result.get('data')) result["data"] = _data @@ -1229,6 +1271,10 @@ def run_task(self, in_chat: bool = True, stream: bool = True, try: if chart.get('type') != 'table': # yield '### generated chart picture\n\n' + self.current_logs[OperationEnum.GENERATE_PICTURE] = start_log(session=_session, + operate=OperationEnum.GENERATE_PICTURE, + record_id=self.record.id, + local_operation=True) image_url, error = request_picture(self.record.chat_id, self.record.id, chart, format_json_data(result)) SQLBotLogUtil.info(image_url) @@ -1238,6 +1284,11 @@ def run_task(self, in_chat: bool = True, stream: bool = True, json_result['image_url'] = image_url if error is not None: raise error + + self.current_logs[OperationEnum.GENERATE_PICTURE] = end_log(session=_session, + log=self.current_logs[ + OperationEnum.GENERATE_PICTURE], + full_message=image_url) except Exception as e: if stream: if chart.get('type') != 'table': diff --git a/backend/apps/data_training/curd/data_training.py b/backend/apps/data_training/curd/data_training.py index 20ea3fb65..175f6c812 100644 --- a/backend/apps/data_training/curd/data_training.py +++ b/backend/apps/data_training/curd/data_training.py @@ -610,15 +610,15 @@ def to_xml_string(_dict: list[dict] | dict, root: str = 'sql-examples') -> str: def get_training_template(session: SessionDep, question: str, oid: Optional[int] = 1, datasource: Optional[int] = None, - advanced_application_id: Optional[int] = None) -> str: + advanced_application_id: Optional[int] = None) -> tuple[str, list[dict]]: if not oid: oid = 1 if not datasource and not advanced_application_id: - return '' + return '', [] _results = select_training_by_question(session, question, oid, datasource, advanced_application_id) if _results and len(_results) > 0: data_training = to_xml_string(_results) template = get_base_data_training_template().format(data_training=data_training) - return template + return template, _results else: - return '' + return '', [] diff --git a/backend/apps/terminology/curd/terminology.py b/backend/apps/terminology/curd/terminology.py index 4a866cbeb..296fd1eba 100644 --- a/backend/apps/terminology/curd/terminology.py +++ b/backend/apps/terminology/curd/terminology.py @@ -846,13 +846,13 @@ def to_xml_string(_dict: list[dict] | dict, root: str = 'terminologies') -> str: def get_terminology_template(session: SessionDep, question: str, oid: Optional[int] = 1, - datasource: Optional[int] = None) -> str: + datasource: Optional[int] = None) -> tuple[str, list[dict]]: if not oid: oid = 1 _results = select_terminology_by_word(session, question, oid, datasource) if _results and len(_results) > 0: terminology = to_xml_string(_results) template = get_base_terminology_template().format(terminologies=terminology) - return template + return template, _results else: - return '' + return '', [] From b69048c8e033ffd9c04acfa5020c8837b115a380 Mon Sep 17 00:00:00 2001 From: ulleo Date: Wed, 28 Jan 2026 11:09:59 +0800 Subject: [PATCH 2/2] chore: update pyproject.toml --- backend/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 0166dbfd8..2f4e9d317 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -39,7 +39,7 @@ dependencies = [ "pyyaml (>=6.0.2,<7.0.0)", "fastapi-mcp (>=0.3.4,<0.4.0)", "tabulate>=0.9.0", - "sqlbot-xpack>=0.0.5.9,<0.0.6.0", + "sqlbot-xpack>=0.0.5.10,<0.0.6.0", "fastapi-cache2>=0.2.2", "sqlparse>=0.5.3", "redis>=6.2.0",