Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backend/apps/data_training/api/data_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@


@router.get("/page/{current_page}/{page_size}", summary=f"{PLACEHOLDER_PREFIX}get_dt_page")
@require_permissions(permission=SqlbotPermission(role=['ws_admin']))
async def pager(session: SessionDep, current_user: CurrentUser, current_page: int, page_size: int,
question: Optional[str] = Query(None, description="搜索问题(可选)")):
current_page, page_size, total_count, total_pages, _list = page_data_training(session, current_page, page_size,
Expand All @@ -43,6 +44,7 @@ async def pager(session: SessionDep, current_user: CurrentUser, current_page: in


@router.put("", response_model=int, summary=f"{PLACEHOLDER_PREFIX}create_or_update_dt")
@require_permissions(permission=SqlbotPermission(role=['ws_admin'], type='ds', keyExpression="info.datasource"))
@system_log(LogConfig(operation_type=OperationType.CREATE_OR_UPDATE, module=OperationModules.DATA_TRAINING,resource_id_expr='info.id', result_id_expr="result_self"))
async def create_or_update(session: SessionDep, current_user: CurrentUser, trans: Trans, info: DataTrainingInfo):
oid = current_user.oid
Expand Down
17 changes: 14 additions & 3 deletions backend/apps/datasource/api/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,21 @@ def inner():
@system_log(LogConfig(operation_type=OperationType.CREATE, module=OperationModules.DATASOURCE, result_id_expr="id"))
@require_permissions(permission=SqlbotPermission(role=['ws_admin']))
async def add(session: SessionDep, trans: Trans, user: CurrentUser, ds: CreateDatasource):
def inner():
""" def inner():
return create_ds(session, trans, user, ds)

return await asyncio.to_thread(inner)
return await asyncio.to_thread(inner) """
loop = asyncio.get_event_loop()

def sync_wrapper():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(create_ds(session, trans, user, ds))
finally:
loop.close()

return await loop.run_in_executor(None, sync_wrapper)


@router.post("/chooseTables/{id}", response_model=None, summary=f"{PLACEHOLDER_PREFIX}ds_choose_tables")
Expand Down Expand Up @@ -107,7 +118,7 @@ def inner():
@system_log(LogConfig(operation_type=OperationType.DELETE, module=OperationModules.DATASOURCE, resource_id_expr="id",
))
async def delete(session: SessionDep, id: int = Path(..., description=f"{PLACEHOLDER_PREFIX}ds_id"), name: str = None):
return delete_ds(session, id)
return await delete_ds(session, id)


@router.post("/getTables/{id}", response_model=List[TableSchemaResponse], summary=f"{PLACEHOLDER_PREFIX}ds_get_tables")
Expand Down
21 changes: 17 additions & 4 deletions backend/apps/datasource/crud/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
from apps.db.constant import DB
from apps.db.db import get_tables, get_fields, exec_sql, check_connection
from apps.db.engine import get_engine_config, get_engine_conn
from apps.system.schemas.auth import CacheName, CacheNamespace
from common.core.config import settings
from common.core.deps import SessionDep, CurrentUser, Trans
from common.utils.embedding_threads import run_save_table_embeddings, run_save_ds_embeddings
from common.utils.utils import deepcopy_ignore_extra
from common.utils.utils import SQLBotLogUtil, deepcopy_ignore_extra
from common.core.sqlbot_cache import cache, clear_cache
from .table import get_tables_by_ds_id
from ..crud.field import delete_field_by_ds_id, update_field
from ..crud.table import delete_table_by_ds_id, update_table
Expand Down Expand Up @@ -63,8 +65,8 @@ def check_name(session: SessionDep, trans: Trans, user: CurrentUser, ds: CoreDat
if ds_list is not None and len(ds_list) > 0:
raise HTTPException(status_code=500, detail=trans('i18n_ds_name_exist'))


def create_ds(session: SessionDep, trans: Trans, user: CurrentUser, create_ds: CreateDatasource):
@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.DS_ID_LIST, keyExpression="user.oid")
async def create_ds(session: SessionDep, trans: Trans, user: CurrentUser, create_ds: CreateDatasource):
ds = CoreDatasource()
deepcopy_ignore_extra(create_ds, ds)
check_name(session, trans, user, ds)
Expand Down Expand Up @@ -117,7 +119,7 @@ def update_ds_recommended_config(session: SessionDep, datasource_id: int, recomm
session.commit()


def delete_ds(session: SessionDep, id: int):
async def delete_ds(session: SessionDep, id: int):
term = session.exec(select(CoreDatasource).where(CoreDatasource.id == id)).first()
if term.type == "excel":
# drop all tables for current datasource
Expand All @@ -132,6 +134,8 @@ def delete_ds(session: SessionDep, id: int):
session.commit()
delete_table_by_ds_id(session, id)
delete_field_by_ds_id(session, id)
if term:
await clear_ws_resource_cache(term.oid)
return {
"message": f"Datasource with ID {id} deleted successfully."
}
Expand Down Expand Up @@ -526,3 +530,12 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
schema_str += f"{table_dict.get(int(ele.get('source').get('cell')))}.{field_dict.get(int(ele.get('source').get('port')))}={table_dict.get(int(ele.get('target').get('cell')))}.{field_dict.get(int(ele.get('target').get('port')))}\n"

return schema_str

@cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.DS_ID_LIST, keyExpression="oid")
async def get_ws_ds(session, oid) -> list:
stmt = select(CoreDatasource.id).distinct().where(CoreDatasource.oid == oid)
db_list = session.exec(stmt).all()
return db_list
@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.DS_ID_LIST, keyExpression="oid")
async def clear_ws_ds_cache(oid):
SQLBotLogUtil.info(f"ds cache for ws [{oid}] has been cleaned")
10 changes: 8 additions & 2 deletions backend/apps/system/api/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from apps.system.crud.assistant_manage import dynamic_upgrade_cors, save
from apps.system.models.system_model import AssistantModel
from apps.system.schemas.auth import CacheName, CacheNamespace
from apps.system.schemas.permission import SqlbotPermission, require_permissions
from apps.system.schemas.system_schema import AssistantBase, AssistantDTO, AssistantUiSchema, AssistantValidator
from common.core.config import settings
from common.core.deps import CurrentAssistant, SessionDep, Trans, CurrentUser
Expand Down Expand Up @@ -217,27 +218,31 @@ def get_db_type(type):


@router.get("", response_model=list[AssistantModel], summary=f"{PLACEHOLDER_PREFIX}assistant_grid_api", description=f"{PLACEHOLDER_PREFIX}assistant_grid_api")
@require_permissions(permission=SqlbotPermission(role=['ws_admin']))
async def query(session: SessionDep, current_user: CurrentUser):
list_result = session.exec(select(AssistantModel).where(AssistantModel.oid == current_user.oid, AssistantModel.type != 4).order_by(AssistantModel.name,
AssistantModel.create_time)).all()
return list_result


@router.get("/advanced_application", response_model=list[AssistantModel], include_in_schema=False)
async def query_advanced_application(session: SessionDep):
list_result = session.exec(select(AssistantModel).where(AssistantModel.type == 1).order_by(AssistantModel.name,
@require_permissions(permission=SqlbotPermission(role=['ws_admin']))
async def query_advanced_application(session: SessionDep, current_user: CurrentUser):
list_result = session.exec(select(AssistantModel).where(AssistantModel.type == 1, AssistantModel.oid == current_user.oid).order_by(AssistantModel.name,
AssistantModel.create_time)).all()
return list_result


@router.post("", summary=f"{PLACEHOLDER_PREFIX}assistant_create_api", description=f"{PLACEHOLDER_PREFIX}assistant_create_api")
@require_permissions(permission=SqlbotPermission(role=['ws_admin']))
@system_log(LogConfig(operation_type=OperationType.CREATE, module=OperationModules.APPLICATION, result_id_expr="id"))
async def add(request: Request, session: SessionDep, current_user: CurrentUser, creator: AssistantBase):
oid = current_user.oid if creator.type != 4 else 1
return await save(request, session, creator, oid)


@router.put("", summary=f"{PLACEHOLDER_PREFIX}assistant_update_api", description=f"{PLACEHOLDER_PREFIX}assistant_update_api")
@require_permissions(permission=SqlbotPermission(role=['ws_admin']))
@clear_cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="editor.id")
@system_log(LogConfig(operation_type=OperationType.UPDATE, module=OperationModules.APPLICATION, resource_id_expr="editor.id"))
async def update(request: Request, session: SessionDep, editor: AssistantDTO):
Expand All @@ -262,6 +267,7 @@ async def get_one(session: SessionDep, id: int = Path(description="ID")):


@router.delete("/{id}", summary=f"{PLACEHOLDER_PREFIX}assistant_del_api", description=f"{PLACEHOLDER_PREFIX}assistant_del_api")
@require_permissions(permission=SqlbotPermission(role=['ws_admin']))
@clear_cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="id")
@system_log(LogConfig(operation_type=OperationType.DELETE, module=OperationModules.APPLICATION, resource_id_expr="id"))
async def delete(request: Request, session: SessionDep, id: int = Path(description="ID")):
Expand Down
1 change: 1 addition & 0 deletions backend/apps/system/schemas/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class CacheName(Enum):
ASSISTANT_INFO = "assistant:info"
ASSISTANT_DS = "assistant:ds"
ASK_INFO = "ask:info"
DS_ID_LIST = "ds:id:list"
def __str__(self):
return self.value

10 changes: 7 additions & 3 deletions backend/apps/system/schemas/permission.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from starlette.middleware.base import BaseHTTPMiddleware
from sqlmodel import Session, select
from apps.chat.models.chat_model import Chat
from apps.datasource.crud.datasource import get_ws_ds
from apps.datasource.models.datasource import CoreDatasource
from common.core.db import engine
from apps.system.schemas.system_schema import UserInfoDTO
Expand All @@ -22,7 +23,7 @@ async def get_ws_resource(oid, type) -> list:
with Session(engine) as session:
stmt = None
if type == 'ds' or type == 'datasource':
stmt = select(CoreDatasource.id).where(CoreDatasource.oid == oid)
return await get_ws_ds(session, oid)
if type == 'chat':
stmt = select(Chat.id).where(Chat.oid == oid)
if stmt is not None:
Expand All @@ -32,6 +33,9 @@ async def get_ws_resource(oid, type) -> list:


async def check_ws_permission(oid, type, resource) -> bool:
if not resource or (isinstance(resource, list) and len(resource) == 0):
return True

resource_id_list = await get_ws_resource(oid, type)
if not resource_id_list:
return False
Expand All @@ -53,7 +57,7 @@ async def wrapper(*args, **kwargs):
)
current_oid = current_user.oid

