diff --git a/packages/pyright-internal/src/analyzer/checker.ts b/packages/pyright-internal/src/analyzer/checker.ts index ca4d5b36bccf..83c81d8f6e27 100644 --- a/packages/pyright-internal/src/analyzer/checker.ts +++ b/packages/pyright-internal/src/analyzer/checker.ts @@ -6558,9 +6558,11 @@ export class Checker extends ParseTreeWalker { } const baseClass = baseClassAndSymbol.classType; - const childClassSelf = ClassType.cloneAsInstance(selfSpecializeClass(childClassType)); + const childClassSelf = ClassType.cloneAsInstance( + selfSpecializeClass(childClassType, { useBoundTypeVars: true }) + ); - let baseType = partiallySpecializeType( + const baseType = partiallySpecializeType( this._evaluator.getEffectiveTypeOfSymbol(baseClassAndSymbol.symbol), baseClass, this._evaluator.getTypeClassType(), @@ -6575,7 +6577,6 @@ export class Checker extends ParseTreeWalker { ); if (childClassType.shared.typeVarScopeId) { - baseType = makeTypeVarsBound(baseType, [childClassType.shared.typeVarScopeId]); overrideType = makeTypeVarsBound(overrideType, [childClassType.shared.typeVarScopeId]); } diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index 618f3d2ccaa6..78b9d8f2aba2 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -26812,6 +26812,7 @@ export function createTypeEvaluator( ): boolean { const baseParamDetails = getParamListDetails(baseMethod); const overrideParamDetails = getParamListDetails(overrideMethod); + const constraints = new ConstraintTracker(); let canOverride = true; @@ -26853,8 +26854,8 @@ export function createTypeEvaluator( overrideArgsType, baseParamDetails.params[i].type, diag?.createAddendum(), - /* constraints */ undefined, - AssignTypeFlags.Default + constraints, + AssignTypeFlags.Contravariant ) ) { LocAddendum.overrideParamType().format({ @@ -26976,8 +26977,8 @@ export function createTypeEvaluator( overrideParamType, baseParamType, diag?.createAddendum(), - /* constraints */ undefined, - AssignTypeFlags.Default + constraints, + AssignTypeFlags.Contravariant ) ) { diag?.addMessage( @@ -27038,8 +27039,8 @@ export function createTypeEvaluator( overrideParamType, baseParamType, diag?.createAddendum(), - /* constraints */ undefined, - AssignTypeFlags.Default + constraints, + AssignTypeFlags.Contravariant ) ) { diag?.addMessage( @@ -27083,8 +27084,8 @@ export function createTypeEvaluator( targetParamType, paramInfo.type, diag?.createAddendum(), - /* constraints */ undefined, - AssignTypeFlags.Default + constraints, + AssignTypeFlags.Contravariant ) ) { diag?.addMessage( @@ -27164,7 +27165,7 @@ export function createTypeEvaluator( baseReturnType, overrideReturnType, diag?.createAddendum(), - /* constraints */ undefined, + constraints, AssignTypeFlags.Default ) ) { diff --git a/packages/pyright-internal/src/tests/samples/methodOverride1.py b/packages/pyright-internal/src/tests/samples/methodOverride1.py index b4df2b747a02..06b00cd1c74f 100644 --- a/packages/pyright-internal/src/tests/samples/methodOverride1.py +++ b/packages/pyright-internal/src/tests/samples/methodOverride1.py @@ -18,10 +18,10 @@ P = ParamSpec("P") T = TypeVar("T") S = TypeVar("S") +U = TypeVar("U", bound=int) -def decorator(func: Callable[P, None]) -> Callable[P, int]: - ... +def decorator(func: Callable[P, None]) -> Callable[P, int]: ... class ParentClass: @@ -73,21 +73,16 @@ def my_method15(self, a: int) -> int: def my_method16(self, a: int) -> int: return 1 - def my_method17(self, a: str, b: int, c: float, d: bool) -> None: - ... + def my_method17(self, a: str, b: int, c: float, d: bool) -> None: ... - def my_method18(self, a: str, b: int, c: float, d: bool) -> None: - ... + def my_method18(self, a: str, b: int, c: float, d: bool) -> None: ... - def my_method19(self, a: str, b: int, c: float, d: bool) -> None: - ... + def my_method19(self, a: str, b: int, c: float, d: bool) -> None: ... @classmethod - def my_method20(cls: type[T_ParentClass], a: str) -> T_ParentClass: - ... + def my_method20(cls: type[T_ParentClass], a: str) -> T_ParentClass: ... - def my_method21(self, var: int) -> None: - ... + def my_method21(self, var: int) -> None: ... def _protected_method1(self, a: int): return 1 @@ -95,86 +90,61 @@ def _protected_method1(self, a: int): def __private_method1(self, a: int): return 1 - def my_method22(self, a: str, b: int, c: float, d: bool) -> None: - ... + def my_method22(self, a: str, b: int, c: float, d: bool) -> None: ... - def my_method23(self, a: str = "") -> None: - ... + def my_method23(self, a: str = "") -> None: ... - def my_method24(self, a: str) -> None: - ... + def my_method24(self, a: str) -> None: ... - def my_method25(self, *, a: str = "") -> None: - ... + def my_method25(self, *, a: str = "") -> None: ... - def my_method26(self, *, a: str) -> None: - ... + def my_method26(self, *, a: str) -> None: ... - def my_method27(self, a: object, /) -> None: - ... + def my_method27(self, a: object, /) -> None: ... - def my_method28(self, __a: object) -> None: - ... + def my_method28(self, __a: object) -> None: ... @classmethod - def my_method29(cls, /) -> None: - ... + def my_method29(cls, /) -> None: ... @classmethod - def my_method30(cls, /) -> None: - ... + def my_method30(cls, /) -> None: ... @staticmethod - def my_method31(a: "type[ParentClass]", /) -> None: - ... + def my_method31(a: "type[ParentClass]", /) -> None: ... @staticmethod - def my_method32(a: "type[ParentClass]", /) -> None: - ... + def my_method32(a: "type[ParentClass]", /) -> None: ... - def my_method33(self, /) -> None: - ... + def my_method33(self, /) -> None: ... - def my_method34(self, /) -> None: - ... + def my_method34(self, /) -> None: ... - def my_method35(self, *, a: int) -> None: - ... + def my_method35(self, *, a: int) -> None: ... - def my_method36(self, *, a: int) -> None: - ... + def my_method36(self, *, a: int) -> None: ... - def my_method37(self, a: int, /) -> None: - ... + def my_method37(self, a: int, /) -> None: ... - def my_method38(self, a: int, /) -> None: - ... + def my_method38(self, a: int, /) -> None: ... - def my_method39(self, a: int, /) -> None: - ... + def my_method39(self, a: int, /) -> None: ... - def my_method40(self, a: int, /) -> None: - ... + def my_method40(self, a: int, /) -> None: ... - def my_method41(self, a: int, b: str, c: str) -> None: - ... + def my_method41(self, a: int, b: str, c: str) -> None: ... - def my_method42(self, a: int, b: int, c: str) -> None: - ... + def my_method42(self, a: int, b: int, c: str) -> None: ... my_method43: Callable[..., None] - def my_method44(self, *args: object, **kwargs: object) -> None: - ... + def my_method44(self, *args: object, **kwargs: object) -> None: ... - def my_method45(self, __i: int) -> None: - ... + def my_method45(self, __i: int) -> None: ... - def __my_method46__(self, x: int) -> None: - ... + def __my_method46__(self, x: int) -> None: ... - def __my_method47__(self, x: int) -> None: - ... + def __my_method47__(self, x: int) -> None: ... T_ChildClass = TypeVar("T_ChildClass", bound="ChildClass") @@ -243,24 +213,19 @@ def my_method14(self, a: int) -> int | str: class my_method16: pass - def my_method17(self, *args: object, **kwargs: object) -> None: - ... + def my_method17(self, *args: object, **kwargs: object) -> None: ... - def my_method18(self, a: str, *args: object, **kwargs: object) -> None: - ... + def my_method18(self, a: str, *args: object, **kwargs: object) -> None: ... # This should generate an error because b param doesn't match a in name. - def my_method19(self, b: str, *args: object, **kwargs: object) -> None: - ... + def my_method19(self, b: str, *args: object, **kwargs: object) -> None: ... @classmethod - def my_method20(cls: type[T_ChildClass], a: str) -> T_ChildClass: - ... + def my_method20(cls: type[T_ChildClass], a: str) -> T_ChildClass: ... # This should generate an error. @decorator - def my_method21(self, var: int) -> None: - ... + def my_method21(self, var: int) -> None: ... # This should generate an error. def _protected_method1(self): @@ -270,143 +235,111 @@ def __private_method1(self): return 1 # This should generate an error. - def my_method22(self, a: str, b: int, c: float, d: bool, /) -> None: - ... + def my_method22(self, a: str, b: int, c: float, d: bool, /) -> None: ... # This should generate an error because a is missing a default value. - def my_method23(self, a: str) -> None: - ... + def my_method23(self, a: str) -> None: ... - def my_method24(self, a: str = "") -> None: - ... + def my_method24(self, a: str = "") -> None: ... # This should generate an error because a is missing a default value. - def my_method25(self, *, a: str) -> None: - ... + def my_method25(self, *, a: str) -> None: ... - def my_method26(self, *, a: str = "") -> None: - ... + def my_method26(self, *, a: str = "") -> None: ... - def my_method27(self, __a: object) -> None: - ... + def my_method27(self, __a: object) -> None: ... - def my_method28(self, a: object, /) -> None: - ... + def my_method28(self, a: object, /) -> None: ... # This should generate an error because it is not a classmethod. - def my_method29(self, /) -> None: - ... + def my_method29(self, /) -> None: ... # This should generate an error because it is not a classmethod. @staticmethod - def my_method30(a: type[ParentClass], /) -> None: - ... + def my_method30(a: type[ParentClass], /) -> None: ... # This should generate an error because it is not a staticmethod. @classmethod - def my_method31(cls, /) -> None: - ... + def my_method31(cls, /) -> None: ... # This should generate an error because it is not a staticmethod. - def my_method32(self, /) -> None: - ... + def my_method32(self, /) -> None: ... # This should generate an error because it is not an instance method. @classmethod - def my_method33(cls, /) -> None: - ... + def my_method33(cls, /) -> None: ... # This should generate an error because it is not an instance method. @staticmethod - def my_method34(a: type[ParentClass], /) -> None: - ... + def my_method34(a: type[ParentClass], /) -> None: ... - def my_method35(self, **kwargs: int) -> None: - ... + def my_method35(self, **kwargs: int) -> None: ... # This should generate an error because the method in the parent # class has a keyword-only parameter that is type 'int', and this # isn't compatible with 'str'. - def my_method36(self, **kwargs: str) -> None: - ... + def my_method36(self, **kwargs: str) -> None: ... - def my_method37(self, *args: Any) -> None: - ... + def my_method37(self, *args: Any) -> None: ... # This should generate an error because the number of position-only # parameters doesn't match. - def my_method38(self, **kwargs: Any) -> None: - ... + def my_method38(self, **kwargs: Any) -> None: ... - def my_method39(self, *args: Any) -> None: - ... + def my_method39(self, *args: Any) -> None: ... # This should generate an error because the number of position-only # parameters doesn't match. - def my_method40(self, **kwargs: Any) -> None: - ... + def my_method40(self, **kwargs: Any) -> None: ... # This should generate an error because keyword parameters "a" # and "b" are missing. - def my_method41(self, a: int, *args: str) -> None: - ... + def my_method41(self, a: int, *args: str) -> None: ... # This should generate an error because args doesn't have the right type. - def my_method42(self, a: int, *args: int) -> None: - ... + def my_method42(self, a: int, *args: int) -> None: ... - def my_method43(self, a: int, b: str, c: str) -> None: - ... + def my_method43(self, a: int, b: str, c: str) -> None: ... # This should generate an error because kwargs is missing. - def my_method44(self, *object) -> None: - ... + def my_method44(self, *object) -> None: ... - def my_method45(self, i: int, /) -> None: - ... + def my_method45(self, i: int, /) -> None: ... - def __my_method46__(self, y: int) -> None: - ... + def __my_method46__(self, y: int) -> None: ... # This should generate an error because of a type mismatch. - def __my_method47__(self, y: str) -> None: - ... + def __my_method47__(self, y: str) -> None: ... class A: - def test(self, t: Sequence[int]) -> Sequence[str]: - ... + def test(self, t: Sequence[int]) -> Sequence[str]: ... class GeneralizedArgument(A): - def test(self, t: Iterable[int], bbb: str = "") -> Sequence[str]: - ... + def test(self, t: Iterable[int], bbb: str = "") -> Sequence[str]: ... class NarrowerArgument(A): # This should generate error because list[int] is narrower # than Iterable[int]. - def test(self, t: list[int]) -> Sequence[str]: - ... + def test(self, t: list[int]) -> Sequence[str]: ... class NarrowerReturn(A): - def test(self, t: Sequence[int]) -> list[str]: - ... + def test(self, t: Sequence[int]) -> list[str]: ... class GeneralizedReturn1(A): # This should generate an error because Iterable[str] is # wider than Sequence[str]. - def test(self, t: Sequence[int]) -> Iterable[str]: - ... + def test(self, t: Sequence[int]) -> Iterable[str]: ... class GeneralizedReturn2(A): # This should generate an error because list[int] is # incompatible with Sequence[str]. - def test(self, t: Sequence[int]) -> list[int]: - ... + def test(self, t: Sequence[int]) -> list[int]: ... _T1 = TypeVar("_T1") @@ -529,32 +462,27 @@ def case(self, value: Any) -> Iterable[Any]: class Derived3(Base3): @overload - def case(self, value: int) -> Iterable[int]: - ... + def case(self, value: int) -> Iterable[int]: ... @overload - def case(self, value: float) -> Iterable[float]: - ... + def case(self, value: float) -> Iterable[float]: ... def case(self, value: Any) -> Iterable[Any]: return [] class Base4: - def a(self) -> int: - ... + def a(self) -> int: ... class Base5: - def a(self) -> int: - ... + def a(self) -> int: ... class C(Base4, Base5): # This should generate two error if reportIncompatibleMethodOverride # is enabled. - def a(self) -> float: - ... + def a(self) -> float: ... class MyObject(TypedDict): @@ -562,37 +490,41 @@ class MyObject(TypedDict): class Base6(Generic["T"]): - def method1(self, v: int) -> None: - ... + def method1(self, v: int) -> None: ... - def method2(self, v: T) -> None: - ... + def method2(self, v: T) -> None: ... - def method3(self, v: T) -> None: - ... + def method3(self, v: T) -> None: ... - def method4(self, v: S) -> S: - ... + def method4(self, v: S) -> S: ... - def method5(self, v: S) -> S: - ... + def method5(self, v: S) -> S: ... class Derived6(Base6[int], Generic["T"]): # This should generate an error. - def method1(self, v: T): - ... + def method1(self, v: T): ... # This should generate an error. - def method2(self, v: T) -> None: - ... + def method2(self, v: T) -> None: ... - def method3(self, v: int) -> None: - ... + def method3(self, v: int) -> None: ... + + def method4(self, v: T) -> T: ... + + def method5(self, v: S) -> S: ... + + +class Base7(Generic[T]): + def method1(self, x: T) -> T: + return x + + +class Derived7_1(Base7[T]): + def method1(self, x: S) -> S: + return x - # This should generate an error. - def method4(self, v: T) -> T: - ... - def method5(self, v: S) -> S: - ... +class Derived7_2(Base7[int]): + def method1(self, x: U) -> U: + return x diff --git a/packages/pyright-internal/src/tests/typeEvaluator3.test.ts b/packages/pyright-internal/src/tests/typeEvaluator3.test.ts index 7b57a9fd1544..c6f2890d2d58 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator3.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator3.test.ts @@ -925,7 +925,7 @@ test('MethodOverride1', () => { configOptions.diagnosticRuleSet.reportIncompatibleMethodOverride = 'error'; analysisResults = TestUtils.typeAnalyzeSampleFiles(['methodOverride1.py'], configOptions); - TestUtils.validateResults(analysisResults, 41); + TestUtils.validateResults(analysisResults, 40); }); test('MethodOverride2', () => {