Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,24 +268,37 @@ def get_fields_from_chart(self, _session: Session):
return format_chart_fields(chart_info)

def filter_terminology_template(self, _session: Session, oid: int = None, ds_id: int = None):
calculate_oid = oid
calculate_ds_id = ds_id
if self.current_assistant:
calculate_oid = self.current_assistant.oid if self.current_assistant.type != 4 else self.current_user.oid
if self.current_assistant.type == 1:
calculate_ds_id = None
self.current_logs[OperationEnum.FILTER_TERMS] = start_log(session=_session,
operate=OperationEnum.FILTER_TERMS,
record_id=self.record.id, local_operation=True)

self.chat_question.terminologies, term_list = get_terminology_template(_session, self.chat_question.question,
oid, ds_id)
calculate_oid, calculate_ds_id)
self.current_logs[OperationEnum.FILTER_TERMS] = end_log(session=_session,
log=self.current_logs[OperationEnum.FILTER_TERMS],
full_message=term_list)

def filter_custom_prompts(self, _session: Session, custom_prompt_type: CustomPromptTypeEnum, oid: int = None,
ds_id: int = None):
if SQLBotLicenseUtil.valid():
calculate_oid = oid
calculate_ds_id = ds_id
if self.current_assistant:
calculate_oid = self.current_assistant.oid if self.current_assistant.type != 4 else self.current_user.oid
if self.current_assistant.type == 1:
calculate_ds_id = None
self.current_logs[OperationEnum.FILTER_CUSTOM_PROMPT] = start_log(session=_session,
operate=OperationEnum.FILTER_CUSTOM_PROMPT,
record_id=self.record.id,
local_operation=True)
self.chat_question.custom_prompt, prompt_list = find_custom_prompts(_session, custom_prompt_type, oid,
ds_id)
self.chat_question.custom_prompt, prompt_list = find_custom_prompts(_session, custom_prompt_type, calculate_oid,
calculate_ds_id)
self.current_logs[OperationEnum.FILTER_CUSTOM_PROMPT] = end_log(session=_session,
log=self.current_logs[
OperationEnum.FILTER_CUSTOM_PROMPT],
Expand All @@ -296,14 +309,20 @@ def filter_training_template(self, _session: Session, oid: int = None, ds_id: in
operate=OperationEnum.FILTER_SQL_EXAMPLE,
record_id=self.record.id,
local_operation=True)
calculate_oid = oid
calculate_ds_id = ds_id
if self.current_assistant:
calculate_oid = self.current_assistant.oid if self.current_assistant.type != 4 else self.current_user.oid
if self.current_assistant.type == 1:
calculate_ds_id = None
if self.current_assistant and self.current_assistant.type == 1:
self.chat_question.data_training, example_list = get_training_template(_session,
self.chat_question.question, oid,
self.chat_question.question, calculate_oid,
None, self.current_assistant.id)
else:
self.chat_question.data_training, example_list = get_training_template(_session,
self.chat_question.question, oid,
ds_id)
self.chat_question.question, calculate_oid,
calculate_ds_id)
self.current_logs[OperationEnum.FILTER_SQL_EXAMPLE] = end_log(session=_session,
log=self.current_logs[
OperationEnum.FILTER_SQL_EXAMPLE],
Expand Down
13 changes: 3 additions & 10 deletions backend/apps/system/middleware/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,7 @@ async def validateAssistant(self, assistantToken: Optional[str], trans: I18n) ->
assistant_info = await get_assistant_info(session=session, assistant_id=payload['assistant_id'])
assistant_info = AssistantModel.model_validate(assistant_info)
assistant_info = AssistantHeader.model_validate(assistant_info.model_dump(exclude_unset=True))
if assistant_info and assistant_info.type == 0:
if payload['oid']:
session_user.oid = int(payload['oid'])
else:
assistant_oid = 1
configuration = assistant_info.configuration
config_obj = json.loads(configuration) if configuration else {}
assistant_oid = config_obj.get('oid', 1)
session_user.oid = int(assistant_oid)
session_user.oid = int(assistant_info.oid)

return True, session_user, assistant_info
except Exception as e:
Expand Down Expand Up @@ -226,7 +218,8 @@ async def validateEmbedded(self, param: str, trans: I18n) -> tuple[any]:
if not session_user.oid or session_user.oid == 0:
message = trans('i18n_login.no_associated_ws', msg = trans('i18n_concat_admin'))
raise Exception(message)

if session_user.oid:
assistant_info.oid = int(session_user.oid)
return True, session_user, assistant_info
except Exception as e:
SQLBotLogUtil.exception(f"Embedded validation error: {str(e)}")
Expand Down
1 change: 1 addition & 0 deletions backend/apps/system/schemas/system_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class AssistantBase(BaseModel):
type: int = Field(default=0, description=f"{PLACEHOLDER_PREFIX}assistant_type") # 0普通小助手 1高级 4页面嵌入
configuration: Optional[str] = Field(default=None, description=f"{PLACEHOLDER_PREFIX}assistant_configuration")
description: Optional[str] = Field(default=None, description=f"{PLACEHOLDER_PREFIX}assistant_description")
oid: Optional[int] = Field(default=1, description=f"{PLACEHOLDER_PREFIX}oid")


class AssistantDTO(AssistantBase, BaseCreatorDTO):
Expand Down