Skip to content

Commit

Permalink
Fixed a few places where union order resulted in different type evalu… (
Browse files Browse the repository at this point in the history
#6306)

* Fixed a few places where union order resulted in different type evaluation behaviors. This addresses #6302.

* Fixed style issue.
  • Loading branch information
erictraut authored Nov 3, 2023
1 parent 91c915d commit 2285641
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 39 deletions.
3 changes: 2 additions & 1 deletion packages/pyright-internal/src/analyzer/constraintSolver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ import {
isEffectivelyInstantiable,
isPartlyUnknown,
mapSubtypes,
sortTypes,
specializeTupleClass,
specializeWithDefaultTypeArgs,
transformExpectedType,
Expand Down Expand Up @@ -1090,7 +1091,7 @@ export function populateTypeVarContextBasedOnExpectedType(
if (isUnion(synthTypeVar)) {
let foundSynthTypeVar: TypeVarType | undefined;

synthTypeVar.subtypes.forEach((subtype) => {
sortTypes(synthTypeVar.subtypes).forEach((subtype) => {
if (
isTypeVar(subtype) &&
subtype.details.isSynthesized &&
Expand Down
66 changes: 37 additions & 29 deletions packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3762,6 +3762,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
type: Type,
conditionFilter: TypeCondition[] | undefined,
callback: (expandedSubtype: Type, unexpandedSubtype: Type, isLastIteration: boolean) => Type | undefined,
sortSubtypes = false,
recursionCount = 0
): Type {
const newSubtypes: Type[] = [];
Expand All @@ -3772,42 +3773,47 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions

expandedType = transformPossibleRecursiveTypeAlias(expandedType);

doForEachSubtype(expandedType, (subtype, index, allSubtypes) => {
if (conditionFilter) {
const filteredType = applyConditionFilterToType(subtype, conditionFilter, recursionCount);
if (!filteredType) {
return undefined;
}
doForEachSubtype(
expandedType,
(subtype, index, allSubtypes) => {
if (conditionFilter) {
const filteredType = applyConditionFilterToType(subtype, conditionFilter, recursionCount);
if (!filteredType) {
return undefined;
}

subtype = filteredType;
}
subtype = filteredType;
}

let transformedType = callback(
subtype,
unexpandedType,
isLastSubtype && index === allSubtypes.length - 1
);
if (transformedType !== unexpandedType) {
typeChanged = true;
}
if (transformedType) {
// Apply the type condition if it's associated with a constrained TypeVar.
const typeCondition = getTypeCondition(subtype)?.filter(
(condition) => condition.isConstrainedTypeVar
let transformedType = callback(
subtype,
unexpandedType,
isLastSubtype && index === allSubtypes.length - 1
);

if (typeCondition && typeCondition.length > 0) {
transformedType = addConditionToType(transformedType, typeCondition);
if (transformedType !== unexpandedType) {
typeChanged = true;
}
if (transformedType) {
// Apply the type condition if it's associated with a constrained TypeVar.
const typeCondition = getTypeCondition(subtype)?.filter(
(condition) => condition.isConstrainedTypeVar
);

newSubtypes.push(transformedType);
}
return undefined;
});
if (typeCondition && typeCondition.length > 0) {
transformedType = addConditionToType(transformedType, typeCondition);
}

newSubtypes.push(transformedType);
}
return undefined;
},
sortSubtypes
);
}

if (isUnion(type)) {
type.subtypes.forEach((subtype, index) => {
const subtypes = sortSubtypes ? sortTypes(type.subtypes) : type.subtypes;
subtypes.forEach((subtype, index) => {
expandSubtype(subtype, index === type.subtypes.length - 1);
});
} else {
Expand Down Expand Up @@ -3867,6 +3873,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
(expandedSubtype) => {
return expandedSubtype;
},
/* sortSubtypes */ undefined,
recursionCount
);

Expand Down Expand Up @@ -8998,7 +9005,8 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
allowDiagnostics: true,
}
);
}
},
/* sortSubtypes */ true
);

// If we ended up with a "Never" type because all code paths returned
Expand Down
10 changes: 7 additions & 3 deletions packages/pyright-internal/src/analyzer/typeUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -541,15 +541,19 @@ function compareTypes(a: Type, b: Type, recursionCount = 0): number {
return bParam.category - aParam.category;
}

const typeComparison = compareTypes(aParam.type, bParam.type);
const typeComparison = compareTypes(
FunctionType.getEffectiveParameterType(a, i),
FunctionType.getEffectiveParameterType(bFunc, i)
);

if (typeComparison !== 0) {
return typeComparison;
}
}

const returnTypeComparison = compareTypes(
a.details.declaredReturnType ?? UnknownType.create(),
bFunc.details.declaredReturnType ?? UnknownType.create()
FunctionType.getSpecializedReturnType(a) ?? UnknownType.create(),
FunctionType.getSpecializedReturnType(bFunc) ?? UnknownType.create()
);

if (returnTypeComparison !== 0) {
Expand Down
4 changes: 2 additions & 2 deletions packages/pyright-internal/src/tests/samples/assignment12.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ def a_test(x: int):

def b_test(x: int | str):
u = x.upper() # type: ignore
reveal_type(u, expected_text="Unknown | str")
reveal_type(u, expected_text="str | Unknown")

# This should generate an error if reportUnknownVariableType is enabled.
y: str = u
reveal_type(y, expected_text="Unknown | str")
reveal_type(y, expected_text="str | Unknown")
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/tests/samples/call11.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ def func() -> Either[int, str]:


result = func().map_left(lambda lv: lv + 1).map_right(lambda rv: rv + "a")
reveal_type(result, expected_text="Left[int] | Right[str]")
reveal_type(result, expected_text="Right[str] | Left[int]")
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ class ClassD(TypedDict):
def foo(a: CorD):
reveal_type(a, expected_text="ClassC | ClassD")
options = a.get("options", [])
reveal_type(options, expected_text="Any | list[Any] | list[ClassC | ClassD]")
reveal_type(options, expected_text="list[ClassC | ClassD] | Any | list[Any]")

for option in options:
reveal_type(option, expected_text="Any | ClassC | ClassD")
reveal_type(option["type"], expected_text="Any | int")
reveal_type(option, expected_text="ClassC | ClassD | Any")
reveal_type(option["type"], expected_text="int | Any")

0 comments on commit 2285641

Please sign in to comment.