Skip to content

Commit 1cea058

Browse files
authored
[mypyc] Add primitive for bytes.startswith (#20387)
Implements `bytes.startswith` in mypy. Potentially could be more efficient without relying on `memcmp` but not sure. Tested with the following benchmark code, which shows a ~6.3x performance improvement compared to standard Python: ``` import time def bench(prefix: bytes, a: list[bytes], n: int) -> int: i = 0 for x in range(n): for b in a: if b.startswith(prefix): i += 1 return i a = [b"foo", b"barasdfsf", b"foobar", b"ab", b"asrtert", b"sertyeryt"] n = 5 * 1000 * 1000 prefix = b"foo" bench(prefix, a, n) t0 = time.time() bench(prefix, a, n) td = time.time() - t0 print(f"{td}s") ``` Output: ``` $ python /tmp/bench.py 1.0015509128570557s $ python -c 'import bench' 0.154998779296875s ```
1 parent 6adb7a4 commit 1cea058

File tree

6 files changed

+86
-1
lines changed

6 files changed

+86
-1
lines changed

mypyc/lib-rt/CPy.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -783,7 +783,7 @@ PyObject *CPyBytes_Concat(PyObject *a, PyObject *b);
783783
PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter);
784784
CPyTagged CPyBytes_Ord(PyObject *obj);
785785
PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count);
786-
786+
int CPyBytes_Startswith(PyObject *self, PyObject *subobj);
787787

788788
int CPyBytes_Compare(PyObject *left, PyObject *right);
789789

mypyc/lib-rt/bytes_ops.c

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,3 +171,41 @@ PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count) {
171171
}
172172
return PySequence_Repeat(bytes, temp_count);
173173
}
174+
175+
int CPyBytes_Startswith(PyObject *self, PyObject *subobj) {
176+
if (PyBytes_CheckExact(self) && PyBytes_CheckExact(subobj)) {
177+
if (self == subobj) {
178+
return 1;
179+
}
180+
181+
Py_ssize_t subobj_len = PyBytes_GET_SIZE(subobj);
182+
if (subobj_len == 0) {
183+
return 1;
184+
}
185+
186+
Py_ssize_t self_len = PyBytes_GET_SIZE(self);
187+
if (subobj_len > self_len) {
188+
return 0;
189+
}
190+
191+
const char *self_buf = PyBytes_AS_STRING(self);
192+
const char *subobj_buf = PyBytes_AS_STRING(subobj);
193+
194+
return memcmp(self_buf, subobj_buf, (size_t)subobj_len) == 0 ? 1 : 0;
195+
}
196+
_Py_IDENTIFIER(startswith);
197+
PyObject *name = _PyUnicode_FromId(&PyId_startswith);
198+
if (name == NULL) {
199+
return 2;
200+
}
201+
PyObject *result = PyObject_CallMethodOneArg(self, name, subobj);
202+
if (result == NULL) {
203+
return 2;
204+
}
205+
int ret = PyObject_IsTrue(result);
206+
Py_DECREF(result);
207+
if (ret < 0) {
208+
return 2;
209+
}
210+
return ret;
211+
}

mypyc/primitives/bytes_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from mypyc.ir.rtypes import (
88
RUnion,
99
bit_rprimitive,
10+
bool_rprimitive,
1011
bytes_rprimitive,
1112
c_int_rprimitive,
1213
c_pyssize_t_rprimitive,
@@ -139,6 +140,16 @@
139140
dependencies=[BYTES_EXTRA_OPS],
140141
)
141142

143+
# bytes.startswith(bytes)
144+
method_op(
145+
name="startswith",
146+
arg_types=[bytes_rprimitive, bytes_rprimitive],
147+
return_type=c_int_rprimitive,
148+
c_function_name="CPyBytes_Startswith",
149+
truncated_type=bool_rprimitive,
150+
error_kind=ERR_MAGIC,
151+
)
152+
142153
# Join bytes objects and return a new bytes.
143154
# The first argument is the total number of the following bytes.
144155
bytes_build_op = custom_op(

mypyc/test-data/fixtures/ir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def __getitem__(self, i: slice) -> bytes: ...
179179
def join(self, x: Iterable[object]) -> bytes: ...
180180
def decode(self, encoding: str=..., errors: str=...) -> str: ...
181181
def translate(self, t: bytes) -> bytes: ...
182+
def startswith(self, t: bytes) -> bool: ...
182183
def __iter__(self) -> Iterator[int]: ...
183184

184185
class bytearray:
@@ -192,6 +193,7 @@ def __add__(self, s: bytes) -> bytearray: ...
192193
def __setitem__(self, i: int, o: int) -> None: ...
193194
def __getitem__(self, i: int) -> int: ...
194195
def decode(self, x: str = ..., y: str = ...) -> str: ...
196+
def startswith(self, t: bytes) -> bool: ...
195197

196198
class bool(int):
197199
def __init__(self, o: object = ...) -> None: ...

mypyc/test-data/irbuild-bytes.test

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,16 @@ def f(b, table):
248248
L0:
249249
r0 = CPyBytes_Translate(b, table)
250250
return r0
251+
252+
[case testBytesStartsWith]
253+
def f(a: bytes, b: bytes) -> bool:
254+
return a.startswith(b)
255+
[out]
256+
def f(a, b):
257+
a, b :: bytes
258+
r0 :: i32
259+
r1 :: bool
260+
L0:
261+
r0 = CPyBytes_Startswith(a, b)
262+
r1 = truncate r0: i32 to builtins.bool
263+
return r1

mypyc/test-data/run-bytes.test

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,27 @@ def test_translate() -> None:
200200
with assertRaises(ValueError, "translation table must be 256 characters long"):
201201
b'test'.translate(bytes(100))
202202

203+
def test_startswith() -> None:
204+
# Test default behavior
205+
test = b'some string'
206+
assert test.startswith(b'some')
207+
assert test.startswith(b'some string')
208+
assert not test.startswith(b'other')
209+
assert not test.startswith(b'some string but longer')
210+
211+
# Test empty cases
212+
assert test.startswith(b'')
213+
assert b''.startswith(b'')
214+
assert not b''.startswith(test)
215+
216+
# Test bytearray to verify slow paths
217+
assert test.startswith(bytearray(b'some'))
218+
assert not test.startswith(bytearray(b'other'))
219+
220+
test = bytearray(b'some string')
221+
assert test.startswith(b'some')
222+
assert not test.startswith(b'other')
223+
203224
[case testBytesSlicing]
204225
def test_bytes_slicing() -> None:
205226
b = b'abcdefg'

0 commit comments

Comments
 (0)