diff --git a/IDEHelper/Compiler/BfExprEvaluator.cpp b/IDEHelper/Compiler/BfExprEvaluator.cpp index b9f1f16e..791544d0 100644 --- a/IDEHelper/Compiler/BfExprEvaluator.cpp +++ b/IDEHelper/Compiler/BfExprEvaluator.cpp @@ -1865,7 +1865,7 @@ bool BfMethodMatcher::CheckMethod(BfTypeInstance* targetTypeInstance, BfTypeInst if (methodInstance->GetParamKind(paramIdx) == BfParamKind_Params) { paramsParamIdx = paramIdx; - if ((wantType->IsArray()) || (wantType->IsInstanceOf(mModule->mCompiler->mSpanTypeDef))) + if ((wantType->IsArray()) || (wantType->IsSizedArray()) || (wantType->IsInstanceOf(mModule->mCompiler->mSpanTypeDef))) wantType = wantType->GetUnderlyingType(); } @@ -1888,6 +1888,26 @@ bool BfMethodMatcher::CheckMethod(BfTypeInstance* targetTypeInstance, BfTypeInst goto NoMatch; } + if ((checkMethod->mParams.mSize > 0) && (methodInstance->GetParamKind(checkMethod->mParams.mSize - 1) == BfParamKind_Params)) + { + // Handle `params int[C]` generic sized array params case + auto paramsType = methodInstance->GetParamType(checkMethod->mParams.mSize - 1); + if (paramsType->IsUnknownSizedArrayType()) + { + auto unknownSizedArray = (BfUnknownSizedArrayType*)paramsType; + if (unknownSizedArray->mElementCountSource->IsMethodGenericParam()) + { + auto genericParam = (BfGenericParamType*)unknownSizedArray->mElementCountSource; + if ((*genericArgumentsSubstitute)[genericParam->mGenericParamIdx] == NULL) + { + int paramsCount = (int)mArguments.mSize - inferParamOffset; + (*genericArgumentsSubstitute)[genericParam->mGenericParamIdx] = mModule->CreateConstExprValueType( + BfTypedValue(mModule->mBfIRBuilder->CreateConst(BfTypeCode_IntPtr, paramsCount), mModule->GetPrimitiveType(BfTypeCode_IntPtr))); + } + } + } + } + if (!deferredArgs.IsEmpty()) { genericInferContext.InferGenericArguments(methodInstance); @@ -6915,6 +6935,7 @@ BfTypedValue BfExprEvaluator::CreateCall(BfAstNode* targetSrc, const BfTypedValu PushArg(expandedParamsArray, irArgs); } + continue; } else if (wantType->IsArray()) { @@ -6947,10 +6968,11 @@ BfTypedValue BfExprEvaluator::CreateCall(BfAstNode* targetSrc, const BfTypedValu mModule->mBfIRBuilder->CreateAlignedStore(mModule->GetConstValue32(numElements), addr, 4); PushArg(expandedParamsArray, irArgs); + continue; } else if (wantType->IsInstanceOf(mModule->mCompiler->mSpanTypeDef)) { - auto genericTypeInst = wantType->ToGenericTypeInstance(); + auto genericTypeInst = wantType->ToGenericTypeInstance(); expandedParamsElementType = genericTypeInst->mGenericTypeInfo->mTypeGenericArguments[0]; expandedParamsArray = BfTypedValue(mModule->CreateAlloca(wantType), wantType, true); @@ -6959,9 +6981,26 @@ BfTypedValue BfExprEvaluator::CreateCall(BfAstNode* targetSrc, const BfTypedValu mModule->mBfIRBuilder->CreateStore(mModule->GetConstValue(numElements), mModule->mBfIRBuilder->CreateInBoundsGEP(expandedParamsArray.mValue, 0, 2)); PushArg(expandedParamsArray, irArgs); + continue; } + else if (wantType->IsSizedArray()) + { + BfSizedArrayType* sizedArrayType = (BfSizedArrayType*)wantType; + expandedParamsElementType = wantType->GetUnderlyingType(); - continue; + if (numElements != sizedArrayType->mElementCount) + { + BfAstNode* refNode = targetSrc; + if (argExprIdx < (int)argValues.size()) + refNode = argValues[argExprIdx].mExpression; + mModule->Fail(StrFormat("Incorrect number of arguments to match params type '%s'", mModule->TypeToString(wantType).c_str()), refNode); + } + + expandedParamsArray = BfTypedValue(mModule->CreateAlloca(wantType), wantType, true); + expandedParamAlloca = mModule->mBfIRBuilder->CreateBitCast(expandedParamsArray.mValue, mModule->mBfIRBuilder->GetPointerTo(mModule->mBfIRBuilder->MapType(expandedParamsElementType))); + PushArg(expandedParamsArray, irArgs); + continue; + } } } } @@ -20828,7 +20867,7 @@ void BfExprEvaluator::PerformUnaryOperation_OnResult(BfExpression* unaryOpExpr, auto genericTypeInst = mResult.mType->ToGenericTypeInstance(); if ((genericTypeInst != NULL) && (genericTypeInst->IsInstanceOf(mModule->mCompiler->mSpanTypeDef))) isValid = true; - else if (mResult.mType->IsArray()) + else if ((mResult.mType->IsArray()) || (mResult.mType->IsSizedArray())) isValid = true; if (!isValid) { diff --git a/IDEHelper/Compiler/BfModule.cpp b/IDEHelper/Compiler/BfModule.cpp index 4300a2df..d2b1d903 100644 --- a/IDEHelper/Compiler/BfModule.cpp +++ b/IDEHelper/Compiler/BfModule.cpp @@ -22672,6 +22672,10 @@ void BfModule::DoMethodDeclaration(BfMethodDeclaration* methodDeclaration, bool // Array is the 'normal' params type isValid = true; } + else if (resolvedParamType->IsSizedArray()) + { + isValid = true; + } else if ((resolvedParamType->IsDelegate()) || (resolvedParamType->IsFunction())) { hadDelegateParams = true; @@ -22700,7 +22704,7 @@ void BfModule::DoMethodDeclaration(BfMethodDeclaration* methodDeclaration, bool if (genericParamInstance->mTypeConstraint != NULL) { auto typeInstConstraint = genericParamInstance->mTypeConstraint->ToTypeInstance(); - if (genericParamInstance->mTypeConstraint->IsArray()) + if ((genericParamInstance->mTypeConstraint->IsArray()) || (genericParamInstance->mTypeConstraint->IsSizedArray())) { BfMethodParam methodParam; methodParam.mResolvedType = resolvedParamType; diff --git a/IDEHelper/Compiler/BfResolvedTypeUtils.h b/IDEHelper/Compiler/BfResolvedTypeUtils.h index 20f90cb3..df8253b0 100644 --- a/IDEHelper/Compiler/BfResolvedTypeUtils.h +++ b/IDEHelper/Compiler/BfResolvedTypeUtils.h @@ -1081,7 +1081,7 @@ public: public: bool IsGenericParam() override { return true; } bool IsTypeGenericParam() override { return mGenericParamKind == BfGenericParamKind_Type; } - bool IsMethodGenericParam() override { return mGenericParamKind == BfGenericParamKind_Type; } + bool IsMethodGenericParam() override { return mGenericParamKind == BfGenericParamKind_Method; } virtual bool IsUnspecializedType() override { return true; } virtual bool IsReified() override { return false; } }; diff --git a/IDEHelper/Tests/src/MethodCalls.bf b/IDEHelper/Tests/src/MethodCalls.bf index 13d4ef0d..26260594 100644 --- a/IDEHelper/Tests/src/MethodCalls.bf +++ b/IDEHelper/Tests/src/MethodCalls.bf @@ -96,6 +96,14 @@ namespace Tests return val3; } + public static float AddFloats(params float[C] vals) where C : const int + { + float total = 0; + for (var val in vals) + total += val; + return total; + } + [Test] public static void TestBasics() { @@ -117,6 +125,8 @@ namespace Tests Test.Assert(self.Method3b(sa) == sa); Test.Assert(self.Method4b(sa, sa2) == sa2); Test.Assert(self.Method5b(sa, sa2, sa3) == sa3); + + Test.Assert(AddFloats(1.0f, 2, 3) == 6.0f); } } }