From 22856412a156bba17c0bd211a810b994a95a8614 Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Thu, 2 Nov 2023 20:02:11 -0700 Subject: [PATCH] =?UTF-8?q?Fixed=20a=20few=20places=20where=20union=20orde?= =?UTF-8?q?r=20resulted=20in=20different=20type=20evalu=E2=80=A6=20(#6306)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fixed a few places where union order resulted in different type evaluation behaviors. This addresses #6302. * Fixed style issue. --- .../src/analyzer/constraintSolver.ts | 3 +- .../src/analyzer/typeEvaluator.ts | 66 +++++++++++-------- .../src/analyzer/typeUtils.ts | 10 ++- .../src/tests/samples/assignment12.py | 4 +- .../src/tests/samples/call11.py | 2 +- .../src/tests/samples/recursiveTypeAlias8.py | 6 +- 6 files changed, 52 insertions(+), 39 deletions(-) diff --git a/packages/pyright-internal/src/analyzer/constraintSolver.ts b/packages/pyright-internal/src/analyzer/constraintSolver.ts index 14ef720fafbe..95add04e952c 100644 --- a/packages/pyright-internal/src/analyzer/constraintSolver.ts +++ b/packages/pyright-internal/src/analyzer/constraintSolver.ts @@ -55,6 +55,7 @@ import { isEffectivelyInstantiable, isPartlyUnknown, mapSubtypes, + sortTypes, specializeTupleClass, specializeWithDefaultTypeArgs, transformExpectedType, @@ -1090,7 +1091,7 @@ export function populateTypeVarContextBasedOnExpectedType( if (isUnion(synthTypeVar)) { let foundSynthTypeVar: TypeVarType | undefined; - synthTypeVar.subtypes.forEach((subtype) => { + sortTypes(synthTypeVar.subtypes).forEach((subtype) => { if ( isTypeVar(subtype) && subtype.details.isSynthesized && diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index cf78943cbf0c..543b35133d35 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -3762,6 +3762,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions type: Type, conditionFilter: TypeCondition[] | undefined, callback: (expandedSubtype: Type, unexpandedSubtype: Type, isLastIteration: boolean) => Type | undefined, + sortSubtypes = false, recursionCount = 0 ): Type { const newSubtypes: Type[] = []; @@ -3772,42 +3773,47 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions expandedType = transformPossibleRecursiveTypeAlias(expandedType); - doForEachSubtype(expandedType, (subtype, index, allSubtypes) => { - if (conditionFilter) { - const filteredType = applyConditionFilterToType(subtype, conditionFilter, recursionCount); - if (!filteredType) { - return undefined; - } + doForEachSubtype( + expandedType, + (subtype, index, allSubtypes) => { + if (conditionFilter) { + const filteredType = applyConditionFilterToType(subtype, conditionFilter, recursionCount); + if (!filteredType) { + return undefined; + } - subtype = filteredType; - } + subtype = filteredType; + } - let transformedType = callback( - subtype, - unexpandedType, - isLastSubtype && index === allSubtypes.length - 1 - ); - if (transformedType !== unexpandedType) { - typeChanged = true; - } - if (transformedType) { - // Apply the type condition if it's associated with a constrained TypeVar. - const typeCondition = getTypeCondition(subtype)?.filter( - (condition) => condition.isConstrainedTypeVar + let transformedType = callback( + subtype, + unexpandedType, + isLastSubtype && index === allSubtypes.length - 1 ); - - if (typeCondition && typeCondition.length > 0) { - transformedType = addConditionToType(transformedType, typeCondition); + if (transformedType !== unexpandedType) { + typeChanged = true; } + if (transformedType) { + // Apply the type condition if it's associated with a constrained TypeVar. + const typeCondition = getTypeCondition(subtype)?.filter( + (condition) => condition.isConstrainedTypeVar + ); - newSubtypes.push(transformedType); - } - return undefined; - }); + if (typeCondition && typeCondition.length > 0) { + transformedType = addConditionToType(transformedType, typeCondition); + } + + newSubtypes.push(transformedType); + } + return undefined; + }, + sortSubtypes + ); } if (isUnion(type)) { - type.subtypes.forEach((subtype, index) => { + const subtypes = sortSubtypes ? sortTypes(type.subtypes) : type.subtypes; + subtypes.forEach((subtype, index) => { expandSubtype(subtype, index === type.subtypes.length - 1); }); } else { @@ -3867,6 +3873,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions (expandedSubtype) => { return expandedSubtype; }, + /* sortSubtypes */ undefined, recursionCount ); @@ -8998,7 +9005,8 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions allowDiagnostics: true, } ); - } + }, + /* sortSubtypes */ true ); // If we ended up with a "Never" type because all code paths returned diff --git a/packages/pyright-internal/src/analyzer/typeUtils.ts b/packages/pyright-internal/src/analyzer/typeUtils.ts index 7ec748539497..5bd553737885 100644 --- a/packages/pyright-internal/src/analyzer/typeUtils.ts +++ b/packages/pyright-internal/src/analyzer/typeUtils.ts @@ -541,15 +541,19 @@ function compareTypes(a: Type, b: Type, recursionCount = 0): number { return bParam.category - aParam.category; } - const typeComparison = compareTypes(aParam.type, bParam.type); + const typeComparison = compareTypes( + FunctionType.getEffectiveParameterType(a, i), + FunctionType.getEffectiveParameterType(bFunc, i) + ); + if (typeComparison !== 0) { return typeComparison; } } const returnTypeComparison = compareTypes( - a.details.declaredReturnType ?? UnknownType.create(), - bFunc.details.declaredReturnType ?? UnknownType.create() + FunctionType.getSpecializedReturnType(a) ?? UnknownType.create(), + FunctionType.getSpecializedReturnType(bFunc) ?? UnknownType.create() ); if (returnTypeComparison !== 0) { diff --git a/packages/pyright-internal/src/tests/samples/assignment12.py b/packages/pyright-internal/src/tests/samples/assignment12.py index f823d4fa8f05..ec7d8004c9eb 100644 --- a/packages/pyright-internal/src/tests/samples/assignment12.py +++ b/packages/pyright-internal/src/tests/samples/assignment12.py @@ -13,8 +13,8 @@ def a_test(x: int): def b_test(x: int | str): u = x.upper() # type: ignore - reveal_type(u, expected_text="Unknown | str") + reveal_type(u, expected_text="str | Unknown") # This should generate an error if reportUnknownVariableType is enabled. y: str = u - reveal_type(y, expected_text="Unknown | str") + reveal_type(y, expected_text="str | Unknown") diff --git a/packages/pyright-internal/src/tests/samples/call11.py b/packages/pyright-internal/src/tests/samples/call11.py index c1af3c803078..5f73ff45dab7 100644 --- a/packages/pyright-internal/src/tests/samples/call11.py +++ b/packages/pyright-internal/src/tests/samples/call11.py @@ -41,4 +41,4 @@ def func() -> Either[int, str]: result = func().map_left(lambda lv: lv + 1).map_right(lambda rv: rv + "a") -reveal_type(result, expected_text="Left[int] | Right[str]") +reveal_type(result, expected_text="Right[str] | Left[int]") diff --git a/packages/pyright-internal/src/tests/samples/recursiveTypeAlias8.py b/packages/pyright-internal/src/tests/samples/recursiveTypeAlias8.py index bf410c9ee371..17e71927ab30 100644 --- a/packages/pyright-internal/src/tests/samples/recursiveTypeAlias8.py +++ b/packages/pyright-internal/src/tests/samples/recursiveTypeAlias8.py @@ -31,8 +31,8 @@ class ClassD(TypedDict): def foo(a: CorD): reveal_type(a, expected_text="ClassC | ClassD") options = a.get("options", []) - reveal_type(options, expected_text="Any | list[Any] | list[ClassC | ClassD]") + reveal_type(options, expected_text="list[ClassC | ClassD] | Any | list[Any]") for option in options: - reveal_type(option, expected_text="Any | ClassC | ClassD") - reveal_type(option["type"], expected_text="Any | int") + reveal_type(option, expected_text="ClassC | ClassD | Any") + reveal_type(option["type"], expected_text="int | Any")