Skip to content

Commit 3bcde7a

Browse files
committed
Correctly fold UnaryOp nodes
1 parent 12b49b3 commit 3bcde7a

File tree

2 files changed

+323
-25
lines changed

2 files changed

+323
-25
lines changed

src/python_minifier/transforms/constant_folding.py

Lines changed: 59 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,36 +10,35 @@
1010
from python_minifier.util import is_constant_node
1111

1212

13-
class FoldConstants(SuiteTransformer):
14-
"""
15-
Fold Constants if it would reduce the size of the source
13+
def is_foldable_constant(node):
1614
"""
15+
Check if a node is a constant expression that can participate in folding.
1716
18-
def __init__(self):
19-
super(FoldConstants, self).__init__()
17+
We can asume that children have already been folded, so foldable constants are either:
18+
- Simple literals (Num, NameConstant)
19+
- UnaryOp(USub/Invert) on a Num - these don't fold to shorter forms,
20+
so they remain after child visiting. UAdd and Not would have been
21+
folded away since they always produce shorter results.
22+
"""
23+
if is_constant_node(node, (ast.Num, ast.NameConstant)):
24+
return True
2025

21-
def visit_BinOp(self, node):
26+
if isinstance(node, ast.UnaryOp):
27+
if isinstance(node.op, (ast.USub, ast.Invert)):
28+
return is_constant_node(node.operand, ast.Num)
2229

23-
node.left = self.visit(node.left)
24-
node.right = self.visit(node.right)
30+
return False
2531

26-
# Check this is a constant expression that could be folded
27-
# We don't try to fold strings or bytes, since they have probably been arranged this way to make the source shorter and we are unlikely to beat that
28-
if not is_constant_node(node.left, (ast.Num, ast.NameConstant)):
29-
return node
30-
if not is_constant_node(node.right, (ast.Num, ast.NameConstant)):
31-
return node
3232

33-
if isinstance(node.op, ast.Div):
34-
# Folding div is subtle, since it can have different results in Python 2 and Python 3
35-
# Do this once target version options have been implemented
36-
return node
33+
class FoldConstants(SuiteTransformer):
34+
"""
35+
Fold Constants if it would reduce the size of the source
36+
"""
3737

38-
if isinstance(node.op, ast.Pow):
39-
# This can be folded, but it is unlikely to reduce the size of the source
40-
# It can also be slow to evaluate
41-
return node
38+
def __init__(self):
39+
super(FoldConstants, self).__init__()
4240

41+
def fold(self, node):
4342
# Evaluate the expression
4443
try:
4544
original_expression = unparse_expression(node)
@@ -96,6 +95,44 @@ def visit_BinOp(self, node):
9695
# New representation is shorter and has the same value, so use it
9796
return self.add_child(new_node, get_parent(node), node.namespace)
9897

98+
def visit_BinOp(self, node):
99+
100+
node.left = self.visit(node.left)
101+
node.right = self.visit(node.right)
102+
103+
# Check this is a constant expression that could be folded
104+
# We don't try to fold strings or bytes, since they have probably been arranged this way to make the source shorter and we are unlikely to beat that
105+
if not is_foldable_constant(node.left):
106+
return node
107+
if not is_foldable_constant(node.right):
108+
return node
109+
110+
if isinstance(node.op, ast.Div):
111+
# Folding div is subtle, since it can have different results in Python 2 and Python 3
112+
# Do this once target version options have been implemented
113+
return node
114+
115+
if isinstance(node.op, ast.Pow):
116+
# This can be folded, but it is unlikely to reduce the size of the source
117+
# It can also be slow to evaluate
118+
return node
119+
120+
return self.fold(node)
121+
122+
def visit_UnaryOp(self, node):
123+
124+
node.operand = self.visit(node.operand)
125+
126+
# Only fold if the operand is a foldable constant
127+
if not is_foldable_constant(node.operand):
128+
return node
129+
130+
# Only fold these unary operators
131+
if not isinstance(node.op, (ast.USub, ast.UAdd, ast.Invert, ast.Not)):
132+
return node
133+
134+
return self.fold(node)
135+
99136

100137
def equal_value_and_type(a, b):
101138
if type(a) != type(b):

test/test_folding.py

Lines changed: 264 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33

44
import pytest
55

6+
from python_minifier import minify
67
from python_minifier.ast_annotation import add_parent
78
from python_minifier.ast_compare import compare_ast
89
from python_minifier.rename import add_namespace
9-
from python_minifier.transforms.constant_folding import FoldConstants
10-
10+
from python_minifier.transforms.constant_folding import FoldConstants, equal_value_and_type
1111

1212
def fold_constants(source):
1313
module = ast.parse(source)
@@ -106,7 +106,9 @@ def test_bool(source, expected):
106106
('0xf0|0x0f', '0xff'),
107107
('10%2', '0'),
108108
('10%3', '1'),
109-
('10-100', '-90')
109+
('10-100', '-90'),
110+
('1+1', '2'),
111+
('2+2', '4'),
110112
]
111113
)
112114
def test_int(source, expected):
@@ -149,3 +151,262 @@ def test_not_eval(source, expected):
149151
"""
150152

