diff --git a/docs/type-concepts-advanced.md b/docs/type-concepts-advanced.md index 0c7b2a7ac630..19c3097aa170 100644 --- a/docs/type-concepts-advanced.md +++ b/docs/type-concepts-advanced.md @@ -62,6 +62,7 @@ In addition to assignment-based type narrowing, Pyright supports the following t * `type(x) is T` and `type(x) is not T` * `type(x) == T` and `type(x) != T` * `x is E` and `x is not E` (where E is a literal enum or bool) +* `x is C` and `x is not C` (where C is a class) * `x == L` and `x != L` (where L is an expression that evaluates to a literal type) * `x.y is None` and `x.y is not None` (where x is a type that is distinguished by a field with a None) * `x.y is E` and `x.y is not E` (where E is a literal enum or bool and x is a type that is distinguished by a field with a literal type) diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index 3ab508cfbc66..6919fe08dfc3 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -222,12 +222,12 @@ export function getTypeNarrowingCallback( } } - // Look for "X is Y" or "X is not Y" where Y is a an enum or bool literal. if (isOrIsNotOperator) { if (ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression)) { const rightTypeResult = evaluator.getTypeOfExpression(testExpression.rightExpression); const rightType = rightTypeResult.type; + // Look for "X is Y" or "X is not Y" where Y is a an enum or bool literal. if ( isClassInstance(rightType) && (ClassType.isEnumClass(rightType) || ClassType.isBuiltIn(rightType, 'bool')) && @@ -246,9 +246,19 @@ export function getTypeNarrowingCallback( }; }; } + + // Look for X is or X is not . + if (isInstantiableClass(rightType)) { + return (type: Type) => { + return { + type: narrowTypeForClassComparison(evaluator, type, rightType, adjIsPositiveTest), + isIncomplete: !!rightTypeResult.isIncomplete, + }; + }; + } } - // Look for X[] is or X[] is not + // Look for X[] is or X[] is not . if ( testExpression.leftExpression.nodeType === ParseNodeType.Index && testExpression.leftExpression.items.length === 1 && @@ -2078,6 +2088,54 @@ function narrowTypeForTypeIs(evaluator: TypeEvaluator, type: Type, classType: Cl ); } +// Attempts to narrow a type based on a comparison with a class using "is" or +// "is not". This pattern is sometimes used for sentinels. +function narrowTypeForClassComparison( + evaluator: TypeEvaluator, + referenceType: Type, + classType: ClassType, + isPositiveTest: boolean +): Type { + return mapSubtypes(referenceType, (subtype) => { + const concreteSubtype = evaluator.makeTopLevelTypeVarsConcrete(subtype); + + if (isPositiveTest) { + if (isNoneInstance(concreteSubtype)) { + return undefined; + } + + if (isClassInstance(concreteSubtype) && TypeBase.isInstance(subtype)) { + return undefined; + } + + if (isInstantiableClass(concreteSubtype) && ClassType.isFinal(concreteSubtype)) { + if ( + !ClassType.isSameGenericClass(concreteSubtype, classType) && + !isIsinstanceFilterSuperclass( + evaluator, + concreteSubtype, + classType, + classType, + /* isInstanceCheck */ false + ) + ) { + return undefined; + } + } + } else { + if ( + isInstantiableClass(concreteSubtype) && + ClassType.isSameGenericClass(classType, concreteSubtype) && + ClassType.isFinal(classType) + ) { + return undefined; + } + } + + return subtype; + }); +} + // Attempts to narrow a type (make it more constrained) based on a comparison // (equal or not equal) to a literal value. It also handles "is" or "is not" // operators if isIsOperator is true. diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingIsClass1.py b/packages/pyright-internal/src/tests/samples/typeNarrowingIsClass1.py new file mode 100644 index 000000000000..22c299f09858 --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingIsClass1.py @@ -0,0 +1,58 @@ +# This sample tests type narrowing for conditional +# statements of the form X is or X is not . + +from typing import Any, TypeVar, final + + +@final +class A: + ... + + +@final +class B: + ... + + +class C: + ... + + +def func1(x: type[A] | type[B] | None | int): + if x is A: + reveal_type(x, expected_text="type[A]") + else: + reveal_type(x, expected_text="type[B] | int | None") + + +def func2(x: type[A] | type[B] | None | int, y: type[A]): + if x is not y: + reveal_type(x, expected_text="type[B] | int | None") + else: + reveal_type(x, expected_text="type[A]") + + +def func3(x: type[A] | type[B] | Any): + if x is A: + reveal_type(x, expected_text="type[A] | Any") + else: + reveal_type(x, expected_text="type[B] | Any") + + +def func4(x: type[A] | type[B] | type[C]): + if x is C: + reveal_type(x, expected_text="type[C]") + else: + reveal_type(x, expected_text="type[A] | type[B] | type[C]") + + +T = TypeVar("T") + + +def func5(x: type[A] | type[B] | type[T]) -> type[A] | type[B] | type[T]: + if x is A: + reveal_type(x, expected_text="type[A] | type[T@func5]") + else: + reveal_type(x, expected_text="type[B] | type[T@func5]") + + return x diff --git a/packages/pyright-internal/src/tests/typeEvaluator1.test.ts b/packages/pyright-internal/src/tests/typeEvaluator1.test.ts index ce80b4731e94..6e02bf4cc290 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator1.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator1.test.ts @@ -315,6 +315,12 @@ test('TypeNarrowingIsNone2', () => { TestUtils.validateResults(analysisResults, 0); }); +test('TypeNarrowingIsClass1', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeNarrowingIsClass1.py']); + + TestUtils.validateResults(analysisResults, 0); +}); + test('TypeNarrowingIsNoneTuple1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeNarrowingIsNoneTuple1.py']);