Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 28 additions & 25 deletions src/eigenscript/compiler/analysis/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
This implements zero-cost abstraction: pay for geometric semantics only when used.
"""

from typing import Set
from typing import Optional, Set
from eigenscript.parser.ast_builder import (
ASTNode,
Identifier,
Expand All @@ -26,6 +26,7 @@
ListLiteral,
Index,
Program,
TentativeAssignment,
)


Expand All @@ -40,10 +41,19 @@ class ObserverAnalyzer:
Unobserved variables can be compiled to raw doubles for maximum performance.
"""

PREDICATE_NAMES = {
"converged",
"diverging",
"oscillating",
"stable",
"improving",
}

def __init__(self):
self.observed: Set[str] = set()
self.user_functions: Set[str] = set()
self.current_function: str = None
self.last_assigned: Optional[str] = None

def analyze(self, ast_nodes: list[ASTNode]) -> Set[str]:
"""Analyze AST and return set of variable names that need EigenValue tracking.
Expand All @@ -58,6 +68,7 @@ def analyze(self, ast_nodes: list[ASTNode]) -> Set[str]:
self.observed = set()
self.user_functions = set()
self.current_function = None
self.last_assigned = None

# First pass: collect all user-defined function names
for node in ast_nodes:
Expand All @@ -82,17 +93,22 @@ def _visit(self, node: ASTNode):
elif isinstance(node, FunctionDef):
# Function parameters are always observed (might be interrogated inside)
prev_function = self.current_function
prev_last_assigned = self.last_assigned
self.current_function = node.name
self.last_assigned = None

# In EigenScript, functions implicitly have parameter 'n'
self.observed.add("n")

for stmt in node.body:
self._visit(stmt)

self.last_assigned = prev_last_assigned
self.current_function = prev_function

elif isinstance(node, Assignment):
elif isinstance(node, (Assignment, TentativeAssignment)):
# Assignment/TentativeAssignment identifier is a string name of the target
self.last_assigned = node.identifier
self._visit(node.expression)

elif isinstance(node, Interrogative):
Expand Down Expand Up @@ -152,32 +168,19 @@ def _visit(self, node: ASTNode):
self._visit(node.list_expr)
self._visit(node.index_expr)

elif isinstance(node, Identifier):
# Check if this identifier is a predicate
if node.name in [
"converged",
"diverging",
"oscillating",
"stable",
"improving",
]:
# Predicates require the last variable to be observed
# This is a simplified heuristic - ideally we'd track scope
pass

def _check_for_predicates(self, node: ASTNode):
"""Check if condition uses predicates (converged, diverging, etc.)."""
if node is None:
return

if isinstance(node, Identifier):
if node.name in [
"converged",
"diverging",
"oscillating",
"stable",
"improving",
]:
# TODO: Mark the variable being tested as observed
# For now, this is handled by the codegen heuristic of "last variable"
pass
if node.name in self.PREDICATE_NAMES and self.last_assigned:
self.observed.add(self.last_assigned)
elif isinstance(node, UnaryOp):
self._check_for_predicates(node.operand)
elif isinstance(node, BinaryOp):
self._check_for_predicates(node.left)
self._check_for_predicates(node.right)

def _mark_expression_observed(self, node: ASTNode):
"""Recursively mark all identifiers in an expression as observed."""
Expand Down
45 changes: 45 additions & 0 deletions tests/test_observer_predicates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
Tests for predicate handling in the ObserverAnalyzer.

Ensure that predicate usage marks the relevant variables as observed so
predicate checks operate on EigenValue-tracked variables.
"""

import textwrap

from eigenscript.lexer import Tokenizer
from eigenscript.parser import Parser
from eigenscript.compiler.analysis.observer import ObserverAnalyzer


def _analyze(code: str):
"""Helper to run observer analysis on EigenScript code."""
source = textwrap.dedent(code).strip()
tokens = Tokenizer(source).tokenize()
ast = Parser(tokens).parse()
analyzer = ObserverAnalyzer()
return analyzer.analyze(ast.statements)


def test_predicate_marks_last_assignment_observed():
"""A predicate condition should mark the last assigned variable as observed."""
observed = _analyze(
"""
x is 1
if converged:
x is x + 1
"""
)
assert "x" in observed


def test_predicate_with_not_marks_last_assignment_observed():
"""NOT predicate conditions should also mark the last assigned variable."""
observed = _analyze(
"""
value is 0
loop while not converged:
value is value + 1
"""
)
assert "value" in observed
Loading