1
0
Fork 0
mirror of https://github.com/beefytech/Beef.git synced 2025-06-08 19:48:20 +02:00

Lambda return type inference

This commit is contained in:
Brian Fiete 2021-01-14 06:24:34 -08:00
parent d557e11dad
commit bb12a4ec20
6 changed files with 220 additions and 41 deletions

View file

@ -588,6 +588,32 @@ bool BfGenericInferContext::InferGenericArgument(BfMethodInstance* methodInstanc
return true;
}
void BfGenericInferContext::InferGenericArguments(BfMethodInstance* methodInstance)
{
// Attempt to infer from other generic args
for (int srcGenericIdx = 0; srcGenericIdx < (int)mCheckMethodGenericArguments->size(); srcGenericIdx++)
{
auto& srcGenericArg = (*mCheckMethodGenericArguments)[srcGenericIdx];
if (srcGenericArg == NULL)
continue;
auto srcGenericParam = methodInstance->mMethodInfoEx->mGenericParams[srcGenericIdx];
for (auto ifaceConstraint : srcGenericParam->mInterfaceConstraints)
{
if ((ifaceConstraint->IsUnspecializedTypeVariation()) && (ifaceConstraint->IsGenericTypeInstance()))
{
InferGenericArgument(methodInstance, srcGenericArg, ifaceConstraint, BfIRValue());
auto typeInstance = srcGenericArg->ToTypeInstance();
if (typeInstance != NULL)
{
for (auto ifaceEntry : typeInstance->mInterfaces)
InferGenericArgument(methodInstance, ifaceEntry.mInterfaceType, ifaceConstraint, BfIRValue());
}
}
}
}
}
// void BfGenericInferContext::PropogateInference(BfType* resolvedType, BfType* unresovledType)
// {
// if (!unresovledType->IsUnspecializedTypeVariation())
@ -1207,7 +1233,7 @@ BfTypedValue BfMethodMatcher::ResolveArgTypedValue(BfResolvedArg& resolvedArg, B
if (!resolvedArg.mTypedValue)
{
// Resolve for real
resolvedArg.mTypedValue = mModule->CreateValueFromExpression(lambdaBindExpr, checkType, BfEvalExprFlags_NoCast);
resolvedArg.mTypedValue = mModule->CreateValueFromExpression(lambdaBindExpr, checkType, (BfEvalExprFlags)(BfEvalExprFlags_NoCast | BfEvalExprFlags_NoAutoComplete));
}
argTypedValue = resolvedArg.mTypedValue;
}
@ -1217,6 +1243,73 @@ BfTypedValue BfMethodMatcher::ResolveArgTypedValue(BfResolvedArg& resolvedArg, B
}
}
}
else if ((checkType == NULL) && (origCheckType != NULL) && (origCheckType->IsUnspecializedTypeVariation()) && (genericArgumentsSubstitute != NULL))
{
BfMethodInstance* methodInstance = mModule->GetRawMethodInstanceAtIdx(origCheckType->ToTypeInstance(), 0, "Invoke");
if (methodInstance != NULL)
{
if ((methodInstance->mReturnType->IsGenericParam()) && (((BfGenericParamType*)methodInstance->mReturnType)->mGenericParamKind == BfGenericParamKind_Method))
{
bool isValid = true;
int returnMethodGenericArgIdx = ((BfGenericParamType*)methodInstance->mReturnType)->mGenericParamIdx;
if ((*genericArgumentsSubstitute)[returnMethodGenericArgIdx] != NULL)
{
isValid = false;
}
if (methodInstance->mParams.size() != (int)lambdaBindExpr->mParams.size())
isValid = false;
for (auto& param : methodInstance->mParams)
{
if (param.mResolvedType->IsGenericParam())
{
auto genericParamType = (BfGenericParamType*)param.mResolvedType;
if ((genericParamType->mGenericParamKind == BfGenericParamKind_Method) && ((*genericArgumentsSubstitute)[genericParamType->mGenericParamIdx] == NULL))
{
isValid = false;
}
}
}
if (isValid)
{
bool success = false;
(*genericArgumentsSubstitute)[returnMethodGenericArgIdx] = mModule->GetPrimitiveType(BfTypeCode_None);
auto tryType = mModule->ResolveGenericType(origCheckType, NULL, genericArgumentsSubstitute);
if (tryType != NULL)
{
auto inferredReturnType = mModule->CreateValueFromExpression(lambdaBindExpr, tryType, (BfEvalExprFlags)(BfEvalExprFlags_NoCast | BfEvalExprFlags_InferReturnType | BfEvalExprFlags_NoAutoComplete));
if (inferredReturnType.mType != NULL)
{
(*genericArgumentsSubstitute)[returnMethodGenericArgIdx] = inferredReturnType.mType;
if (((flags & BfResolveArgFlag_FromGenericParam) != 0) && (lambdaBindExpr->mNewToken == NULL))
{
auto resolvedType = mModule->ResolveGenericType(origCheckType, NULL, genericArgumentsSubstitute);
if (resolvedType != NULL)
{
// Resolve for real
resolvedArg.mTypedValue = mModule->CreateValueFromExpression(lambdaBindExpr, resolvedType, (BfEvalExprFlags)(BfEvalExprFlags_NoCast | BfEvalExprFlags_NoAutoComplete));
argTypedValue = resolvedArg.mTypedValue;
}
}
success = true;
}
}
if (!success)
{
// Put back
(*genericArgumentsSubstitute)[returnMethodGenericArgIdx] = NULL;
}
}
}
}
}
}
else if ((resolvedArg.mArgFlags & BfArgFlag_UnqualifiedDotAttempt) != 0)
{
@ -1672,7 +1765,12 @@ bool BfMethodMatcher::CheckMethod(BfTypeInstance* targetTypeInstance, BfTypeInst
argTypedValue = mTarget;
}
else
argTypedValue = ResolveArgTypedValue(mArguments[argIdx], checkType, genericArgumentsSubstitute, origCheckType, BfResolveArgFlag_FromGeneric);
{
BfResolveArgFlags flags = BfResolveArgFlag_FromGeneric;
if (wantType->IsGenericParam())
flags = (BfResolveArgFlags)(flags | BfResolveArgFlag_FromGenericParam);
argTypedValue = ResolveArgTypedValue(mArguments[argIdx], checkType, genericArgumentsSubstitute, origCheckType, flags);
}
if (!argTypedValue.IsUntypedValue())
{
auto type = argTypedValue.mType;
@ -1709,6 +1807,11 @@ bool BfMethodMatcher::CheckMethod(BfTypeInstance* targetTypeInstance, BfTypeInst
goto NoMatch;
}
if (!deferredArgs.IsEmpty())
{
genericInferContext.InferGenericArguments(methodInstance);
}
while (!deferredArgs.IsEmpty())
{
int prevDeferredSize = (int)deferredArgs.size();
@ -11504,7 +11607,12 @@ void BfExprEvaluator::VisitLambdaBodies(BfAstNode* body, BfFieldDtorDeclaration*
if (auto blockBody = BfNodeDynCast<BfBlock>(body))
mModule->VisitChild(blockBody);
else if (auto bodyExpr = BfNodeDynCast<BfExpression>(body))
mModule->CreateValueFromExpression(bodyExpr);
{
auto result = mModule->CreateValueFromExpression(bodyExpr);
if ((result) && (mModule->mCurMethodState->mClosureState != NULL) &&
(mModule->mCurMethodState->mClosureState->mReturnTypeInferState == BfReturnTypeInferState_Inferring))
mModule->mCurMethodState->mClosureState->mReturnType = result.mType;
}
while (fieldDtor != NULL)
{
@ -11530,8 +11638,10 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam
}
}
bool isInferReturnType = (mBfEvalExprFlags & BfEvalExprFlags_InferReturnType) != 0;
BfLambdaInstance* lambdaInstance = NULL;
if (rootMethodState->mLambdaCache.TryGetValue(cacheNodeList, &lambdaInstance))
if ((!isInferReturnType) && (rootMethodState->mLambdaCache.TryGetValue(cacheNodeList, &lambdaInstance)))
return lambdaInstance;
static int sBindCount = 0;
@ -11683,7 +11793,11 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam
return NULL;
}
if ((lambdaBindExpr->mNewToken == NULL) || (isFunctionBind))
if ((lambdaBindExpr->mNewToken == NULL) && (isInferReturnType))
{
// Method ref, but let this follow infer route
}
else if ((lambdaBindExpr->mNewToken == NULL) || (isFunctionBind))
{
if ((lambdaBindExpr->mNewToken != NULL) && (isFunctionBind))
mModule->Fail("Binds to functions should do not require allocations.", lambdaBindExpr->mNewToken);
@ -11951,6 +12065,12 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam
closureState.mCaptureVisitingBody = true;
closureState.mClosureInstanceInfo = closureInstanceInfo;
if ((mBfEvalExprFlags & BfEvalExprFlags_InferReturnType) != 0)
{
closureState.mReturnType = NULL;
closureState.mReturnTypeInferState = BfReturnTypeInferState_Inferring;
}
VisitLambdaBodies(lambdaBindExpr->mBody, lambdaBindExpr->mDtor);
if (hasExplicitCaptureNames)
@ -11964,12 +12084,32 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam
prevClosureState->mCaptureStartAccessId = closureState.mCaptureStartAccessId;
}
if (mModule->mCurMethodInstance->mIsUnspecialized)
bool earlyExit = false;
if (isInferReturnType)
{
if ((closureState.mReturnTypeInferState == BfReturnTypeInferState_Fail) ||
(closureState.mReturnType == NULL))
{
mResult = BfTypedValue();
}
else
{
mResult = BfTypedValue(closureState.mReturnType);
}
earlyExit = true;
}
else if (mModule->mCurMethodInstance->mIsUnspecialized)
{
earlyExit = true;
mResult = mModule->GetDefaultTypedValue(delegateTypeInstance);
}
if (earlyExit)
{
prevIgnoreWrites.Restore();
mModule->mBfIRBuilder->RestoreDebugLocation();
mResult = mModule->GetDefaultTypedValue(delegateTypeInstance);
mModule->mBfIRBuilder->SetActiveFunction(prevActiveFunction);
if (!prevInsertBlock.IsFake())
mModule->mBfIRBuilder->SetInsertPoint(prevInsertBlock);
@ -14189,34 +14329,10 @@ BfModuleMethodInstance BfExprEvaluator::GetSelectedMethod(BfAstNode* targetSrc,
if (genericArg == NULL)
{
// Attempt to infer from other generic args
for (int srcGenericIdx = 0; srcGenericIdx < (int)methodMatcher.mBestMethodGenericArguments.size(); srcGenericIdx++)
{
auto& srcGenericArg = methodMatcher.mBestMethodGenericArguments[srcGenericIdx];
if (srcGenericArg == NULL)
continue;
auto srcGenericParam = unspecializedMethod->mMethodInfoEx->mGenericParams[srcGenericIdx];
BfGenericInferContext genericInferContext;
genericInferContext.mModule = mModule;
genericInferContext.mCheckMethodGenericArguments = &methodMatcher.mBestMethodGenericArguments;
for (auto ifaceConstraint : srcGenericParam->mInterfaceConstraints)
{
if ((ifaceConstraint->IsUnspecializedTypeVariation()) && (ifaceConstraint->IsGenericTypeInstance()))
{
genericInferContext.InferGenericArgument(unspecializedMethod, srcGenericArg, ifaceConstraint, BfIRValue());
auto typeInstance = srcGenericArg->ToTypeInstance();
if (typeInstance != NULL)
{
for (auto ifaceEntry : typeInstance->mInterfaces)
genericInferContext.InferGenericArgument(unspecializedMethod, ifaceEntry.mInterfaceType, ifaceConstraint, BfIRValue());
}
}
}
}
genericInferContext.InferGenericArguments(unspecializedMethod);
}
if (genericArg == NULL)

View file

@ -37,7 +37,8 @@ enum BfResolveArgsFlags
enum BfResolveArgFlags
{
BfResolveArgFlag_None = 0,
BfResolveArgFlag_FromGeneric = 1
BfResolveArgFlag_FromGeneric = 1,
BfResolveArgFlag_FromGenericParam = 2
};
class BfResolvedArg
@ -141,6 +142,7 @@ public:
{
return (int)mCheckMethodGenericArguments->size() - mInferredCount;
}
void InferGenericArguments(BfMethodInstance* methodInstance);
};
class BfMethodMatcher

View file

@ -8037,6 +8037,8 @@ BfTypedValue BfModule::CreateValueFromExpression(BfExprEvaluator& exprEvaluator,
if (!exprEvaluator.mResult)
{
if ((flags & BfEvalExprFlags_InferReturnType) != 0)
return exprEvaluator.mResult;
if (!mCompiler->mPassInstance->HasFailed())
Fail("INTERNAL ERROR: No expression result returned but no error caught in expression evaluator", expr);
return BfTypedValue();

View file

@ -74,6 +74,8 @@ enum BfEvalExprFlags
BfEvalExprFlags_NoLookupError = 0x40000,
BfEvalExprFlags_Comptime = 0x80000,
BfEvalExprFlags_InCascade = 0x100000,
BfEvalExprFlags_InferReturnType = 0x200000,
BfEvalExprFlags_WasMethodRef = 0x400000
};
enum BfCastFlags
@ -652,6 +654,13 @@ public:
Array<BfMixinRecord> mMixinStateRecords;
};
enum BfReturnTypeInferState
{
BfReturnTypeInferState_None,
BfReturnTypeInferState_Inferring,
BfReturnTypeInferState_Fail,
};
class BfClosureState
{
public:
@ -661,6 +670,7 @@ public:
// When we need to look into another local method to determine captures, but we don't want to process local variable declarations or cause infinite recursion
bool mBlindCapturing;
bool mDeclaringMethodIsMutating;
BfReturnTypeInferState mReturnTypeInferState;
BfLocalMethod* mLocalMethod;
BfClosureInstanceInfo* mClosureInstanceInfo;
BfMethodDef* mClosureMethodDef;
@ -684,6 +694,7 @@ public:
mCaptureStartAccessId = -1;
mBlindCapturing = false;
mDeclaringMethodIsMutating = false;
mReturnTypeInferState = BfReturnTypeInferState_None;
mActiveDeferredLocalMethod = NULL;
mReturnType = NULL;
mClosureType = NULL;

View file

@ -4900,8 +4900,12 @@ void BfModule::Visit(BfReturnStatement* returnStmt)
if (mCurMethodInstance->IsMixin())
retType = NULL;
bool inferReturnType = false;
if (mCurMethodState->mClosureState != NULL)
{
retType = mCurMethodState->mClosureState->mReturnType;
inferReturnType = (mCurMethodState->mClosureState->mReturnTypeInferState != BfReturnTypeInferState_None);
}
auto checkScope = mCurMethodState->mCurScope;
while (checkScope != NULL)
@ -4931,7 +4935,7 @@ void BfModule::Visit(BfReturnStatement* returnStmt)
checkLocalAssignData = checkLocalAssignData->mChainedAssignData;
}
if (retType == NULL)
if ((retType == NULL) && (!inferReturnType))
{
if (returnStmt->mExpression != NULL)
{
@ -4972,7 +4976,38 @@ void BfModule::Visit(BfReturnStatement* returnStmt)
exprEvaluator.mReceivingValue = &mCurMethodState->mRetVal;
if (mCurMethodInstance->mMethodDef->mIsReadOnly)
exprEvaluator.mAllowReadOnlyReference = true;
if (inferReturnType)
expectingReturnType = NULL;
auto retValue = CreateValueFromExpression(exprEvaluator, returnStmt->mExpression, expectingReturnType, BfEvalExprFlags_AllowRefExpr, &origType);
if ((retValue) && (inferReturnType))
{
if (mCurMethodState->mClosureState->mReturnType == NULL)
mCurMethodState->mClosureState->mReturnType = retValue.mType;
else
{
if ((retValue.mType == mCurMethodState->mClosureState->mReturnType) ||
(CanCast(retValue, mCurMethodState->mClosureState->mReturnType)))
{
// Leave as-is
}
else if (CanCast(GetFakeTypedValue(mCurMethodState->mClosureState->mReturnType), retValue.mType))
{
mCurMethodState->mClosureState->mReturnType = retValue.mType;
}
else
{
mCurMethodState->mClosureState->mReturnTypeInferState = BfReturnTypeInferState_Fail;
}
}
}
if ((retType == NULL) && (inferReturnType))
retType = mCurMethodState->mClosureState->mReturnType;
if (retType == NULL)
retType = GetPrimitiveType(BfTypeCode_None);
if ((!mIsComptimeModule) && (mCurMethodInstance->GetStructRetIdx() != -1))
alreadyWritten = exprEvaluator.mReceivingValue == NULL;
MarkScopeLeft(&mCurMethodState->mHeadScope);

View file

@ -241,6 +241,17 @@ namespace Tests
return 0;
}
public static TResult Sum<T, TElem, TDlg, TResult>(this T it, TDlg dlg)
where T: concrete, IEnumerable<TElem>
where TDlg: delegate TResult(TElem)
where TResult: operator TResult + TResult
{
var result = default(TResult);
for(var elem in it)
result += dlg(elem);
return result;
}
[Test]
public static void TestBasics()
{
@ -293,6 +304,8 @@ namespace Tests
} == false);*/
Test.Assert(MethodE(floatList) == 6);
Test.Assert(MethodF(floatList) == 0);
Test.Assert(floatList.Sum((x) => x * 2) == 12);
}
}