Skip to content
16 changes: 16 additions & 0 deletions src/eligibility_signposting_api/common/api_error_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,19 @@ def log_and_generate_response(
fhir_error_code=FHIRSpineErrorCode.ACCESS_DENIED,
fhir_display_message="Access has been denied to process this request.",
)

CONSUMER_ID_NOT_PROVIDED_ERROR = APIErrorResponse(
status_code=HTTPStatus.FORBIDDEN,
fhir_issue_code=FHIRIssueCode.FORBIDDEN,
fhir_issue_severity=FHIRIssueSeverity.ERROR,
fhir_error_code=FHIRSpineErrorCode.ACCESS_DENIED,
fhir_display_message="Access has been denied to process this request.",
)

CONSUMER_HAS_NO_CAMPAIGN_MAPPING = APIErrorResponse(
status_code=HTTPStatus.FORBIDDEN,
fhir_issue_code=FHIRIssueCode.FORBIDDEN,
fhir_issue_severity=FHIRIssueSeverity.ERROR,
fhir_error_code=FHIRSpineErrorCode.ACCESS_DENIED,
fhir_display_message="Access has been denied to process this request.",
)
11 changes: 10 additions & 1 deletion src/eligibility_signposting_api/common/request_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from flask.typing import ResponseReturnValue

from eligibility_signposting_api.common.api_error_response import (
CONSUMER_ID_NOT_PROVIDED_ERROR,
INVALID_CATEGORY_ERROR,
INVALID_CONDITION_FORMAT_ERROR,
INVALID_INCLUDE_ACTIONS_ERROR,
NHS_NUMBER_MISMATCH_ERROR,
)
from eligibility_signposting_api.config.constants import NHS_NUMBER_HEADER
from eligibility_signposting_api.config.constants import CONSUMER_ID, NHS_NUMBER_HEADER

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -56,6 +57,14 @@ def validate_request_params() -> Callable:
def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs) -> ResponseReturnValue: # noqa:ANN002,ANN003
consumer_id = str(request.headers.get(CONSUMER_ID))

if not consumer_id:
message = "You are not authorised to request"
return CONSUMER_ID_NOT_PROVIDED_ERROR.log_and_generate_response(
log_message=message, diagnostics=message
)

path_nhs_number = str(kwargs.get("nhs_number"))
header_nhs_no = str(request.headers.get(NHS_NUMBER_HEADER))

Expand Down
3 changes: 3 additions & 0 deletions src/eligibility_signposting_api/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
def config() -> dict[str, Any]:
person_table_name = TableName(os.getenv("PERSON_TABLE_NAME", "test_eligibility_datastore"))
rules_bucket_name = BucketName(os.getenv("RULES_BUCKET_NAME", "test-rules-bucket"))
consumer_mapping_bucket_name = BucketName(os.getenv("CONSUMER_MAPPING_BUCKET_NAME", "test-consumer-mapping-bucket"))
audit_bucket_name = BucketName(os.getenv("AUDIT_BUCKET_NAME", "test-audit-bucket"))
hashing_secret_name = HashSecretName(os.getenv("HASHING_SECRET_NAME", "test_secret"))
aws_default_region = AwsRegion(os.getenv("AWS_DEFAULT_REGION", "eu-west-1"))
Expand All @@ -41,6 +42,7 @@ def config() -> dict[str, Any]:
"s3_endpoint": None,
"rules_bucket_name": rules_bucket_name,
"audit_bucket_name": audit_bucket_name,
"consumer_mapping_bucket_name": consumer_mapping_bucket_name,
"firehose_endpoint": None,
"kinesis_audit_stream_to_s3": kinesis_audit_stream_to_s3,
"enable_xray_patching": enable_xray_patching,
Expand All @@ -59,6 +61,7 @@ def config() -> dict[str, Any]:
"s3_endpoint": URL(os.getenv("S3_ENDPOINT", local_stack_endpoint)),
"rules_bucket_name": rules_bucket_name,
"audit_bucket_name": audit_bucket_name,
"consumer_mapping_bucket_name": consumer_mapping_bucket_name,
"firehose_endpoint": URL(os.getenv("FIREHOSE_ENDPOINT", local_stack_endpoint)),
"kinesis_audit_stream_to_s3": kinesis_audit_stream_to_s3,
"enable_xray_patching": enable_xray_patching,
Expand Down
1 change: 1 addition & 0 deletions src/eligibility_signposting_api/config/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
URL_PREFIX = "patient-check"
RULE_STOP_DEFAULT = False
NHS_NUMBER_HEADER = "nhs-login-nhs-number"
CONSUMER_ID = "consumer-id"
ALLOWED_CONDITIONS = Literal["COVID", "FLU", "MMR", "RSV"]
12 changes: 12 additions & 0 deletions src/eligibility_signposting_api/model/consumer_mapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import NewType

