Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,9 @@ class VariablePostBody(StrictBaseModel):

value: str | None = Field(alias="val")
description: str | None = Field(default=None)


class VariableKeysResponse(StrictBaseModel):
"""Variable keys list response."""

keys: list[str]
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from fastapi import APIRouter, Depends, HTTPException, Path, Request, status

from airflow.api_fastapi.execution_api.datamodels.variable import (
VariableKeysResponse,
VariablePostBody,
VariableResponse,
)
Expand All @@ -47,20 +48,33 @@ async def has_variable_access(
return True


router = APIRouter(
responses={status.HTTP_404_NOT_FOUND: {"description": "Variable not found"}},
dependencies=[Depends(has_variable_access)],
)
router = APIRouter(responses={status.HTTP_404_NOT_FOUND: {"description": "Variable not found"}})

log = logging.getLogger(__name__)


@router.get(
"/keys/list",
responses={
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the variable"},
},
)
def list_keys(
prefix: str | None = None,
team_name: Annotated[str | None, Depends(get_team_name_dep)] = None,
token=JWTBearerDep,
) -> VariableKeysResponse:
return VariableKeysResponse(keys=Variable.list_keys(prefix=prefix, team_name=team_name))


@router.get(
"/{variable_key:path}",
responses={
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the variable"},
},
dependencies=[Depends(has_variable_access)],
)
def get_variable(
variable_key: str, team_name: Annotated[str | None, Depends(get_team_name_dep)]
Expand Down Expand Up @@ -90,6 +104,7 @@ def get_variable(
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the variable"},
},
dependencies=[Depends(has_variable_access)],
)
def put_variable(
variable_key: str, body: VariablePostBody, team_name: Annotated[str | None, Depends(get_team_name_dep)]
Expand Down
6 changes: 6 additions & 0 deletions airflow-core/src/airflow/jobs/triggerer_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
GetTaskStates,
GetTICount,
GetVariable,
GetVariableKeys,
GetXCom,
MaskSecret,
OKResponse,
Expand All @@ -73,6 +74,7 @@
TaskStatesResult,
TICount,
UpdateHITLDetail,
VariableKeysResult,
VariableResult,
XComResult,
_new_encoder,
Expand Down Expand Up @@ -253,6 +255,7 @@ def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseRe
| messages.TriggerStateSync
| ConnectionResult
| VariableResult
| VariableKeysResult
| XComResult
| DagRunStateResult
| DRCount
Expand All @@ -275,6 +278,7 @@ def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseRe
| DeleteVariable
| GetVariable
| PutVariable
| GetVariableKeys
| DeleteXCom
| GetXCom
| SetXCom
Expand Down Expand Up @@ -463,6 +467,8 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r
resp = var
elif isinstance(msg, PutVariable):
self.client.variables.set(msg.key, msg.value, msg.description)
elif isinstance(msg, GetVariableKeys):
resp = self.client.variables.list_keys(msg.prefix)
elif isinstance(msg, DeleteXCom):
self.client.xcoms.delete(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
elif isinstance(msg, GetXCom):
Expand Down
38 changes: 38 additions & 0 deletions airflow-core/src/airflow/models/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,44 @@ def get_variable_from_secrets(key: str, team_name: str | None = None) -> str | N
SecretCache.save_variable(key, var_val, team_name=team_name) # we save None as well
return var_val

@staticmethod
def list_keys(
prefix: str | None = None, team_name: str | None = None, session: Session | None = None
) -> list[str]:
"""
List variable keys, optionally filtered by prefix.

:param prefix: Prefix to filter keys
:param team_name: Team name filter
"""
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
warnings.warn(
"Using Variable.list_keys from `airflow.models` is deprecated. "
"Please use `list_keys` on Variable from sdk(`airflow.sdk.Variable`) instead",
DeprecationWarning,
stacklevel=1,
)
from airflow.sdk import Variable as TaskSDKVariable

return TaskSDKVariable.list_keys(prefix=prefix)

if team_name and not conf.getboolean("core", "multi_team"):
raise ValueError("Multi-team mode is not configured in the Airflow environment")

ctx: contextlib.AbstractContextManager
if session is not None:
ctx = contextlib.nullcontext(session)
else:
ctx = create_session()
with ctx as session:
stmt = select(Variable.key)
if team_name:
stmt = stmt.where(or_(Variable.team_name == team_name, Variable.team_name.is_(None)))
if prefix:
stmt = stmt.where(Variable.key.like(f"{prefix}%"))

return list(session.scalars(stmt).all())

@staticmethod
@provide_session
def get_team_name(variable_key: str, session=NEW_SESSION) -> str | None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,3 +276,29 @@ def test_should_not_delete_variable(self, client, session):

vars = session.scalars(select(Variable)).all()
assert len(vars) == 1


class TestGetVariableKeys:
def test_get_all_keys(self, client, session):
Variable.set(key="key1", value="value")
Variable.set(key="key2", value="value")
response = client.get("/execution/variables/keys/list")
keys = session.scalars(select(Variable.key)).all()
response_json = response.json()
assert len(response_json.get("keys")) == 2
assert sorted(response_json.get("keys")) == sorted(keys)

def test_get_keys_by_prefix(self, client):
Variable.set(key="key1", value="value")
Variable.set(key="test_key", value="value")
Variable.set(key="test_key2", value="value")
response = client.get("/execution/variables/keys/list", params={"prefix": "test_"})
response_json = response.json()
assert len(response_json.get("keys")) == 2
assert sorted(response_json.get("keys")) == ["test_key", "test_key2"]
assert "key1" not in response_json

def test_no_keys_with_prefix(self, client):
Variable.set(key="key1", value="value")
response = client.get("/execution/variables/keys/list", params={"prefix": "api_"})
assert not response.json().get("keys")
9 changes: 9 additions & 0 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
TITerminalStatePayload,
TriggerDAGRunPayload,
ValidationError as RemoteValidationError,
VariableKeysResponse,
VariablePostBody,
VariableResponse,
XComResponse,
Expand Down Expand Up @@ -453,6 +454,14 @@ def delete(
# decouple from the server response string
return OKResponse(ok=True)

def list_keys(self, prefix: str | None = None) -> VariableKeysResponse:
"""List variable keys from the API server."""
params = {}
if prefix is not None:
params["prefix"] = prefix
resp = self.client.get("variables/keys/list", params=params)
return VariableKeysResponse.model_validate_json(resp.read())


class XComOperations:
__slots__ = ("client",)
Expand Down
11 changes: 11 additions & 0 deletions task-sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,17 @@ class ValidationError(BaseModel):
ctx: Annotated[dict[str, Any] | None, Field(title="Context")] = None


class VariableKeysResponse(BaseModel):
"""
Variable keys list response.
"""

model_config = ConfigDict(
extra="forbid",
)
keys: Annotated[list[str], Field(title="Keys")]


class VariablePostBody(BaseModel):
"""
Request body schema for creating variables.
Expand Down
10 changes: 10 additions & 0 deletions task-sdk/src/airflow/sdk/definitions/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,13 @@ def delete(cls, key: str) -> None:
_delete_variable(key=key)
except AirflowRuntimeError as e:
log.exception(e)

@classmethod
def list_keys(cls, prefix: str | None = None):
from airflow.sdk.exceptions import AirflowRuntimeError
from airflow.sdk.execution_time.context import _list_variable_keys

try:
return _list_variable_keys(prefix=prefix)
except AirflowRuntimeError as e:
log.exception(e)
16 changes: 16 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
TISuccessStatePayload,
TriggerDAGRunPayload,
UpdateHITLDetailPayload,
VariableKeysResponse,
VariableResponse,
XComResponse,
XComSequenceIndexResponse,
Expand Down Expand Up @@ -544,6 +545,14 @@ def from_variable_response(cls, variable_response: VariableResponse) -> Variable
return cls(**variable_response.model_dump(exclude_defaults=True), type="VariableResult")


class VariableKeysResult(VariableKeysResponse):
type: Literal["VariableKeysResult"] = "VariableKeysResult"

@classmethod
def from_api_response(cls, response: VariableKeysResponse) -> VariableKeysResult:
return cls(**response.model_dump(exclude_defaults=True), type="VariableKeysResult")


class DagRunResult(DagRun):
type: Literal["DagRunResult"] = "DagRunResult"

Expand Down Expand Up @@ -711,6 +720,7 @@ def from_api_response(cls, hitl_request: HITLDetailRequest) -> HITLDetailRequest
| TaskBreadcrumbsResult
| TaskStatesResult
| VariableResult
| VariableKeysResult
| XComCountResponse
| XComResult
| XComSequenceIndexResult
Expand Down Expand Up @@ -856,6 +866,11 @@ class DeleteVariable(BaseModel):
type: Literal["DeleteVariable"] = "DeleteVariable"


class GetVariableKeys(BaseModel):
prefix: str | None = None
type: Literal["GetVariableKeys"] = "GetVariableKeys"


class ResendLoggingFD(BaseModel):
type: Literal["ResendLoggingFD"] = "ResendLoggingFD"

Expand Down Expand Up @@ -1021,6 +1036,7 @@ class MaskSecret(BaseModel):
ToSupervisor = Annotated[
DeferTask
| DeleteXCom
| GetVariableKeys
| GetAssetByName
| GetAssetByUri
| GetAssetEventByAsset
Expand Down
20 changes: 20 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,26 @@ def _delete_variable(key: str) -> None:
SecretCache.invalidate_variable(key)


def _list_variable_keys(prefix: str | None = None) -> list[str]:
from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded

backends = ensure_secrets_backend_loaded()
all_keys: set[str] = set()
for secrets_backend in backends:
try:
if hasattr(secrets_backend, "list_variable_keys"):
keys = secrets_backend.list_variable_keys(prefix=prefix)
if keys:
all_keys.update(keys)
except Exception:
log.exception(
"Unable to list variable keys from secrets backend (%s).",
type(secrets_backend).__name__,
)

return list(all_keys)


class ConnectionAccessor:
"""Wrapper to access Connection entries in template."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,21 @@ def get_variable(self, key: str, team_name: str | None = None) -> str | None:
# to allow fallback to other backends
return None

def list_variable_keys(self, prefix: str | None = None, team_name: str | None = None) -> list | None:
from airflow.sdk.execution_time.comms import ErrorResponse, GetVariableKeys, VariableKeysResult
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

try:
msg = SUPERVISOR_COMMS.send(GetVariableKeys(prefix=prefix))
if isinstance(msg, ErrorResponse):
return None
if isinstance(msg, VariableKeysResult):
return msg.keys
except Exception:
# If SUPERVISOR_COMMS fails for any reason, return None
# to allow fallback to other backends
return None

async def aget_connection(self, conn_id: str) -> Connection | None: # type: ignore[override]
"""
Return connection object asynchronously via SUPERVISOR_COMMS.
Expand Down
5 changes: 5 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
GetTaskStates,
GetTICount,
GetVariable,
GetVariableKeys,
GetXCom,
GetXComCount,
GetXComSequenceItem,
Expand All @@ -113,6 +114,7 @@
ToSupervisor,
TriggerDagRun,
ValidateInletsAndOutlets,
VariableKeysResult,
VariableResult,
XComResult,
XComSequenceIndexResult,
Expand Down Expand Up @@ -1276,6 +1278,9 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id:
dump_opts = {"exclude_unset": True}
else:
resp = var
elif isinstance(msg, GetVariableKeys):
keys_response = self.client.variables.list_keys(msg.prefix)
resp = VariableKeysResult.from_api_response(keys_response)
elif isinstance(msg, GetXCom):
xcom = self.client.xcoms.get(
msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index, msg.include_prior_dates
Expand Down
12 changes: 12 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
GetTaskStates,
GetTICount,
GetVariable,
GetVariableKeys,
GetXCom,
GetXComCount,
GetXComSequenceItem,
Expand Down Expand Up @@ -125,6 +126,7 @@
TriggerDagRun,
UpdateHITLDetail,
ValidateInletsAndOutlets,
VariableKeysResult,
VariableResult,
XComCountResponse,
XComResult,
Expand Down Expand Up @@ -1474,6 +1476,16 @@ class RequestTestCase:
response=OKResponse(ok=True),
),
),
RequestTestCase(
message=GetVariableKeys(prefix="test_"),
test_id="get_variable",
client_mock=ClientMock(
method_path="variables.list_keys",
args=("test_",),
response=VariableKeysResult(keys=["test_key"]),
),
expected_body={"keys": ["test_key"], "type": "VariableKeysResult"},
),
RequestTestCase(
message=DeleteVariable(key="test_key"),
test_id="delete_variable",
Expand Down
Loading