Skip to content

Commit 5a8acc8

Browse files
authored
0.2.4 Fixed array literal parsing in the cal block (#5)
1 parent e843050 commit 5a8acc8

File tree

3 files changed

+62
-10
lines changed

3 files changed

+62
-10
lines changed

source/openpulse/openpulse/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
the :obj:`~parser.parse` function.
1515
"""
1616

17-
__version__ = "0.2.3"
17+
__version__ = "0.2.4"
1818

1919
from . import ast
2020

source/openpulse/openpulse/parser.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,11 @@ def _in_subroutine(self):
120120
def _in_loop(self):
121121
return any(
122122
isinstance(
123-
scope, (openpulseParser.ForStatementContext, openpulseParser.WhileStatementContext)
123+
scope,
124+
(
125+
openpulseParser.ForStatementContext,
126+
openpulseParser.WhileStatementContext,
127+
),
124128
)
125129
for scope in reversed(self._current_context())
126130
)
@@ -134,6 +138,20 @@ def _visitPulseType(self, ctx: openpulseParser.ScalarTypeContext):
134138
if ctx.FRAME():
135139
return openpulse_ast.FrameType()
136140

141+
@span
142+
def visitArrayLiteral(self, ctx: qasm3Parser.ArrayLiteralContext):
143+
array_literal_element = (
144+
openpulseParser.ExpressionContext,
145+
openpulseParser.ArrayLiteralContext,
146+
)
147+
148+
def predicate(child):
149+
return isinstance(child, array_literal_element)
150+
151+
return ast.ArrayLiteral(
152+
values=[self.visit(element) for element in ctx.getChildren(predicate=predicate)],
153+
)
154+
137155
@span
138156
def visitCalibrationBlock(self, ctx: openpulseParser.CalibrationBlockContext):
139157
with self._push_context(ctx):
@@ -202,6 +220,7 @@ def visitReturnStatement(self, ctx: qasm3Parser.ReturnStatementContext):
202220
# Reuse some QASMNodeVisitor methods in OpenPulseNodeVisitor
203221
# The following methods are overridden in OpenPulseNodeVisitor and thus not imported:
204222
"""
223+
"visitArrayLiteral",
205224
"visitIndexOperator",
206225
"visitRangeExpression",
207226
"visitReturnStatement",
@@ -213,7 +232,6 @@ def visitReturnStatement(self, ctx: qasm3Parser.ReturnStatementContext):
213232
OpenPulseNodeVisitor.visitAliasExpression = QASMNodeVisitor.visitAliasExpression
214233
OpenPulseNodeVisitor.visitAnnotation = QASMNodeVisitor.visitAnnotation
215234
OpenPulseNodeVisitor.visitArgumentDefinition = QASMNodeVisitor.visitArgumentDefinition
216-
OpenPulseNodeVisitor.visitArrayLiteral = QASMNodeVisitor.visitArrayLiteral
217235
OpenPulseNodeVisitor.visitArrayType = QASMNodeVisitor.visitArrayType
218236
OpenPulseNodeVisitor.visitAssignmentStatement = QASMNodeVisitor.visitAssignmentStatement
219237
OpenPulseNodeVisitor.visitBarrierStatement = QASMNodeVisitor.visitBarrierStatement

source/openpulse/tests/test_openpulse_parser.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from openpulse.parser import parse
88
from openpulse.ast import (
99
AngleType,
10+
ArrayLiteral,
11+
ArrayType,
1012
CalibrationDefinition,
1113
CalibrationStatement,
1214
ClassicalArgument,
@@ -94,7 +96,8 @@ def test_calibration_definition():
9496
arguments=[Identifier(name="$1")],
9597
),
9698
UnaryExpression(
97-
op=UnaryOperator["-"], expression=Identifier(name="theta")
99+
op=UnaryOperator["-"],
100+
expression=Identifier(name="theta"),
98101
),
99102
],
100103
)
@@ -194,7 +197,10 @@ def test_calibration2():
194197
identifier=Identifier(name="readout_waveform_wf"),
195198
init_expression=FunctionCall(
196199
name=Identifier(name="constant"),
197-
arguments=[FloatLiteral(value=5e-06), FloatLiteral(value=0.03)],
200+
arguments=[
201+
FloatLiteral(value=5e-06),
202+
FloatLiteral(value=0.03),
203+
],
198204
),
199205
),
200206
ForInLoop(
@@ -229,6 +235,39 @@ def test_calibration2():
229235
SpanGuard().visit(program)
230236

231237

238+
def test_array():
239+
p = """
240+
cal {
241+
array[int[32], 4] my_array = {3, 4, 5, 5};
242+
}
243+
""".strip()
244+
program = parse(p)
245+
assert _remove_spans(program) == Program(
246+
statements=[
247+
CalibrationStatement(
248+
body=[
249+
ClassicalDeclaration(
250+
type=ArrayType(
251+
base_type=IntType(size=IntegerLiteral(value=32)),
252+
dimensions=[IntegerLiteral(value=4)],
253+
),
254+
identifier=Identifier(name="my_array"),
255+
init_expression=ArrayLiteral(
256+
values=[
257+
IntegerLiteral(value=3),
258+
IntegerLiteral(value=4),
259+
IntegerLiteral(value=5),
260+
IntegerLiteral(value=5),
261+
],
262+
),
263+
)
264+
]
265+
)
266+
]
267+
)
268+
SpanGuard().visit(program)
269+
270+
232271
@pytest.mark.parametrize(
233272
"p",
234273
[
@@ -261,11 +300,6 @@ def test_calibration2():
261300
}
262301
}
263302
""",
264-
"""
265-
cal {
266-
array[int[32], 3] my_ints = {5, 6, 7};
267-
}
268-
""",
269303
],
270304
)
271305
def test_parsing(p: str):

0 commit comments

Comments
 (0)