diff --git a/packages/pyright-internal/src/analyzer/typeCacheUtils.ts b/packages/pyright-internal/src/analyzer/typeCacheUtils.ts index 208207dc41fa..e6b26a0c1e9b 100644 --- a/packages/pyright-internal/src/analyzer/typeCacheUtils.ts +++ b/packages/pyright-internal/src/analyzer/typeCacheUtils.ts @@ -23,6 +23,7 @@ interface SpeculativeContext { speculativeRootNode: ParseNode; entriesToUndo: SpeculativeEntry[]; dependentType: Type | undefined; + allowDiagnostics?: boolean; } interface DependentType { @@ -42,6 +43,17 @@ export interface SpeculativeTypeEntry { dependentTypes?: DependentType[]; } +export interface SpeculativeModeOptions { + // If specified, the type cached speculative result depends on + // this dependent type. + dependentType?: Type; + + // Normally, diagnostics are suppressed for nodes under + // a speculative root, but this can be overridden by specifying + // this option. + allowDiagnostics?: boolean; +} + // This class maintains a stack of "speculative type contexts". When // a context is popped off the stack, all of the speculative type cache // entries that were created within that context are removed from the @@ -58,20 +70,21 @@ export class SpeculativeTypeTracker { private _speculativeTypeCache = new Map(); private _activeDependentTypes: DependentType[] = []; - enterSpeculativeContext(speculativeRootNode: ParseNode, dependentType: Type | undefined) { + enterSpeculativeContext(speculativeRootNode: ParseNode, options?: SpeculativeModeOptions) { this._speculativeContextStack.push({ speculativeRootNode, entriesToUndo: [], - dependentType, + dependentType: options?.dependentType, + allowDiagnostics: options?.allowDiagnostics, }); // Retain a list of active dependent types. This information is already // contained within the speculative context stack, but we retain a copy // in this alternate form for performance reasons. - if (dependentType) { + if (options?.dependentType) { this._activeDependentTypes.push({ speculativeRootNode, - dependentType, + dependentType: options.dependentType, }); } } @@ -92,7 +105,7 @@ export class SpeculativeTypeTracker { }); } - isSpeculative(node: ParseNode | undefined) { + isSpeculative(node: ParseNode | undefined, ignoreIfDiagnosticsAllowed = false) { if (this._speculativeContextStack.length === 0) { return false; } @@ -102,8 +115,11 @@ export class SpeculativeTypeTracker { } for (let i = this._speculativeContextStack.length - 1; i >= 0; i--) { - if (ParseTreeUtils.isNodeContainedWithin(node, this._speculativeContextStack[i].speculativeRootNode)) { - return true; + const stackEntry = this._speculativeContextStack[i]; + if (ParseTreeUtils.isNodeContainedWithin(node, stackEntry.speculativeRootNode)) { + if (!ignoreIfDiagnosticsAllowed || !stackEntry.allowDiagnostics) { + return true; + } } } diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index c081f30d3d61..680e2a4e585a 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -147,7 +147,7 @@ import { evaluateStaticBoolExpression } from './staticExpressions'; import { Symbol, SymbolFlags, indeterminateSymbolId } from './symbol'; import { isConstantName, isPrivateName, isPrivateOrProtectedName } from './symbolNameUtils'; import { getLastTypedDeclaredForSymbol } from './symbolUtils'; -import { SpeculativeTypeTracker } from './typeCacheUtils'; +import { SpeculativeModeOptions, SpeculativeTypeTracker } from './typeCacheUtils'; import { AbstractMethod, AnnotationTypeOptions, @@ -2877,7 +2877,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions function isDiagnosticSuppressedForNode(node: ParseNode) { return ( suppressedNodeStack.some((suppressedNode) => ParseTreeUtils.isNodeContainedWithin(node, suppressedNode)) || - isSpeculativeModeInUse(node) + speculativeTypeTracker.isSpeculative(node, /* ignoreIfDiagnosticsAllowed */ true) ); } @@ -3607,24 +3607,28 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions function mapSubtypesExpandTypeVars( type: Type, conditionFilter: TypeCondition[] | undefined, - callback: (expandedSubtype: Type, unexpandedSubtype: Type) => Type | undefined + callback: (expandedSubtype: Type, unexpandedSubtype: Type, isLastIteration: boolean) => Type | undefined ): Type { const newSubtypes: Type[] = []; let typeChanged = false; - const expandSubtype = (unexpandedType: Type) => { + function expandSubtype(unexpandedType: Type, isLastSubtype: boolean) { let expandedType = isUnion(unexpandedType) ? unexpandedType : makeTopLevelTypeVarsConcrete(unexpandedType); expandedType = transformPossibleRecursiveTypeAlias(expandedType); - doForEachSubtype(expandedType, (subtype) => { + doForEachSubtype(expandedType, (subtype, index, allSubtypes) => { if (conditionFilter) { if (!TypeCondition.isCompatible(getTypeCondition(subtype), conditionFilter)) { return undefined; } } - let transformedType = callback(subtype, unexpandedType); + let transformedType = callback( + subtype, + unexpandedType, + isLastSubtype && index === allSubtypes.length - 1 + ); if (transformedType !== unexpandedType) { typeChanged = true; } @@ -3633,6 +3637,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions const typeCondition = getTypeCondition(subtype)?.filter( (condition) => condition.isConstrainedTypeVar ); + if (typeCondition && typeCondition.length > 0) { transformedType = addConditionToType(transformedType, typeCondition); } @@ -3641,14 +3646,14 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions } return undefined; }); - }; + } if (isUnion(type)) { - type.subtypes.forEach((subtype) => { - expandSubtype(subtype); + type.subtypes.forEach((subtype, index) => { + expandSubtype(subtype, index === type.subtypes.length - 1); }); } else { - expandSubtype(type); + expandSubtype(type, /* isLastSubtype */ true); } if (!typeChanged) { @@ -7594,6 +7599,9 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions ): TypeResult { let baseTypeResult: TypeResult | undefined; + // Check for the use of `type(x)` within a type annotation. This isn't + // allowed, and it's a common mistake, so we want to emit a diagnostic + // that guides the user to the right solution. if ( (flags & EvaluatorFlags.ExpectingTypeAnnotation) !== 0 && node.leftExpression.nodeType === ParseNodeType.Name && @@ -8735,34 +8743,42 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions let returnType = mapSubtypesExpandTypeVars( callTypeResult.type, /* conditionFilter */ undefined, - (expandedSubtype, unexpandedSubtype) => { - const callResult = validateCallArgumentsForSubtype( - errorNode, - argList, - expandedSubtype, - unexpandedSubtype, - !!callTypeResult.isIncomplete, - typeVarContext, - skipUnknownArgCheck, - inferenceContext, - recursionCount - ); + (expandedSubtype, unexpandedSubtype, isLastIteration) => { + return useSpeculativeMode( + isLastIteration ? undefined : errorNode, + () => { + const callResult = validateCallArgumentsForSubtype( + errorNode, + argList, + expandedSubtype, + unexpandedSubtype, + !!callTypeResult.isIncomplete, + typeVarContext, + skipUnknownArgCheck, + inferenceContext, + recursionCount + ); - if (callResult.argumentErrors) { - argumentErrors = true; - } + if (callResult.argumentErrors) { + argumentErrors = true; + } - if (callResult.isTypeIncomplete) { - isTypeIncomplete = true; - } + if (callResult.isTypeIncomplete) { + isTypeIncomplete = true; + } - if (callResult.overloadsUsedForCall) { - appendArray(overloadsUsedForCall, callResult.overloadsUsedForCall); - } + if (callResult.overloadsUsedForCall) { + appendArray(overloadsUsedForCall, callResult.overloadsUsedForCall); + } - specializedInitSelfType = callResult.specializedInitSelfType; + specializedInitSelfType = callResult.specializedInitSelfType; - return callResult.returnType; + return callResult.returnType; + }, + { + allowDiagnostics: true, + } + ); } ); @@ -13478,7 +13494,9 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions isIncomplete = true; } }, - inferenceContext?.expectedType + { + dependentType: inferenceContext?.expectedType, + } ); // Mark the function type as no longer being evaluated. @@ -19157,13 +19175,13 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions function useSpeculativeMode( speculativeNode: ParseNode | undefined, callback: () => T, - dependentType?: Type | undefined + options?: SpeculativeModeOptions ) { if (!speculativeNode) { return callback(); } - speculativeTypeTracker.enterSpeculativeContext(speculativeNode, dependentType); + speculativeTypeTracker.enterSpeculativeContext(speculativeNode, options); try { const result = callback(); diff --git a/packages/pyright-internal/src/tests/samples/call11.py b/packages/pyright-internal/src/tests/samples/call11.py new file mode 100644 index 000000000000..c1af3c803078 --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/call11.py @@ -0,0 +1,44 @@ +# This sample tests the case where a call expression involves a union +# on the LHS where the subtypes of the union have different signatures. + +# pyright: strict + +from __future__ import annotations +from typing import Any, Callable, Generic, Self, TypeAlias, TypeVar + +T = TypeVar("T") +E = TypeVar("E") +U = TypeVar("U") +F = TypeVar("F") + +Either: TypeAlias = "Left[T]" | "Right[E]" + + +class Left(Generic[T]): + def __init__(self, value: T) -> None: + self.value = value + + def map_left(self, fn: Callable[[T], U]) -> Left[U]: + return Left(fn(self.value)) + + def map_right(self, fn: Callable[[Any], Any]) -> Self: + return self + + +class Right(Generic[E]): + def __init__(self, value: E) -> None: + self.value = value + + def map_left(self, fn: Callable[[Any], Any]) -> Self: + return self + + def map_right(self, fn: Callable[[E], F]) -> Right[F]: + return Right(fn(self.value)) + + +def func() -> Either[int, str]: + raise NotImplementedError + + +result = func().map_left(lambda lv: lv + 1).map_right(lambda rv: rv + "a") +reveal_type(result, expected_text="Left[int] | Right[str]") diff --git a/packages/pyright-internal/src/tests/typeEvaluator1.test.ts b/packages/pyright-internal/src/tests/typeEvaluator1.test.ts index 6e02bf4cc290..4e9675062356 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator1.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator1.test.ts @@ -797,6 +797,12 @@ test('Call10', () => { TestUtils.validateResults(analysisResults, 3); }); +test('Call11', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['call11.py']); + + TestUtils.validateResults(analysisResults, 0); +}); + test('Function1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['function1.py']);