if current_user.isAdmin:
if current_user.isAdmin and not permission.type:
return await func(*args, **kwargs)
role_list = permission.role
keyExpression = permission.keyExpression
Expand All @@ -62,7 +66,7 @@ async def wrapper(*args, **kwargs):
if role_list:
if 'admin' in role_list and not current_user.isAdmin:
raise Exception('no permission to execute, only for admin')
if 'ws_admin' in role_list and current_user.weight == 0:
if 'ws_admin' in role_list and current_user.weight == 0 and not current_user.isAdmin:
raise Exception('no permission to execute, only for workspace admin')
if not resource_type:
return await func(*args, **kwargs)
Expand Down
2 changes: 2 additions & 0 deletions backend/apps/terminology/api/terminology.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@


@router.get("/page/{current_page}/{page_size}", summary=f"{PLACEHOLDER_PREFIX}get_term_page")
@require_permissions(permission=SqlbotPermission(role=['ws_admin']))
async def pager(session: SessionDep, current_user: CurrentUser, current_page: int, page_size: int,
word: Optional[str] = Query(None, description="搜索术语(可选)"),
dslist: Optional[list[int]] = Query(None, description="数据集ID集合(可选)")):
Expand All @@ -42,6 +43,7 @@ async def pager(session: SessionDep, current_user: CurrentUser, current_page: in


@router.put("", summary=f"{PLACEHOLDER_PREFIX}create_or_update_term")
@require_permissions(permission=SqlbotPermission(role=['ws_admin'], type='ds', keyExpression="info.datasource_ids"))
@system_log(LogConfig(operation_type=OperationType.CREATE_OR_UPDATE, module=OperationModules.TERMINOLOGY,resource_id_expr='info.id', result_id_expr="result_self"))
async def create_or_update(session: SessionDep, current_user: CurrentUser, trans: Trans, info: TerminologyInfo):
oid = current_user.oid
Expand Down