Skip to content

Commit

Permalink
Fixed an issue in type evaluation of call expressions where the calla…
Browse files Browse the repository at this point in the history
…ble subexpression evaluates to a union, and the callable subtypes have different signatures. Pyright was previously caching the types from the first subtype, so it didn't re-evaluate using the second subtype (which may require bidirectional type inference). This addresses #5428. (#5547)

Co-authored-by: Eric Traut <[email protected]>
  • Loading branch information
erictraut and msfterictraut authored Jul 20, 2023
1 parent 98fcba1 commit 9aa29e4
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 43 deletions.
30 changes: 23 additions & 7 deletions packages/pyright-internal/src/analyzer/typeCacheUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ interface SpeculativeContext {
speculativeRootNode: ParseNode;
entriesToUndo: SpeculativeEntry[];
dependentType: Type | undefined;
allowDiagnostics?: boolean;
}

interface DependentType {
Expand All @@ -42,6 +43,17 @@ export interface SpeculativeTypeEntry {
dependentTypes?: DependentType[];
}

export interface SpeculativeModeOptions {
// If specified, the type cached speculative result depends on
// this dependent type.
dependentType?: Type;

// Normally, diagnostics are suppressed for nodes under
// a speculative root, but this can be overridden by specifying
// this option.
allowDiagnostics?: boolean;
}

// This class maintains a stack of "speculative type contexts". When
// a context is popped off the stack, all of the speculative type cache
// entries that were created within that context are removed from the
Expand All @@ -58,20 +70,21 @@ export class SpeculativeTypeTracker {
private _speculativeTypeCache = new Map<number, SpeculativeTypeEntry[]>();
private _activeDependentTypes: DependentType[] = [];

enterSpeculativeContext(speculativeRootNode: ParseNode, dependentType: Type | undefined) {
enterSpeculativeContext(speculativeRootNode: ParseNode, options?: SpeculativeModeOptions) {
this._speculativeContextStack.push({
speculativeRootNode,
entriesToUndo: [],
dependentType,
dependentType: options?.dependentType,
allowDiagnostics: options?.allowDiagnostics,
});

// Retain a list of active dependent types. This information is already
// contained within the speculative context stack, but we retain a copy
// in this alternate form for performance reasons.
if (dependentType) {
if (options?.dependentType) {
this._activeDependentTypes.push({
speculativeRootNode,
dependentType,
dependentType: options.dependentType,
});
}
}
Expand All @@ -92,7 +105,7 @@ export class SpeculativeTypeTracker {
});
}

isSpeculative(node: ParseNode | undefined) {
isSpeculative(node: ParseNode | undefined, ignoreIfDiagnosticsAllowed = false) {
if (this._speculativeContextStack.length === 0) {
return false;
}
Expand All @@ -102,8 +115,11 @@ export class SpeculativeTypeTracker {
}

for (let i = this._speculativeContextStack.length - 1; i >= 0; i--) {
if (ParseTreeUtils.isNodeContainedWithin(node, this._speculativeContextStack[i].speculativeRootNode)) {
return true;
const stackEntry = this._speculativeContextStack[i];
if (ParseTreeUtils.isNodeContainedWithin(node, stackEntry.speculativeRootNode)) {
if (!ignoreIfDiagnosticsAllowed || !stackEntry.allowDiagnostics) {
return true;
}
}
}

Expand Down
90 changes: 54 additions & 36 deletions packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ import { evaluateStaticBoolExpression } from './staticExpressions';
import { Symbol, SymbolFlags, indeterminateSymbolId } from './symbol';
import { isConstantName, isPrivateName, isPrivateOrProtectedName } from './symbolNameUtils';
import { getLastTypedDeclaredForSymbol } from './symbolUtils';
import { SpeculativeTypeTracker } from './typeCacheUtils';
import { SpeculativeModeOptions, SpeculativeTypeTracker } from './typeCacheUtils';
import {
AbstractMethod,
AnnotationTypeOptions,
Expand Down Expand Up @@ -2877,7 +2877,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
function isDiagnosticSuppressedForNode(node: ParseNode) {
return (
suppressedNodeStack.some((suppressedNode) => ParseTreeUtils.isNodeContainedWithin(node, suppressedNode)) ||
isSpeculativeModeInUse(node)
speculativeTypeTracker.isSpeculative(node, /* ignoreIfDiagnosticsAllowed */ true)
);
}

Expand Down Expand Up @@ -3607,24 +3607,28 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
function mapSubtypesExpandTypeVars(
type: Type,
conditionFilter: TypeCondition[] | undefined,
callback: (expandedSubtype: Type, unexpandedSubtype: Type) => Type | undefined
callback: (expandedSubtype: Type, unexpandedSubtype: Type, isLastIteration: boolean) => Type | undefined
): Type {
const newSubtypes: Type[] = [];
let typeChanged = false;

const expandSubtype = (unexpandedType: Type) => {
function expandSubtype(unexpandedType: Type, isLastSubtype: boolean) {
let expandedType = isUnion(unexpandedType) ? unexpandedType : makeTopLevelTypeVarsConcrete(unexpandedType);

expandedType = transformPossibleRecursiveTypeAlias(expandedType);

doForEachSubtype(expandedType, (subtype) => {
doForEachSubtype(expandedType, (subtype, index, allSubtypes) => {
if (conditionFilter) {
if (!TypeCondition.isCompatible(getTypeCondition(subtype), conditionFilter)) {
return undefined;
}
}

let transformedType = callback(subtype, unexpandedType);
let transformedType = callback(
subtype,
unexpandedType,
isLastSubtype && index === allSubtypes.length - 1
);
if (transformedType !== unexpandedType) {
typeChanged = true;
}
Expand All @@ -3633,6 +3637,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
const typeCondition = getTypeCondition(subtype)?.filter(
(condition) => condition.isConstrainedTypeVar
);

if (typeCondition && typeCondition.length > 0) {
transformedType = addConditionToType(transformedType, typeCondition);
}
Expand All @@ -3641,14 +3646,14 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
}
return undefined;
});
};
}

if (isUnion(type)) {
type.subtypes.forEach((subtype) => {
expandSubtype(subtype);
type.subtypes.forEach((subtype, index) => {
expandSubtype(subtype, index === type.subtypes.length - 1);
});
} else {
expandSubtype(type);
expandSubtype(type, /* isLastSubtype */ true);
}

if (!typeChanged) {
Expand Down Expand Up @@ -7594,6 +7599,9 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
): TypeResult {
let baseTypeResult: TypeResult | undefined;

// Check for the use of `type(x)` within a type annotation. This isn't
// allowed, and it's a common mistake, so we want to emit a diagnostic
// that guides the user to the right solution.
if (
(flags & EvaluatorFlags.ExpectingTypeAnnotation) !== 0 &&
node.leftExpression.nodeType === ParseNodeType.Name &&
Expand Down Expand Up @@ -8735,34 +8743,42 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
let returnType = mapSubtypesExpandTypeVars(
callTypeResult.type,
/* conditionFilter */ undefined,
(expandedSubtype, unexpandedSubtype) => {
const callResult = validateCallArgumentsForSubtype(
errorNode,
argList,
expandedSubtype,
unexpandedSubtype,
!!callTypeResult.isIncomplete,
typeVarContext,
skipUnknownArgCheck,
inferenceContext,
recursionCount
);
(expandedSubtype, unexpandedSubtype, isLastIteration) => {
return useSpeculativeMode(
isLastIteration ? undefined : errorNode,
() => {
const callResult = validateCallArgumentsForSubtype(
errorNode,
argList,
expandedSubtype,
unexpandedSubtype,
!!callTypeResult.isIncomplete,
typeVarContext,
skipUnknownArgCheck,
inferenceContext,
recursionCount
);

if (callResult.argumentErrors) {
argumentErrors = true;
}
if (callResult.argumentErrors) {
argumentErrors = true;
}

if (callResult.isTypeIncomplete) {
isTypeIncomplete = true;
}
if (callResult.isTypeIncomplete) {
isTypeIncomplete = true;
}

if (callResult.overloadsUsedForCall) {
appendArray(overloadsUsedForCall, callResult.overloadsUsedForCall);
}
if (callResult.overloadsUsedForCall) {
appendArray(overloadsUsedForCall, callResult.overloadsUsedForCall);
}

specializedInitSelfType = callResult.specializedInitSelfType;
specializedInitSelfType = callResult.specializedInitSelfType;

return callResult.returnType;
return callResult.returnType;
},
{
allowDiagnostics: true,
}
);
}
);

Expand Down Expand Up @@ -13478,7 +13494,9 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
isIncomplete = true;
}
},
inferenceContext?.expectedType
{
dependentType: inferenceContext?.expectedType,
}
);

// Mark the function type as no longer being evaluated.
Expand Down Expand Up @@ -19157,13 +19175,13 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
function useSpeculativeMode<T>(
speculativeNode: ParseNode | undefined,
callback: () => T,
dependentType?: Type | undefined
options?: SpeculativeModeOptions
) {
if (!speculativeNode) {
return callback();
}

speculativeTypeTracker.enterSpeculativeContext(speculativeNode, dependentType);
speculativeTypeTracker.enterSpeculativeContext(speculativeNode, options);

try {
const result = callback();
Expand Down
44 changes: 44 additions & 0 deletions packages/pyright-internal/src/tests/samples/call11.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# This sample tests the case where a call expression involves a union
# on the LHS where the subtypes of the union have different signatures.

# pyright: strict

from __future__ import annotations
from typing import Any, Callable, Generic, Self, TypeAlias, TypeVar

T = TypeVar("T")
E = TypeVar("E")
U = TypeVar("U")
F = TypeVar("F")

Either: TypeAlias = "Left[T]" | "Right[E]"


class Left(Generic[T]):
def __init__(self, value: T) -> None:
self.value = value

def map_left(self, fn: Callable[[T], U]) -> Left[U]:
return Left(fn(self.value))

def map_right(self, fn: Callable[[Any], Any]) -> Self:
return self


class Right(Generic[E]):
def __init__(self, value: E) -> None:
self.value = value

def map_left(self, fn: Callable[[Any], Any]) -> Self:
return self

def map_right(self, fn: Callable[[E], F]) -> Right[F]:
return Right(fn(self.value))


def func() -> Either[int, str]:
raise NotImplementedError


result = func().map_left(lambda lv: lv + 1).map_right(lambda rv: rv + "a")
reveal_type(result, expected_text="Left[int] | Right[str]")
6 changes: 6 additions & 0 deletions packages/pyright-internal/src/tests/typeEvaluator1.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,12 @@ test('Call10', () => {
TestUtils.validateResults(analysisResults, 3);
});

test('Call11', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['call11.py']);

TestUtils.validateResults(analysisResults, 0);
});

test('Function1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['function1.py']);

Expand Down

0 comments on commit 9aa29e4

Please sign in to comment.