diff --git a/backend/apps/data_training/api/data_training.py b/backend/apps/data_training/api/data_training.py index ce2ecec1..2e5bf77b 100644 --- a/backend/apps/data_training/api/data_training.py +++ b/backend/apps/data_training/api/data_training.py @@ -27,6 +27,7 @@ @router.get("/page/{current_page}/{page_size}", summary=f"{PLACEHOLDER_PREFIX}get_dt_page") +@require_permissions(permission=SqlbotPermission(role=['ws_admin'])) async def pager(session: SessionDep, current_user: CurrentUser, current_page: int, page_size: int, question: Optional[str] = Query(None, description="搜索问题(可选)")): current_page, page_size, total_count, total_pages, _list = page_data_training(session, current_page, page_size, @@ -43,6 +44,7 @@ async def pager(session: SessionDep, current_user: CurrentUser, current_page: in @router.put("", response_model=int, summary=f"{PLACEHOLDER_PREFIX}create_or_update_dt") +@require_permissions(permission=SqlbotPermission(role=['ws_admin'], type='ds', keyExpression="info.datasource")) @system_log(LogConfig(operation_type=OperationType.CREATE_OR_UPDATE, module=OperationModules.DATA_TRAINING,resource_id_expr='info.id', result_id_expr="result_self")) async def create_or_update(session: SessionDep, current_user: CurrentUser, trans: Trans, info: DataTrainingInfo): oid = current_user.oid diff --git a/backend/apps/datasource/api/datasource.py b/backend/apps/datasource/api/datasource.py index 5b4f7608..a731f003 100644 --- a/backend/apps/datasource/api/datasource.py +++ b/backend/apps/datasource/api/datasource.py @@ -75,10 +75,21 @@ def inner(): @system_log(LogConfig(operation_type=OperationType.CREATE, module=OperationModules.DATASOURCE, result_id_expr="id")) @require_permissions(permission=SqlbotPermission(role=['ws_admin'])) async def add(session: SessionDep, trans: Trans, user: CurrentUser, ds: CreateDatasource): - def inner(): + """ def inner(): return create_ds(session, trans, user, ds) - return await asyncio.to_thread(inner) + return await asyncio.to_thread(inner) """ + loop = asyncio.get_event_loop() + + def sync_wrapper(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete(create_ds(session, trans, user, ds)) + finally: + loop.close() + + return await loop.run_in_executor(None, sync_wrapper) @router.post("/chooseTables/{id}", response_model=None, summary=f"{PLACEHOLDER_PREFIX}ds_choose_tables") @@ -107,7 +118,7 @@ def inner(): @system_log(LogConfig(operation_type=OperationType.DELETE, module=OperationModules.DATASOURCE, resource_id_expr="id", )) async def delete(session: SessionDep, id: int = Path(..., description=f"{PLACEHOLDER_PREFIX}ds_id"), name: str = None): - return delete_ds(session, id) + return await delete_ds(session, id) @router.post("/getTables/{id}", response_model=List[TableSchemaResponse], summary=f"{PLACEHOLDER_PREFIX}ds_get_tables") diff --git a/backend/apps/datasource/crud/datasource.py b/backend/apps/datasource/crud/datasource.py index 9fc54842..8c919166 100644 --- a/backend/apps/datasource/crud/datasource.py +++ b/backend/apps/datasource/crud/datasource.py @@ -13,10 +13,12 @@ from apps.db.constant import DB from apps.db.db import get_tables, get_fields, exec_sql, check_connection from apps.db.engine import get_engine_config, get_engine_conn +from apps.system.schemas.auth import CacheName, CacheNamespace from common.core.config import settings from common.core.deps import SessionDep, CurrentUser, Trans from common.utils.embedding_threads import run_save_table_embeddings, run_save_ds_embeddings -from common.utils.utils import deepcopy_ignore_extra +from common.utils.utils import SQLBotLogUtil, deepcopy_ignore_extra +from common.core.sqlbot_cache import cache, clear_cache from .table import get_tables_by_ds_id from ..crud.field import delete_field_by_ds_id, update_field from ..crud.table import delete_table_by_ds_id, update_table @@ -63,8 +65,8 @@ def check_name(session: SessionDep, trans: Trans, user: CurrentUser, ds: CoreDat if ds_list is not None and len(ds_list) > 0: raise HTTPException(status_code=500, detail=trans('i18n_ds_name_exist')) - -def create_ds(session: SessionDep, trans: Trans, user: CurrentUser, create_ds: CreateDatasource): +@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.DS_ID_LIST, keyExpression="user.oid") +async def create_ds(session: SessionDep, trans: Trans, user: CurrentUser, create_ds: CreateDatasource): ds = CoreDatasource() deepcopy_ignore_extra(create_ds, ds) check_name(session, trans, user, ds) @@ -117,7 +119,7 @@ def update_ds_recommended_config(session: SessionDep, datasource_id: int, recomm session.commit() -def delete_ds(session: SessionDep, id: int): +async def delete_ds(session: SessionDep, id: int): term = session.exec(select(CoreDatasource).where(CoreDatasource.id == id)).first() if term.type == "excel": # drop all tables for current datasource @@ -132,6 +134,8 @@ def delete_ds(session: SessionDep, id: int): session.commit() delete_table_by_ds_id(session, id) delete_field_by_ds_id(session, id) + if term: + await clear_ws_resource_cache(term.oid) return { "message": f"Datasource with ID {id} deleted successfully." } @@ -526,3 +530,12 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat schema_str += f"{table_dict.get(int(ele.get('source').get('cell')))}.{field_dict.get(int(ele.get('source').get('port')))}={table_dict.get(int(ele.get('target').get('cell')))}.{field_dict.get(int(ele.get('target').get('port')))}\n" return schema_str + +@cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.DS_ID_LIST, keyExpression="oid") +async def get_ws_ds(session, oid) -> list: + stmt = select(CoreDatasource.id).distinct().where(CoreDatasource.oid == oid) + db_list = session.exec(stmt).all() + return db_list +@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.DS_ID_LIST, keyExpression="oid") +async def clear_ws_ds_cache(oid): + SQLBotLogUtil.info(f"ds cache for ws [{oid}] has been cleaned") \ No newline at end of file diff --git a/backend/apps/system/api/assistant.py b/backend/apps/system/api/assistant.py index c4404835..5db3f403 100644 --- a/backend/apps/system/api/assistant.py +++ b/backend/apps/system/api/assistant.py @@ -15,6 +15,7 @@ from apps.system.crud.assistant_manage import dynamic_upgrade_cors, save from apps.system.models.system_model import AssistantModel from apps.system.schemas.auth import CacheName, CacheNamespace +from apps.system.schemas.permission import SqlbotPermission, require_permissions from apps.system.schemas.system_schema import AssistantBase, AssistantDTO, AssistantUiSchema, AssistantValidator from common.core.config import settings from common.core.deps import CurrentAssistant, SessionDep, Trans, CurrentUser @@ -217,6 +218,7 @@ def get_db_type(type): @router.get("", response_model=list[AssistantModel], summary=f"{PLACEHOLDER_PREFIX}assistant_grid_api", description=f"{PLACEHOLDER_PREFIX}assistant_grid_api") +@require_permissions(permission=SqlbotPermission(role=['ws_admin'])) async def query(session: SessionDep, current_user: CurrentUser): list_result = session.exec(select(AssistantModel).where(AssistantModel.oid == current_user.oid, AssistantModel.type != 4).order_by(AssistantModel.name, AssistantModel.create_time)).all() @@ -224,13 +226,15 @@ async def query(session: SessionDep, current_user: CurrentUser): @router.get("/advanced_application", response_model=list[AssistantModel], include_in_schema=False) -async def query_advanced_application(session: SessionDep): - list_result = session.exec(select(AssistantModel).where(AssistantModel.type == 1).order_by(AssistantModel.name, +@require_permissions(permission=SqlbotPermission(role=['ws_admin'])) +async def query_advanced_application(session: SessionDep, current_user: CurrentUser): + list_result = session.exec(select(AssistantModel).where(AssistantModel.type == 1, AssistantModel.oid == current_user.oid).order_by(AssistantModel.name, AssistantModel.create_time)).all() return list_result @router.post("", summary=f"{PLACEHOLDER_PREFIX}assistant_create_api", description=f"{PLACEHOLDER_PREFIX}assistant_create_api") +@require_permissions(permission=SqlbotPermission(role=['ws_admin'])) @system_log(LogConfig(operation_type=OperationType.CREATE, module=OperationModules.APPLICATION, result_id_expr="id")) async def add(request: Request, session: SessionDep, current_user: CurrentUser, creator: AssistantBase): oid = current_user.oid if creator.type != 4 else 1 @@ -238,6 +242,7 @@ async def add(request: Request, session: SessionDep, current_user: CurrentUser, @router.put("", summary=f"{PLACEHOLDER_PREFIX}assistant_update_api", description=f"{PLACEHOLDER_PREFIX}assistant_update_api") +@require_permissions(permission=SqlbotPermission(role=['ws_admin'])) @clear_cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="editor.id") @system_log(LogConfig(operation_type=OperationType.UPDATE, module=OperationModules.APPLICATION, resource_id_expr="editor.id")) async def update(request: Request, session: SessionDep, editor: AssistantDTO): @@ -262,6 +267,7 @@ async def get_one(session: SessionDep, id: int = Path(description="ID")): @router.delete("/{id}", summary=f"{PLACEHOLDER_PREFIX}assistant_del_api", description=f"{PLACEHOLDER_PREFIX}assistant_del_api") +@require_permissions(permission=SqlbotPermission(role=['ws_admin'])) @clear_cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="id") @system_log(LogConfig(operation_type=OperationType.DELETE, module=OperationModules.APPLICATION, resource_id_expr="id")) async def delete(request: Request, session: SessionDep, id: int = Path(description="ID")): diff --git a/backend/apps/system/schemas/auth.py b/backend/apps/system/schemas/auth.py index 14db17d9..de4a4dfe 100644 --- a/backend/apps/system/schemas/auth.py +++ b/backend/apps/system/schemas/auth.py @@ -17,6 +17,7 @@ class CacheName(Enum): ASSISTANT_INFO = "assistant:info" ASSISTANT_DS = "assistant:ds" ASK_INFO = "ask:info" + DS_ID_LIST = "ds:id:list" def __str__(self): return self.value diff --git a/backend/apps/system/schemas/permission.py b/backend/apps/system/schemas/permission.py index d0dccd24..c590ec2e 100644 --- a/backend/apps/system/schemas/permission.py +++ b/backend/apps/system/schemas/permission.py @@ -8,6 +8,7 @@ from starlette.middleware.base import BaseHTTPMiddleware from sqlmodel import Session, select from apps.chat.models.chat_model import Chat +from apps.datasource.crud.datasource import get_ws_ds from apps.datasource.models.datasource import CoreDatasource from common.core.db import engine from apps.system.schemas.system_schema import UserInfoDTO @@ -22,7 +23,7 @@ async def get_ws_resource(oid, type) -> list: with Session(engine) as session: stmt = None if type == 'ds' or type == 'datasource': - stmt = select(CoreDatasource.id).where(CoreDatasource.oid == oid) + return await get_ws_ds(session, oid) if type == 'chat': stmt = select(Chat.id).where(Chat.oid == oid) if stmt is not None: @@ -32,6 +33,9 @@ async def get_ws_resource(oid, type) -> list: async def check_ws_permission(oid, type, resource) -> bool: + if not resource or (isinstance(resource, list) and len(resource) == 0): + return True + resource_id_list = await get_ws_resource(oid, type) if not resource_id_list: return False @@ -53,7 +57,7 @@ async def wrapper(*args, **kwargs): ) current_oid = current_user.oid - if current_user.isAdmin: + if current_user.isAdmin and not permission.type: return await func(*args, **kwargs) role_list = permission.role keyExpression = permission.keyExpression @@ -62,7 +66,7 @@ async def wrapper(*args, **kwargs): if role_list: if 'admin' in role_list and not current_user.isAdmin: raise Exception('no permission to execute, only for admin') - if 'ws_admin' in role_list and current_user.weight == 0: + if 'ws_admin' in role_list and current_user.weight == 0 and not current_user.isAdmin: raise Exception('no permission to execute, only for workspace admin') if not resource_type: return await func(*args, **kwargs) diff --git a/backend/apps/terminology/api/terminology.py b/backend/apps/terminology/api/terminology.py index 2a8cb0b4..7240b278 100644 --- a/backend/apps/terminology/api/terminology.py +++ b/backend/apps/terminology/api/terminology.py @@ -26,6 +26,7 @@ @router.get("/page/{current_page}/{page_size}", summary=f"{PLACEHOLDER_PREFIX}get_term_page") +@require_permissions(permission=SqlbotPermission(role=['ws_admin'])) async def pager(session: SessionDep, current_user: CurrentUser, current_page: int, page_size: int, word: Optional[str] = Query(None, description="搜索术语(可选)"), dslist: Optional[list[int]] = Query(None, description="数据集ID集合(可选)")): @@ -42,6 +43,7 @@ async def pager(session: SessionDep, current_user: CurrentUser, current_page: in @router.put("", summary=f"{PLACEHOLDER_PREFIX}create_or_update_term") +@require_permissions(permission=SqlbotPermission(role=['ws_admin'], type='ds', keyExpression="info.datasource_ids")) @system_log(LogConfig(operation_type=OperationType.CREATE_OR_UPDATE, module=OperationModules.TERMINOLOGY,resource_id_expr='info.id', result_id_expr="result_self")) async def create_or_update(session: SessionDep, current_user: CurrentUser, trans: Trans, info: TerminologyInfo): oid = current_user.oid