diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index bc4be4da36fd..ceab74247eb7 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -9367,9 +9367,10 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions // Does this function define the param spec, or is it an inner // function nested within another function that defines the param // spec? We need to handle these two cases differently. + const paramSpecScopeId = varArgListParam.type.scopeId; if ( - varArgListParam.type.scopeId === typeResult.type.details.typeVarScopeId || - varArgListParam.type.scopeId === typeResult.type.details.constructorTypeVarScopeId + paramSpecScopeId === typeResult.type.details.typeVarScopeId || + paramSpecScopeId === typeResult.type.details.constructorTypeVarScopeId ) { paramSpecArgList = []; paramSpecTarget = TypeVarType.cloneForParamSpecAccess(varArgListParam.type, /* access */ undefined); @@ -9377,6 +9378,19 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions positionalOnlyLimitIndex = varArgListParamIndex; } } + } else if (typeResult.type.details.paramSpec) { + const paramSpecScopeId = typeResult.type.details.paramSpec.scopeId; + if ( + paramSpecScopeId === typeResult.type.details.typeVarScopeId || + paramSpecScopeId === typeResult.type.details.constructorTypeVarScopeId + ) { + hasParamSpecArgsKwargs = true; + paramSpecArgList = []; + paramSpecTarget = TypeVarType.cloneForParamSpecAccess( + typeResult.type.details.paramSpec, + /* access */ undefined + ); + } } // If there are keyword arguments present after a *args argument, @@ -9994,27 +10008,25 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions }); trySetActive(argList[argIndex], paramDetails.params[paramInfoIndex].param); } + } else if (paramSpecArgList) { + paramSpecArgList.push(argList[argIndex]); } else if (paramDetails.kwargsIndex !== undefined) { - if (paramSpecArgList) { - paramSpecArgList.push(argList[argIndex]); - } else { - const paramType = paramDetails.params[paramDetails.kwargsIndex].type; - validateArgTypeParams.push({ - paramCategory: ParameterCategory.KwargsDict, - paramType, - requiresTypeVarMatching: requiresSpecialization(paramType), - argument: argList[argIndex], - errorNode: argList[argIndex].valueExpression ?? errorNode, - paramName: paramNameValue, - }); + const paramType = paramDetails.params[paramDetails.kwargsIndex].type; + validateArgTypeParams.push({ + paramCategory: ParameterCategory.KwargsDict, + paramType, + requiresTypeVarMatching: requiresSpecialization(paramType), + argument: argList[argIndex], + errorNode: argList[argIndex].valueExpression ?? errorNode, + paramName: paramNameValue, + }); - // Remember that this parameter has already received a value. - paramMap.set(paramNameValue, { - argsNeeded: 1, - argsReceived: 1, - isPositionalOnly: false, - }); - } + // Remember that this parameter has already received a value. + paramMap.set(paramNameValue, { + argsNeeded: 1, + argsReceived: 1, + isPositionalOnly: false, + }); assert( paramDetails.params[paramDetails.kwargsIndex], 'paramDetails.kwargsIndex params entry is undefined' @@ -10030,20 +10042,24 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions reportedArgError = true; } } else if (argList[argIndex].argumentCategory === ArgumentCategory.Simple) { - if (!isDiagnosticSuppressedForNode(errorNode)) { - const fileInfo = AnalyzerNodeInfo.getFileInfo(errorNode); - addDiagnostic( - fileInfo.diagnosticRuleSet.reportGeneralTypeIssues, - DiagnosticRule.reportGeneralTypeIssues, - positionParamLimitIndex === 1 - ? Localizer.Diagnostic.argPositionalExpectedOne() - : Localizer.Diagnostic.argPositionalExpectedCount().format({ - expected: positionParamLimitIndex, - }), - argList[argIndex].valueExpression || errorNode - ); + if (paramSpecArgList) { + paramSpecArgList.push(argList[argIndex]); + } else { + if (!isDiagnosticSuppressedForNode(errorNode)) { + const fileInfo = AnalyzerNodeInfo.getFileInfo(errorNode); + addDiagnostic( + fileInfo.diagnosticRuleSet.reportGeneralTypeIssues, + DiagnosticRule.reportGeneralTypeIssues, + positionParamLimitIndex === 1 + ? Localizer.Diagnostic.argPositionalExpectedOne() + : Localizer.Diagnostic.argPositionalExpectedCount().format({ + expected: positionParamLimitIndex, + }), + argList[argIndex].valueExpression || errorNode + ); + } + reportedArgError = true; } - reportedArgError = true; } else if (argList[argIndex].argumentCategory === ArgumentCategory.UnpackedList) { // Handle the case where a *args: P.args is passed as an argument to // a function that accepts a ParamSpec. diff --git a/packages/pyright-internal/src/tests/samples/paramSpec45.py b/packages/pyright-internal/src/tests/samples/paramSpec45.py new file mode 100644 index 000000000000..ca1923535125 --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/paramSpec45.py @@ -0,0 +1,32 @@ +# This sample tests the case where the same function that uses a ParamSpec +# is called multiple times as arguments to the same call. + +from typing import Callable, ParamSpec + +P = ParamSpec("P") + + +def func1(func: Callable[P, object], *args: P.args, **kwargs: P.kwargs) -> object: + ... + + +def func2(x: str) -> int: + ... + + +def func3(y: str) -> int: + ... + + +print(func1(func2, x="..."), func1(func3, y="...")) + + +def func4(fn: Callable[P, int], *args: P.args, **kwargs: P.kwargs) -> int: + return fn(*args, **kwargs) + + +def func5(x: int, y: int) -> int: + return x + y + + +func5(func4(lambda x: x, 1), func4(lambda x, y: x + y, 2, 3)) diff --git a/packages/pyright-internal/src/tests/typeEvaluator4.test.ts b/packages/pyright-internal/src/tests/typeEvaluator4.test.ts index 510b50df2810..4d1017757f5d 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator4.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator4.test.ts @@ -1033,6 +1033,11 @@ test('ParamSpec44', () => { TestUtils.validateResults(results, 0); }); +test('ParamSpec45', () => { + const results = TestUtils.typeAnalyzeSampleFiles(['paramSpec45.py']); + TestUtils.validateResults(results, 0); +}); + test('ClassVar1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['classVar1.py']);