Skip to content

Commit 2064d9e

Browse files
committed
Implement bytes.startswith in mypyc
1 parent fefc070 commit 2064d9e

File tree

7 files changed

+79
-1
lines changed

7 files changed

+79
-1
lines changed

bench.cpython-314-darwin.so

162 KB
Binary file not shown.

mypyc/lib-rt/CPy.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -783,7 +783,7 @@ PyObject *CPyBytes_Concat(PyObject *a, PyObject *b);
783783
PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter);
784784
CPyTagged CPyBytes_Ord(PyObject *obj);
785785
PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count);
786-
786+
int CPyBytes_Startswith(PyObject *self, PyObject *subobj);
787787

788788
int CPyBytes_Compare(PyObject *left, PyObject *right);
789789

mypyc/lib-rt/bytes_ops.c

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,3 +171,42 @@ PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count) {
171171
}
172172
return PySequence_Repeat(bytes, temp_count);
173173
}
174+
175+
int CPyBytes_Startswith(PyObject *self, PyObject *subobj) {
176+
if (PyBytes_CheckExact(self) && PyBytes_CheckExact(subobj)) {
177+
if (self == subobj) {
178+
return 1;
179+
}
180+
181+
Py_ssize_t self_len = PyBytes_GET_SIZE(self);
182+
Py_ssize_t subobj_len = PyBytes_GET_SIZE(subobj);
183+
184+
if (subobj_len > self_len) {
185+
return 0;
186+
}
187+
188+
const char *self_buf = PyBytes_AS_STRING(self);
189+
const char *subobj_buf = PyBytes_AS_STRING(subobj);
190+
191+
if (subobj_len == 0) {
192+
return 1;
193+
}
194+
195+
return memcmp(self_buf, subobj_buf, (size_t)subobj_len) == 0 ? 1 : 0;
196+
}
197+
_Py_IDENTIFIER(startswith);
198+
PyObject *name = _PyUnicode_FromId(&PyId_startswith);
199+
if (name == NULL) {
200+
return 2;
201+
}
202+
PyObject *result = PyObject_CallMethodOneArg(self, name, subobj);
203+
if (result == NULL) {
204+
return 2;
205+
}
206+
int ret = PyObject_IsTrue(result);
207+
Py_DECREF(result);
208+
if (ret < 0) {
209+
return 2;
210+
}
211+
return ret;
212+
}

mypyc/primitives/bytes_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from mypyc.ir.rtypes import (
88
RUnion,
99
bit_rprimitive,
10+
bool_rprimitive,
1011
bytes_rprimitive,
1112
c_int_rprimitive,
1213
c_pyssize_t_rprimitive,
@@ -139,6 +140,16 @@
139140
dependencies=[BYTES_EXTRA_OPS],
140141
)
141142

143+
# bytes.startswith(bytes)
144+
method_op(
145+
name="startswith",
146+
arg_types=[bytes_rprimitive, bytes_rprimitive],
147+
return_type=c_int_rprimitive,
148+
c_function_name="CPyBytes_Startswith",
149+
truncated_type=bool_rprimitive,
150+
error_kind=ERR_MAGIC,
151+
)
152+
142153
# Join bytes objects and return a new bytes.
143154
# The first argument is the total number of the following bytes.
144155
bytes_build_op = custom_op(

mypyc/test-data/fixtures/ir.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def __getitem__(self, i: slice) -> bytes: ...
179179
def join(self, x: Iterable[object]) -> bytes: ...
180180
def decode(self, encoding: str=..., errors: str=...) -> str: ...
181181
def translate(self, t: bytes) -> bytes: ...
182+
def startswith(self, t: bytes) -> bool: ...
182183
def __iter__(self) -> Iterator[int]: ...
183184

184185
class bytearray:

mypyc/test-data/irbuild-bytes.test

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,17 @@ def f(b, table):
248248
L0:
249249
r0 = CPyBytes_Translate(b, table)
250250
return r0
251+
252+
[case testBytesStartsWith]
253+
def f(a: bytes, b: bytes) -> bool:
254+
return a.startswith(b)
255+
[out]
256+
def f(a, b):
257+
a, b :: bytes
258+
r0 :: i32
259+
r1 :: bool
260+
L0:
261+
r0 = CPyBytes_Startswith(a, b)
262+
r1 = truncate r0: i32 to builtins.bool
263+
return r1
264+

mypyc/test-data/run-bytes.test

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,19 @@ def test_translate() -> None:
200200
with assertRaises(ValueError, "translation table must be 256 characters long"):
201201
b'test'.translate(bytes(100))
202202

203+
def test_startswith() -> None:
204+
# Test default behavior
205+
test = b'some string'
206+
assert test.startswith(b'some')
207+
assert test.startswith(b'some string')
208+
assert not test.startswith(b'other')
209+
assert not test.startswith(b'some string but longer')
210+
211+
# Test empty cases
212+
assert test.startswith(b'')
213+
assert b''.startswith(b'')
214+
assert not b''.startswith(test)
215+
203216
[case testBytesSlicing]
204217
def test_bytes_slicing() -> None:
205218
b = b'abcdefg'

0 commit comments

Comments
 (0)