40 changes: 20 additions & 20 deletions aieng-eval-agents/aieng/agent_evals/aml_investigation/data/cases.py
@@ -106,9 +106,9 @@ class AnalystOutput(BaseModel):
 class CaseRecord(BaseModel):
     """Combined case file and ground truth record."""

-    case: CaseFile = Field(..., description="Metadata for the laundering case.")
-    groundtruth: GroundTruth = Field(..., description="Ground truth information for the laundering case.")
-    analysis: AnalystOutput | None = Field(
+    input: CaseFile = Field(..., description="Metadata for the laundering case.")
+    expected_output: GroundTruth = Field(..., description="Ground truth information for the laundering case.")
+    output: AnalystOutput | None = Field(
         default=None,
         description="Optional analyst output for the laundering case. Typically populated after investigation.",
     )
Comment on lines +109 to 114
Copilot AI Feb 6, 2026

Renaming these serialized fields will break parsing of any existing JSON/JSONL persisted with the old keys (case, groundtruth, analysis). Since the agent reads prior results via CaseRecord.model_validate_json(...) for resume behavior, legacy lines will fail validation and be skipped. Consider adding Pydantic v2 compatibility via validation_alias (e.g., AliasChoices('input', 'case'), AliasChoices('expected_output', 'groundtruth'), AliasChoices('output', 'analysis')) so existing artifacts remain readable while emitting the new field names.
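A minimal sketch of that aliasing, assuming Pydantic v2 and the field types shown in this diff (CaseFile, GroundTruth, AnalystOutput are the models defined earlier in cases.py); placement and wording are illustrative, not the PR's implementation:

from pydantic import AliasChoices, BaseModel, Field


class CaseRecord(BaseModel):
    """Combined case file and ground truth record."""

    # New field names are used for serialization; AliasChoices lets
    # model_validate_json() still accept the legacy keys on input.
    input: CaseFile = Field(
        ...,
        description="Metadata for the laundering case.",
        validation_alias=AliasChoices("input", "case"),
    )
    expected_output: GroundTruth = Field(
        ...,
        description="Ground truth information for the laundering case.",
        validation_alias=AliasChoices("expected_output", "groundtruth"),
    )
    output: AnalystOutput | None = Field(
        default=None,
        description="Optional analyst output for the laundering case. Typically populated after investigation.",
        validation_alias=AliasChoices("output", "analysis"),
    )

Because no serialization_alias is set, model_dump_json() keeps emitting the new keys, so freshly written JSONL stays on the new schema while old artifacts remain readable.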

@@ -246,7 +246,7 @@ def build_cases(
     laundering_attempt_txn_ids: set[str] = set()
     for case in laundering_cases:
         attempt_ids_list = []
-        attempt_ids_str = case.groundtruth.attempt_transaction_ids
+        attempt_ids_str = case.expected_output.attempt_transaction_ids
         if attempt_ids_str:
             attempt_ids_list = [item.strip() for item in attempt_ids_str.split(",") if item.strip()]
         laundering_attempt_txn_ids.update(attempt_ids_list)
@@ -257,7 +257,7 @@

     false_positive_cases = _build_false_positive_cases(transactions, num_false_positive_cases)

-    fp_seed_ids = {case.case.seed_transaction_id for case in false_positive_cases}
+    fp_seed_ids = {case.input.seed_transaction_id for case in false_positive_cases}
     normal_cases = _build_normal_cases(transactions, num_normal_cases, fp_seed_ids, lookback_days, min_timestamp)

     return laundering_cases + false_negative_cases + false_positive_cases + normal_cases
@@ -367,7 +367,7 @@ def _finalize_attempt_block(
         pattern_description=current["pattern_description"],
         attempt_transaction_ids=attempt_ids,
     )
-    return CaseRecord(case=case_file, groundtruth=groundtruth)
+    return CaseRecord(input=case_file, expected_output=groundtruth)


 def _build_false_negative_cases(
@@ -377,24 +377,24 @@ def _build_false_negative_cases(
     false_negative_cases: list[CaseRecord] = []
     fn_attempts = remaining_attempts[:num_false_negative_cases]
     for case in fn_attempts:
-        if case.case.seed_transaction_id in laundering_attempt_txn_ids:
+        if case.input.seed_transaction_id in laundering_attempt_txn_ids:
             continue
         case_file = CaseFile(
-            case_id=_create_id(f"fn:{case.case.case_id}"),
-            seed_transaction_id=case.case.seed_transaction_id,
-            seed_timestamp=case.case.seed_timestamp,
-            window_start=_date_window_start(case.case.seed_timestamp)
-            if case.case.window_start == case.case.seed_timestamp
-            else case.case.window_start,
+            case_id=_create_id(f"fn:{case.input.case_id}"),
+            seed_transaction_id=case.input.seed_transaction_id,
+            seed_timestamp=case.input.seed_timestamp,
+            window_start=_date_window_start(case.input.seed_timestamp)
+            if case.input.window_start == case.input.seed_timestamp
+            else case.input.window_start,
             trigger_label=random.choice(_LOW_SIGNAL_REVIEW_LABELS),
         )
         groundtruth = GroundTruth(
-            is_laundering=case.groundtruth.is_laundering,
-            pattern_type=case.groundtruth.pattern_type,
-            pattern_description=case.groundtruth.pattern_description,
-            attempt_transaction_ids=case.groundtruth.attempt_transaction_ids,
+            is_laundering=case.expected_output.is_laundering,
+            pattern_type=case.expected_output.pattern_type,
+            pattern_description=case.expected_output.pattern_description,
+            attempt_transaction_ids=case.expected_output.attempt_transaction_ids,
         )
-        false_negative_cases.append(CaseRecord(case=case_file, groundtruth=groundtruth))
+        false_negative_cases.append(CaseRecord(input=case_file, expected_output=groundtruth))

     if len(false_negative_cases) < num_false_negative_cases:
         logger.warning(
@@ -444,7 +444,7 @@ def _build_false_positive_cases(transc_df: pd.DataFrame, num_false_positive_case
             pattern_description="Normal transaction",
             attempt_transaction_ids="",
         )
-        false_positive_cases.append(CaseRecord(case=case_file, groundtruth=groundtruth))
+        false_positive_cases.append(CaseRecord(input=case_file, expected_output=groundtruth))

     return false_positive_cases

@@ -484,7 +484,7 @@ def _build_normal_cases(
             pattern_description="Normal transaction",
             attempt_transaction_ids="",
         )
-        normal_cases.append(CaseRecord(case=case_file, groundtruth=groundtruth))
+        normal_cases.append(CaseRecord(input=case_file, expected_output=groundtruth))

     return normal_cases

Changes to the accompanying unit tests for the case-building module (file path not shown):
@@ -99,17 +99,19 @@ def test_parse_patterns_file_parses_attempts_and_sets_seed_and_window(patterns_f
     }

     for record in cases:
-        assert record.groundtruth.is_laundering is True
-        assert record.case.seed_transaction_id
-        assert record.case.case_id
-        assert record.case.window_start == min_timestamp
+        assert record.expected_output.is_laundering is True
+        assert record.input.seed_transaction_id
+        assert record.input.case_id
+        assert record.input.window_start == min_timestamp

-        seed_timestamp, expected_count = expected_by_pattern[record.groundtruth.pattern_type]
-        assert record.case.seed_timestamp == seed_timestamp
+        seed_timestamp, expected_count = expected_by_pattern[record.expected_output.pattern_type]
+        assert record.input.seed_timestamp == seed_timestamp

-        attempt_ids = [item.strip() for item in record.groundtruth.attempt_transaction_ids.split(",") if item.strip()]
+        attempt_ids = [
+            item.strip() for item in record.expected_output.attempt_transaction_ids.split(",") if item.strip()
+        ]
         assert len(attempt_ids) == expected_count
-        assert attempt_ids[-1] == record.case.seed_transaction_id
+        assert attempt_ids[-1] == record.input.seed_transaction_id


 def test_parse_patterns_file_rejects_negative_lookback(patterns_file: Path) -> None:
@@ -175,36 +177,38 @@ def test_build_cases_builds_each_case_type(patterns_file: Path, transactions_df:
     laundering = [
         case
         for case in cases
-        if case.groundtruth.is_laundering and case.case.trigger_label == case.groundtruth.pattern_type.value
+        if case.expected_output.is_laundering and case.input.trigger_label == case.expected_output.pattern_type.value
     ]
     false_negatives = [
-        case for case in cases if case.groundtruth.is_laundering and case.case.trigger_label in low_signal_labels
+        case for case in cases if case.expected_output.is_laundering and case.input.trigger_label in low_signal_labels
     ]
     false_positives = [
         case
         for case in cases
-        if (not case.groundtruth.is_laundering) and case.case.trigger_label not in low_signal_labels
+        if (not case.expected_output.is_laundering) and case.input.trigger_label not in low_signal_labels
     ]
     normals = [
-        case for case in cases if (not case.groundtruth.is_laundering) and case.case.trigger_label in low_signal_labels
+        case
+        for case in cases
+        if (not case.expected_output.is_laundering) and case.input.trigger_label in low_signal_labels
     ]

     assert len(laundering) == 1
-    assert laundering[0].groundtruth.pattern_type in {
+    assert laundering[0].expected_output.pattern_type in {
         LaunderingPattern.CYCLE,
         LaunderingPattern.STACK,
         LaunderingPattern.GATHER_SCATTER,
     }

     assert len(false_negatives) == 1
-    assert false_negatives[0].groundtruth.is_laundering is True
-    assert false_negatives[0].groundtruth.pattern_type != LaunderingPattern.NONE
+    assert false_negatives[0].expected_output.is_laundering is True
+    assert false_negatives[0].expected_output.pattern_type != LaunderingPattern.NONE

     assert len(false_positives) == 1
-    assert false_positives[0].groundtruth.pattern_type == LaunderingPattern.NONE
+    assert false_positives[0].expected_output.pattern_type == LaunderingPattern.NONE

     assert len(normals) == 1
-    assert normals[0].groundtruth.pattern_type == LaunderingPattern.NONE
+    assert normals[0].expected_output.pattern_type == LaunderingPattern.NONE


 @pytest.mark.parametrize(
28 changes: 14 additions & 14 deletions implementations/aml_investigation/agent.py
@@ -163,12 +163,12 @@ def _write_results(output_path: Path, input_records: list[CaseRecord], results_b

     with tmp_path.open("w", encoding="utf-8") as outfile:
         for record in input_records:
-            case_id = record.case.case_id
+            case_id = record.input.case_id
             if case_id in written:
                 continue
             written.add(case_id)
             out_record = results_by_id.get(case_id, record)
-            analyzed += int(out_record.analysis is not None)
+            analyzed += int(out_record.output is not None)
             outfile.write(out_record.model_dump_json() + "\n")

     tmp_path.replace(output_path)
@@ -178,7 +178,7 @@ def _write_results(output_path: Path, input_records: list[CaseRecord], results_b
 async def _analyze_case(runner: Runner, record: CaseRecord) -> CaseRecord:
     """Run the agent on one case and attach the validated AnalystOutput."""
     message = google.genai.types.Content(
-        role="user", parts=[google.genai.types.Part(text=record.case.model_dump_json())]
+        role="user", parts=[google.genai.types.Part(text=record.input.model_dump_json())]
     )
     events_async = runner.run_async(session_id=str(uuid.uuid4()), user_id=getpass.getuser(), new_message=message)

@@ -188,10 +188,10 @@ async def _analyze_case(runner: Runner, record: CaseRecord) -> CaseRecord:
             final_text = "".join(part.text or "" for part in event.content.parts if part.text)

     if not final_text:
-        logger.warning("No analyst output produced for case_id=%s", record.case.case_id)
+        logger.warning("No analyst output produced for case_id=%s", record.input.case_id)
         return record

-    record.analysis = AnalystOutput.model_validate(_extract_json(final_text.strip()))
+    record.output = AnalystOutput.model_validate(_extract_json(final_text.strip()))
     return record


@@ -200,7 +200,7 @@ async def _safe_analyze_case(runner: Runner, record: CaseRecord) -> CaseRecord:
     try:
         return await _analyze_case(runner, record)
     except Exception as exc:
-        logger.exception("Case failed (case_id=%s): %s", record.case.case_id, exc)
+        logger.exception("Case failed (case_id=%s): %s", record.input.case_id, exc)
         return record


@@ -234,7 +234,7 @@ async def _analyze_cases_to_jsonl(

         for finished in asyncio.as_completed(tasks):
             record = await finished
-            analyzed_by_id[record.case.case_id] = record
+            analyzed_by_id[record.input.case_id] = record
             outfile.write(record.model_dump_json() + "\n")
             outfile.flush()
             os.fsync(outfile.fileno())
@@ -253,8 +253,8 @@ async def _main() -> None:
     output_path.parent.mkdir(parents=True, exist_ok=True)

     input_records = _load_records(input_path)
-    existing_results = {record.case.case_id: record for record in _load_records(output_path)}
-    to_run = [r for r in input_records if existing_results.get(r.case.case_id, r).analysis is None]
+    existing_results = {record.input.case_id: record for record in _load_records(output_path)}

Copilot AI Feb 6, 2026

With the CaseRecord field rename, _load_records(output_path) will skip any legacy JSONL rows (old schema) as invalid, which can cause existing_results to be incomplete and make the resume logic re-run cases unnecessarily (and potentially overwrite/duplicate outputs). After adding schema aliases (or a migration path), it would also help to make legacy parsing failures more explicit here (e.g., fail fast or log a clear warning/count) so users don’t silently lose resume behavior.

Suggested change
-    existing_results = {record.input.case_id: record for record in _load_records(output_path)}
+    # Load existing results from the output file, and detect any lines that could not be parsed
+    existing_records = list(_load_records(output_path))
+    existing_results = {record.input.case_id: record for record in existing_records}
+    # Compare parsed records against total lines to surface potential legacy/invalid rows
+    if output_path.exists():
+        try:
+            with output_path.open("r", encoding="utf-8") as f:
+                total_lines = sum(1 for _ in f)
+        except OSError:
+            total_lines = None
+        if total_lines is not None and total_lines > len(existing_records):
+            logger.warning(
+                "Detected %d/%d records in %s that could not be parsed. "
+                "These may be legacy or invalid rows, and resume behavior may be affected.",
+                total_lines - len(existing_records),
+                total_lines,
+                output_path,
+            )
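For the fail-fast option the comment also mentions, a rough sketch; _load_records_strict is a hypothetical helper name (the existing code uses a lenient _load_records that skips bad lines), and the JSONL layout plus CaseRecord.model_validate_json usage are assumed from the surrounding code:

from pathlib import Path

import pydantic


def _load_records_strict(path: Path) -> list[CaseRecord]:
    """Parse every JSONL line; raise instead of silently skipping bad rows."""
    if not path.exists():
        return []
    records: list[CaseRecord] = []
    with path.open("r", encoding="utf-8") as infile:
        for line_no, line in enumerate(infile, start=1):
            if not line.strip():
                continue
            try:
                records.append(CaseRecord.model_validate_json(line))
            except pydantic.ValidationError as exc:
                # Legacy rows (old "case"/"groundtruth"/"analysis" keys) land here
                # unless validation aliases or a migration step are added first.
                raise ValueError(
                    f"{path}:{line_no} does not match the current CaseRecord schema"
                ) from exc
    return records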

+    to_run = [r for r in input_records if existing_results.get(r.input.case_id, r).output is None]

     logger.info("Resume: %d/%d done; %d remaining.", len(input_records) - len(to_run), len(input_records), len(to_run))

@@ -272,16 +272,16 @@ async def _main() -> None:
     analyzed_count = _write_results(output_path, input_records, existing_results)
     logger.info("Wrote %d analyzed cases to %s", analyzed_count, output_path)

-    final_records = [existing_results.get(r.case.case_id, r) for r in input_records]
-    scored = [r for r in final_records if r.analysis is not None]
+    final_records = [existing_results.get(r.input.case_id, r) for r in input_records]
+    scored = [r for r in final_records if r.output is not None]
     if not scored:
         logger.info("Metrics: N/A (no analyzed cases)")
     else:
         tp = fp = fn = tn = 0
         for r in scored:
-            gt = r.groundtruth.is_laundering
-            assert r.analysis is not None  # Guaranteed by filter above
-            pred = r.analysis.is_laundering
+            gt = r.expected_output.is_laundering
+            assert r.output is not None  # Guaranteed by filter above
+            pred = r.output.is_laundering
             if gt and pred:
                 tp += 1
             elif (not gt) and pred: