Skip to content

Commit

Permalink
Fixed a bug that led to incorrect type evaluation when an inferred me…
Browse files Browse the repository at this point in the history
…thod return type includes a union where the subtypes are conditioned on constraints of a constrained TypeVar that parameterizes the class. In this case, one or more of these subtypes should be eliminated when a specialized class is bound to the method. This addresses #6446. (#6449)
  • Loading branch information
erictraut authored Nov 15, 2023
1 parent 7a4cf3d commit 76564a7
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 5 deletions.
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21143,7 +21143,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
args?: ValidateArgTypeParams[],
inferTypeIfNeeded = true
) {
const specializedReturnType = FunctionType.getSpecializedReturnType(type);
const specializedReturnType = FunctionType.getSpecializedReturnType(type, /* includeInferred */ false);
if (specializedReturnType) {
return adjustCallableReturnType(specializedReturnType, /* trackedSignatures */ undefined);
}
Expand Down
51 changes: 51 additions & 0 deletions packages/pyright-internal/src/analyzer/typeUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2885,6 +2885,11 @@ export function requiresSpecialization(
}

function _requiresSpecialization(type: Type, options?: RequiresSpecializationOptions, recursionCount = 0): boolean {
// If the type is conditioned on a TypeVar, it may need to be specialized.
if (type.condition) {
return true;
}

switch (type.category) {
case TypeCategory.Class: {
if (ClassType.isPseudoGenericClass(type) && options?.ignorePseudoGeneric) {
Expand Down Expand Up @@ -3325,6 +3330,12 @@ class TypeVarTransformer {

type = this.transformGenericTypeAlias(type, recursionCount);

// If the type is conditioned on a type variable, see if the condition
// still applies.
if (type.condition) {
type = this.transformConditionalType(type, recursionCount);
}

// Shortcut the operation if possible.
if (!requiresSpecialization(type)) {
return type;
Expand Down Expand Up @@ -3551,6 +3562,11 @@ class TypeVarTransformer {
: type;
}

transformConditionalType(type: Type, recursionCount: number): Type {
// By default, do not perform any transform.
return type;
}

transformTypeVarsInClassType(classType: ClassType, recursionCount: number): ClassType {
const typeParams = ClassType.getTypeParameters(classType);

Expand Down Expand Up @@ -4246,6 +4262,41 @@ class ApplySolvedTypeVarsTransformer extends TypeVarTransformer {
return undefined;
}

override transformConditionalType(type: Type, recursionCount: number): Type {
if (!type.condition) {
return type;
}

const signatureContext = this._typeVarContext.getSignatureContext(
this._activeTypeVarSignatureContextIndex ?? 0
);

for (const condition of type.condition) {
// This doesn't apply to bound type variables.
if (!condition.isConstrainedTypeVar) {
continue;
}

const typeVarEntry = signatureContext.getTypeVarByName(condition.typeVarName);
if (!typeVarEntry || condition.constraintIndex >= typeVarEntry.typeVar.details.constraints.length) {
continue;
}

const value = signatureContext.getTypeVarType(typeVarEntry.typeVar);
if (!value) {
continue;
}

const constraintType = typeVarEntry.typeVar.details.constraints[condition.constraintIndex];

// If this violates the constraint, substitute a Never type.
if (!isTypeSame(constraintType, value)) {
return NeverType.createNever();
}
}
return type;
}

override doForEachSignatureContext(callback: () => FunctionType): FunctionType | OverloadedFunctionType {
const signatureContexts = this._typeVarContext.getSignatureContexts();

Expand Down
4 changes: 4 additions & 0 deletions packages/pyright-internal/src/analyzer/typeVarContext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ export class TypeVarSignatureContext {
return this._typeVarMap.get(key);
}

getTypeVarByName(key: string): TypeVarMapEntry | undefined {
return this._typeVarMap.get(key);
}

getTypeVars(): TypeVarMapEntry[] {
const entries: TypeVarMapEntry[] = [];

Expand Down
18 changes: 14 additions & 4 deletions packages/pyright-internal/src/analyzer/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2034,10 +2034,20 @@ export namespace FunctionType {
}
}

export function getSpecializedReturnType(type: FunctionType) {
return type.specializedTypes && type.specializedTypes.returnType
? type.specializedTypes.returnType
: type.details.declaredReturnType;
export function getSpecializedReturnType(type: FunctionType, includeInferred = true) {
if (type.specializedTypes?.returnType) {
return type.specializedTypes.returnType;
}

if (type.details.declaredReturnType) {
return type.details.declaredReturnType;
}

if (includeInferred) {
return type.inferredReturnType;
}

return undefined;
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# This sample tests the case where an inferred method return type is
# a union with subtypes that are conditioned on different constraints of
# a constrained TypeVar. When the method is bound, one or more of these
# subtypes should be eliminated.

from typing import Generic, TypeVar, Awaitable

T = TypeVar("T")


class Async:
def fn(self, returnable: T) -> Awaitable[T]:
...


class Sync:
def fn(self, returnable: T) -> T:
...


T = TypeVar("T", Async, Sync)


class A(Generic[T]):
def __init__(self, client: T):
self._client = client

def method1(self):
return self._client.fn(7)


a1 = A(Async())
r1 = a1.method1()
reveal_type(r1, expected_text="Awaitable[int]*")

a2 = A(Sync())
r2 = a2.method1()
reveal_type(r2, expected_text="int*")
6 changes: 6 additions & 0 deletions packages/pyright-internal/src/tests/typeEvaluator2.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,12 @@ test('ConstrainedTypeVar17', () => {
TestUtils.validateResults(analysisResults, 0);
});

test('ConstrainedTypeVar18', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['constrainedTypeVar18.py']);

TestUtils.validateResults(analysisResults, 0);
});

test('MissingTypeArg1', () => {
const configOptions = new ConfigOptions('.');

Expand Down

0 comments on commit 76564a7

Please sign in to comment.