151153
run_test(source, expected)
154+
155+
156+
@pytest.mark.parametrize(
157+
('source', 'expected'), [
158+
# Folding results in infinity, which can be represented as 1e999
159+
('1e308 + 1e308', '1e999'),
160+
('1e308 * 2', '1e999'),
161+
]
162+
)
163+
def test_fold_infinity(source, expected):
164+
"""
165+
Test that expressions resulting in infinity are folded to 1e999.
166+
167+
Infinity can be represented as 1e999, which is shorter than
168+
the original expression.
169+
"""
170+
run_test(source, expected)
171+
172+
173+
@pytest.mark.parametrize(
174+
('source', 'expected'), [
175+
# Folding would result in NaN, which cannot be represented as a literal
176+
('1e999 - 1e999', '1e999 - 1e999'),
177+
('0.0 * 1e999', '0.0 * 1e999'),
178+
]
179+
)
180+
def test_no_fold_nan(source, expected):
181+
"""
182+
Test that expressions resulting in NaN are not folded.
183+
184+
NaN is not a valid Python literal, so we cannot fold expressions
185+
that would produce it.
186+
"""
187+
run_test(source, expected)
188+
189+
190+
@pytest.mark.parametrize(
191+
('source', 'expected'), [
192+
('100.0+100.0', '200.0'),
193+
('1000.0+1000.0', '2000.0'),
194+
]
195+
)
196+
def test_fold_float(source, expected):
197+
"""
198+
Test that float expressions are folded when the result is shorter.
199+
"""
200+
run_test(source, expected)
201+
202+
203+
def test_equal_value_and_type():
204+
"""
205+
Test the equal_value_and_type helper function.
206+
"""
207+
208+
# Same type and value
209+
assert equal_value_and_type(1, 1) is True
210+
assert equal_value_and_type(1.0, 1.0) is True
211+
assert equal_value_and_type(True, True) is True
212+
assert equal_value_and_type('hello', 'hello') is True
213+
214+
# Different types
215+
assert equal_value_and_type(1, 1.0) is False
216+
assert equal_value_and_type(1, True) is False
217+
assert equal_value_and_type(True, 1) is False
218+
219+
# Different values
220+
assert equal_value_and_type(1, 2) is False
221+
assert equal_value_and_type(1.0, 2.0) is False
222+
223+
224+
def test_equal_value_and_type_nan():
225+
"""
226+
Test the equal_value_and_type helper function with NaN values.
227+
"""
228+
229+
nan = float('nan')
230+
231+
# NaN is not equal to itself in Python (nan != nan is True)
232+
# But if both are NaN, equal_value_and_type returns True via a == b
233+
# Since nan == nan is False, we need to check the actual behavior
234+
result = equal_value_and_type(nan, nan)
235+
# Python's nan == nan is False, so this should be False
236+
assert result is False
237+
238+
# NaN compared to non-NaN should be False
239+
assert equal_value_and_type(nan, 1.0) is False
240+
assert equal_value_and_type(1.0, nan) is False
241+
242+
243+
@pytest.mark.parametrize(
244+
('source', 'expected'), [
245+
('5 - 10', '-5'),
246+
('0 - 100', '-100'),
247+
('1.0 - 2.0', '-1.0'),
248+
('0.0 - 100.0', '-100.0'),
249+
]
250+
)
251+
def test_negative_results(source, expected):
252+
"""
253+
Test BinOp expressions that produce negative results.
254+
"""
255+
run_test(source, expected)
256+
257+
258+
@pytest.mark.parametrize(
259+
('source', 'expected'), [
260+
('5 * -2', '-10'),
261+
('-5 * 2', '-10'),
262+
('-5 + 10', '5'),
263+
('-90 + 10', '-80'),
264+
('10 - 20 + 5', '-5'),
265+
('(5 - 10) * 2', '-10'),
266+
('2 * (0 - 5)', '-10'),
267+
('(1 - 10) + (2 - 20)', '-27'),
268+
]
269+
)
270+
def test_negative_operands_folded(source, expected):
271+
"""
272+
Test that expressions with negative operands are folded.
273+
"""
274+
run_test(source, expected)
275+
276+
277+
@pytest.mark.parametrize(
278+
('source', 'expected'), [
279+
('-(-5)', '5'),
280+
('--5', '5'),
281+
('-(-100)', '100'),
282+
('-(-(5 + 5))', '10'),
283+
('~(~0)', '0'),
284+
('~~5', '5'),
285+
('~~100', '100'),
286+
('+(+5)', '5'),
287+
('+(-5)', '-5'),
288+
]
289+
)
290+
def test_unary_folded(source, expected):
291+
"""
292+
Test that unary operations on constant expressions are folded.
293+
"""
294+
run_test(source, expected)
295+
296+
297+
@pytest.mark.parametrize(
298+
('source', 'expected'), [
299+
('not not True', 'True'),
300+
('not not False', 'False'),
301+
('not True', 'False'),
302+
('not False', 'True'),
303+
]
304+
)
305+
def test_unary_not_folded(source, expected):
306+
"""
307+
Test that 'not' operations on constant expressions are folded.
308+
"""
309+
if sys.version_info < (3, 4):
310+
pytest.skip('NameConstant not in python < 3.4')
311+
312+
run_test(source, expected)
313+
314+
315+
@pytest.mark.parametrize(
316+
('source', 'expected'), [
317+
('-5', '-5'),
318+
('~5', '~5'),
319+
]
320+
)
321+
def test_unary_simple_not_folded(source, expected):
322+
"""
323+
Test that simple unary operations on literals are not folded
324+
when the result would not be shorter.
325+
"""
326+
run_test(source, expected)
327+
328+
329+
def test_unary_plus_folded():
330+
"""
331+
Test that unary plus on a literal is folded to remove the plus.
332+
"""
333+
run_test('+5', '5')
334+
335+
336+
def test_constant_folding_enabled_by_default():
337+
"""Verify constant folding is enabled by default."""
338+
source = 'x = 10 + 10'
339+
result = minify(source)
340+
assert '20' in result
341+
assert '10+10' not in result and '10 + 10' not in result
342+
343+
344+
def test_constant_folding_disabled():
345+
"""Verify expressions are not folded when constant_folding=False."""
346+
source = 'x = 10 + 10'
347+
result = minify(source, constant_folding=False)
348+
assert '10+10' in result or '10 + 10' in result
349+
assert result.strip() != 'x=20'
350+
351+
352+
def test_constant_folding_disabled_complex_expression():
353+
"""Verify complex expressions are preserved when disabled."""
354+
source = 'SECONDS_IN_A_DAY = 60 * 60 * 24'
355+
result = minify(source, constant_folding=False)
356+
assert '60*60*24' in result or '60 * 60 * 24' in result
357+
358+
359+
def test_constant_folding_enabled_complex_expression():
360+
"""Verify complex expressions are folded when enabled."""
361+
source = 'SECONDS_IN_A_DAY = 60 * 60 * 24'
362+
result = minify(source, constant_folding=True)
363+
assert '86400' in result
364+
365+
366+
@pytest.mark.parametrize(
367+
('source', 'should_contain_when_disabled'), [
368+
('x = 5 - 10', '5-10'),
369+
('x = True | False', 'True|False'),
370+
('x = 0xff ^ 0x0f', '255^15'),
371+
]
372+
)
373+
def test_constant_folding_disabled_various_ops(source, should_contain_when_disabled):
374+
"""Verify various operations are not folded when disabled."""
375+
if sys.version_info < (3, 4) and 'True' in source:
376+
pytest.skip('NameConstant not in python < 3.4')
377+
378+
result = minify(source, constant_folding=False)
379+
assert should_contain_when_disabled in result.replace(' ', '')
380+
381+
382+
@pytest.mark.parametrize(
383+
('source', 'expected'), [
384+
('1j + 2j', '3j'),
385+
('3j * 2', '6j'),
386+
('2 * 3j', '6j'),
387+
('10j - 5j', '5j'),
388+
]
389+
)
390+
def test_complex_folded(source, expected):
391+
"""
392+
Test complex number operations that are folded.
393+
394+
Complex operations are folded when the result is shorter than the original.
395+
"""
396+
run_test(source, expected)
397+
398+
399+
@pytest.mark.parametrize(
400+
('source', 'expected'), [
401+
('1j - 2j', '1j-2j'),
402+
('1j * 1j', '1j*1j'),
403+
('0j + 5', '0j+5'),
404+
('2j / 1j', '2j/1j'),
405+
('1j ** 2', '1j**2'),
406+
]
407+
)
408+
def test_complex_not_folded(source, expected):
409+
"""
410+
Test complex number operations that are not folded.
411+
"""
412+
run_test(source, expected)

0 commit comments

Comments
 (0)