Skip to content

Commit

Permalink
Additional simplifications for member accesses. (#6294)
Browse files Browse the repository at this point in the history
  • Loading branch information
erictraut authored Nov 2, 2023
1 parent 5a7bfaf commit 53a2b22
Show file tree
Hide file tree
Showing 15 changed files with 217 additions and 323 deletions.
24 changes: 12 additions & 12 deletions packages/pyright-internal/src/analyzer/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ import {
import {
AssignTypeFlags,
ClassMember,
ClassMemberLookupFlags,
MemberAccessFlags,
applySolvedTypeVars,
buildTypeVarContextFromSpecializedClass,
convertToInstance,
Expand Down Expand Up @@ -2137,7 +2137,7 @@ export class Checker extends ParseTreeWalker {
// Does the class have an operator overload for eq?
const metaclass = leftType.details.effectiveMetaclass;
if (metaclass && isClass(metaclass)) {
if (lookUpClassMember(metaclass, '__eq__', ClassMemberLookupFlags.SkipObjectBaseClass)) {
if (lookUpClassMember(metaclass, '__eq__', MemberAccessFlags.SkipObjectBaseClass)) {
return true;
}
}
Expand Down Expand Up @@ -2176,7 +2176,7 @@ export class Checker extends ParseTreeWalker {
const eqMethod = lookUpClassMember(
ClassType.cloneAsInstantiable(leftType),
'__eq__',
ClassMemberLookupFlags.SkipObjectBaseClass
MemberAccessFlags.SkipObjectBaseClass
);

if (eqMethod) {
Expand Down Expand Up @@ -4519,7 +4519,7 @@ export class Checker extends ParseTreeWalker {
// as Final in parent classes.
private _validateFinalMemberOverrides(classType: ClassType) {
classType.details.fields.forEach((localSymbol, name) => {
const parentSymbol = lookUpClassMember(classType, name, ClassMemberLookupFlags.SkipOriginalClass);
const parentSymbol = lookUpClassMember(classType, name, MemberAccessFlags.SkipOriginalClass);
if (
parentSymbol &&
isInstantiableClass(parentSymbol.classType) &&
Expand Down Expand Up @@ -4656,7 +4656,7 @@ export class Checker extends ParseTreeWalker {
const postInitMember = lookUpClassMember(
classType,
'__post_init__',
ClassMemberLookupFlags.SkipBaseClasses | ClassMemberLookupFlags.DeclaredTypesOnly
MemberAccessFlags.SkipBaseClasses | MemberAccessFlags.DeclaredTypesOnly
);

// If there's no __post_init__ method, there's nothing to check.
Expand Down Expand Up @@ -4897,7 +4897,7 @@ export class Checker extends ParseTreeWalker {

// If the symbol is declared by its parent, we can assume it
// is initialized there.
const parentSymbol = lookUpClassMember(classType, name, ClassMemberLookupFlags.SkipOriginalClass);
const parentSymbol = lookUpClassMember(classType, name, MemberAccessFlags.SkipOriginalClass);
if (parentSymbol) {
return;
}
Expand Down Expand Up @@ -5097,12 +5097,12 @@ export class Checker extends ParseTreeWalker {
const initMember = lookUpClassMember(
classType,
'__init__',
ClassMemberLookupFlags.SkipObjectBaseClass | ClassMemberLookupFlags.SkipInstanceVariables
MemberAccessFlags.SkipObjectBaseClass | MemberAccessFlags.SkipInstanceMembers
);
const newMember = lookUpClassMember(
classType,
'__new__',
ClassMemberLookupFlags.SkipObjectBaseClass | ClassMemberLookupFlags.SkipInstanceVariables
MemberAccessFlags.SkipObjectBaseClass | MemberAccessFlags.SkipInstanceMembers
);

if (!initMember || !newMember || !isClass(initMember.classType) || !isClass(newMember.classType)) {
Expand All @@ -5125,7 +5125,7 @@ export class Checker extends ParseTreeWalker {
const callMethod = lookUpClassMember(
metaclass,
'__call__',
ClassMemberLookupFlags.SkipTypeBaseClass | ClassMemberLookupFlags.SkipInstanceVariables
MemberAccessFlags.SkipTypeBaseClass | MemberAccessFlags.SkipInstanceMembers
);
if (callMethod) {
return;
Expand Down Expand Up @@ -5628,7 +5628,7 @@ export class Checker extends ParseTreeWalker {
}

assert(isClass(mroBaseClass));
const baseClassAndSymbol = lookUpClassMember(mroBaseClass, name, ClassMemberLookupFlags.Default);
const baseClassAndSymbol = lookUpClassMember(mroBaseClass, name, MemberAccessFlags.Default);
if (!baseClassAndSymbol) {
continue;
}
Expand Down Expand Up @@ -6286,9 +6286,9 @@ export class Checker extends ParseTreeWalker {
// it could be combined with other classes in a multi-inheritance
// situation that effectively adds new superclasses that we don't know
// about statically.
let effectiveFlags = ClassMemberLookupFlags.SkipInstanceVariables | ClassMemberLookupFlags.SkipOriginalClass;
let effectiveFlags = MemberAccessFlags.SkipInstanceMembers | MemberAccessFlags.SkipOriginalClass;
if (ClassType.isFinal(classType)) {
effectiveFlags |= ClassMemberLookupFlags.SkipObjectBaseClass;
effectiveFlags |= MemberAccessFlags.SkipObjectBaseClass;
}

const methodMember = lookUpClassMember(classType, methodType.details.name, effectiveFlags);
Expand Down
10 changes: 5 additions & 5 deletions packages/pyright-internal/src/analyzer/codeFlowEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ import {
UnknownType,
} from './types';
import {
ClassMemberLookupFlags,
cleanIncompleteUnknown,
derivesFromStdlibClass,
doForEachSubtype,
isIncompleteUnknown,
isTypeAliasPlaceholder,
lookUpClassMember,
mapSubtypes,
MemberAccessFlags,
} from './typeUtils';

export interface FlowNodeTypeResult {
Expand Down Expand Up @@ -1525,7 +1525,7 @@ export function getCodeFlowEngine(
const metaclassCallMember = lookUpClassMember(
callSubtype.details.effectiveMetaclass,
'__call__',
ClassMemberLookupFlags.SkipInstanceVariables | ClassMemberLookupFlags.SkipObjectBaseClass
MemberAccessFlags.SkipInstanceMembers | MemberAccessFlags.SkipObjectBaseClass
);
if (metaclassCallMember) {
return;
Expand All @@ -1535,14 +1535,14 @@ export function getCodeFlowEngine(
let constructorMember = lookUpClassMember(
callSubtype,
'__init__',
ClassMemberLookupFlags.SkipInstanceVariables | ClassMemberLookupFlags.SkipObjectBaseClass
MemberAccessFlags.SkipInstanceMembers | MemberAccessFlags.SkipObjectBaseClass
);

if (constructorMember === undefined) {
constructorMember = lookUpClassMember(
callSubtype,
'__new__',
ClassMemberLookupFlags.SkipInstanceVariables | ClassMemberLookupFlags.SkipObjectBaseClass
MemberAccessFlags.SkipInstanceMembers | MemberAccessFlags.SkipObjectBaseClass
);
}

Expand All @@ -1564,7 +1564,7 @@ export function getCodeFlowEngine(
const callMember = lookUpClassMember(
callSubtype,
'__call__',
ClassMemberLookupFlags.SkipInstanceVariables
MemberAccessFlags.SkipInstanceMembers
);
if (callMember) {
const callMemberType = evaluator.getTypeOfMember(callMember);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ import {
} from './types';
import {
applySolvedTypeVars,
ClassMemberLookupFlags,
convertToInstance,
getTypeVarScopeId,
lookUpObjectMember,
makeInferenceContext,
MemberAccessFlags,
} from './typeUtils';
import { TypeVarContext } from './typeVarContext';

Expand Down Expand Up @@ -79,11 +79,7 @@ function applyPartialTransform(
return result;
}

const callMemberResult = lookUpObjectMember(
result.returnType,
'__call__',
ClassMemberLookupFlags.SkipInstanceVariables
);
const callMemberResult = lookUpObjectMember(result.returnType, '__call__', MemberAccessFlags.SkipInstanceMembers);
if (!callMemberResult || !isTypeSame(convertToInstance(callMemberResult.classType), result.returnType)) {
return result;
}
Expand Down
86 changes: 34 additions & 52 deletions packages/pyright-internal/src/analyzer/constructors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,10 @@ import { getFileInfo } from './analyzerNodeInfo';
import { populateTypeVarContextBasedOnExpectedType } from './constraintSolver';
import { applyConstructorTransform, hasConstructorTransform } from './constructorTransform';
import { getTypeVarScopesForNode } from './parseTreeUtils';
import { CallResult, FunctionArgument, TypeEvaluator, TypeResult } from './typeEvaluatorTypes';
import {
CallResult,
ClassMemberLookup,
FunctionArgument,
MemberAccessFlags,
TypeEvaluator,
TypeResult,
} from './typeEvaluatorTypes';
import {
ClassMemberLookupFlags,
InferenceContext,
MemberAccessFlags,
applySolvedTypeVars,
buildTypeVarContextFromSpecializedClass,
convertToInstance,
Expand All @@ -57,7 +50,6 @@ import {
isAnyOrUnknown,
isClassInstance,
isFunction,
isInstantiableClass,
isNever,
isOverloadedFunction,
isTypeVar,
Expand Down Expand Up @@ -97,13 +89,13 @@ export function validateConstructorArguments(
}

// Determine whether the class overrides the object.__new__ method.
const newMethodTypeResult = evaluator.getTypeOfClassMemberName(
const newMethodTypeResult = evaluator.getTypeOfObjectMember(
errorNode,
type,
'__new__',
{ method: 'get' },
/* diag */ undefined,
MemberAccessFlags.AccessClassMembersOnly |
MemberAccessFlags.SkipClassMembers |
MemberAccessFlags.SkipObjectBaseClass |
MemberAccessFlags.SkipAttributeAccessOverride |
MemberAccessFlags.TreatConstructorAsClassMethod
Expand Down Expand Up @@ -200,7 +192,7 @@ function validateNewAndInitMethods(
type: ClassType,
skipUnknownArgCheck: boolean,
inferenceContext: InferenceContext | undefined,
newMethodTypeResult: ClassMemberLookup | undefined
newMethodTypeResult: TypeResult | undefined
): CallResult {
let returnType: Type | undefined;
let validatedArgExpressions = false;
Expand Down Expand Up @@ -270,13 +262,13 @@ function validateNewAndInitMethods(
}

// Determine whether the class overrides the object.__init__ method.
initMethodTypeResult = evaluator.getTypeOfClassMemberName(
initMethodTypeResult = evaluator.getTypeOfObjectMember(
errorNode,
initMethodBindToType,
'__init__',
{ method: 'get' },
/* diag */ undefined,
MemberAccessFlags.AccessClassMembersOnly |
MemberAccessFlags.SkipInstanceMembers |
MemberAccessFlags.SkipObjectBaseClass |
MemberAccessFlags.SkipAttributeAccessOverride
);
Expand Down Expand Up @@ -655,46 +647,36 @@ function validateMetaclassCall(
skipUnknownArgCheck: boolean,
inferenceContext: InferenceContext | undefined
): CallResult | undefined {
const metaclass = type.details.effectiveMetaclass;
const metaclassCallMethodInfo = evaluator.getTypeOfObjectMember(
errorNode,
type,
'__call__',
{ method: 'get' },
/* diag */ undefined,
MemberAccessFlags.SkipInstanceMembers |
MemberAccessFlags.SkipTypeBaseClass |
MemberAccessFlags.SkipAttributeAccessOverride
);

if (metaclass && isInstantiableClass(metaclass) && !ClassType.isSameGenericClass(metaclass, type)) {
const metaclassCallMethodInfo = evaluator.getTypeOfClassMemberName(
if (metaclassCallMethodInfo) {
const callResult = evaluator.validateCallArguments(
errorNode,
ClassType.cloneAsInstance(metaclass),
'__call__',
{ method: 'get' },
/* diag */ undefined,
MemberAccessFlags.AccessClassMembersOnly |
MemberAccessFlags.SkipTypeBaseClass |
MemberAccessFlags.SkipAttributeAccessOverride,
type
argList,
metaclassCallMethodInfo,
/* typeVarContext */ undefined,
skipUnknownArgCheck,
inferenceContext
);

if (metaclassCallMethodInfo) {
const callResult = evaluator.validateCallArguments(
errorNode,
argList,
metaclassCallMethodInfo,
/* typeVarContext */ undefined,
skipUnknownArgCheck,
inferenceContext
);

if (!callResult.returnType || isUnknown(callResult.returnType)) {
// The return result isn't known. We'll assume in this case that
// the metaclass __call__ method allocated a new instance of the
// requested class.
const typeVarContext = new TypeVarContext(getTypeVarScopeId(type));
callResult.returnType = applyExpectedTypeForConstructor(
evaluator,
type,
inferenceContext,
typeVarContext
);
}

return callResult;
if (!callResult.returnType || isUnknown(callResult.returnType)) {
// The return result isn't known. We'll assume in this case that
// the metaclass __call__ method allocated a new instance of the
// requested class.
const typeVarContext = new TypeVarContext(getTypeVarScopeId(type));
callResult.returnType = applyExpectedTypeForConstructor(evaluator, type, inferenceContext, typeVarContext);
}

return callResult;
}

return undefined;
Expand Down Expand Up @@ -789,7 +771,7 @@ export function createFunctionFromConstructor(
const initInfo = lookUpClassMember(
classType,
'__init__',
ClassMemberLookupFlags.SkipInstanceVariables | ClassMemberLookupFlags.SkipObjectBaseClass
MemberAccessFlags.SkipInstanceMembers | MemberAccessFlags.SkipObjectBaseClass
);

if (initInfo) {
Expand Down Expand Up @@ -853,7 +835,7 @@ export function createFunctionFromConstructor(
const newInfo = lookUpClassMember(
classType,
'__new__',
ClassMemberLookupFlags.SkipInstanceVariables | ClassMemberLookupFlags.SkipObjectBaseClass
MemberAccessFlags.SkipInstanceMembers | MemberAccessFlags.SkipObjectBaseClass
);

if (newInfo) {
Expand Down
4 changes: 2 additions & 2 deletions packages/pyright-internal/src/analyzer/enums.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import { Symbol, SymbolFlags } from './symbol';
import { isSingleDunderName } from './symbolNameUtils';
import { FunctionArgument, TypeEvaluator, TypeResult } from './typeEvaluatorTypes';
import { enumerateLiteralsForType } from './typeGuards';
import { ClassMemberLookupFlags, computeMroLinearization, lookUpClassMember } from './typeUtils';
import { MemberAccessFlags, computeMroLinearization, lookUpClassMember } from './typeUtils';
import {
AnyType,
ClassType,
Expand Down Expand Up @@ -332,7 +332,7 @@ export function transformTypeForPossibleEnumClass(
// the value of each enum element is simply the value assigned to it.
// The __new__ method can transform the value in ways that we cannot
// determine statically.
const newMember = lookUpClassMember(enumClassInfo.classType, '__new__', ClassMemberLookupFlags.SkipBaseClasses);
const newMember = lookUpClassMember(enumClassInfo.classType, '__new__', MemberAccessFlags.SkipBaseClasses);
if (newMember) {
// We may want to change this to UnknownType in the future, but
// for now, we'll leave it as Any which is consistent with the
Expand Down
4 changes: 2 additions & 2 deletions packages/pyright-internal/src/analyzer/functionTransform.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import {
OverloadedFunctionType,
Type,
} from './types';
import { ClassMember, ClassMemberLookupFlags, lookUpObjectMember, synthesizeTypeVarForSelfCls } from './typeUtils';
import { ClassMember, lookUpObjectMember, MemberAccessFlags, synthesizeTypeVarForSelfCls } from './typeUtils';

export function applyFunctionTransform(
evaluator: TypeEvaluator,
Expand Down Expand Up @@ -65,7 +65,7 @@ function applyTotalOrderingTransform(
// Verify that the class has at least one of the required functions.
let firstMemberFound: ClassMember | undefined;
const missingMethods = orderingMethods.filter((methodName) => {
const memberInfo = lookUpObjectMember(instanceType, methodName, ClassMemberLookupFlags.SkipInstanceVariables);
const memberInfo = lookUpObjectMember(instanceType, methodName, MemberAccessFlags.SkipInstanceMembers);
if (memberInfo && !firstMemberFound) {
firstMemberFound = memberInfo;
}
Expand Down
6 changes: 2 additions & 4 deletions packages/pyright-internal/src/analyzer/protocols.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ import {
applySolvedTypeVars,
AssignTypeFlags,
ClassMember,
ClassMemberLookupFlags,
containsLiteralType,
getTypeVarScopeId,
lookUpClassMember,
MemberAccessFlags,
partiallySpecializeType,
populateTypeVarContextForSelfType,
removeParamSpecVariadicsFromSignature,
Expand Down Expand Up @@ -489,9 +489,7 @@ function assignClassToProtocolInternal(
if (isSrcReadOnly) {
// The source attribute is read-only. Make sure the setter
// is not defined in the dest property.
if (
lookUpClassMember(destMemberType, '__set__', ClassMemberLookupFlags.SkipInstanceVariables)
) {
if (lookUpClassMember(destMemberType, '__set__', MemberAccessFlags.SkipInstanceMembers)) {
if (subDiag) {
subDiag.addMessage(
Localizer.DiagnosticAddendum.memberIsWritableInProtocol().format({ name })
Expand Down
Loading

0 comments on commit 53a2b22

Please sign in to comment.