diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py index 7190cac3edb9..1ee13955fb0e 100644 --- a/mypyc/irbuild/for_helpers.py +++ b/mypyc/irbuild/for_helpers.py @@ -25,6 +25,7 @@ StarExpr, TupleExpr, TypeAlias, + Var, ) from mypy.types import LiteralType, TupleType, get_proper_type, get_proper_types from mypyc.ir.ops import ( @@ -1235,10 +1236,50 @@ def get_expr_length(builder: IRBuilder, expr: Expression) -> int | None: return other + sum(stars) # type: ignore [arg-type] elif isinstance(expr, StarExpr): return get_expr_length(builder, expr.expr) + elif ( + isinstance(expr, RefExpr) + and isinstance(expr.node, Var) + and expr.node.is_final + and isinstance(expr.node.final_value, str) + and expr.node.has_explicit_value + ): + return len(expr.node.final_value) + elif ( + isinstance(expr, CallExpr) + and isinstance(callee := expr.callee, NameExpr) + and all(kind == ARG_POS for kind in expr.arg_kinds) + ): + fullname = callee.fullname + if ( + fullname + in ( + "builtins.list", + "builtins.tuple", + "builtins.enumerate", + "builtins.sorted", + "builtins.reversed", + ) + and len(expr.args) == 1 + ): + return get_expr_length(builder, expr.args[0]) + elif fullname == "builtins.map" and len(expr.args) == 2: + return get_expr_length(builder, expr.args[1]) + elif fullname == "builtins.zip" and expr.args: + arg_lengths = [get_expr_length(builder, arg) for arg in expr.args] + if all(arg is not None for arg in arg_lengths): + return min(arg_lengths) # type: ignore [type-var] + elif fullname == "builtins.range" and len(expr.args) <= 3: + folded_args = [constant_fold_expr(builder, arg) for arg in expr.args] + if all(isinstance(arg, int) for arg in folded_args): + try: + return len(range(*cast(list[int], folded_args))) + except ValueError: # prevent crash if invalid args + pass + # TODO: extend this, passing length of listcomp and genexp should have worthwhile # performance boost and can be (sometimes) figured out pretty easily. set and dict # comps *can* be done as well but will need special logic to consider the possibility - # of key conflicts. Range, enumerate, zip are all simple logic. + # of key conflicts. # we might still be able to get the length directly from the type rtype = builder.node_type(expr) diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index d9202707124b..38b51a5f19fa 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -312,6 +312,11 @@ def __iter__(self) -> Iterator[int]: pass def __len__(self) -> int: pass def __next__(self) -> int: pass +class map(Iterator[_S]): + def __init__(self, func: Callable[[_T], _S], iterable: Iterable[_T]) -> None: pass + def __iter__(self) -> Self: pass + def __next__(self) -> _S: pass + class property: def __init__(self, fget: Optional[Callable[[Any], Any]] = ..., fset: Optional[Callable[[Any, Any], None]] = ..., diff --git a/mypyc/test-data/irbuild-tuple.test b/mypyc/test-data/irbuild-tuple.test index 0ae1a2f0cda6..9b6949760910 100644 --- a/mypyc/test-data/irbuild-tuple.test +++ b/mypyc/test-data/irbuild-tuple.test @@ -898,6 +898,126 @@ L4: a = r1 return 1 +[case testTupleBuiltFromLengthCheckable] +from typing import Tuple + +def f(val: bool) -> bool: + return not val + +def test() -> None: + # this tuple is created from a very complex genexp but we can still compute the length and preallocate the tuple + a = tuple( + x + for x + in zip( + map(str, range(5)), + enumerate(sorted(reversed(tuple("abcdefg")))) + ) + ) +[out] +def f(val): + val, r0 :: bool +L0: + r0 = val ^ 1 + return r0 +def test(): + r0 :: list + r1, r2, r3 :: object + r4 :: object[1] + r5 :: object_ptr + r6 :: object + r7 :: range + r8 :: object + r9 :: str + r10 :: object + r11 :: object[2] + r12 :: object_ptr + r13 :: object + r14 :: str + r15 :: tuple + r16 :: object + r17 :: str + r18 :: object + r19 :: object[1] + r20 :: object_ptr + r21 :: object + r22 :: list + r23 :: object + r24 :: str + r25 :: object + r26 :: object[1] + r27 :: object_ptr + r28, r29 :: object + r30 :: str + r31 :: object + r32 :: object[2] + r33 :: object_ptr + r34, r35, r36 :: object + r37, x :: tuple[str, tuple[int, str]] + r38 :: object + r39 :: i32 + r40, r41 :: bit + r42, a :: tuple +L0: + r0 = PyList_New(0) + r1 = load_address PyUnicode_Type + r2 = load_address PyRange_Type + r3 = object 5 + r4 = [r3] + r5 = load_address r4 + r6 = PyObject_Vectorcall(r2, r5, 1, 0) + keep_alive r3 + r7 = cast(range, r6) + r8 = builtins :: module + r9 = 'map' + r10 = CPyObject_GetAttr(r8, r9) + r11 = [r1, r7] + r12 = load_address r11 + r13 = PyObject_Vectorcall(r10, r12, 2, 0) + keep_alive r1, r7 + r14 = 'abcdefg' + r15 = PySequence_Tuple(r14) + r16 = builtins :: module + r17 = 'reversed' + r18 = CPyObject_GetAttr(r16, r17) + r19 = [r15] + r20 = load_address r19 + r21 = PyObject_Vectorcall(r18, r20, 1, 0) + keep_alive r15 + r22 = CPySequence_Sort(r21) + r23 = builtins :: module + r24 = 'enumerate' + r25 = CPyObject_GetAttr(r23, r24) + r26 = [r22] + r27 = load_address r26 + r28 = PyObject_Vectorcall(r25, r27, 1, 0) + keep_alive r22 + r29 = builtins :: module + r30 = 'zip' + r31 = CPyObject_GetAttr(r29, r30) + r32 = [r13, r28] + r33 = load_address r32 + r34 = PyObject_Vectorcall(r31, r33, 2, 0) + keep_alive r13, r28 + r35 = PyObject_GetIter(r34) +L1: + r36 = PyIter_Next(r35) + if is_error(r36) goto L4 else goto L2 +L2: + r37 = unbox(tuple[str, tuple[int, str]], r36) + x = r37 + r38 = box(tuple[str, tuple[int, str]], x) + r39 = PyList_Append(r0, r38) + r40 = r39 >= 0 :: signed +L3: + goto L1 +L4: + r41 = CPy_NoErrOccurred() +L5: + r42 = PyList_AsTuple(r0) + a = r42 + return 1 + [case testTupleBuiltFromStars] from typing import Final