From e8250ba18f4a3d61e7ca26c6d2ccd396615acf3c Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Sun, 23 Jul 2023 16:49:25 -0700 Subject: [PATCH] Fixed a bug that resulted in the incorrect inferred variance for a type variable used within a frozen dataclass. This addresses #5568. --- .../src/analyzer/typeEvaluator.ts | 18 ++++++++++++------ .../src/tests/samples/autoVariance1.py | 18 ++++++++++++++++++ .../src/tests/typeEvaluator5.test.ts | 2 +- 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index e3c54f26f5d4..eeb6323c6ccd 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -21204,12 +21204,18 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions } } else { const primaryDecl = symbol.getDeclarations()[0]; - // Class and instance variables that are mutable need to - // enforce invariance. - const flags = - primaryDecl?.type === DeclarationType.Variable && !isFinalVariableDeclaration(primaryDecl) - ? AssignTypeFlags.EnforceInvariance - : AssignTypeFlags.Default; + + let flags = AssignTypeFlags.Default; + if ( + primaryDecl?.type === DeclarationType.Variable && + !isFinalVariableDeclaration(primaryDecl) && + !ClassType.isFrozenDataClass(destType) + ) { + // Class and instance variables that are mutable need to + // enforce invariance. + flags |= AssignTypeFlags.EnforceInvariance; + } + if ( !assignType( destMemberType, diff --git a/packages/pyright-internal/src/tests/samples/autoVariance1.py b/packages/pyright-internal/src/tests/samples/autoVariance1.py index 6b2cb023cc22..c415312647ed 100644 --- a/packages/pyright-internal/src/tests/samples/autoVariance1.py +++ b/packages/pyright-internal/src/tests/samples/autoVariance1.py @@ -1,6 +1,7 @@ # This sample tests variance inference for type variables that use # autovariance. +from dataclasses import dataclass from typing import Iterator, Sequence @@ -32,6 +33,14 @@ def method1(self) -> "ShouldBeCovariant2[T]": vco3_1: ShouldBeCovariant3[float] = ShouldBeCovariant3[int]() +@dataclass(frozen=True) +class ShouldBeCovariant4[T]: + x: T + +vo4_1: ShouldBeCovariant4[int] = ShouldBeCovariant4(1) +vo4_2: ShouldBeCovariant4[float] = vo4_1 + + class ShouldBeInvariant1[T]: def __init__(self, value: T) -> None: self._value = value @@ -83,6 +92,15 @@ class ShouldBeInvariant3[K, V](dict[K, V]): # This should generate an error based on variance vinv3_4: ShouldBeInvariant3[str, int] = ShouldBeInvariant3[str, float]() +@dataclass +class ShouldBeInvariant4[T]: + x: T + +vinv4_1: ShouldBeInvariant4[int] = ShouldBeInvariant4(1) + +# This should generate an error based on variance +vinv4_2: ShouldBeInvariant4[float] = vinv4_1 + class ShouldBeContravariant1[T]: def __init__(self, value: T) -> None: diff --git a/packages/pyright-internal/src/tests/typeEvaluator5.test.ts b/packages/pyright-internal/src/tests/typeEvaluator5.test.ts index b7682e27511a..033b4c9cbe08 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator5.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator5.test.ts @@ -77,7 +77,7 @@ test('AutoVariance1', () => { configOptions.defaultPythonVersion = PythonVersion.V3_12; const analysisResults = TestUtils.typeAnalyzeSampleFiles(['autoVariance1.py'], configOptions); - TestUtils.validateResults(analysisResults, 11); + TestUtils.validateResults(analysisResults, 12); }); test('AutoVariance2', () => {