diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/variable.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/variable.py
index fd49a5eae46d6..96be8b413e457 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/variable.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/variable.py
@@ -34,3 +34,9 @@ class VariablePostBody(StrictBaseModel):
 
     value: str | None = Field(alias="val")
     description: str | None = Field(default=None)
+
+
+class VariableKeysResponse(StrictBaseModel):
+    """Variable keys list response."""
+
+    keys: list[str]
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/variables.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/variables.py
index 5621b6cd081ba..d4f1432da84ab 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/variables.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/variables.py
@@ -23,6 +23,7 @@
 from fastapi import APIRouter, Depends, HTTPException, Path, Request, status
 
 from airflow.api_fastapi.execution_api.datamodels.variable import (
+    VariableKeysResponse,
     VariablePostBody,
     VariableResponse,
 )
@@ -47,20 +48,34 @@ async def has_variable_access(
     return True
 
 
-router = APIRouter(
-    responses={status.HTTP_404_NOT_FOUND: {"description": "Variable not found"}},
-    dependencies=[Depends(has_variable_access)],
-)
+router = APIRouter(responses={status.HTTP_404_NOT_FOUND: {"description": "Variable not found"}})
 
 log = logging.getLogger(__name__)
 
 
+@router.get(
+    "/keys/list",
+    responses={
+        status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
+        status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the variable"},
+    },
+)
+def list_keys(
+    prefix: str | None = None,
+    team_name: Annotated[str | None, Depends(get_team_name_dep)] = None,
+    token=JWTBearerDep,
+) -> VariableKeysResponse:
+    """List Airflow Variable keys, optionally restricted to keys starting with ``prefix``."""
+    return VariableKeysResponse(keys=Variable.list_keys(prefix=prefix, team_name=team_name))
+
+
 @router.get(
     "/{variable_key:path}",
     responses={
         status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
         status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the variable"},
     },
+    dependencies=[Depends(has_variable_access)],
 )
 def get_variable(
     variable_key: str, team_name: Annotated[str | None, Depends(get_team_name_dep)]
@@ -90,6 +105,7 @@ def get_variable(
         status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
         status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the variable"},
     },
+    dependencies=[Depends(has_variable_access)],
 )
 def put_variable(
     variable_key: str, body: VariablePostBody, team_name: Annotated[str | None, Depends(get_team_name_dep)]
diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
index 57660c6820eee..6c17c4aafb970 100644
--- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
+++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -65,6 +65,7 @@
     GetTaskStates,
     GetTICount,
     GetVariable,
+    GetVariableKeys,
     GetXCom,
     MaskSecret,
     OKResponse,
@@ -73,6 +74,7 @@
     TaskStatesResult,
     TICount,
     UpdateHITLDetail,
+    VariableKeysResult,
     VariableResult,
     XComResult,
     _new_encoder,
@@ -253,6 +255,7 @@ def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseRe
     | messages.TriggerStateSync
     | ConnectionResult
     | VariableResult
+    | VariableKeysResult
     | XComResult
     | DagRunStateResult
     | DRCount
@@ -275,6 +278,7 @@ def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseRe
     | DeleteVariable
     | GetVariable
     | PutVariable
+    | GetVariableKeys
     | DeleteXCom
     | GetXCom
     | SetXCom
@@ -463,6 +467,8 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r
                 resp = var
         elif isinstance(msg, PutVariable):
             self.client.variables.set(msg.key, msg.value, msg.description)
+        elif isinstance(msg, GetVariableKeys):
+            resp = VariableKeysResult.from_api_response(self.client.variables.list_keys(msg.prefix))
         elif isinstance(msg, DeleteXCom):
             self.client.xcoms.delete(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
         elif isinstance(msg, GetXCom):
diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py
index 5435326de5417..454140db59203 100644
--- a/airflow-core/src/airflow/models/variable.py
+++ b/airflow-core/src/airflow/models/variable.py
@@ -524,6 +524,46 @@ def get_variable_from_secrets(key: str, team_name: str | None = None) -> str | N
         SecretCache.save_variable(key, var_val, team_name=team_name)  # we save None as well
         return var_val
 
+    @staticmethod
+    def list_keys(
+        prefix: str | None = None, team_name: str | None = None, session: Session | None = None
+    ) -> list[str]:
+        """
+        List variable keys, optionally filtered by prefix.
+
+        :param prefix: Prefix to filter keys; matched literally (``%`` and ``_`` are not wildcards)
+        :param team_name: Restrict to variables owned by this team or by no team
+        :param session: Session to use; a new one is created when omitted
+        :return: Keys of the matching variables
+        """
+        if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
+            warnings.warn(
+                "Using Variable.list_keys from `airflow.models` is deprecated. "
+                "Please use `list_keys` on Variable from sdk(`airflow.sdk.Variable`) instead",
+                DeprecationWarning,
+                stacklevel=1,
+            )
+            from airflow.sdk import Variable as TaskSDKVariable
+
+            return TaskSDKVariable.list_keys(prefix=prefix)
+
+        if team_name and not conf.getboolean("core", "multi_team"):
+            raise ValueError("Multi-team mode is not configured in the Airflow environment")
+
+        ctx: contextlib.AbstractContextManager
+        if session is not None:
+            ctx = contextlib.nullcontext(session)
+        else:
+            ctx = create_session()
+        with ctx as session:
+            stmt = select(Variable.key)
+            if team_name:
+                stmt = stmt.where(or_(Variable.team_name == team_name, Variable.team_name.is_(None)))
+            if prefix:
+                stmt = stmt.where(Variable.key.startswith(prefix, autoescape=True))
+
+            return list(session.scalars(stmt).all())
+
     @staticmethod
     @provide_session
     def get_team_name(variable_key: str, session=NEW_SESSION) -> str | None:
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py
index 59b206441dea6..ceb8b0128b2fd 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py
@@ -276,3 +276,29 @@ def test_should_not_delete_variable(self, client, session):
 
         vars = session.scalars(select(Variable)).all()
         assert len(vars) == 1
+
+
+class TestGetVariableKeys:
+    def test_get_all_keys(self, client, session):
+        Variable.set(key="key1", value="value")
+        Variable.set(key="key2", value="value")
+        response = client.get("/execution/variables/keys/list")
+        keys = session.scalars(select(Variable.key)).all()
+        response_json = response.json()
+        assert len(response_json.get("keys")) == 2
+        assert sorted(response_json.get("keys")) == sorted(keys)
+
+    def test_get_keys_by_prefix(self, client):
+        Variable.set(key="key1", value="value")
+        Variable.set(key="test_key", value="value")
+        Variable.set(key="test_key2", value="value")
+        response = client.get("/execution/variables/keys/list", params={"prefix": "test_"})
+        response_json = response.json()
+        assert len(response_json.get("keys")) == 2
+        assert sorted(response_json.get("keys")) == ["test_key", "test_key2"]
+        assert "key1" not in response_json.get("keys")
+
+    def test_no_keys_with_prefix(self, client):
+        Variable.set(key="key1", value="value")
+        response = client.get("/execution/variables/keys/list", params={"prefix": "api_"})
+        assert not response.json().get("keys")
diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py
index 2a38ef2fad7b3..378be7eb69eb9 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -68,6 +68,7 @@
     TITerminalStatePayload,
     TriggerDAGRunPayload,
     ValidationError as RemoteValidationError,
+    VariableKeysResponse,
     VariablePostBody,
     VariableResponse,
     XComResponse,
@@ -453,6 +454,14 @@ def delete(
         # decouple from the server response string
         return OKResponse(ok=True)
 
+    def list_keys(self, prefix: str | None = None) -> VariableKeysResponse:
+        """List variable keys from the API server."""
+        params = {}
+        if prefix is not None:
+            params["prefix"] = prefix
+        resp = self.client.get("variables/keys/list", params=params)
+        return VariableKeysResponse.model_validate_json(resp.read())
+
 
 class XComOperations:
     __slots__ = ("client",)
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index 3c2099ffbf5b0..9ceb0a51d5c1c 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -376,6 +376,17 @@ class ValidationError(BaseModel):
     ctx: Annotated[dict[str, Any] | None, Field(title="Context")] = None
 
 
+class VariableKeysResponse(BaseModel):
+    """
+    Variable keys list response.
+    """
+
+    model_config = ConfigDict(
+        extra="forbid",
+    )
+    keys: Annotated[list[str], Field(title="Keys")]
+
+
 class VariablePostBody(BaseModel):
     """
     Request body schema for creating variables.
diff --git a/task-sdk/src/airflow/sdk/definitions/variable.py b/task-sdk/src/airflow/sdk/definitions/variable.py
index 2e4c9aae3ca0f..13e73162696b0 100644
--- a/task-sdk/src/airflow/sdk/definitions/variable.py
+++ b/task-sdk/src/airflow/sdk/definitions/variable.py
@@ -76,3 +76,15 @@ def delete(cls, key: str) -> None:
             _delete_variable(key=key)
         except AirflowRuntimeError as e:
             log.exception(e)
+
+    @classmethod
+    def list_keys(cls, prefix: str | None = None) -> list[str]:
+        """List variable keys, optionally filtered by prefix; returns an empty list on runtime errors."""
+        from airflow.sdk.exceptions import AirflowRuntimeError
+        from airflow.sdk.execution_time.context import _list_variable_keys
+
+        try:
+            return _list_variable_keys(prefix=prefix)
+        except AirflowRuntimeError as e:
+            log.exception(e)
+            return []
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
index 15755e640d97e..c08fd5c77a0b4 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -89,6 +89,7 @@
     TISuccessStatePayload,
     TriggerDAGRunPayload,
     UpdateHITLDetailPayload,
+    VariableKeysResponse,
     VariableResponse,
     XComResponse,
     XComSequenceIndexResponse,
@@ -544,6 +545,14 @@ def from_variable_response(cls, variable_response: VariableResponse) -> Variable
         return cls(**variable_response.model_dump(exclude_defaults=True), type="VariableResult")
 
 
+class VariableKeysResult(VariableKeysResponse):
+    type: Literal["VariableKeysResult"] = "VariableKeysResult"
+
+    @classmethod
+    def from_api_response(cls, response: VariableKeysResponse) -> VariableKeysResult:
+        return cls(**response.model_dump(exclude_defaults=True), type="VariableKeysResult")
+
+
 class DagRunResult(DagRun):
     type: Literal["DagRunResult"] = "DagRunResult"
 
@@ -711,6 +720,7 @@ def from_api_response(cls, hitl_request: HITLDetailRequest) -> HITLDetailRequest
     | TaskBreadcrumbsResult
     | TaskStatesResult
     | VariableResult
+    | VariableKeysResult
     | XComCountResponse
     | XComResult
     | XComSequenceIndexResult
@@ -856,6 +866,11 @@ class DeleteVariable(BaseModel):
     type: Literal["DeleteVariable"] = "DeleteVariable"
 
 
+class GetVariableKeys(BaseModel):
+    prefix: str | None = None
+    type: Literal["GetVariableKeys"] = "GetVariableKeys"
+
+
 class ResendLoggingFD(BaseModel):
     type: Literal["ResendLoggingFD"] = "ResendLoggingFD"
 
@@ -1021,6 +1036,7 @@ class MaskSecret(BaseModel):
 ToSupervisor = Annotated[
     DeferTask
     | DeleteXCom
+    | GetVariableKeys
     | GetAssetByName
     | GetAssetByUri
     | GetAssetEventByAsset
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py
index db5a75e10c18d..5eae00793c7fe 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -341,6 +341,27 @@ def _delete_variable(key: str) -> None:
     SecretCache.invalidate_variable(key)
 
 
+def _list_variable_keys(prefix: str | None = None) -> list[str]:
+    """Collect variable keys from every secrets backend that implements ``list_variable_keys``."""
+    from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded
+
+    backends = ensure_secrets_backend_loaded()
+    all_keys: set[str] = set()
+    for secrets_backend in backends:
+        try:
+            if hasattr(secrets_backend, "list_variable_keys"):
+                keys = secrets_backend.list_variable_keys(prefix=prefix)
+                if keys:
+                    all_keys.update(keys)
+        except Exception:
+            log.exception(
+                "Unable to list variable keys from secrets backend (%s).",
+                type(secrets_backend).__name__,
+            )
+
+    return sorted(all_keys)
+
+
 class ConnectionAccessor:
     """Wrapper to access Connection entries in template."""
 
diff --git a/task-sdk/src/airflow/sdk/execution_time/secrets/execution_api.py b/task-sdk/src/airflow/sdk/execution_time/secrets/execution_api.py
index a44b23d06dc6d..7d812132039fd 100644
--- a/task-sdk/src/airflow/sdk/execution_time/secrets/execution_api.py
+++ b/task-sdk/src/airflow/sdk/execution_time/secrets/execution_api.py
@@ -118,6 +118,22 @@ def get_variable(self, key: str, team_name: str | None = None) -> str | None:
             # to allow fallback to other backends
             return None
 
+    def list_variable_keys(self, prefix: str | None = None, team_name: str | None = None) -> list | None:
+        """Return variable keys via SUPERVISOR_COMMS, or None when they cannot be fetched."""
+        from airflow.sdk.execution_time.comms import ErrorResponse, GetVariableKeys, VariableKeysResult
+        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+        try:
+            msg = SUPERVISOR_COMMS.send(GetVariableKeys(prefix=prefix))
+            if isinstance(msg, ErrorResponse):
+                return None
+            if isinstance(msg, VariableKeysResult):
+                return msg.keys
+        except Exception:
+            # If SUPERVISOR_COMMS fails for any reason, return None
+            # to allow fallback to other backends
+            return None
+
     async def aget_connection(self, conn_id: str) -> Connection | None:  # type: ignore[override]
         """
         Return connection object asynchronously via SUPERVISOR_COMMS.
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index b87131aa7336d..02cd2ea6b288d 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -88,6 +88,7 @@
     GetTaskStates,
     GetTICount,
     GetVariable,
+    GetVariableKeys,
     GetXCom,
     GetXComCount,
     GetXComSequenceItem,
@@ -113,6 +114,7 @@
     ToSupervisor,
     TriggerDagRun,
     ValidateInletsAndOutlets,
+    VariableKeysResult,
     VariableResult,
     XComResult,
     XComSequenceIndexResult,
@@ -1276,6 +1278,9 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id:
                 dump_opts = {"exclude_unset": True}
             else:
                 resp = var
+        elif isinstance(msg, GetVariableKeys):
+            keys_response = self.client.variables.list_keys(msg.prefix)
+            resp = VariableKeysResult.from_api_response(keys_response)
         elif isinstance(msg, GetXCom):
             xcom = self.client.xcoms.get(
                 msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index, msg.include_prior_dates
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 15e412d552adc..8e7a086d61eeb 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -95,6 +95,7 @@
     GetTaskStates,
     GetTICount,
     GetVariable,
+    GetVariableKeys,
     GetXCom,
     GetXComCount,
     GetXComSequenceItem,
@@ -125,6 +126,7 @@
     TriggerDagRun,
     UpdateHITLDetail,
     ValidateInletsAndOutlets,
+    VariableKeysResult,
     VariableResult,
     XComCountResponse,
     XComResult,
@@ -1474,6 +1476,16 @@ class RequestTestCase:
             response=OKResponse(ok=True),
         ),
     ),
+    RequestTestCase(
+        message=GetVariableKeys(prefix="test_"),
+        test_id="get_variable_keys",
+        client_mock=ClientMock(
+            method_path="variables.list_keys",
+            args=("test_",),
+            response=VariableKeysResult(keys=["test_key"]),
+        ),
+        expected_body={"keys": ["test_key"], "type": "VariableKeysResult"},
+    ),
     RequestTestCase(
         message=DeleteVariable(key="test_key"),
         test_id="delete_variable",