Skip to content

Commit 57c761f

Browse files
authored
Type multi.pyi (#1539)
* Type multi.pyi * extra test * remove undocumented, refine view * fixup
1 parent d904a17 commit 57c761f

File tree

5 files changed

+165
-40
lines changed

5 files changed

+165
-40
lines changed

pandas-stubs/_typing.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,7 @@ np_1darray_complex: TypeAlias = np_1darray[np.complexfloating]
951951
np_1darray_object: TypeAlias = np_1darray[np.object_]
952952
np_1darray_bool: TypeAlias = np_1darray[np.bool]
953953
np_1darray_intp: TypeAlias = np_1darray[np.intp]
954+
np_1darray_int8: TypeAlias = np_1darray[np.int8]
954955
np_1darray_int64: TypeAlias = np_1darray[np.int64]
955956
np_1darray_anyint: TypeAlias = np_1darray[np.integer]
956957
np_1darray_float: TypeAlias = np_1darray[np.floating]

pandas-stubs/core/indexes/multi.pyi

Lines changed: 70 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ from collections.abc import (
77
)
88
from typing import (
99
Any,
10-
final,
1110
overload,
1211
)
1312

@@ -19,15 +18,22 @@ from typing_extensions import Self
1918
from pandas._typing import (
2019
AnyAll,
2120
Axes,
22-
DropKeep,
2321
Dtype,
2422
HashableT,
2523
IndexLabel,
24+
Label,
2625
Level,
2726
MaskType,
2827
NaPosition,
28+
NumpyNotTimeDtypeArg,
29+
NumpyTimedeltaDtypeArg,
30+
NumpyTimestampDtypeArg,
2931
SequenceNotStr,
32+
Shape,
3033
np_1darray_bool,
34+
np_1darray_int8,
35+
np_1darray_intp,
36+
np_ndarray,
3137
np_ndarray_anyint,
3238
)
3339

@@ -70,19 +76,46 @@ class MultiIndex(Index):
7076
sortorder: int | None = ...,
7177
names: SequenceNotStr[Hashable] = ...,
7278
) -> Self: ...
73-
@property
74-
def shape(self): ...
7579
@property # Should be read-only
7680
def levels(self) -> list[Index]: ...
77-
def set_levels(self, levels, *, level=..., verify_integrity: bool = ...): ...
81+
@overload
82+
def set_levels(
83+
self,
84+
levels: Sequence[SequenceNotStr[Hashable]],
85+
*,
86+
level: Sequence[Level] | None = None,
87+
verify_integrity: bool = True,
88+
) -> MultiIndex: ...
89+
@overload
90+
def set_levels(
91+
self,
92+
levels: SequenceNotStr[Hashable],
93+
*,
94+
level: Level,
95+
verify_integrity: bool = True,
96+
) -> MultiIndex: ...
7897
@property
79-
def codes(self): ...
80-
def set_codes(self, codes, *, level=..., verify_integrity: bool = ...): ...
98+
def codes(self) -> list[np_1darray_int8]: ...
99+
@overload
100+
def set_codes(
101+
self,
102+
codes: Sequence[Sequence[int]],
103+
*,
104+
level: Sequence[Level] | None = None,
105+
verify_integrity: bool = True,
106+
) -> MultiIndex: ...
107+
@overload
108+
def set_codes(
109+
self,
110+
codes: Sequence[int],
111+
*,
112+
level: Level,
113+
verify_integrity: bool = True,
114+
) -> MultiIndex: ...
81115
def copy( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] # pyrefly: ignore
82116
self, names: SequenceNotStr[Hashable] = ..., deep: bool = False
83117
) -> Self: ...
84-
def view(self, cls=...): ...
85-
def __contains__(self, key) -> bool: ...
118+
def view(self, cls: NumpyNotTimeDtypeArg | NumpyTimedeltaDtypeArg | NumpyTimestampDtypeArg | type[np_ndarray] | None = None) -> MultiIndex: ... # type: ignore[override] # pyrefly: ignore[bad-override] # pyright: ignore[reportIncompatibleMethodOverride]
86119
@property
87120
def dtype(self) -> np.dtype: ...
88121
@property
@@ -92,29 +125,34 @@ class MultiIndex(Index):
92125
def nbytes(self) -> int: ...
93126
def __len__(self) -> int: ...
94127
@property
95-
def values(self): ...
96-
@property
97128
def is_monotonic_increasing(self) -> bool: ...
98129
@property
99130
def is_monotonic_decreasing(self) -> bool: ...
100-
def duplicated(self, keep: DropKeep = "first"): ...
101131
def dropna(self, how: AnyAll = "any") -> Self: ...
102132
def droplevel(self, level: Level | Sequence[Level] = 0) -> MultiIndex | Index: ... # type: ignore[override]
103133
def get_level_values(self, level: str | int) -> Index: ...
104-
def unique(self, level=...): ...
134+
@overload # type: ignore[override]
135+
def unique( # pyrefly: ignore[bad-override]
136+
self, level: None = None
137+
) -> MultiIndex: ...
138+
@overload
139+
def unique( # ty: ignore[invalid-method-override] # pyright: ignore[reportIncompatibleMethodOverride]
140+
self, level: Level
141+
) -> (
142+
Index
143+
): ... # ty: ignore[invalid-method-override] # pyrefly: ignore[bad-override]
105144
def to_frame( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
106145
self,
107146
index: bool = True,
108147
name: list[HashableT] = ...,
109148
allow_duplicates: bool = False,
110149
) -> pd.DataFrame: ...
111150
def to_flat_index(self) -> Index: ...
112-
def remove_unused_levels(self): ...
151+
def remove_unused_levels(self) -> MultiIndex: ...
113152
@property
114153
def nlevels(self) -> int: ...
115154
@property
116-
def levshape(self): ...
117-
def __reduce__(self): ...
155+
def levshape(self) -> Shape: ...
118156
@overload # type: ignore[override]
119157
# pyrefly: ignore # bad-override
120158
def __getitem__(
@@ -125,36 +163,32 @@ class MultiIndex(Index):
125163
def __getitem__( # pyright: ignore[reportIncompatibleMethodOverride] # ty: ignore[invalid-method-override]
126164
self, key: int
127165
) -> tuple[Hashable, ...]: ...
128-
def append(self, other): ...
129-
def repeat(self, repeats, axis=...): ...
130-
def drop(self, codes, level: Level | None = None, errors: str = "raise") -> Self: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
166+
@overload # type: ignore[override]
167+
def append(self, other: MultiIndex | Sequence[MultiIndex]) -> MultiIndex: ...
168+
@overload
169+
def append( # pyright: ignore[reportIncompatibleMethodOverride]
170+
self, other: Index | Sequence[Index]
171+
) -> Index: ... # pyrefly: ignore[bad-override]
172+
def drop(self, codes: Level | Sequence[Level], level: Level | None = None, errors: str = "raise") -> MultiIndex: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
131173
def swaplevel(self, i: int = -2, j: int = -1) -> Self: ...
132-
def reorder_levels(self, order): ...
174+
def reorder_levels(self, order: Sequence[Level]) -> MultiIndex: ...
133175
def sortlevel(
134176
self,
135177
level: Level | Sequence[Level] = 0,
136178
ascending: bool = True,
137179
sort_remaining: bool = True,
138180
na_position: NaPosition = "first",
139-
): ...
140-
@final
141-
def get_indexer(self, target, method=..., limit=..., tolerance=...): ...
142-
def get_indexer_non_unique(self, target): ...
143-
def reindex(self, target, method=..., level=..., limit=..., tolerance=...): ...
144-
def get_slice_bound(
145-
self, label: Hashable | Sequence[Hashable], side: str
146-
) -> int: ...
181+
) -> tuple[MultiIndex, np_1darray_intp]: ...
147182
def get_loc_level(
148-
self, key, level: Level | list[Level] | None = None, drop_level: bool = True
149-
): ...
150-
def get_locs(self, seq): ...
183+
self,
184+
key: Label | Sequence[Label],
185+
level: Level | Sequence[Level] | None = None,
186+
drop_level: bool = True,
187+
) -> tuple[int | slice | np_1darray_bool, Index]: ...
188+
def get_locs(self, seq: Level | Sequence[Level]) -> np_1darray_intp: ...
151189
def truncate(
152190
self, before: IndexLabel | None = None, after: IndexLabel | None = None
153-
): ...
154-
def equals(self, other) -> bool: ...
155-
def equal_levels(self, other): ...
156-
def insert(self, loc, item): ...
157-
def delete(self, loc): ...
191+
) -> MultiIndex: ...
158192
@overload # type: ignore[override]
159193
def isin( # pyrefly: ignore[bad-override]
160194
self, values: Iterable[Any], level: Level

pyproject.toml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,10 +205,6 @@ ignore = [
205205
"PYI042", # https://docs.astral.sh/ruff/rules/snake-case-type-alias/
206206
"ERA001", "PLR0402", "PLC0105"
207207
]
208-
"multi.pyi" = [
209-
# TODO: remove when multi.pyi is fully typed
210-
"ANN001", "ANN201", "ANN204", "ANN206",
211-
]
212208
"indexing.pyi" = [
213209
# TODO: remove when indexing.pyi is fully typed
214210
"ANN001", "ANN201", "ANN204", "ANN206",

tests/_typing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
np_1darray_complex,
3232
np_1darray_dt,
3333
np_1darray_float,
34+
np_1darray_int8,
3435
np_1darray_int64,
3536
np_1darray_intp,
3637
np_1darray_object,
@@ -81,6 +82,7 @@
8182
"np_ndarray_dt",
8283
"np_1darray_object",
8384
"np_1darray_td",
85+
"np_1darray_int8",
8486
"np_1darray_int64",
8587
"np_ndarray_num",
8688
"FloatDtypeArg",
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from __future__ import annotations
2+
3+
import numpy as np
4+
import pandas as pd
5+
from typing_extensions import (
6+
assert_type,
7+
)
8+
9+
from tests import (
10+
check,
11+
)
12+
from tests._typing import (
13+
np_1darray_bool,
14+
np_1darray_int8,
15+
np_1darray_intp,
16+
)
17+
18+
19+
def test_multiindex_unique() -> None:
20+
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
21+
check(assert_type(mi.unique(), pd.MultiIndex), pd.MultiIndex)
22+
check(assert_type(mi.unique(level=0), pd.Index), pd.Index)
23+
24+
25+
def test_multiindex_set_levels() -> None:
26+
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
27+
res = mi.set_levels([[10, 20, 30], [40, 50, 60]])
28+
check(assert_type(res, pd.MultiIndex), pd.MultiIndex)
29+
res = mi.set_levels([10, 20, 30], level=0)
30+
check(assert_type(res, pd.MultiIndex), pd.MultiIndex)
31+
32+
33+
def test_multiindex_codes() -> None:
34+
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
35+
check(assert_type(mi.codes, list[np_1darray_int8]), list)
36+
37+
38+
def test_multiindex_set_codes() -> None:
39+
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
40+
res = mi.set_codes([[0, 1, 2], [0, 1, 2]])
41+
check(assert_type(res, pd.MultiIndex), pd.MultiIndex)
42+
res = mi.set_codes([0, 1, 2], level=0)
43+
check(assert_type(res, pd.MultiIndex), pd.MultiIndex)
44+
45+
46+
def test_multiindex_view() -> None:
47+
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
48+
check(assert_type(mi.view(), pd.MultiIndex), pd.MultiIndex)
49+
check(assert_type(mi.view(np.ndarray), pd.MultiIndex), pd.MultiIndex)
50+
51+
52+
def test_multiindex_remove_unused_levels() -> None:
53+
mi = pd.MultiIndex.from_arrays([[1, 2, 3, 1], [4, 5, 6, 4]])
54+
res = mi.remove_unused_levels()
55+
check(assert_type(res, pd.MultiIndex), pd.MultiIndex)
56+
57+
58+
def test_multiindex_levshape() -> None:
59+
mi = pd.MultiIndex.from_arrays([[1, 2, 3, 1], [4, 5, 6, 4]])
60+
ls = mi.levshape
61+
check(assert_type(ls, tuple[int, ...]), tuple, int)
62+
63+
64+
def test_multiindex_append() -> None:
65+
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
66+
check(assert_type(mi.append([mi]), pd.MultiIndex), pd.MultiIndex)
67+
check(assert_type(mi.append([pd.Index([1, 2])]), pd.Index), pd.Index)
68+
69+
70+
def test_multiindex_drop() -> None:
71+
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
72+
dropped = mi.drop([1])
73+
check(assert_type(dropped, pd.MultiIndex), pd.MultiIndex)
74+
75+
76+
def test_multiindex_reorder_levels() -> None:
77+
mi = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
78+
reordered = mi.reorder_levels([1, 0])
79+
check(assert_type(reordered, pd.MultiIndex), pd.MultiIndex)
80+
81+
82+
def test_multiindex_get_locs() -> None:
83+
mi = pd.MultiIndex.from_arrays([[1, 2, 3, 1], [4, 5, 6, 4]])
84+
locs = mi.get_locs([1, 4])
85+
check(assert_type(locs, np_1darray_intp), np_1darray_intp)
86+
87+
88+
def test_multiindex_get_loc_level() -> None:
89+
mi = pd.MultiIndex.from_arrays([[1, 2, 3, 1], [4, 5, 6, 4]])
90+
res_0, res_1 = mi.get_loc_level(1, level=0)
91+
check(assert_type(res_0, int | slice | np_1darray_bool), np_1darray_bool)
92+
check(assert_type(res_1, pd.Index), pd.Index)

0 commit comments

Comments
 (0)