from pydantic import RootModel

from eligibility_signposting_api.model.campaign_config import CampaignID

ConsumerId = NewType("ConsumerId", str)


class ConsumerMapping(RootModel[dict[ConsumerId, list[CampaignID]]]):
def get(self, key: ConsumerId, default: list[CampaignID] | None = None) -> list[CampaignID] | None:
return self.root.get(key, default)
32 changes: 32 additions & 0 deletions src/eligibility_signposting_api/repos/consumer_mapping_repo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import json
from typing import Annotated, NewType

from botocore.client import BaseClient
from wireup import Inject, service

from eligibility_signposting_api.model.campaign_config import CampaignID
from eligibility_signposting_api.model.consumer_mapping import ConsumerId, ConsumerMapping

BucketName = NewType("BucketName", str)


@service
class ConsumerMappingRepo:
"""Repository class for Campaign Rules, which we can use to calculate a person's eligibility for vaccination.

These rules are stored as JSON files in AWS S3."""

def __init__(
self,
s3_client: Annotated[BaseClient, Inject(qualifier="s3")],
bucket_name: Annotated[BucketName, Inject(param="consumer_mapping_bucket_name")],
) -> None:
super().__init__()
self.s3_client = s3_client
self.bucket_name = bucket_name

def get_permitted_campaign_ids(self, consumer_id: ConsumerId) -> list[CampaignID] | None:
consumer_mappings = self.s3_client.list_objects(Bucket=self.bucket_name)["Contents"][0]
response = self.s3_client.get_object(Bucket=self.bucket_name, Key=f"{consumer_mappings['Key']}")
body = response["Body"].read()
return ConsumerMapping.model_validate(json.loads(body)).get(consumer_id)
30 changes: 28 additions & 2 deletions src/eligibility_signposting_api/services/eligibility_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from wireup import service

from eligibility_signposting_api.model import eligibility_status
from eligibility_signposting_api.model.campaign_config import CampaignConfig
from eligibility_signposting_api.model.consumer_mapping import ConsumerId
from eligibility_signposting_api.repos import CampaignRepo, NotFoundError, PersonRepo
from eligibility_signposting_api.repos.consumer_mapping_repo import ConsumerMappingRepo
from eligibility_signposting_api.services.calculators import eligibility_calculator as calculator

logger = logging.getLogger(__name__)
Expand All @@ -17,35 +20,58 @@ class InvalidQueryParamError(Exception):
pass


class NoPermittedCampaignsError(Exception):
pass


@service
class EligibilityService:
def __init__(
self,
person_repo: PersonRepo,
campaign_repo: CampaignRepo,
consumer_mapping_repo: ConsumerMappingRepo,
calculator_factory: calculator.EligibilityCalculatorFactory,
) -> None:
super().__init__()
self.person_repo = person_repo
self.campaign_repo = campaign_repo
self.calculator_factory = calculator_factory
self.consumer_mapping = consumer_mapping_repo

def get_eligibility_status(
self,
nhs_number: eligibility_status.NHSNumber,
include_actions: str,
conditions: list[str],
category: str,
consumer_id: str,
) -> eligibility_status.EligibilityStatus:
"""Calculate a person's eligibility for vaccination given an NHS number."""
if nhs_number:
try:
person_data = self.person_repo.get_eligibility_data(nhs_number)
campaign_configs = list(self.campaign_repo.get_campaign_configs())
except NotFoundError as e:
raise UnknownPersonError from e
else:
calc: calculator.EligibilityCalculator = self.calculator_factory.get(person_data, campaign_configs)
campaign_configs: list[CampaignConfig] = list(self.campaign_repo.get_campaign_configs())
permitted_campaign_configs = self.__collect_permitted_campaign_configs(
campaign_configs, ConsumerId(consumer_id)
)
calc: calculator.EligibilityCalculator = self.calculator_factory.get(
person_data, permitted_campaign_configs
)
return calc.get_eligibility_status(include_actions, conditions, category)

