Skip to content

Commit 6adb7a4

Browse files
authored
[mypyc] Fix calling async methods through vectorcall (#20393)
Fixes mypyc/mypyc#1170 Fixed a bug introduced when adding the async function wrapper in 81eaa5d that would result in `TypeErrors` being raised when the wrapped function was called using the vectorcall protocol. The exception was caused by the wrapper keeping the `self` argument of methods in the `args` vector instead of extracting it to a separate c function argument. The called function would not expect the `self` argument to be part of `args` so it appeared as if too many arguments were passed.
1 parent 88a3c58 commit 6adb7a4

File tree

4 files changed

+33
-9
lines changed

4 files changed

+33
-9
lines changed

mypyc/codegen/emit.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
TYPE_PREFIX,
2424
)
2525
from mypyc.ir.class_ir import ClassIR, all_concrete_classes
26-
from mypyc.ir.func_ir import FuncDecl, FuncIR, get_text_signature
26+
from mypyc.ir.func_ir import FUNC_STATICMETHOD, FuncDecl, FuncIR, get_text_signature
2727
from mypyc.ir.ops import BasicBlock, Value
2828
from mypyc.ir.rtypes import (
2929
RInstance,
@@ -1222,10 +1222,11 @@ def emit_cpyfunction_instance(
12221222
cfunc = f"(PyCFunction){cname}"
12231223
func_flags = "METH_FASTCALL | METH_KEYWORDS"
12241224
doc = f"PyDoc_STR({native_function_doc_initializer(fn)})"
1225+
has_self_arg = "true" if fn.class_name and fn.decl.kind != FUNC_STATICMETHOD else "false"
12251226

12261227
code_flags = "CO_COROUTINE"
12271228
self.emit_line(
1228-
f'PyObject* {wrapper_name} = CPyFunction_New({module}, "{filepath}", "{name}", {cfunc}, {func_flags}, {doc}, {fn.line}, {code_flags});'
1229+
f'PyObject* {wrapper_name} = CPyFunction_New({module}, "{filepath}", "{name}", {cfunc}, {func_flags}, {doc}, {fn.line}, {code_flags}, {has_self_arg});'
12291230
)
12301231
self.emit_line(f"if (unlikely(!{wrapper_name}))")
12311232
self.emit_line(error_stmt)

mypyc/lib-rt/CPy.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,7 @@ typedef struct {
971971

972972
PyObject* CPyFunction_New(PyObject *module, const char *filename, const char *funcname,
973973
PyCFunction func, int func_flags, const char *func_doc,
974-
int first_line, int code_flags);
974+
int first_line, int code_flags, bool has_self_arg);
975975
PyObject* CPyFunction_get_name(PyObject *op, void *context);
976976
int CPyFunction_set_name(PyObject *op, PyObject *value, void *context);
977977
PyObject* CPyFunction_get_code(PyObject *op, void *context);

mypyc/lib-rt/function_wrapper.c

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,16 +177,21 @@ static PyObject* CPyFunction_Vectorcall(PyObject *func, PyObject *const *args, s
177177
PyCFunction meth = ((PyCFunctionObject *)f)->m_ml->ml_meth;
178178

179179
self = ((PyCFunctionObject *)f)->m_self;
180+
if (!self) {
181+
self = args[0];
182+
args += 1;
183+
nargs -= 1;
184+
}
180185
return ((_PyCFunctionFastWithKeywords)(void(*)(void))meth)(self, args, nargs, kwnames);
181186
}
182187

183188

184189
static CPyFunction* CPyFunction_Init(CPyFunction *op, PyMethodDef *ml, PyObject* name,
185-
PyObject *module, PyObject* code) {
190+
PyObject *module, PyObject* code, bool set_self) {
186191
PyCFunctionObject *cf = (PyCFunctionObject *)op;
187192
CPyFunction_weakreflist(op) = NULL;
188193
cf->m_ml = ml;
189-
cf->m_self = (PyObject *) op;
194+
cf->m_self = set_self ? (PyObject *) op : NULL;
190195

191196
Py_XINCREF(module);
192197
cf->m_module = module;
@@ -226,9 +231,10 @@ static PyMethodDef* CPyMethodDef_New(const char *name, PyCFunction func, int fla
226231

227232
PyObject* CPyFunction_New(PyObject *module, const char *filename, const char *funcname,
228233
PyCFunction func, int func_flags, const char *func_doc,
229-
int first_line, int code_flags) {
234+
int first_line, int code_flags, bool has_self_arg) {
230235
PyMethodDef *method = NULL;
231236
PyObject *code = NULL, *op = NULL;
237+
bool set_self = false;
232238

233239
if (!CPyFunctionType) {
234240
CPyFunctionType = (PyTypeObject *)PyType_FromSpec(&CPyFunction_spec);
@@ -245,8 +251,15 @@ PyObject* CPyFunction_New(PyObject *module, const char *filename, const char *fu
245251
if (unlikely(!code)) {
246252
goto err;
247253
}
254+
255+
// Set m_self inside the function wrapper only if the wrapped function has no self arg
256+
// to pass m_self as the self arg when the function is called.
257+
// When the function has a self arg, it will come in the args vector passed to the
258+
// vectorcall handler.
259+
set_self = !has_self_arg;
248260
op = (PyObject *)CPyFunction_Init(PyObject_GC_New(CPyFunction, CPyFunctionType),
249-
method, PyUnicode_FromString(funcname), module, code);
261+
method, PyUnicode_FromString(funcname), module,
262+
code, set_self);
250263
if (unlikely(!op)) {
251264
goto err;
252265
}

mypyc/test-data/run-async.test

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ async def sleep(t: float) -> None: ...
228228

229229
[case testAsyncWith]
230230
from testutil import async_val
231+
from typing import Any
231232

232233
class async_ctx:
233234
async def __aenter__(self) -> str:
@@ -242,15 +243,22 @@ async def async_with() -> str:
242243
async with async_ctx() as x:
243244
return await async_val("body")
244245

246+
async def async_with_vectorcall() -> str:
247+
ctx: Any = async_ctx()
248+
async with ctx:
249+
return await async_val("vc")
245250

246251
[file driver.py]
247-
from native import async_with
252+
from native import async_with, async_with_vectorcall
248253
from testutil import run_generator
249254

250255
yields, val = run_generator(async_with(), [None, 'x', None])
251256
assert yields == ('enter', 'body', 'exit'), yields
252257
assert val == 'x', val
253258

259+
yields, val = run_generator(async_with_vectorcall(), [None, 'x', None])
260+
assert yields == ('enter', 'vc', 'exit'), yields
261+
assert val == 'x', val
254262

255263
[case testAsyncReturn]
256264
from testutil import async_val
@@ -1516,7 +1524,9 @@ def test_method() -> None:
15161524
assert str(T.returns_one_async).startswith("<function T.returns_one_async"), str(T.returns_one_async)
15171525

15181526
t = T()
1519-
assert asyncio.run(t.returns_one_async()) == 1
1527+
# Call through variable to make sure the call is through vectorcall and not optimized to a native call.
1528+
f: Any = t.returns_one_async
1529+
assert asyncio.run(f()) == 1
15201530

15211531
assert not is_coroutine(T.returns_two)
15221532
assert is_coroutine(T.returns_two_async)

0 commit comments

Comments
 (0)