diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index 754d6b280..ff5c5456e 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -215,18 +215,17 @@ class SeriesGroupBy(GroupBy[Series[S2]], Generic[S2, ByT]): _TT = TypeVar("_TT", bound=Literal[True, False]) -# ty ignore needed because of https://github.com/astral-sh/ty/issues/157#issuecomment-3017337945 -class DFCallable1(Protocol[P]): # ty: ignore[invalid-argument-type] +class DFCallable1(Protocol[P]): def __call__( self, df: DataFrame, /, *args: P.args, **kwargs: P.kwargs ) -> Scalar | list[Any] | dict[Hashable, Any]: ... -class DFCallable2(Protocol[P]): # ty: ignore[invalid-argument-type] +class DFCallable2(Protocol[P]): def __call__( self, df: DataFrame, /, *args: P.args, **kwargs: P.kwargs ) -> DataFrame | Series: ... -class DFCallable3(Protocol[P]): # ty: ignore[invalid-argument-type] +class DFCallable3(Protocol[P]): def __call__( self, df: Iterable[Any], /, *args: P.args, **kwargs: P.kwargs ) -> float: ... diff --git a/pandas-stubs/core/groupby/groupby.pyi b/pandas-stubs/core/groupby/groupby.pyi index 75d3265b0..e9699eccc 100644 --- a/pandas-stubs/core/groupby/groupby.pyi +++ b/pandas-stubs/core/groupby/groupby.pyi @@ -71,9 +71,9 @@ from pandas._typing import ( from pandas.plotting import PlotAccessor _ResamplerGroupBy: TypeAlias = ( - DatetimeIndexResamplerGroupby[NDFrameT] # ty: ignore[invalid-argument-type] - | PeriodIndexResamplerGroupby[NDFrameT] # ty: ignore[invalid-argument-type] - | TimedeltaIndexResamplerGroupby[NDFrameT] # ty: ignore[invalid-argument-type] + DatetimeIndexResamplerGroupby[NDFrameT] + | PeriodIndexResamplerGroupby[NDFrameT] + | TimedeltaIndexResamplerGroupby[NDFrameT] ) class GroupBy(BaseGroupBy[NDFrameT]): diff --git a/pandas-stubs/core/reshape/pivot.pyi b/pandas-stubs/core/reshape/pivot.pyi index eb7d9d479..7567ae385 100644 --- a/pandas-stubs/core/reshape/pivot.pyi +++ b/pandas-stubs/core/reshape/pivot.pyi @@ -22,9 +22,6 @@ from pandas.core.series import Series from pandas._typing import ( AnyArrayLike, ArrayLike, - HashableT1, - HashableT2, - HashableT3, Label, Scalar, ScalarT, @@ -33,12 +30,16 @@ from pandas._typing import ( ) _PivotAggCallable: TypeAlias = Callable[[Series], ScalarT] - _PivotAggFunc: TypeAlias = ( _PivotAggCallable[ScalarT] | np.ufunc | Literal["mean", "sum", "count", "min", "max", "median", "std", "var"] ) +_PivotAggFuncTypes: TypeAlias = ( + _PivotAggFunc[ScalarT] + | Sequence[_PivotAggFunc[ScalarT]] + | Mapping[Any, _PivotAggFunc[ScalarT]] +) _NonIterableHashable: TypeAlias = ( str @@ -53,13 +54,11 @@ _NonIterableHashable: TypeAlias = ( | pd.Timedelta ) -_PivotTableIndexTypes: TypeAlias = ( - Label | Sequence[HashableT1] | Series | Grouper | None -) +_PivotTableIndexTypes: TypeAlias = Label | Sequence[Hashable] | Series | Grouper | None _PivotTableColumnsTypes: TypeAlias = ( - Label | Sequence[HashableT2] | Series | Grouper | None + Label | Sequence[Hashable] | Series | Grouper | None ) -_PivotTableValuesTypes: TypeAlias = Label | Sequence[HashableT3] | None +_PivotTableValuesTypes: TypeAlias = Label | Sequence[Hashable] | None _ExtendedAnyArrayLike: TypeAlias = AnyArrayLike | ArrayLike _Values: TypeAlias = SequenceNotStr[Any] | _ExtendedAnyArrayLike @@ -67,18 +66,10 @@ _Values: TypeAlias = SequenceNotStr[Any] | _ExtendedAnyArrayLike @overload def pivot_table( data: DataFrame, - values: _PivotTableValuesTypes[ - Hashable # ty: ignore[invalid-type-arguments] - ] = None, - index: _PivotTableIndexTypes[Hashable] = None, # ty: ignore[invalid-type-arguments] - columns: _PivotTableColumnsTypes[ - Hashable # ty: ignore[invalid-type-arguments] - ] = None, - aggfunc: ( - _PivotAggFunc[Scalar] - | Sequence[_PivotAggFunc[Scalar]] - | Mapping[Any, _PivotAggFunc[Scalar]] - ) = "mean", + values: _PivotTableValuesTypes = None, + index: _PivotTableIndexTypes = None, + columns: _PivotTableColumnsTypes = None, + aggfunc: _PivotAggFuncTypes[Scalar] = "mean", fill_value: Scalar | None = None, margins: bool = False, dropna: bool = True, @@ -91,21 +82,11 @@ def pivot_table( @overload def pivot_table( data: DataFrame, - values: _PivotTableValuesTypes[ - Hashable # ty: ignore[invalid-type-arguments] - ] = None, + values: _PivotTableValuesTypes = None, *, index: Grouper, - columns: ( - _PivotTableColumnsTypes[Hashable] # ty: ignore[invalid-type-arguments] - | np_ndarray - | Index[Any] - ) = None, - aggfunc: ( - _PivotAggFunc[Scalar] - | Sequence[_PivotAggFunc[Scalar]] - | Mapping[Any, _PivotAggFunc[Scalar]] - ) = "mean", + columns: _PivotTableColumnsTypes | np_ndarray | Index[Any] = None, + aggfunc: _PivotAggFuncTypes[Scalar] = "mean", fill_value: Scalar | None = None, margins: bool = False, dropna: bool = True, @@ -116,21 +97,11 @@ def pivot_table( @overload def pivot_table( data: DataFrame, - values: _PivotTableValuesTypes[ - Hashable # ty: ignore[invalid-type-arguments] - ] = None, - index: ( - _PivotTableIndexTypes[Hashable] # ty: ignore[invalid-type-arguments] - | np_ndarray - | Index[Any] - ) = None, + values: _PivotTableValuesTypes = None, + index: _PivotTableIndexTypes | np_ndarray | Index[Any] = None, *, columns: Grouper, - aggfunc: ( - _PivotAggFunc[Scalar] - | Sequence[_PivotAggFunc[Scalar]] - | Mapping[Any, _PivotAggFunc[Scalar]] - ) = "mean", + aggfunc: _PivotAggFuncTypes[Scalar] = "mean", fill_value: Scalar | None = None, margins: bool = False, dropna: bool = True, @@ -141,17 +112,17 @@ def pivot_table( def pivot( data: DataFrame, *, - index: _NonIterableHashable | Sequence[HashableT1] = ..., - columns: _NonIterableHashable | Sequence[HashableT2] = ..., - values: _NonIterableHashable | Sequence[HashableT3] = ..., + index: _NonIterableHashable | Sequence[Hashable] = ..., + columns: _NonIterableHashable | Sequence[Hashable] = ..., + values: _NonIterableHashable | Sequence[Hashable] = ..., ) -> DataFrame: ... @overload def crosstab( index: _Values | list[_Values], columns: _Values | list[_Values], values: _Values, - rownames: list[HashableT1] | None = ..., - colnames: list[HashableT2] | None = ..., + rownames: SequenceNotStr[Hashable] | None = ..., + colnames: SequenceNotStr[Hashable] | None = ..., *, aggfunc: str | np.ufunc | Callable[[Series], float], margins: bool = ..., @@ -164,8 +135,8 @@ def crosstab( index: _Values | list[_Values], columns: _Values | list[_Values], values: None = None, - rownames: list[HashableT1] | None = ..., - colnames: list[HashableT2] | None = ..., + rownames: SequenceNotStr[Hashable] | None = ..., + colnames: SequenceNotStr[Hashable] | None = ..., aggfunc: None = None, margins: bool = ..., margins_name: str = ..., diff --git a/tests/test_pandas.py b/tests/test_pandas.py index 9e2e87949..c3d237eb9 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -1716,8 +1716,7 @@ def m2(x: pd.Series) -> int: colnames: list[tuple[str]] = [("a",)] check( assert_type( - pd.crosstab(a, b, colnames=colnames, rownames=rownames), - pd.DataFrame, + pd.crosstab(a, b, colnames=colnames, rownames=rownames), pd.DataFrame ), pd.DataFrame, )