raise UnknownPersonError # pragma: no cover

def __collect_permitted_campaign_configs(
self, campaign_configs: list[CampaignConfig], consumer_id: ConsumerId
) -> list[CampaignConfig]:
permitted_campaign_ids = self.consumer_mapping.get_permitted_campaign_ids(ConsumerId(consumer_id))
if permitted_campaign_ids:
permitted_campaign_configs: list[CampaignConfig] = [
campaign for campaign in campaign_configs if campaign.id in permitted_campaign_ids
]
return permitted_campaign_configs
raise NoPermittedCampaignsError
33 changes: 29 additions & 4 deletions src/eligibility_signposting_api/views/eligibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@

from eligibility_signposting_api.audit.audit_context import AuditContext
from eligibility_signposting_api.audit.audit_service import AuditService
from eligibility_signposting_api.common.api_error_response import NHS_NUMBER_NOT_FOUND_ERROR
from eligibility_signposting_api.common.api_error_response import (
CONSUMER_HAS_NO_CAMPAIGN_MAPPING,
NHS_NUMBER_NOT_FOUND_ERROR,
)
from eligibility_signposting_api.common.request_validator import validate_request_params
from eligibility_signposting_api.config.constants import URL_PREFIX
from eligibility_signposting_api.config.constants import CONSUMER_ID, URL_PREFIX
from eligibility_signposting_api.model.consumer_mapping import ConsumerId
from eligibility_signposting_api.model.eligibility_status import Condition, EligibilityStatus, NHSNumber, Status
from eligibility_signposting_api.services import EligibilityService, UnknownPersonError
from eligibility_signposting_api.services.eligibility_services import NoPermittedCampaignsError
from eligibility_signposting_api.views.response_model import eligibility_response
from eligibility_signposting_api.views.response_model.eligibility_response import ProcessedSuggestion

Expand Down Expand Up @@ -47,23 +52,36 @@ def check_eligibility(
nhs_number: NHSNumber, eligibility_service: Injected[EligibilityService], audit_service: Injected[AuditService]
) -> ResponseReturnValue:
logger.info("checking nhs_number %r in %r", nhs_number, eligibility_service, extra={"nhs_number": nhs_number})

query_params = _get_or_default_query_params()
consumer_id = _get_consumer_id_from_headers()

try:
query_params = get_or_default_query_params()
eligibility_status = eligibility_service.get_eligibility_status(
nhs_number,
query_params["includeActions"],
query_params["conditions"],
query_params["category"],
consumer_id,
)
except UnknownPersonError:
return handle_unknown_person_error(nhs_number)
except NoPermittedCampaignsError:
return handle_no_permitted_campaigns_for_the_consumer_error(consumer_id)
else:
response: eligibility_response.EligibilityResponse = build_eligibility_response(eligibility_status)
AuditContext.write_to_firehose(audit_service)
return make_response(response.model_dump(by_alias=True, mode="json", exclude_none=True), HTTPStatus.OK)


def get_or_default_query_params() -> dict[str, Any]:
def _get_consumer_id_from_headers() -> ConsumerId:
"""
@validate_request_params() ensures the consumer ID is never null at this stage.
"""
return ConsumerId(request.headers.get(CONSUMER_ID, ""))


def _get_or_default_query_params() -> dict[str, Any]:
default_query_params = {"category": "ALL", "conditions": ["ALL"], "includeActions": "Y"}

if not request.args:
Expand Down Expand Up @@ -102,6 +120,13 @@ def handle_unknown_person_error(nhs_number: NHSNumber) -> ResponseReturnValue:
)


