mirror of
https://github.com/beefytech/Beef.git
synced 2025-06-08 19:48:20 +02:00
Lambda return type inference
This commit is contained in:
parent
d557e11dad
commit
bb12a4ec20
6 changed files with 220 additions and 41 deletions
|
@ -588,6 +588,32 @@ bool BfGenericInferContext::InferGenericArgument(BfMethodInstance* methodInstanc
|
|||
return true;
|
||||
}
|
||||
|
||||
void BfGenericInferContext::InferGenericArguments(BfMethodInstance* methodInstance)
|
||||
{
|
||||
// Attempt to infer from other generic args
|
||||
for (int srcGenericIdx = 0; srcGenericIdx < (int)mCheckMethodGenericArguments->size(); srcGenericIdx++)
|
||||
{
|
||||
auto& srcGenericArg = (*mCheckMethodGenericArguments)[srcGenericIdx];
|
||||
if (srcGenericArg == NULL)
|
||||
continue;
|
||||
|
||||
auto srcGenericParam = methodInstance->mMethodInfoEx->mGenericParams[srcGenericIdx];
|
||||
for (auto ifaceConstraint : srcGenericParam->mInterfaceConstraints)
|
||||
{
|
||||
if ((ifaceConstraint->IsUnspecializedTypeVariation()) && (ifaceConstraint->IsGenericTypeInstance()))
|
||||
{
|
||||
InferGenericArgument(methodInstance, srcGenericArg, ifaceConstraint, BfIRValue());
|
||||
auto typeInstance = srcGenericArg->ToTypeInstance();
|
||||
if (typeInstance != NULL)
|
||||
{
|
||||
for (auto ifaceEntry : typeInstance->mInterfaces)
|
||||
InferGenericArgument(methodInstance, ifaceEntry.mInterfaceType, ifaceConstraint, BfIRValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// void BfGenericInferContext::PropogateInference(BfType* resolvedType, BfType* unresovledType)
|
||||
// {
|
||||
// if (!unresovledType->IsUnspecializedTypeVariation())
|
||||
|
@ -1207,7 +1233,7 @@ BfTypedValue BfMethodMatcher::ResolveArgTypedValue(BfResolvedArg& resolvedArg, B
|
|||
if (!resolvedArg.mTypedValue)
|
||||
{
|
||||
// Resolve for real
|
||||
resolvedArg.mTypedValue = mModule->CreateValueFromExpression(lambdaBindExpr, checkType, BfEvalExprFlags_NoCast);
|
||||
resolvedArg.mTypedValue = mModule->CreateValueFromExpression(lambdaBindExpr, checkType, (BfEvalExprFlags)(BfEvalExprFlags_NoCast | BfEvalExprFlags_NoAutoComplete));
|
||||
}
|
||||
argTypedValue = resolvedArg.mTypedValue;
|
||||
}
|
||||
|
@ -1217,6 +1243,73 @@ BfTypedValue BfMethodMatcher::ResolveArgTypedValue(BfResolvedArg& resolvedArg, B
|
|||
}
|
||||
}
|
||||
}
|
||||
else if ((checkType == NULL) && (origCheckType != NULL) && (origCheckType->IsUnspecializedTypeVariation()) && (genericArgumentsSubstitute != NULL))
|
||||
{
|
||||
BfMethodInstance* methodInstance = mModule->GetRawMethodInstanceAtIdx(origCheckType->ToTypeInstance(), 0, "Invoke");
|
||||
if (methodInstance != NULL)
|
||||
{
|
||||
if ((methodInstance->mReturnType->IsGenericParam()) && (((BfGenericParamType*)methodInstance->mReturnType)->mGenericParamKind == BfGenericParamKind_Method))
|
||||
{
|
||||
bool isValid = true;
|
||||
|
||||
int returnMethodGenericArgIdx = ((BfGenericParamType*)methodInstance->mReturnType)->mGenericParamIdx;
|
||||
if ((*genericArgumentsSubstitute)[returnMethodGenericArgIdx] != NULL)
|
||||
{
|
||||
isValid = false;
|
||||
}
|
||||
|
||||
if (methodInstance->mParams.size() != (int)lambdaBindExpr->mParams.size())
|
||||
isValid = false;
|
||||
|
||||
for (auto& param : methodInstance->mParams)
|
||||
{
|
||||
if (param.mResolvedType->IsGenericParam())
|
||||
{
|
||||
auto genericParamType = (BfGenericParamType*)param.mResolvedType;
|
||||
if ((genericParamType->mGenericParamKind == BfGenericParamKind_Method) && ((*genericArgumentsSubstitute)[genericParamType->mGenericParamIdx] == NULL))
|
||||
{
|
||||
isValid = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (isValid)
|
||||
{
|
||||
bool success = false;
|
||||
|
||||
(*genericArgumentsSubstitute)[returnMethodGenericArgIdx] = mModule->GetPrimitiveType(BfTypeCode_None);
|
||||
auto tryType = mModule->ResolveGenericType(origCheckType, NULL, genericArgumentsSubstitute);
|
||||
if (tryType != NULL)
|
||||
{
|
||||
auto inferredReturnType = mModule->CreateValueFromExpression(lambdaBindExpr, tryType, (BfEvalExprFlags)(BfEvalExprFlags_NoCast | BfEvalExprFlags_InferReturnType | BfEvalExprFlags_NoAutoComplete));
|
||||
if (inferredReturnType.mType != NULL)
|
||||
{
|
||||
(*genericArgumentsSubstitute)[returnMethodGenericArgIdx] = inferredReturnType.mType;
|
||||
|
||||
if (((flags & BfResolveArgFlag_FromGenericParam) != 0) && (lambdaBindExpr->mNewToken == NULL))
|
||||
{
|
||||
auto resolvedType = mModule->ResolveGenericType(origCheckType, NULL, genericArgumentsSubstitute);
|
||||
if (resolvedType != NULL)
|
||||
{
|
||||
// Resolve for real
|
||||
resolvedArg.mTypedValue = mModule->CreateValueFromExpression(lambdaBindExpr, resolvedType, (BfEvalExprFlags)(BfEvalExprFlags_NoCast | BfEvalExprFlags_NoAutoComplete));
|
||||
argTypedValue = resolvedArg.mTypedValue;
|
||||
}
|
||||
}
|
||||
|
||||
success = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!success)
|
||||
{
|
||||
// Put back
|
||||
(*genericArgumentsSubstitute)[returnMethodGenericArgIdx] = NULL;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else if ((resolvedArg.mArgFlags & BfArgFlag_UnqualifiedDotAttempt) != 0)
|
||||
{
|
||||
|
@ -1672,7 +1765,12 @@ bool BfMethodMatcher::CheckMethod(BfTypeInstance* targetTypeInstance, BfTypeInst
|
|||
argTypedValue = mTarget;
|
||||
}
|
||||
else
|
||||
argTypedValue = ResolveArgTypedValue(mArguments[argIdx], checkType, genericArgumentsSubstitute, origCheckType, BfResolveArgFlag_FromGeneric);
|
||||
{
|
||||
BfResolveArgFlags flags = BfResolveArgFlag_FromGeneric;
|
||||
if (wantType->IsGenericParam())
|
||||
flags = (BfResolveArgFlags)(flags | BfResolveArgFlag_FromGenericParam);
|
||||
argTypedValue = ResolveArgTypedValue(mArguments[argIdx], checkType, genericArgumentsSubstitute, origCheckType, flags);
|
||||
}
|
||||
if (!argTypedValue.IsUntypedValue())
|
||||
{
|
||||
auto type = argTypedValue.mType;
|
||||
|
@ -1709,6 +1807,11 @@ bool BfMethodMatcher::CheckMethod(BfTypeInstance* targetTypeInstance, BfTypeInst
|
|||
goto NoMatch;
|
||||
}
|
||||
|
||||
if (!deferredArgs.IsEmpty())
|
||||
{
|
||||
genericInferContext.InferGenericArguments(methodInstance);
|
||||
}
|
||||
|
||||
while (!deferredArgs.IsEmpty())
|
||||
{
|
||||
int prevDeferredSize = (int)deferredArgs.size();
|
||||
|
@ -11504,7 +11607,12 @@ void BfExprEvaluator::VisitLambdaBodies(BfAstNode* body, BfFieldDtorDeclaration*
|
|||
if (auto blockBody = BfNodeDynCast<BfBlock>(body))
|
||||
mModule->VisitChild(blockBody);
|
||||
else if (auto bodyExpr = BfNodeDynCast<BfExpression>(body))
|
||||
mModule->CreateValueFromExpression(bodyExpr);
|
||||
{
|
||||
auto result = mModule->CreateValueFromExpression(bodyExpr);
|
||||
if ((result) && (mModule->mCurMethodState->mClosureState != NULL) &&
|
||||
(mModule->mCurMethodState->mClosureState->mReturnTypeInferState == BfReturnTypeInferState_Inferring))
|
||||
mModule->mCurMethodState->mClosureState->mReturnType = result.mType;
|
||||
}
|
||||
|
||||
while (fieldDtor != NULL)
|
||||
{
|
||||
|
@ -11530,8 +11638,10 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam
|
|||
}
|
||||
}
|
||||
|
||||
bool isInferReturnType = (mBfEvalExprFlags & BfEvalExprFlags_InferReturnType) != 0;
|
||||
|
||||
BfLambdaInstance* lambdaInstance = NULL;
|
||||
if (rootMethodState->mLambdaCache.TryGetValue(cacheNodeList, &lambdaInstance))
|
||||
if ((!isInferReturnType) && (rootMethodState->mLambdaCache.TryGetValue(cacheNodeList, &lambdaInstance)))
|
||||
return lambdaInstance;
|
||||
|
||||
static int sBindCount = 0;
|
||||
|
@ -11683,7 +11793,11 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam
|
|||
return NULL;
|
||||
}
|
||||
|
||||
if ((lambdaBindExpr->mNewToken == NULL) || (isFunctionBind))
|
||||
if ((lambdaBindExpr->mNewToken == NULL) && (isInferReturnType))
|
||||
{
|
||||
// Method ref, but let this follow infer route
|
||||
}
|
||||
else if ((lambdaBindExpr->mNewToken == NULL) || (isFunctionBind))
|
||||
{
|
||||
if ((lambdaBindExpr->mNewToken != NULL) && (isFunctionBind))
|
||||
mModule->Fail("Binds to functions should do not require allocations.", lambdaBindExpr->mNewToken);
|
||||
|
@ -11951,6 +12065,12 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam
|
|||
closureState.mCaptureVisitingBody = true;
|
||||
closureState.mClosureInstanceInfo = closureInstanceInfo;
|
||||
|
||||
if ((mBfEvalExprFlags & BfEvalExprFlags_InferReturnType) != 0)
|
||||
{
|
||||
closureState.mReturnType = NULL;
|
||||
closureState.mReturnTypeInferState = BfReturnTypeInferState_Inferring;
|
||||
}
|
||||
|
||||
VisitLambdaBodies(lambdaBindExpr->mBody, lambdaBindExpr->mDtor);
|
||||
|
||||
if (hasExplicitCaptureNames)
|
||||
|
@ -11964,12 +12084,32 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam
|
|||
prevClosureState->mCaptureStartAccessId = closureState.mCaptureStartAccessId;
|
||||
}
|
||||
|
||||
if (mModule->mCurMethodInstance->mIsUnspecialized)
|
||||
bool earlyExit = false;
|
||||
if (isInferReturnType)
|
||||
{
|
||||
if ((closureState.mReturnTypeInferState == BfReturnTypeInferState_Fail) ||
|
||||
(closureState.mReturnType == NULL))
|
||||
{
|
||||
mResult = BfTypedValue();
|
||||
}
|
||||
else
|
||||
{
|
||||
mResult = BfTypedValue(closureState.mReturnType);
|
||||
}
|
||||
|
||||
earlyExit = true;
|
||||
}
|
||||
else if (mModule->mCurMethodInstance->mIsUnspecialized)
|
||||
{
|
||||
earlyExit = true;
|
||||
mResult = mModule->GetDefaultTypedValue(delegateTypeInstance);
|
||||
}
|
||||
|
||||
if (earlyExit)
|
||||
{
|
||||
prevIgnoreWrites.Restore();
|
||||
mModule->mBfIRBuilder->RestoreDebugLocation();
|
||||
|
||||
mResult = mModule->GetDefaultTypedValue(delegateTypeInstance);
|
||||
mModule->mBfIRBuilder->SetActiveFunction(prevActiveFunction);
|
||||
if (!prevInsertBlock.IsFake())
|
||||
mModule->mBfIRBuilder->SetInsertPoint(prevInsertBlock);
|
||||
|
@ -14189,34 +14329,10 @@ BfModuleMethodInstance BfExprEvaluator::GetSelectedMethod(BfAstNode* targetSrc,
|
|||
|
||||
if (genericArg == NULL)
|
||||
{
|
||||
// Attempt to infer from other generic args
|
||||
for (int srcGenericIdx = 0; srcGenericIdx < (int)methodMatcher.mBestMethodGenericArguments.size(); srcGenericIdx++)
|
||||
{
|
||||
auto& srcGenericArg = methodMatcher.mBestMethodGenericArguments[srcGenericIdx];
|
||||
if (srcGenericArg == NULL)
|
||||
continue;
|
||||
|
||||
auto srcGenericParam = unspecializedMethod->mMethodInfoEx->mGenericParams[srcGenericIdx];
|
||||
|
||||
BfGenericInferContext genericInferContext;
|
||||
genericInferContext.mModule = mModule;
|
||||
genericInferContext.mCheckMethodGenericArguments = &methodMatcher.mBestMethodGenericArguments;
|
||||
|
||||
for (auto ifaceConstraint : srcGenericParam->mInterfaceConstraints)
|
||||
{
|
||||
if ((ifaceConstraint->IsUnspecializedTypeVariation()) && (ifaceConstraint->IsGenericTypeInstance()))
|
||||
{
|
||||
genericInferContext.InferGenericArgument(unspecializedMethod, srcGenericArg, ifaceConstraint, BfIRValue());
|
||||
|
||||
auto typeInstance = srcGenericArg->ToTypeInstance();
|
||||
if (typeInstance != NULL)
|
||||
{
|
||||
for (auto ifaceEntry : typeInstance->mInterfaces)
|
||||
genericInferContext.InferGenericArgument(unspecializedMethod, ifaceEntry.mInterfaceType, ifaceConstraint, BfIRValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
genericInferContext.InferGenericArguments(unspecializedMethod);
|
||||
}
|
||||
|
||||
if (genericArg == NULL)
|
||||
|
|
|
@ -37,7 +37,8 @@ enum BfResolveArgsFlags
|
|||
enum BfResolveArgFlags
|
||||
{
|
||||
BfResolveArgFlag_None = 0,
|
||||
BfResolveArgFlag_FromGeneric = 1
|
||||
BfResolveArgFlag_FromGeneric = 1,
|
||||
BfResolveArgFlag_FromGenericParam = 2
|
||||
};
|
||||
|
||||
class BfResolvedArg
|
||||
|
@ -141,6 +142,7 @@ public:
|
|||
{
|
||||
return (int)mCheckMethodGenericArguments->size() - mInferredCount;
|
||||
}
|
||||
void InferGenericArguments(BfMethodInstance* methodInstance);
|
||||
};
|
||||
|
||||
class BfMethodMatcher
|
||||
|
|
|
@ -8037,6 +8037,8 @@ BfTypedValue BfModule::CreateValueFromExpression(BfExprEvaluator& exprEvaluator,
|
|||
|
||||
if (!exprEvaluator.mResult)
|
||||
{
|
||||
if ((flags & BfEvalExprFlags_InferReturnType) != 0)
|
||||
return exprEvaluator.mResult;
|
||||
if (!mCompiler->mPassInstance->HasFailed())
|
||||
Fail("INTERNAL ERROR: No expression result returned but no error caught in expression evaluator", expr);
|
||||
return BfTypedValue();
|
||||
|
|
|
@ -74,6 +74,8 @@ enum BfEvalExprFlags
|
|||
BfEvalExprFlags_NoLookupError = 0x40000,
|
||||
BfEvalExprFlags_Comptime = 0x80000,
|
||||
BfEvalExprFlags_InCascade = 0x100000,
|
||||
BfEvalExprFlags_InferReturnType = 0x200000,
|
||||
BfEvalExprFlags_WasMethodRef = 0x400000
|
||||
};
|
||||
|
||||
enum BfCastFlags
|
||||
|
@ -652,6 +654,13 @@ public:
|
|||
Array<BfMixinRecord> mMixinStateRecords;
|
||||
};
|
||||
|
||||
enum BfReturnTypeInferState
|
||||
{
|
||||
BfReturnTypeInferState_None,
|
||||
BfReturnTypeInferState_Inferring,
|
||||
BfReturnTypeInferState_Fail,
|
||||
};
|
||||
|
||||
class BfClosureState
|
||||
{
|
||||
public:
|
||||
|
@ -661,6 +670,7 @@ public:
|
|||
// When we need to look into another local method to determine captures, but we don't want to process local variable declarations or cause infinite recursion
|
||||
bool mBlindCapturing;
|
||||
bool mDeclaringMethodIsMutating;
|
||||
BfReturnTypeInferState mReturnTypeInferState;
|
||||
BfLocalMethod* mLocalMethod;
|
||||
BfClosureInstanceInfo* mClosureInstanceInfo;
|
||||
BfMethodDef* mClosureMethodDef;
|
||||
|
@ -684,6 +694,7 @@ public:
|
|||
mCaptureStartAccessId = -1;
|
||||
mBlindCapturing = false;
|
||||
mDeclaringMethodIsMutating = false;
|
||||
mReturnTypeInferState = BfReturnTypeInferState_None;
|
||||
mActiveDeferredLocalMethod = NULL;
|
||||
mReturnType = NULL;
|
||||
mClosureType = NULL;
|
||||
|
|
|
@ -4900,8 +4900,12 @@ void BfModule::Visit(BfReturnStatement* returnStmt)
|
|||
if (mCurMethodInstance->IsMixin())
|
||||
retType = NULL;
|
||||
|
||||
bool inferReturnType = false;
|
||||
if (mCurMethodState->mClosureState != NULL)
|
||||
{
|
||||
retType = mCurMethodState->mClosureState->mReturnType;
|
||||
inferReturnType = (mCurMethodState->mClosureState->mReturnTypeInferState != BfReturnTypeInferState_None);
|
||||
}
|
||||
|
||||
auto checkScope = mCurMethodState->mCurScope;
|
||||
while (checkScope != NULL)
|
||||
|
@ -4931,7 +4935,7 @@ void BfModule::Visit(BfReturnStatement* returnStmt)
|
|||
checkLocalAssignData = checkLocalAssignData->mChainedAssignData;
|
||||
}
|
||||
|
||||
if (retType == NULL)
|
||||
if ((retType == NULL) && (!inferReturnType))
|
||||
{
|
||||
if (returnStmt->mExpression != NULL)
|
||||
{
|
||||
|
@ -4972,7 +4976,38 @@ void BfModule::Visit(BfReturnStatement* returnStmt)
|
|||
exprEvaluator.mReceivingValue = &mCurMethodState->mRetVal;
|
||||
if (mCurMethodInstance->mMethodDef->mIsReadOnly)
|
||||
exprEvaluator.mAllowReadOnlyReference = true;
|
||||
|
||||
if (inferReturnType)
|
||||
expectingReturnType = NULL;
|
||||
|
||||
auto retValue = CreateValueFromExpression(exprEvaluator, returnStmt->mExpression, expectingReturnType, BfEvalExprFlags_AllowRefExpr, &origType);
|
||||
|
||||
if ((retValue) && (inferReturnType))
|
||||
{
|
||||
if (mCurMethodState->mClosureState->mReturnType == NULL)
|
||||
mCurMethodState->mClosureState->mReturnType = retValue.mType;
|
||||
else
|
||||
{
|
||||
if ((retValue.mType == mCurMethodState->mClosureState->mReturnType) ||
|
||||
(CanCast(retValue, mCurMethodState->mClosureState->mReturnType)))
|
||||
{
|
||||
// Leave as-is
|
||||
}
|
||||
else if (CanCast(GetFakeTypedValue(mCurMethodState->mClosureState->mReturnType), retValue.mType))
|
||||
{
|
||||
mCurMethodState->mClosureState->mReturnType = retValue.mType;
|
||||
}
|
||||
else
|
||||
{
|
||||
mCurMethodState->mClosureState->mReturnTypeInferState = BfReturnTypeInferState_Fail;
|
||||
}
|
||||
}
|
||||
}
|
||||
if ((retType == NULL) && (inferReturnType))
|
||||
retType = mCurMethodState->mClosureState->mReturnType;
|
||||
if (retType == NULL)
|
||||
retType = GetPrimitiveType(BfTypeCode_None);
|
||||
|
||||
if ((!mIsComptimeModule) && (mCurMethodInstance->GetStructRetIdx() != -1))
|
||||
alreadyWritten = exprEvaluator.mReceivingValue == NULL;
|
||||
MarkScopeLeft(&mCurMethodState->mHeadScope);
|
||||
|
|
|
@ -241,6 +241,17 @@ namespace Tests
|
|||
return 0;
|
||||
}
|
||||
|
||||
public static TResult Sum<T, TElem, TDlg, TResult>(this T it, TDlg dlg)
|
||||
where T: concrete, IEnumerable<TElem>
|
||||
where TDlg: delegate TResult(TElem)
|
||||
where TResult: operator TResult + TResult
|
||||
{
|
||||
var result = default(TResult);
|
||||
for(var elem in it)
|
||||
result += dlg(elem);
|
||||
return result;
|
||||
}
|
||||
|
||||
[Test]
|
||||
public static void TestBasics()
|
||||
{
|
||||
|
@ -293,6 +304,8 @@ namespace Tests
|
|||
} == false);*/
|
||||
Test.Assert(MethodE(floatList) == 6);
|
||||
Test.Assert(MethodF(floatList) == 0);
|
||||
|
||||
Test.Assert(floatList.Sum((x) => x * 2) == 12);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue