From 14b3c704b9202b99357dbbbfb18b4cf468eb158c Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Mon, 13 Jan 2025 15:04:34 -0700 Subject: [PATCH] Improved error reporting for "async with" statement. Added check that return result of `__aexit__` is awaitable and improved error messages for the case where `__enter__`, etc. are present but have incorrect signatures. This addresses #9694. (#9697) --- .../src/analyzer/typeEvaluator.ts | 46 ++++++++++--------- .../src/localization/localize.ts | 5 ++ .../src/localization/package.nls.en-us.json | 10 +++- .../src/tests/samples/coroutines1.py | 2 +- .../src/tests/samples/with1.py | 5 +- 5 files changed, 42 insertions(+), 26 deletions(-) diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index 066e6cf805ce..6c8b76877181 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -19607,7 +19607,11 @@ export function createTypeEvaluator( const isAsync = node.parent && node.parent.nodeType === ParseNodeType.With && !!node.parent.d.isAsync; if (isOptionalType(exprType)) { - addDiagnostic(DiagnosticRule.reportOptionalContextManager, LocMessage.noneNotUsableWith(), node.d.expr); + addDiagnostic( + DiagnosticRule.reportOptionalContextManager, + isAsync ? LocMessage.noneNotUsableWithAsync() : LocMessage.noneNotUsableWith(), + node.d.expr + ); exprType = removeNoneFromUnion(exprType); } @@ -19620,25 +19624,20 @@ export function createTypeEvaluator( return subtype; } - const additionalHelp = new DiagnosticAddendum(); + const enterDiag = new DiagnosticAddendum(); if (isClass(subtype)) { - let enterType = getTypeOfMagicMethodCall( + const enterTypeResult = getTypeOfMagicMethodCall( subtype, enterMethodName, [], node.d.expr, /* inferenceContext */ undefined, - additionalHelp.createAddendum() - )?.type; - - if (enterType) { - // For "async while", an implicit "await" is performed. - if (isAsync) { - enterType = getTypeOfAwaitable(enterType, node.d.expr); - } + enterDiag.createAddendum() + ); - return enterType; + if (enterTypeResult) { + return isAsync ? getTypeOfAwaitable(enterTypeResult.type, node.d.expr) : enterTypeResult.type; } if (!isAsync) { @@ -19651,15 +19650,15 @@ export function createTypeEvaluator( /* inferenceContext */ undefined )?.type ) { - additionalHelp.addMessage(LocAddendum.asyncHelp()); + enterDiag.addMessage(LocAddendum.asyncHelp()); } } } + const message = isAsync ? LocMessage.typeNotUsableWithAsync() : LocMessage.typeNotUsableWith(); addDiagnostic( DiagnosticRule.reportGeneralTypeIssues, - LocMessage.typeNotUsableWith().format({ type: printType(subtype), method: enterMethodName }) + - additionalHelp.getString(), + message.format({ type: printType(subtype), method: enterMethodName }) + enterDiag.getString(), node.d.expr ); return UnknownType.create(); @@ -19667,6 +19666,8 @@ export function createTypeEvaluator( // Verify that the target has an __exit__ or __aexit__ method defined. const exitMethodName = isAsync ? '__aexit__' : '__exit__'; + const exitDiag = new DiagnosticAddendum(); + doForEachSubtype(exprType, (subtype) => { subtype = makeTopLevelTypeVarsConcrete(subtype); @@ -19676,24 +19677,27 @@ export function createTypeEvaluator( if (isClass(subtype)) { const anyArg: TypeResult = { type: AnyType.create() }; - const exitType = getTypeOfMagicMethodCall( + const exitTypeResult = getTypeOfMagicMethodCall( subtype, exitMethodName, [anyArg, anyArg, anyArg], node.d.expr, - /* inferenceContext */ undefined - )?.type; + /* inferenceContext */ undefined, + exitDiag + ); - if (exitType) { - return; + if (exitTypeResult) { + return isAsync ? getTypeOfAwaitable(exitTypeResult.type, node.d.expr) : exitTypeResult.type; } } addDiagnostic( DiagnosticRule.reportGeneralTypeIssues, - LocMessage.typeNotUsableWith().format({ type: printType(subtype), method: exitMethodName }), + LocMessage.typeNotUsableWith().format({ type: printType(subtype), method: exitMethodName }) + + exitDiag.getString(), node.d.expr ); + return UnknownType.create(); }); if (node.d.target) { diff --git a/packages/pyright-internal/src/localization/localize.ts b/packages/pyright-internal/src/localization/localize.ts index f3239d985cd4..d331308e6264 100644 --- a/packages/pyright-internal/src/localization/localize.ts +++ b/packages/pyright-internal/src/localization/localize.ts @@ -698,6 +698,7 @@ export namespace Localizer { export const noneNotIterable = () => getRawString('Diagnostic.noneNotIterable'); export const noneNotSubscriptable = () => getRawString('Diagnostic.noneNotSubscriptable'); export const noneNotUsableWith = () => getRawString('Diagnostic.noneNotUsableWith'); + export const noneNotUsableWithAsync = () => getRawString('Diagnostic.noneNotUsableWithAsync'); export const noneOperator = () => new ParameterizedString<{ operator: string }>(getRawString('Diagnostic.noneOperator')); export const noneUnknownMember = () => @@ -1038,6 +1039,10 @@ export namespace Localizer { new ParameterizedString<{ type: string }>(getRawString('Diagnostic.typeNotSubscriptable')); export const typeNotUsableWith = () => new ParameterizedString<{ type: string; method: string }>(getRawString('Diagnostic.typeNotUsableWith')); + export const typeNotUsableWithAsync = () => + new ParameterizedString<{ type: string; method: string }>( + getRawString('Diagnostic.typeNotUsableWithAsync') + ); export const typeNotSupportBinaryOperator = () => new ParameterizedString<{ leftType: string; rightType: string; operator: string }>( getRawString('Diagnostic.typeNotSupportBinaryOperator') diff --git a/packages/pyright-internal/src/localization/package.nls.en-us.json b/packages/pyright-internal/src/localization/package.nls.en-us.json index d7a18edb167b..c4dddb38f29d 100644 --- a/packages/pyright-internal/src/localization/package.nls.en-us.json +++ b/packages/pyright-internal/src/localization/package.nls.en-us.json @@ -860,6 +860,10 @@ "message": "Object of type \"None\" cannot be used with \"with\"", "comment": "{Locked='None','with'}" }, + "noneNotUsableWithAsync": { + "message": "Object of type \"None\" cannot be used with \"async with\"", + "comment": "{Locked='None','with', 'async}" + }, "noneOperator": { "message": "Operator \"{operator}\" not supported for \"None\"", "comment": "{Locked='None'}" @@ -1333,7 +1337,11 @@ "typeNotSupportBinaryOperatorBidirectional": "Operator \"{operator}\" not supported for types \"{leftType}\" and \"{rightType}\" when expected type is \"{expectedType}\"", "typeNotSupportUnaryOperator": "Operator \"{operator}\" not supported for type \"{type}\"", "typeNotSupportUnaryOperatorBidirectional": "Operator \"{operator}\" not supported for type \"{type}\" when expected type is \"{expectedType}\"", - "typeNotUsableWith": "Object of type \"{type}\" cannot be used with \"with\" because it does not implement {method}", + "typeNotUsableWith": "Object of type \"{type}\" cannot be used with \"with\" because it does not correctly implement {method}", + "typeNotUsableWithAsync": { + "message": "Object of type \"{type}\" cannot be used with \"async with\" because it does not correctly implement {method}", + "comment": ["{Locked='async','with}"] + }, "typeParameterBoundNotAllowed": { "message": "Bound or constraint cannot be used with a variadic type parameter or ParamSpec", "comment": ["{Locked='ParamSpec'}", "'variadic' means that it accepts a variable number of arguments"] diff --git a/packages/pyright-internal/src/tests/samples/coroutines1.py b/packages/pyright-internal/src/tests/samples/coroutines1.py index 8e0af8bcfe35..4a6739dfbfd5 100644 --- a/packages/pyright-internal/src/tests/samples/coroutines1.py +++ b/packages/pyright-internal/src/tests/samples/coroutines1.py @@ -48,7 +48,7 @@ def __await__(self) -> Generator[Any, None, int]: yield 3 return 3 - def __aexit__( + async def __aexit__( self, t: Optional[type] = None, exc: Optional[BaseException] = None, diff --git a/packages/pyright-internal/src/tests/samples/with1.py b/packages/pyright-internal/src/tests/samples/with1.py index 0aca6001c109..06db6ae6a733 100644 --- a/packages/pyright-internal/src/tests/samples/with1.py +++ b/packages/pyright-internal/src/tests/samples/with1.py @@ -78,7 +78,7 @@ class Class4: async def __aenter__(self: _T1) -> _T1: return self - def __aexit__( + async def __aexit__( self, t: Optional[type] = None, exc: Optional[BaseException] = None, @@ -107,8 +107,7 @@ async def __aexit__(self, *args: Any) -> None: return None -class Class6(Class5[int]): - ... +class Class6(Class5[int]): ... async def do():