def handle_no_permitted_campaigns_for_the_consumer_error(consumer_id: ConsumerId) -> ResponseReturnValue:
diagnostics = f"Consumer ID '{consumer_id}' was not recognised by the Eligibility Signposting API"
return CONSUMER_HAS_NO_CAMPAIGN_MAPPING.log_and_generate_response(
log_message=diagnostics, diagnostics=diagnostics, location_param="id"
)


def build_eligibility_response(eligibility_status: EligibilityStatus) -> eligibility_response.EligibilityResponse:
"""Return an object representing the API response we are going to send, given an evaluation of the person's
eligibility."""
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/builders/model/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class IterationFactory(ModelFactory[Iteration]):

class RawCampaignConfigFactory(ModelFactory[CampaignConfig]):
iterations = Use(IterationFactory.batch, size=2)

id = "42-hi5tch-hi5kers-gu5ide-t2o-t3he-gal6axy"
start_date = Use(past_date)
end_date = Use(future_date)

Expand Down
49 changes: 49 additions & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from eligibility_signposting_api.model.campaign_config import (
AvailableAction,
CampaignConfig,
CampaignID,
EndDate,
RuleCode,
RuleEntry,
Expand All @@ -30,6 +31,7 @@
StartDate,
StatusText,
)
from eligibility_signposting_api.model.consumer_mapping import ConsumerId, ConsumerMapping
from eligibility_signposting_api.processors.hashing_service import HashingService, HashSecretName
from eligibility_signposting_api.repos import SecretRepo
from eligibility_signposting_api.repos.campaign_repo import BucketName
Expand Down Expand Up @@ -661,6 +663,14 @@ def rules_bucket(s3_client: BaseClient) -> Generator[BucketName]:
s3_client.delete_bucket(Bucket=bucket_name)


@pytest.fixture(scope="session")
def consumer_mapping_bucket(s3_client: BaseClient) -> Generator[BucketName]:
bucket_name = BucketName(os.getenv("CONSUMER_MAPPING_BUCKET_NAME", "test-consumer-mapping-bucket"))
s3_client.create_bucket(Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": AWS_REGION})
yield bucket_name
s3_client.delete_bucket(Bucket=bucket_name)


@pytest.fixture(scope="session")
def audit_bucket(s3_client: BaseClient) -> Generator[BucketName]:
bucket_name = BucketName(os.getenv("AUDIT_BUCKET_NAME", "test-audit-bucket"))
Expand Down Expand Up @@ -719,6 +729,45 @@ def campaign_config(s3_client: BaseClient, rules_bucket: BucketName) -> Generato
s3_client.delete_object(Bucket=rules_bucket, Key=f"{campaign.name}.json")


@pytest.fixture(scope="class")
def consumer_mapping(s3_client: BaseClient, consumer_mapping_bucket: BucketName) -> Generator[ConsumerMapping]:
consumer_mapping = ConsumerMapping.model_validate({})
consumer_mapping.root[ConsumerId("23-mic7heal-jor6don")] = [CampaignID("42-hi5tch-hi5kers-gu5ide-t2o-t3he-gal6axy")]

consumer_mapping_data = consumer_mapping.model_dump(by_alias=True)
s3_client.put_object(
Bucket=consumer_mapping_bucket,
Key="consumer_mapping.json",
Body=json.dumps(consumer_mapping_data),
ContentType="application/json",
)
yield consumer_mapping
s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json")


@pytest.fixture(scope="class")
def consumer_mapping_with_various_targets(
s3_client: BaseClient, consumer_mapping_bucket: BucketName
) -> Generator[ConsumerMapping]:
consumer_mapping = ConsumerMapping.model_validate({})
consumer_mapping.root[ConsumerId("23-mic7heal-jor6don")] = [
CampaignID("campaign_start_date"),
CampaignID("campaign_start_date_plus_one_day"),
CampaignID("campaign_today"),
CampaignID("campaign_tomorrow"),
]

consumer_mapping_data = consumer_mapping.model_dump(by_alias=True)
s3_client.put_object(
Bucket=consumer_mapping_bucket,
Key="consumer_mapping.json",
Body=json.dumps(consumer_mapping_data),
ContentType="application/json",
)
yield consumer_mapping
s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json")


@pytest.fixture
def campaign_config_with_rules_having_rule_code(
s3_client: BaseClient, rules_bucket: BucketName
Expand Down
Loading