Skip to content

Commit 3908bc6

Browse files
fix
Added a union-aware fast path in expr_infer_with_hint plus branch-aware list inference so we can explore hints that don’t contain type variables before the generic ones. Together with the new type_contains_var (now driven by Type::universe) and prefer_union_branch_without_vars, contextual typing no longer eagerly binds callable/lambda type vars when a non-generic branch of a union would have sufficed.
1 parent 8f545ea commit 3908bc6

File tree

3 files changed

+65
-12
lines changed

3 files changed

+65
-12
lines changed

pyrefly/lib/alt/answers_solver.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,4 +857,28 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
857857
pub fn error_swallower(&self) -> ErrorCollector {
858858
ErrorCollector::new(self.module().dupe(), ErrorStyle::Never)
859859
}
860+
861+
pub fn prefer_union_branch_without_vars(&self, ty: &Type) -> Option<Type> {
862+
if let Type::Union(options) = ty {
863+
let mut reordered = options.clone();
864+
reordered.sort_by_key(|option| self.type_contains_var(option));
865+
if reordered == *options {
866+
None
867+
} else {
868+
Some(Type::Union(reordered))
869+
}
870+
} else {
871+
None
872+
}
873+
}
874+
875+
pub(crate) fn type_contains_var(&self, ty: &Type) -> bool {
876+
let mut has_var = false;
877+
ty.universe(&mut |t| {
878+
if matches!(t, Type::Var(_)) {
879+
has_var = true;
880+
}
881+
});
882+
has_var
883+
}
860884
}

pyrefly/lib/alt/callable.rs

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,12 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
541541
// We ignore positional-only parameters because they can't be passed in by name.
542542
seen_names.insert(name, ty);
543543
}
544+
let ty = if let Some(reordered) = self.prefer_union_branch_without_vars(ty)
545+
{
546+
type_owner.push(reordered)
547+
} else {
548+
ty
549+
};
544550
arg_pre.post_check(
545551
self,
546552
callable_name,
@@ -568,17 +574,25 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
568574
ty,
569575
name,
570576
kind: PosParamKind::Variadic,
571-
}) => arg_pre.post_check(
572-
self,
573-
callable_name,
574-
ty,
575-
name,
576-
true,
577-
arg.range(),
578-
arg_errors,
579-
call_errors,
580-
context,
581-
),
577+
}) => {
578+
let ty = if let Some(reordered) = self.prefer_union_branch_without_vars(ty)
579+
{
580+
type_owner.push(reordered)
581+
} else {
582+
ty
583+
};
584+
arg_pre.post_check(
585+
self,
586+
callable_name,
587+
ty,
588+
name,
589+
true,
590+
arg.range(),
591+
arg_errors,
592+
call_errors,
593+
context,
594+
)
595+
}
582596
None => {
583597
arg_pre.post_infer(self, arg_errors);
584598
if !arg_pre.is_star() {

pyrefly/lib/alt/expr.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,19 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
231231
hint: Option<HintRef>,
232232
errors: &ErrorCollector,
233233
) -> Type {
234+
if let Some(hint_ref) = hint
235+
&& let Type::Union(options) = hint_ref.ty()
236+
{
237+
let mut branches: Vec<&Type> = options.iter().collect();
238+
branches.sort_by_key(|option| self.type_contains_var(option));
239+
for option in branches {
240+
let branch_hint = HintRef::new(option, hint_ref.errors());
241+
let ty = self.expr_infer_with_hint(x, Some(branch_hint), errors);
242+
if self.is_subset_eq(&ty, option) {
243+
return ty;
244+
}
245+
}
246+
}
234247
self.expr_infer_type_info_with_hint(x, hint, errors)
235248
.into_ty()
236249
}
@@ -423,7 +436,9 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
423436
if let Some(hint_ref) = hint.as_ref()
424437
&& let Type::Union(options) = hint_ref.ty()
425438
{
426-
for option in options {
439+
let mut branches: Vec<&Type> = options.iter().collect();
440+
branches.sort_by_key(|option| self.type_contains_var(option));
441+
for option in branches {
427442
let branch_hint =
428443
self.decompose_list(HintRef::new(option, hint_ref.errors()));
429444
let ty = self.list_with_hint(x, branch_hint, errors);

0 commit comments

Comments
 (0)