diff --git a/src/eigenscript/compiler/analysis/observer.py b/src/eigenscript/compiler/analysis/observer.py index dfbb8bb..7bc9679 100644 --- a/src/eigenscript/compiler/analysis/observer.py +++ b/src/eigenscript/compiler/analysis/observer.py @@ -10,7 +10,7 @@ This implements zero-cost abstraction: pay for geometric semantics only when used. """ -from typing import Set +from typing import Optional, Set from eigenscript.parser.ast_builder import ( ASTNode, Identifier, @@ -26,6 +26,7 @@ ListLiteral, Index, Program, + TentativeAssignment, ) @@ -40,10 +41,19 @@ class ObserverAnalyzer: Unobserved variables can be compiled to raw doubles for maximum performance. """ + PREDICATE_NAMES = { + "converged", + "diverging", + "oscillating", + "stable", + "improving", + } + def __init__(self): self.observed: Set[str] = set() self.user_functions: Set[str] = set() self.current_function: str = None + self.last_assigned: Optional[str] = None def analyze(self, ast_nodes: list[ASTNode]) -> Set[str]: """Analyze AST and return set of variable names that need EigenValue tracking. @@ -58,6 +68,7 @@ def analyze(self, ast_nodes: list[ASTNode]) -> Set[str]: self.observed = set() self.user_functions = set() self.current_function = None + self.last_assigned = None # First pass: collect all user-defined function names for node in ast_nodes: @@ -82,7 +93,9 @@ def _visit(self, node: ASTNode): elif isinstance(node, FunctionDef): # Function parameters are always observed (might be interrogated inside) prev_function = self.current_function + prev_last_assigned = self.last_assigned self.current_function = node.name + self.last_assigned = None # In EigenScript, functions implicitly have parameter 'n' self.observed.add("n") @@ -90,9 +103,12 @@ def _visit(self, node: ASTNode): for stmt in node.body: self._visit(stmt) + self.last_assigned = prev_last_assigned self.current_function = prev_function - elif isinstance(node, Assignment): + elif isinstance(node, (Assignment, TentativeAssignment)): + # Assignment/TentativeAssignment identifier is a string name of the target + self.last_assigned = node.identifier self._visit(node.expression) elif isinstance(node, Interrogative): @@ -152,32 +168,19 @@ def _visit(self, node: ASTNode): self._visit(node.list_expr) self._visit(node.index_expr) - elif isinstance(node, Identifier): - # Check if this identifier is a predicate - if node.name in [ - "converged", - "diverging", - "oscillating", - "stable", - "improving", - ]: - # Predicates require the last variable to be observed - # This is a simplified heuristic - ideally we'd track scope - pass - def _check_for_predicates(self, node: ASTNode): """Check if condition uses predicates (converged, diverging, etc.).""" + if node is None: + return + if isinstance(node, Identifier): - if node.name in [ - "converged", - "diverging", - "oscillating", - "stable", - "improving", - ]: - # TODO: Mark the variable being tested as observed - # For now, this is handled by the codegen heuristic of "last variable" - pass + if node.name in self.PREDICATE_NAMES and self.last_assigned: + self.observed.add(self.last_assigned) + elif isinstance(node, UnaryOp): + self._check_for_predicates(node.operand) + elif isinstance(node, BinaryOp): + self._check_for_predicates(node.left) + self._check_for_predicates(node.right) def _mark_expression_observed(self, node: ASTNode): """Recursively mark all identifiers in an expression as observed.""" diff --git a/tests/test_observer_predicates.py b/tests/test_observer_predicates.py new file mode 100644 index 0000000..381640d --- /dev/null +++ b/tests/test_observer_predicates.py @@ -0,0 +1,45 @@ +""" +Tests for predicate handling in the ObserverAnalyzer. + +Ensure that predicate usage marks the relevant variables as observed so +predicate checks operate on EigenValue-tracked variables. +""" + +import textwrap + +from eigenscript.lexer import Tokenizer +from eigenscript.parser import Parser +from eigenscript.compiler.analysis.observer import ObserverAnalyzer + + +def _analyze(code: str): + """Helper to run observer analysis on EigenScript code.""" + source = textwrap.dedent(code).strip() + tokens = Tokenizer(source).tokenize() + ast = Parser(tokens).parse() + analyzer = ObserverAnalyzer() + return analyzer.analyze(ast.statements) + + +def test_predicate_marks_last_assignment_observed(): + """A predicate condition should mark the last assigned variable as observed.""" + observed = _analyze( + """ + x is 1 + if converged: + x is x + 1 + """ + ) + assert "x" in observed + + +def test_predicate_with_not_marks_last_assignment_observed(): + """NOT predicate conditions should also mark the last assigned variable.""" + observed = _analyze( + """ + value is 0 + loop while not converged: + value is value + 1 + """ + ) + assert "value" in observed