Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions pyrefly/lib/solver/subset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,20 @@ impl<'a, Ans: LookupAnswer> Subset<'a, Ans> {
l_arg = l_args.next();
u_arg = u_args.next();
}
// EDGE CASE: Allow PosOnly parameters to match Pos parameters in protocols
// This handles cases like list.index() (which has position-only params) matching
// SequenceNotStr.index() from pandas 2.x typeshed stubs (which incorrectly
// lacks position-only markers, fixed in pandas 3.0).
// From a typing perspective, this is sound: if a protocol allows a parameter
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems backwards to me. The implementation should be less restrictive, so that it can be used everywhere the protocol can be used.

Copy link
Contributor

@stroxler stroxler Nov 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Jack-GitHub12 what I had in mind if we wanted a hard-coding based fix is that we could just say SequenceNotStr <: list[str], so that we never even get to the level of protocol callable checks where this branch gets hit

If we do it that way, the hack is specific to Pandas SequenceNotStr instead of affecting arbitrary protocol checks

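// to be passed by position or keyword (Pos), an implementation that only allows
// positional (PosOnly) is more restrictive, which is acceptable here.
// NOTE(review): strictly speaking this direction is not sound — a caller holding
// the protocol type may pass the argument by keyword, which a PosOnly
// implementation rejects — so treat this arm as a deliberate compatibility hack
// for the pandas 2.x stubs rather than a general typing rule.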
(Some(Param::PosOnly(_, l, l_req)), Some(Param::Pos(_, u, u_req)))
if (*u_req == Required::Required || matches!(l_req, Required::Optional(_))) =>
{
self.is_subset_eq(u, l)?;
l_arg = l_args.next();
u_arg = u_args.next();
}
(Some(Param::Pos(l_name, l, l_req)), Some(Param::Pos(u_name, u, u_req)))
if *u_req == Required::Required || matches!(l_req, Required::Optional(_)) =>
{
Expand Down
1 change: 1 addition & 0 deletions pyrefly/lib/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ mod natural;
mod new_type;
mod operators;
mod overload;
mod pandas;
mod paramspec;
mod pattern_match;
mod perf;
Expand Down
195 changes: 195 additions & 0 deletions pyrefly/lib/test/pandas/dataframe.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

use crate::test::util::TestEnv;
use crate::testcase;

// Checks a `pd.DataFrame(...)` call that passes a `list[str]` as `columns`, using
// inline pandas stubs in which `SequenceNotStr.index` declares all of its
// parameters position-only.
// NOTE(review): the snippet carries no expected-error markers, so this presumably
// asserts that no type errors are reported — confirm against the `testcase!` macro.
testcase!(
test_dataframe_list_str_columns,
{
let mut env = TestEnv::new();
// Add corrected pandas stubs inline
// (`pandas._typing`: a `SequenceNotStr` protocol whose `index` ends with `/`).
env.add(
"pandas._typing",
r#"
from typing import Any, Iterator, Protocol, Sequence, TypeVar, overload
from typing_extensions import SupportsIndex
_T_co = TypeVar("_T_co", covariant=True)

class SequenceNotStr(Protocol[_T_co]):
@overload
def __getitem__(self, index: SupportsIndex, /) -> _T_co: ...
@overload
def __getitem__(self, index: slice, /) -> Sequence[_T_co]: ...
def __contains__(self, value: object, /) -> bool: ...
def __len__(self) -> int: ...
def __iter__(self) -> Iterator[_T_co]: ...
# FIXED: All parameters position-only to match list.index
def index(self, value: Any, start: int = ..., stop: int = ..., /) -> int: ...
def count(self, value: Any, /) -> int: ...
def __reversed__(self) -> Iterator[_T_co]: ...
"#,
);
// `pandas.core.frame`: a `DataFrame.__init__` stub whose `index` and `columns`
// parameters accept `Axes` (`SequenceNotStr[Any] | range`) or `None`.
env.add(
"pandas.core.frame",
r#"
from typing import Any
from pandas._typing import SequenceNotStr
Axes = SequenceNotStr[Any] | range

class DataFrame:
def __init__(
self,
data: Any = None,
index: Axes | None = None,
columns: Axes | None = None,
dtype: Any = None,
copy: bool | None = None,
) -> None: ...
"#,
);
// Top-level `pandas` module re-exports `DataFrame`, which the snippet reaches
// as `pd.DataFrame`.
env.add(
"pandas",
r#"
from pandas.core.frame import DataFrame as DataFrame
"#,
);
// The populated environment is the value of this block.
env
},
// Python source passed as the final macro argument.
r#"
import pandas as pd

# This should work: passing list[str] for columns
df = pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=["A", "B", "C"])
"#,
);

// Same stub environment as the `columns`-only case, but the snippet passes a
// `list[str]` for both `columns` and `index` of `pd.DataFrame(...)`.
// NOTE(review): the snippet carries no expected-error markers, so this presumably
// asserts that no type errors are reported — confirm against the `testcase!` macro.
testcase!(
test_dataframe_list_str_both,
{
let mut env = TestEnv::new();
// `pandas._typing`: a `SequenceNotStr` protocol whose `index` parameters are all
// position-only (signature ends with `/`).
env.add(
"pandas._typing",
r#"
from typing import Any, Iterator, Protocol, Sequence, TypeVar, overload
from typing_extensions import SupportsIndex
_T_co = TypeVar("_T_co", covariant=True)

class SequenceNotStr(Protocol[_T_co]):
@overload
def __getitem__(self, index: SupportsIndex, /) -> _T_co: ...
@overload
def __getitem__(self, index: slice, /) -> Sequence[_T_co]: ...
def __contains__(self, value: object, /) -> bool: ...
def __len__(self) -> int: ...
def __iter__(self) -> Iterator[_T_co]: ...
# FIXED: All parameters position-only to match list.index
def index(self, value: Any, start: int = ..., stop: int = ..., /) -> int: ...
def count(self, value: Any, /) -> int: ...
def __reversed__(self) -> Iterator[_T_co]: ...
"#,
);
// `pandas.core.frame`: a `DataFrame.__init__` stub whose `index` and `columns`
// parameters accept `Axes` (`SequenceNotStr[Any] | range`) or `None`.
env.add(
"pandas.core.frame",
r#"
from typing import Any
from pandas._typing import SequenceNotStr
Axes = SequenceNotStr[Any] | range

class DataFrame:
def __init__(
self,
data: Any = None,
index: Axes | None = None,
columns: Axes | None = None,
dtype: Any = None,
copy: bool | None = None,
) -> None: ...
"#,
);
// Top-level `pandas` module re-exports `DataFrame` (given here as a plain string
// literal rather than a raw string).
env.add(
"pandas",
"from pandas.core.frame import DataFrame as DataFrame",
);
// The populated environment is the value of this block.
env
},
// Python source passed as the final macro argument.
r#"
import pandas as pd

# Test list[str] for both columns and index
df = pd.DataFrame(
[[1, 2, 3], [4, 5, 6]],
columns=["A", "B", "C"],
index=["row1", "row2"]
)
"#,
);

// Test with BROKEN pandas 2.x stubs (without position-only markers)
// This demonstrates the edge case fix in is_subset_param_list works
//
// The only difference from the stubs in the other DataFrame tests is that
// `SequenceNotStr.index` here has no trailing `/`, so its parameters are
// positional-or-keyword instead of position-only.
// NOTE(review): the snippet carries no expected-error markers, so this presumably
// asserts that no type errors are reported — confirm against the `testcase!` macro.
testcase!(
test_dataframe_with_broken_stubs,
{
let mut env = TestEnv::new();
// Use pandas 2.x stubs WITHOUT position-only markers (the actual broken stubs)
env.add(
"pandas._typing",
r#"
from typing import Any, Iterator, Protocol, Sequence, TypeVar, overload
from typing_extensions import SupportsIndex
_T_co = TypeVar("_T_co", covariant=True)

class SequenceNotStr(Protocol[_T_co]):
@overload
def __getitem__(self, index: SupportsIndex, /) -> _T_co: ...
@overload
def __getitem__(self, index: slice, /) -> Sequence[_T_co]: ...
def __contains__(self, value: object, /) -> bool: ...
def __len__(self) -> int: ...
def __iter__(self) -> Iterator[_T_co]: ...
# BROKEN: Missing position-only markers (actual pandas 2.x stubs)
# This should still work thanks to the edge case in is_subset_param_list
def index(self, value: Any, start: int = ..., stop: int = ...) -> int: ...
def count(self, value: Any, /) -> int: ...
def __reversed__(self) -> Iterator[_T_co]: ...
"#,
);
// `pandas.core.frame`: a `DataFrame.__init__` stub whose `index` and `columns`
// parameters accept `Axes` (`SequenceNotStr[Any] | range`) or `None`.
env.add(
"pandas.core.frame",
r#"
from typing import Any
from pandas._typing import SequenceNotStr
Axes = SequenceNotStr[Any] | range

class DataFrame:
def __init__(
self,
data: Any = None,
index: Axes | None = None,
columns: Axes | None = None,
dtype: Any = None,
copy: bool | None = None,
) -> None: ...
"#,
);
// Top-level `pandas` module re-exports `DataFrame`, which the snippet reaches
// as `pd.DataFrame`.
env.add(
"pandas",
r#"
from pandas.core.frame import DataFrame as DataFrame
"#,
);
// The populated environment is the value of this block.
env
},
// Python source passed as the final macro argument.
r#"
import pandas as pd

# This should work even with broken stubs: list[str] should match SequenceNotStr[Any]
# because list.index() has position-only params, and our edge case allows PosOnly
# to match Pos in protocol checking
df = pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=["A", "B", "C"])
"#,
);
9 changes: 9 additions & 0 deletions pyrefly/lib/test/pandas/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#![cfg(test)]
mod dataframe;