diff --git a/aieng-eval-agents/aieng/agent_evals/aml_investigation/data/cases.py b/aieng-eval-agents/aieng/agent_evals/aml_investigation/data/cases.py index 9b883fb..0ba833c 100644 --- a/aieng-eval-agents/aieng/agent_evals/aml_investigation/data/cases.py +++ b/aieng-eval-agents/aieng/agent_evals/aml_investigation/data/cases.py @@ -106,9 +106,9 @@ class AnalystOutput(BaseModel): class CaseRecord(BaseModel): """Combined case file and ground truth record.""" - case: CaseFile = Field(..., description="Metadata for the laundering case.") - groundtruth: GroundTruth = Field(..., description="Ground truth information for the laundering case.") - analysis: AnalystOutput | None = Field( + input: CaseFile = Field(..., description="Metadata for the laundering case.") + expected_output: GroundTruth = Field(..., description="Ground truth information for the laundering case.") + output: AnalystOutput | None = Field( default=None, description="Optional analyst output for the laundering case. Typically populated after investigation.", ) @@ -246,7 +246,7 @@ def build_cases( laundering_attempt_txn_ids: set[str] = set() for case in laundering_cases: attempt_ids_list = [] - attempt_ids_str = case.groundtruth.attempt_transaction_ids + attempt_ids_str = case.expected_output.attempt_transaction_ids if attempt_ids_str: attempt_ids_list = [item.strip() for item in attempt_ids_str.split(",") if item.strip()] laundering_attempt_txn_ids.update(attempt_ids_list) @@ -257,7 +257,7 @@ def build_cases( false_positive_cases = _build_false_positive_cases(transactions, num_false_positive_cases) - fp_seed_ids = {case.case.seed_transaction_id for case in false_positive_cases} + fp_seed_ids = {case.input.seed_transaction_id for case in false_positive_cases} normal_cases = _build_normal_cases(transactions, num_normal_cases, fp_seed_ids, lookback_days, min_timestamp) return laundering_cases + false_negative_cases + false_positive_cases + normal_cases @@ -367,7 +367,7 @@ def _finalize_attempt_block( pattern_description=current["pattern_description"], attempt_transaction_ids=attempt_ids, ) - return CaseRecord(case=case_file, groundtruth=groundtruth) + return CaseRecord(input=case_file, expected_output=groundtruth) def _build_false_negative_cases( @@ -377,24 +377,24 @@ def _build_false_negative_cases( false_negative_cases: list[CaseRecord] = [] fn_attempts = remaining_attempts[:num_false_negative_cases] for case in fn_attempts: - if case.case.seed_transaction_id in laundering_attempt_txn_ids: + if case.input.seed_transaction_id in laundering_attempt_txn_ids: continue case_file = CaseFile( - case_id=_create_id(f"fn:{case.case.case_id}"), - seed_transaction_id=case.case.seed_transaction_id, - seed_timestamp=case.case.seed_timestamp, - window_start=_date_window_start(case.case.seed_timestamp) - if case.case.window_start == case.case.seed_timestamp - else case.case.window_start, + case_id=_create_id(f"fn:{case.input.case_id}"), + seed_transaction_id=case.input.seed_transaction_id, + seed_timestamp=case.input.seed_timestamp, + window_start=_date_window_start(case.input.seed_timestamp) + if case.input.window_start == case.input.seed_timestamp + else case.input.window_start, trigger_label=random.choice(_LOW_SIGNAL_REVIEW_LABELS), ) groundtruth = GroundTruth( - is_laundering=case.groundtruth.is_laundering, - pattern_type=case.groundtruth.pattern_type, - pattern_description=case.groundtruth.pattern_description, - attempt_transaction_ids=case.groundtruth.attempt_transaction_ids, + is_laundering=case.expected_output.is_laundering, + pattern_type=case.expected_output.pattern_type, + pattern_description=case.expected_output.pattern_description, + attempt_transaction_ids=case.expected_output.attempt_transaction_ids, ) - false_negative_cases.append(CaseRecord(case=case_file, groundtruth=groundtruth)) + false_negative_cases.append(CaseRecord(input=case_file, expected_output=groundtruth)) if len(false_negative_cases) < num_false_negative_cases: logger.warning( @@ -444,7 +444,7 @@ def _build_false_positive_cases(transc_df: pd.DataFrame, num_false_positive_case pattern_description="Normal transaction", attempt_transaction_ids="", ) - false_positive_cases.append(CaseRecord(case=case_file, groundtruth=groundtruth)) + false_positive_cases.append(CaseRecord(input=case_file, expected_output=groundtruth)) return false_positive_cases @@ -484,7 +484,7 @@ def _build_normal_cases( pattern_description="Normal transaction", attempt_transaction_ids="", ) - normal_cases.append(CaseRecord(case=case_file, groundtruth=groundtruth)) + normal_cases.append(CaseRecord(input=case_file, expected_output=groundtruth)) return normal_cases diff --git a/aieng-eval-agents/tests/aieng/agent_evals/aml_investigation/data/test_cases.py b/aieng-eval-agents/tests/aieng/agent_evals/aml_investigation/data/test_cases.py index 2b0e16c..d39a724 100644 --- a/aieng-eval-agents/tests/aieng/agent_evals/aml_investigation/data/test_cases.py +++ b/aieng-eval-agents/tests/aieng/agent_evals/aml_investigation/data/test_cases.py @@ -99,17 +99,19 @@ def test_parse_patterns_file_parses_attempts_and_sets_seed_and_window(patterns_f } for record in cases: - assert record.groundtruth.is_laundering is True - assert record.case.seed_transaction_id - assert record.case.case_id - assert record.case.window_start == min_timestamp + assert record.expected_output.is_laundering is True + assert record.input.seed_transaction_id + assert record.input.case_id + assert record.input.window_start == min_timestamp - seed_timestamp, expected_count = expected_by_pattern[record.groundtruth.pattern_type] - assert record.case.seed_timestamp == seed_timestamp + seed_timestamp, expected_count = expected_by_pattern[record.expected_output.pattern_type] + assert record.input.seed_timestamp == seed_timestamp - attempt_ids = [item.strip() for item in record.groundtruth.attempt_transaction_ids.split(",") if item.strip()] + attempt_ids = [ + item.strip() for item in record.expected_output.attempt_transaction_ids.split(",") if item.strip() + ] assert len(attempt_ids) == expected_count - assert attempt_ids[-1] == record.case.seed_transaction_id + assert attempt_ids[-1] == record.input.seed_transaction_id def test_parse_patterns_file_rejects_negative_lookback(patterns_file: Path) -> None: @@ -175,36 +177,38 @@ def test_build_cases_builds_each_case_type(patterns_file: Path, transactions_df: laundering = [ case for case in cases - if case.groundtruth.is_laundering and case.case.trigger_label == case.groundtruth.pattern_type.value + if case.expected_output.is_laundering and case.input.trigger_label == case.expected_output.pattern_type.value ] false_negatives = [ - case for case in cases if case.groundtruth.is_laundering and case.case.trigger_label in low_signal_labels + case for case in cases if case.expected_output.is_laundering and case.input.trigger_label in low_signal_labels ] false_positives = [ case for case in cases - if (not case.groundtruth.is_laundering) and case.case.trigger_label not in low_signal_labels + if (not case.expected_output.is_laundering) and case.input.trigger_label not in low_signal_labels ] normals = [ - case for case in cases if (not case.groundtruth.is_laundering) and case.case.trigger_label in low_signal_labels + case + for case in cases + if (not case.expected_output.is_laundering) and case.input.trigger_label in low_signal_labels ] assert len(laundering) == 1 - assert laundering[0].groundtruth.pattern_type in { + assert laundering[0].expected_output.pattern_type in { LaunderingPattern.CYCLE, LaunderingPattern.STACK, LaunderingPattern.GATHER_SCATTER, } assert len(false_negatives) == 1 - assert false_negatives[0].groundtruth.is_laundering is True - assert false_negatives[0].groundtruth.pattern_type != LaunderingPattern.NONE + assert false_negatives[0].expected_output.is_laundering is True + assert false_negatives[0].expected_output.pattern_type != LaunderingPattern.NONE assert len(false_positives) == 1 - assert false_positives[0].groundtruth.pattern_type == LaunderingPattern.NONE + assert false_positives[0].expected_output.pattern_type == LaunderingPattern.NONE assert len(normals) == 1 - assert normals[0].groundtruth.pattern_type == LaunderingPattern.NONE + assert normals[0].expected_output.pattern_type == LaunderingPattern.NONE @pytest.mark.parametrize( diff --git a/implementations/aml_investigation/agent.py b/implementations/aml_investigation/agent.py index 8fa34b4..17f9e7e 100644 --- a/implementations/aml_investigation/agent.py +++ b/implementations/aml_investigation/agent.py @@ -163,12 +163,12 @@ def _write_results(output_path: Path, input_records: list[CaseRecord], results_b with tmp_path.open("w", encoding="utf-8") as outfile: for record in input_records: - case_id = record.case.case_id + case_id = record.input.case_id if case_id in written: continue written.add(case_id) out_record = results_by_id.get(case_id, record) - analyzed += int(out_record.analysis is not None) + analyzed += int(out_record.output is not None) outfile.write(out_record.model_dump_json() + "\n") tmp_path.replace(output_path) @@ -178,7 +178,7 @@ def _write_results(output_path: Path, input_records: list[CaseRecord], results_b async def _analyze_case(runner: Runner, record: CaseRecord) -> CaseRecord: """Run the agent on one case and attach the validated AnalystOutput.""" message = google.genai.types.Content( - role="user", parts=[google.genai.types.Part(text=record.case.model_dump_json())] + role="user", parts=[google.genai.types.Part(text=record.input.model_dump_json())] ) events_async = runner.run_async(session_id=str(uuid.uuid4()), user_id=getpass.getuser(), new_message=message) @@ -188,10 +188,10 @@ async def _analyze_case(runner: Runner, record: CaseRecord) -> CaseRecord: final_text = "".join(part.text or "" for part in event.content.parts if part.text) if not final_text: - logger.warning("No analyst output produced for case_id=%s", record.case.case_id) + logger.warning("No analyst output produced for case_id=%s", record.input.case_id) return record - record.analysis = AnalystOutput.model_validate(_extract_json(final_text.strip())) + record.output = AnalystOutput.model_validate(_extract_json(final_text.strip())) return record @@ -200,7 +200,7 @@ async def _safe_analyze_case(runner: Runner, record: CaseRecord) -> CaseRecord: try: return await _analyze_case(runner, record) except Exception as exc: - logger.exception("Case failed (case_id=%s): %s", record.case.case_id, exc) + logger.exception("Case failed (case_id=%s): %s", record.input.case_id, exc) return record @@ -234,7 +234,7 @@ async def _analyze_cases_to_jsonl( for finished in asyncio.as_completed(tasks): record = await finished - analyzed_by_id[record.case.case_id] = record + analyzed_by_id[record.input.case_id] = record outfile.write(record.model_dump_json() + "\n") outfile.flush() os.fsync(outfile.fileno()) @@ -253,8 +253,8 @@ async def _main() -> None: output_path.parent.mkdir(parents=True, exist_ok=True) input_records = _load_records(input_path) - existing_results = {record.case.case_id: record for record in _load_records(output_path)} - to_run = [r for r in input_records if existing_results.get(r.case.case_id, r).analysis is None] + existing_results = {record.input.case_id: record for record in _load_records(output_path)} + to_run = [r for r in input_records if existing_results.get(r.input.case_id, r).output is None] logger.info("Resume: %d/%d done; %d remaining.", len(input_records) - len(to_run), len(input_records), len(to_run)) @@ -272,16 +272,16 @@ async def _main() -> None: analyzed_count = _write_results(output_path, input_records, existing_results) logger.info("Wrote %d analyzed cases to %s", analyzed_count, output_path) - final_records = [existing_results.get(r.case.case_id, r) for r in input_records] - scored = [r for r in final_records if r.analysis is not None] + final_records = [existing_results.get(r.input.case_id, r) for r in input_records] + scored = [r for r in final_records if r.output is not None] if not scored: logger.info("Metrics: N/A (no analyzed cases)") else: tp = fp = fn = tn = 0 for r in scored: - gt = r.groundtruth.is_laundering - assert r.analysis is not None # Guaranteed by filter above - pred = r.analysis.is_laundering + gt = r.expected_output.is_laundering + assert r.output is not None # Guaranteed by filter above + pred = r.output.is_laundering if gt and pred: tp += 1 elif (not gt) and pred: