diff --git a/IDEHelper/Compiler/BfExprEvaluator.cpp b/IDEHelper/Compiler/BfExprEvaluator.cpp index 66266b68..09a9ce34 100644 --- a/IDEHelper/Compiler/BfExprEvaluator.cpp +++ b/IDEHelper/Compiler/BfExprEvaluator.cpp @@ -588,6 +588,32 @@ bool BfGenericInferContext::InferGenericArgument(BfMethodInstance* methodInstanc return true; } +void BfGenericInferContext::InferGenericArguments(BfMethodInstance* methodInstance) +{ + // Attempt to infer from other generic args + for (int srcGenericIdx = 0; srcGenericIdx < (int)mCheckMethodGenericArguments->size(); srcGenericIdx++) + { + auto& srcGenericArg = (*mCheckMethodGenericArguments)[srcGenericIdx]; + if (srcGenericArg == NULL) + continue; + + auto srcGenericParam = methodInstance->mMethodInfoEx->mGenericParams[srcGenericIdx]; + for (auto ifaceConstraint : srcGenericParam->mInterfaceConstraints) + { + if ((ifaceConstraint->IsUnspecializedTypeVariation()) && (ifaceConstraint->IsGenericTypeInstance())) + { + InferGenericArgument(methodInstance, srcGenericArg, ifaceConstraint, BfIRValue()); + auto typeInstance = srcGenericArg->ToTypeInstance(); + if (typeInstance != NULL) + { + for (auto ifaceEntry : typeInstance->mInterfaces) + InferGenericArgument(methodInstance, ifaceEntry.mInterfaceType, ifaceConstraint, BfIRValue()); + } + } + } + } +} + // void BfGenericInferContext::PropogateInference(BfType* resolvedType, BfType* unresovledType) // { // if (!unresovledType->IsUnspecializedTypeVariation()) @@ -1207,7 +1233,7 @@ BfTypedValue BfMethodMatcher::ResolveArgTypedValue(BfResolvedArg& resolvedArg, B if (!resolvedArg.mTypedValue) { // Resolve for real - resolvedArg.mTypedValue = mModule->CreateValueFromExpression(lambdaBindExpr, checkType, BfEvalExprFlags_NoCast); + resolvedArg.mTypedValue = mModule->CreateValueFromExpression(lambdaBindExpr, checkType, (BfEvalExprFlags)(BfEvalExprFlags_NoCast | BfEvalExprFlags_NoAutoComplete)); } argTypedValue = resolvedArg.mTypedValue; } @@ -1217,6 +1243,73 @@ BfTypedValue BfMethodMatcher::ResolveArgTypedValue(BfResolvedArg& resolvedArg, B } } } + else if ((checkType == NULL) && (origCheckType != NULL) && (origCheckType->IsUnspecializedTypeVariation()) && (genericArgumentsSubstitute != NULL)) + { + BfMethodInstance* methodInstance = mModule->GetRawMethodInstanceAtIdx(origCheckType->ToTypeInstance(), 0, "Invoke"); + if (methodInstance != NULL) + { + if ((methodInstance->mReturnType->IsGenericParam()) && (((BfGenericParamType*)methodInstance->mReturnType)->mGenericParamKind == BfGenericParamKind_Method)) + { + bool isValid = true; + + int returnMethodGenericArgIdx = ((BfGenericParamType*)methodInstance->mReturnType)->mGenericParamIdx; + if ((*genericArgumentsSubstitute)[returnMethodGenericArgIdx] != NULL) + { + isValid = false; + } + + if (methodInstance->mParams.size() != (int)lambdaBindExpr->mParams.size()) + isValid = false; + + for (auto& param : methodInstance->mParams) + { + if (param.mResolvedType->IsGenericParam()) + { + auto genericParamType = (BfGenericParamType*)param.mResolvedType; + if ((genericParamType->mGenericParamKind == BfGenericParamKind_Method) && ((*genericArgumentsSubstitute)[genericParamType->mGenericParamIdx] == NULL)) + { + isValid = false; + } + } + } + + if (isValid) + { + bool success = false; + + (*genericArgumentsSubstitute)[returnMethodGenericArgIdx] = mModule->GetPrimitiveType(BfTypeCode_None); + auto tryType = mModule->ResolveGenericType(origCheckType, NULL, genericArgumentsSubstitute); + if (tryType != NULL) + { + auto inferredReturnType = mModule->CreateValueFromExpression(lambdaBindExpr, tryType, (BfEvalExprFlags)(BfEvalExprFlags_NoCast | BfEvalExprFlags_InferReturnType | BfEvalExprFlags_NoAutoComplete)); + if (inferredReturnType.mType != NULL) + { + (*genericArgumentsSubstitute)[returnMethodGenericArgIdx] = inferredReturnType.mType; + + if (((flags & BfResolveArgFlag_FromGenericParam) != 0) && (lambdaBindExpr->mNewToken == NULL)) + { + auto resolvedType = mModule->ResolveGenericType(origCheckType, NULL, genericArgumentsSubstitute); + if (resolvedType != NULL) + { + // Resolve for real + resolvedArg.mTypedValue = mModule->CreateValueFromExpression(lambdaBindExpr, resolvedType, (BfEvalExprFlags)(BfEvalExprFlags_NoCast | BfEvalExprFlags_NoAutoComplete)); + argTypedValue = resolvedArg.mTypedValue; + } + } + + success = true; + } + } + + if (!success) + { + // Put back + (*genericArgumentsSubstitute)[returnMethodGenericArgIdx] = NULL; + } + } + } + } + } } else if ((resolvedArg.mArgFlags & BfArgFlag_UnqualifiedDotAttempt) != 0) { @@ -1672,7 +1765,12 @@ bool BfMethodMatcher::CheckMethod(BfTypeInstance* targetTypeInstance, BfTypeInst argTypedValue = mTarget; } else - argTypedValue = ResolveArgTypedValue(mArguments[argIdx], checkType, genericArgumentsSubstitute, origCheckType, BfResolveArgFlag_FromGeneric); + { + BfResolveArgFlags flags = BfResolveArgFlag_FromGeneric; + if (wantType->IsGenericParam()) + flags = (BfResolveArgFlags)(flags | BfResolveArgFlag_FromGenericParam); + argTypedValue = ResolveArgTypedValue(mArguments[argIdx], checkType, genericArgumentsSubstitute, origCheckType, flags); + } if (!argTypedValue.IsUntypedValue()) { auto type = argTypedValue.mType; @@ -1709,6 +1807,11 @@ bool BfMethodMatcher::CheckMethod(BfTypeInstance* targetTypeInstance, BfTypeInst goto NoMatch; } + if (!deferredArgs.IsEmpty()) + { + genericInferContext.InferGenericArguments(methodInstance); + } + while (!deferredArgs.IsEmpty()) { int prevDeferredSize = (int)deferredArgs.size(); @@ -11504,7 +11607,12 @@ void BfExprEvaluator::VisitLambdaBodies(BfAstNode* body, BfFieldDtorDeclaration* if (auto blockBody = BfNodeDynCast(body)) mModule->VisitChild(blockBody); else if (auto bodyExpr = BfNodeDynCast(body)) - mModule->CreateValueFromExpression(bodyExpr); + { + auto result = mModule->CreateValueFromExpression(bodyExpr); + if ((result) && (mModule->mCurMethodState->mClosureState != NULL) && + (mModule->mCurMethodState->mClosureState->mReturnTypeInferState == BfReturnTypeInferState_Inferring)) + mModule->mCurMethodState->mClosureState->mReturnType = result.mType; + } while (fieldDtor != NULL) { @@ -11530,8 +11638,10 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam } } + bool isInferReturnType = (mBfEvalExprFlags & BfEvalExprFlags_InferReturnType) != 0; + BfLambdaInstance* lambdaInstance = NULL; - if (rootMethodState->mLambdaCache.TryGetValue(cacheNodeList, &lambdaInstance)) + if ((!isInferReturnType) && (rootMethodState->mLambdaCache.TryGetValue(cacheNodeList, &lambdaInstance))) return lambdaInstance; static int sBindCount = 0; @@ -11683,7 +11793,11 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam return NULL; } - if ((lambdaBindExpr->mNewToken == NULL) || (isFunctionBind)) + if ((lambdaBindExpr->mNewToken == NULL) && (isInferReturnType)) + { + // Method ref, but let this follow infer route + } + else if ((lambdaBindExpr->mNewToken == NULL) || (isFunctionBind)) { if ((lambdaBindExpr->mNewToken != NULL) && (isFunctionBind)) mModule->Fail("Binds to functions should do not require allocations.", lambdaBindExpr->mNewToken); @@ -11951,8 +12065,14 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam closureState.mCaptureVisitingBody = true; closureState.mClosureInstanceInfo = closureInstanceInfo; - VisitLambdaBodies(lambdaBindExpr->mBody, lambdaBindExpr->mDtor); + if ((mBfEvalExprFlags & BfEvalExprFlags_InferReturnType) != 0) + { + closureState.mReturnType = NULL; + closureState.mReturnTypeInferState = BfReturnTypeInferState_Inferring; + } + VisitLambdaBodies(lambdaBindExpr->mBody, lambdaBindExpr->mDtor); + if (hasExplicitCaptureNames) _SetNotCapturedFlag(false); @@ -11964,12 +12084,32 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam prevClosureState->mCaptureStartAccessId = closureState.mCaptureStartAccessId; } - if (mModule->mCurMethodInstance->mIsUnspecialized) + bool earlyExit = false; + if (isInferReturnType) + { + if ((closureState.mReturnTypeInferState == BfReturnTypeInferState_Fail) || + (closureState.mReturnType == NULL)) + { + mResult = BfTypedValue(); + } + else + { + mResult = BfTypedValue(closureState.mReturnType); + } + + earlyExit = true; + } + else if (mModule->mCurMethodInstance->mIsUnspecialized) + { + earlyExit = true; + mResult = mModule->GetDefaultTypedValue(delegateTypeInstance); + } + + if (earlyExit) { prevIgnoreWrites.Restore(); mModule->mBfIRBuilder->RestoreDebugLocation(); - - mResult = mModule->GetDefaultTypedValue(delegateTypeInstance); + mModule->mBfIRBuilder->SetActiveFunction(prevActiveFunction); if (!prevInsertBlock.IsFake()) mModule->mBfIRBuilder->SetInsertPoint(prevInsertBlock); @@ -11981,7 +12121,7 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam closureState.mCaptureVisitingBody = false; prevIgnoreWrites.Restore(); - mModule->mBfIRBuilder->RestoreDebugLocation(); + mModule->mBfIRBuilder->RestoreDebugLocation(); auto _GetCaptureType = [&](const StringImpl& str) { @@ -14189,34 +14329,10 @@ BfModuleMethodInstance BfExprEvaluator::GetSelectedMethod(BfAstNode* targetSrc, if (genericArg == NULL) { - // Attempt to infer from other generic args - for (int srcGenericIdx = 0; srcGenericIdx < (int)methodMatcher.mBestMethodGenericArguments.size(); srcGenericIdx++) - { - auto& srcGenericArg = methodMatcher.mBestMethodGenericArguments[srcGenericIdx]; - if (srcGenericArg == NULL) - continue; - - auto srcGenericParam = unspecializedMethod->mMethodInfoEx->mGenericParams[srcGenericIdx]; - - BfGenericInferContext genericInferContext; - genericInferContext.mModule = mModule; - genericInferContext.mCheckMethodGenericArguments = &methodMatcher.mBestMethodGenericArguments; - - for (auto ifaceConstraint : srcGenericParam->mInterfaceConstraints) - { - if ((ifaceConstraint->IsUnspecializedTypeVariation()) && (ifaceConstraint->IsGenericTypeInstance())) - { - genericInferContext.InferGenericArgument(unspecializedMethod, srcGenericArg, ifaceConstraint, BfIRValue()); - - auto typeInstance = srcGenericArg->ToTypeInstance(); - if (typeInstance != NULL) - { - for (auto ifaceEntry : typeInstance->mInterfaces) - genericInferContext.InferGenericArgument(unspecializedMethod, ifaceEntry.mInterfaceType, ifaceConstraint, BfIRValue()); - } - } - } - } + BfGenericInferContext genericInferContext; + genericInferContext.mModule = mModule; + genericInferContext.mCheckMethodGenericArguments = &methodMatcher.mBestMethodGenericArguments; + genericInferContext.InferGenericArguments(unspecializedMethod); } if (genericArg == NULL) diff --git a/IDEHelper/Compiler/BfExprEvaluator.h b/IDEHelper/Compiler/BfExprEvaluator.h index 3ba5398b..8401d157 100644 --- a/IDEHelper/Compiler/BfExprEvaluator.h +++ b/IDEHelper/Compiler/BfExprEvaluator.h @@ -37,7 +37,8 @@ enum BfResolveArgsFlags enum BfResolveArgFlags { BfResolveArgFlag_None = 0, - BfResolveArgFlag_FromGeneric = 1 + BfResolveArgFlag_FromGeneric = 1, + BfResolveArgFlag_FromGenericParam = 2 }; class BfResolvedArg @@ -141,6 +142,7 @@ public: { return (int)mCheckMethodGenericArguments->size() - mInferredCount; } + void InferGenericArguments(BfMethodInstance* methodInstance); }; class BfMethodMatcher diff --git a/IDEHelper/Compiler/BfModule.cpp b/IDEHelper/Compiler/BfModule.cpp index 77d53d12..37f6d256 100644 --- a/IDEHelper/Compiler/BfModule.cpp +++ b/IDEHelper/Compiler/BfModule.cpp @@ -8037,6 +8037,8 @@ BfTypedValue BfModule::CreateValueFromExpression(BfExprEvaluator& exprEvaluator, if (!exprEvaluator.mResult) { + if ((flags & BfEvalExprFlags_InferReturnType) != 0) + return exprEvaluator.mResult; if (!mCompiler->mPassInstance->HasFailed()) Fail("INTERNAL ERROR: No expression result returned but no error caught in expression evaluator", expr); return BfTypedValue(); diff --git a/IDEHelper/Compiler/BfModule.h b/IDEHelper/Compiler/BfModule.h index 6aa82d2b..6ff53a32 100644 --- a/IDEHelper/Compiler/BfModule.h +++ b/IDEHelper/Compiler/BfModule.h @@ -74,6 +74,8 @@ enum BfEvalExprFlags BfEvalExprFlags_NoLookupError = 0x40000, BfEvalExprFlags_Comptime = 0x80000, BfEvalExprFlags_InCascade = 0x100000, + BfEvalExprFlags_InferReturnType = 0x200000, + BfEvalExprFlags_WasMethodRef = 0x400000 }; enum BfCastFlags @@ -652,6 +654,13 @@ public: Array mMixinStateRecords; }; +enum BfReturnTypeInferState +{ + BfReturnTypeInferState_None, + BfReturnTypeInferState_Inferring, + BfReturnTypeInferState_Fail, +}; + class BfClosureState { public: @@ -661,6 +670,7 @@ public: // When we need to look into another local method to determine captures, but we don't want to process local variable declarations or cause infinite recursion bool mBlindCapturing; bool mDeclaringMethodIsMutating; + BfReturnTypeInferState mReturnTypeInferState; BfLocalMethod* mLocalMethod; BfClosureInstanceInfo* mClosureInstanceInfo; BfMethodDef* mClosureMethodDef; @@ -684,6 +694,7 @@ public: mCaptureStartAccessId = -1; mBlindCapturing = false; mDeclaringMethodIsMutating = false; + mReturnTypeInferState = BfReturnTypeInferState_None; mActiveDeferredLocalMethod = NULL; mReturnType = NULL; mClosureType = NULL; diff --git a/IDEHelper/Compiler/BfStmtEvaluator.cpp b/IDEHelper/Compiler/BfStmtEvaluator.cpp index c5350e93..a367d117 100644 --- a/IDEHelper/Compiler/BfStmtEvaluator.cpp +++ b/IDEHelper/Compiler/BfStmtEvaluator.cpp @@ -4900,8 +4900,12 @@ void BfModule::Visit(BfReturnStatement* returnStmt) if (mCurMethodInstance->IsMixin()) retType = NULL; - if (mCurMethodState->mClosureState != NULL) + bool inferReturnType = false; + if (mCurMethodState->mClosureState != NULL) + { retType = mCurMethodState->mClosureState->mReturnType; + inferReturnType = (mCurMethodState->mClosureState->mReturnTypeInferState != BfReturnTypeInferState_None); + } auto checkScope = mCurMethodState->mCurScope; while (checkScope != NULL) @@ -4931,7 +4935,7 @@ void BfModule::Visit(BfReturnStatement* returnStmt) checkLocalAssignData = checkLocalAssignData->mChainedAssignData; } - if (retType == NULL) + if ((retType == NULL) && (!inferReturnType)) { if (returnStmt->mExpression != NULL) { @@ -4972,7 +4976,38 @@ void BfModule::Visit(BfReturnStatement* returnStmt) exprEvaluator.mReceivingValue = &mCurMethodState->mRetVal; if (mCurMethodInstance->mMethodDef->mIsReadOnly) exprEvaluator.mAllowReadOnlyReference = true; + + if (inferReturnType) + expectingReturnType = NULL; + auto retValue = CreateValueFromExpression(exprEvaluator, returnStmt->mExpression, expectingReturnType, BfEvalExprFlags_AllowRefExpr, &origType); + + if ((retValue) && (inferReturnType)) + { + if (mCurMethodState->mClosureState->mReturnType == NULL) + mCurMethodState->mClosureState->mReturnType = retValue.mType; + else + { + if ((retValue.mType == mCurMethodState->mClosureState->mReturnType) || + (CanCast(retValue, mCurMethodState->mClosureState->mReturnType))) + { + // Leave as-is + } + else if (CanCast(GetFakeTypedValue(mCurMethodState->mClosureState->mReturnType), retValue.mType)) + { + mCurMethodState->mClosureState->mReturnType = retValue.mType; + } + else + { + mCurMethodState->mClosureState->mReturnTypeInferState = BfReturnTypeInferState_Fail; + } + } + } + if ((retType == NULL) && (inferReturnType)) + retType = mCurMethodState->mClosureState->mReturnType; + if (retType == NULL) + retType = GetPrimitiveType(BfTypeCode_None); + if ((!mIsComptimeModule) && (mCurMethodInstance->GetStructRetIdx() != -1)) alreadyWritten = exprEvaluator.mReceivingValue == NULL; MarkScopeLeft(&mCurMethodState->mHeadScope); diff --git a/IDEHelper/Tests/src/Generics.bf b/IDEHelper/Tests/src/Generics.bf index d65b9197..731764df 100644 --- a/IDEHelper/Tests/src/Generics.bf +++ b/IDEHelper/Tests/src/Generics.bf @@ -241,6 +241,17 @@ namespace Tests return 0; } + public static TResult Sum(this T it, TDlg dlg) + where T: concrete, IEnumerable + where TDlg: delegate TResult(TElem) + where TResult: operator TResult + TResult + { + var result = default(TResult); + for(var elem in it) + result += dlg(elem); + return result; + } + [Test] public static void TestBasics() { @@ -293,6 +304,8 @@ namespace Tests } == false);*/ Test.Assert(MethodE(floatList) == 6); Test.Assert(MethodF(floatList) == 0); + + Test.Assert(floatList.Sum((x) => x * 2) == 12); } }