1
0
Fork 0
mirror of https://github.com/beefytech/Beef.git synced 2025-06-10 04:22:20 +02:00

Lambda return type inference

This commit is contained in:
Brian Fiete 2021-01-14 06:24:34 -08:00
parent d557e11dad
commit bb12a4ec20
6 changed files with 220 additions and 41 deletions

View file

@ -588,6 +588,32 @@ bool BfGenericInferContext::InferGenericArgument(BfMethodInstance* methodInstanc
return true;
}
void BfGenericInferContext::InferGenericArguments(BfMethodInstance* methodInstance)
{
// Attempt to infer from other generic args
for (int srcGenericIdx = 0; srcGenericIdx < (int)mCheckMethodGenericArguments->size(); srcGenericIdx++)
{
auto& srcGenericArg = (*mCheckMethodGenericArguments)[srcGenericIdx];
if (srcGenericArg == NULL)
continue;
auto srcGenericParam = methodInstance->mMethodInfoEx->mGenericParams[srcGenericIdx];
for (auto ifaceConstraint : srcGenericParam->mInterfaceConstraints)
{
if ((ifaceConstraint->IsUnspecializedTypeVariation()) && (ifaceConstraint->IsGenericTypeInstance()))
{
InferGenericArgument(methodInstance, srcGenericArg, ifaceConstraint, BfIRValue());
auto typeInstance = srcGenericArg->ToTypeInstance();
if (typeInstance != NULL)
{
for (auto ifaceEntry : typeInstance->mInterfaces)
InferGenericArgument(methodInstance, ifaceEntry.mInterfaceType, ifaceConstraint, BfIRValue());
}
}
}
}
}
// void BfGenericInferContext::PropogateInference(BfType* resolvedType, BfType* unresovledType)
// {
// if (!unresovledType->IsUnspecializedTypeVariation())
@ -1207,7 +1233,7 @@ BfTypedValue BfMethodMatcher::ResolveArgTypedValue(BfResolvedArg& resolvedArg, B
if (!resolvedArg.mTypedValue)
{
// Resolve for real
resolvedArg.mTypedValue = mModule->CreateValueFromExpression(lambdaBindExpr, checkType, BfEvalExprFlags_NoCast);
resolvedArg.mTypedValue = mModule->CreateValueFromExpression(lambdaBindExpr, checkType, (BfEvalExprFlags)(BfEvalExprFlags_NoCast | BfEvalExprFlags_NoAutoComplete));
}
argTypedValue = resolvedArg.mTypedValue;
}
@ -1217,6 +1243,73 @@ BfTypedValue BfMethodMatcher::ResolveArgTypedValue(BfResolvedArg& resolvedArg, B
}
}
}
else if ((checkType == NULL) && (origCheckType != NULL) && (origCheckType->IsUnspecializedTypeVariation()) && (genericArgumentsSubstitute != NULL))
{
BfMethodInstance* methodInstance = mModule->GetRawMethodInstanceAtIdx(origCheckType->ToTypeInstance(), 0, "Invoke");
if (methodInstance != NULL)
{
if ((methodInstance->mReturnType->IsGenericParam()) && (((BfGenericParamType*)methodInstance->mReturnType)->mGenericParamKind == BfGenericParamKind_Method))
{
bool isValid = true;
int returnMethodGenericArgIdx = ((BfGenericParamType*)methodInstance->mReturnType)->mGenericParamIdx;
if ((*genericArgumentsSubstitute)[returnMethodGenericArgIdx] != NULL)
{
isValid = false;
}
if (methodInstance->mParams.size() != (int)lambdaBindExpr->mParams.size())
isValid = false;
for (auto& param : methodInstance->mParams)
{
if (param.mResolvedType->IsGenericParam())
{
auto genericParamType = (BfGenericParamType*)param.mResolvedType;
if ((genericParamType->mGenericParamKind == BfGenericParamKind_Method) && ((*genericArgumentsSubstitute)[genericParamType->mGenericParamIdx] == NULL))
{
isValid = false;
}
}
}
if (isValid)
{
bool success = false;
(*genericArgumentsSubstitute)[returnMethodGenericArgIdx] = mModule->GetPrimitiveType(BfTypeCode_None);
auto tryType = mModule->ResolveGenericType(origCheckType, NULL, genericArgumentsSubstitute);
if (tryType != NULL)
{
auto inferredReturnType = mModule->CreateValueFromExpression(lambdaBindExpr, tryType, (BfEvalExprFlags)(BfEvalExprFlags_NoCast | BfEvalExprFlags_InferReturnType | BfEvalExprFlags_NoAutoComplete));
if (inferredReturnType.mType != NULL)
{
(*genericArgumentsSubstitute)[returnMethodGenericArgIdx] = inferredReturnType.mType;
if (((flags & BfResolveArgFlag_FromGenericParam) != 0) && (lambdaBindExpr->mNewToken == NULL))
{
auto resolvedType = mModule->ResolveGenericType(origCheckType, NULL, genericArgumentsSubstitute);
if (resolvedType != NULL)
{
// Resolve for real
resolvedArg.mTypedValue = mModule->CreateValueFromExpression(lambdaBindExpr, resolvedType, (BfEvalExprFlags)(BfEvalExprFlags_NoCast | BfEvalExprFlags_NoAutoComplete));
argTypedValue = resolvedArg.mTypedValue;
}
}
success = true;
}
}
if (!success)
{
// Put back
(*genericArgumentsSubstitute)[returnMethodGenericArgIdx] = NULL;
}
}
}
}
}
}
else if ((resolvedArg.mArgFlags & BfArgFlag_UnqualifiedDotAttempt) != 0)
{
@ -1672,7 +1765,12 @@ bool BfMethodMatcher::CheckMethod(BfTypeInstance* targetTypeInstance, BfTypeInst
argTypedValue = mTarget;
}
else
argTypedValue = ResolveArgTypedValue(mArguments[argIdx], checkType, genericArgumentsSubstitute, origCheckType, BfResolveArgFlag_FromGeneric);
{
BfResolveArgFlags flags = BfResolveArgFlag_FromGeneric;
if (wantType->IsGenericParam())
flags = (BfResolveArgFlags)(flags | BfResolveArgFlag_FromGenericParam);
argTypedValue = ResolveArgTypedValue(mArguments[argIdx], checkType, genericArgumentsSubstitute, origCheckType, flags);
}
if (!argTypedValue.IsUntypedValue())
{
auto type = argTypedValue.mType;
@ -1709,6 +1807,11 @@ bool BfMethodMatcher::CheckMethod(BfTypeInstance* targetTypeInstance, BfTypeInst
goto NoMatch;
}
if (!deferredArgs.IsEmpty())
{
genericInferContext.InferGenericArguments(methodInstance);
}
while (!deferredArgs.IsEmpty())
{
int prevDeferredSize = (int)deferredArgs.size();
@ -11504,7 +11607,12 @@ void BfExprEvaluator::VisitLambdaBodies(BfAstNode* body, BfFieldDtorDeclaration*
if (auto blockBody = BfNodeDynCast<BfBlock>(body))
mModule->VisitChild(blockBody);
else if (auto bodyExpr = BfNodeDynCast<BfExpression>(body))
mModule->CreateValueFromExpression(bodyExpr);
{
auto result = mModule->CreateValueFromExpression(bodyExpr);
if ((result) && (mModule->mCurMethodState->mClosureState != NULL) &&
(mModule->mCurMethodState->mClosureState->mReturnTypeInferState == BfReturnTypeInferState_Inferring))
mModule->mCurMethodState->mClosureState->mReturnType = result.mType;
}
while (fieldDtor != NULL)
{
@ -11530,8 +11638,10 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam
}
}
bool isInferReturnType = (mBfEvalExprFlags & BfEvalExprFlags_InferReturnType) != 0;
BfLambdaInstance* lambdaInstance = NULL;
if (rootMethodState->mLambdaCache.TryGetValue(cacheNodeList, &lambdaInstance))
if ((!isInferReturnType) && (rootMethodState->mLambdaCache.TryGetValue(cacheNodeList, &lambdaInstance)))
return lambdaInstance;
static int sBindCount = 0;
@ -11683,7 +11793,11 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam
return NULL;
}
if ((lambdaBindExpr->mNewToken == NULL) || (isFunctionBind))
if ((lambdaBindExpr->mNewToken == NULL) && (isInferReturnType))
{
// Method ref, but let this follow infer route
}
else if ((lambdaBindExpr->mNewToken == NULL) || (isFunctionBind))
{
if ((lambdaBindExpr->mNewToken != NULL) && (isFunctionBind))
mModule->Fail("Binds to functions should do not require allocations.", lambdaBindExpr->mNewToken);
@ -11951,8 +12065,14 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam
closureState.mCaptureVisitingBody = true;
closureState.mClosureInstanceInfo = closureInstanceInfo;
VisitLambdaBodies(lambdaBindExpr->mBody, lambdaBindExpr->mDtor);
if ((mBfEvalExprFlags & BfEvalExprFlags_InferReturnType) != 0)
{
closureState.mReturnType = NULL;
closureState.mReturnTypeInferState = BfReturnTypeInferState_Inferring;
}
VisitLambdaBodies(lambdaBindExpr->mBody, lambdaBindExpr->mDtor);
if (hasExplicitCaptureNames)
_SetNotCapturedFlag(false);
@ -11964,12 +12084,32 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam
prevClosureState->mCaptureStartAccessId = closureState.mCaptureStartAccessId;
}
if (mModule->mCurMethodInstance->mIsUnspecialized)
bool earlyExit = false;
if (isInferReturnType)
{
if ((closureState.mReturnTypeInferState == BfReturnTypeInferState_Fail) ||
(closureState.mReturnType == NULL))
{
mResult = BfTypedValue();
}
else
{
mResult = BfTypedValue(closureState.mReturnType);
}
earlyExit = true;
}
else if (mModule->mCurMethodInstance->mIsUnspecialized)
{
earlyExit = true;
mResult = mModule->GetDefaultTypedValue(delegateTypeInstance);
}
if (earlyExit)
{
prevIgnoreWrites.Restore();
mModule->mBfIRBuilder->RestoreDebugLocation();
mResult = mModule->GetDefaultTypedValue(delegateTypeInstance);
mModule->mBfIRBuilder->SetActiveFunction(prevActiveFunction);
if (!prevInsertBlock.IsFake())
mModule->mBfIRBuilder->SetInsertPoint(prevInsertBlock);
@ -11981,7 +12121,7 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam
closureState.mCaptureVisitingBody = false;
prevIgnoreWrites.Restore();
mModule->mBfIRBuilder->RestoreDebugLocation();
mModule->mBfIRBuilder->RestoreDebugLocation();
auto _GetCaptureType = [&](const StringImpl& str)
{
@ -14189,34 +14329,10 @@ BfModuleMethodInstance BfExprEvaluator::GetSelectedMethod(BfAstNode* targetSrc,
if (genericArg == NULL)
{
// Attempt to infer from other generic args
for (int srcGenericIdx = 0; srcGenericIdx < (int)methodMatcher.mBestMethodGenericArguments.size(); srcGenericIdx++)
{
auto& srcGenericArg = methodMatcher.mBestMethodGenericArguments[srcGenericIdx];
if (srcGenericArg == NULL)
continue;
auto srcGenericParam = unspecializedMethod->mMethodInfoEx->mGenericParams[srcGenericIdx];
BfGenericInferContext genericInferContext;
genericInferContext.mModule = mModule;
genericInferContext.mCheckMethodGenericArguments = &methodMatcher.mBestMethodGenericArguments;
for (auto ifaceConstraint : srcGenericParam->mInterfaceConstraints)
{
if ((ifaceConstraint->IsUnspecializedTypeVariation()) && (ifaceConstraint->IsGenericTypeInstance()))
{
genericInferContext.InferGenericArgument(unspecializedMethod, srcGenericArg, ifaceConstraint, BfIRValue());
auto typeInstance = srcGenericArg->ToTypeInstance();
if (typeInstance != NULL)
{
for (auto ifaceEntry : typeInstance->mInterfaces)
genericInferContext.InferGenericArgument(unspecializedMethod, ifaceEntry.mInterfaceType, ifaceConstraint, BfIRValue());
}
}
}
}
BfGenericInferContext genericInferContext;
genericInferContext.mModule = mModule;
genericInferContext.mCheckMethodGenericArguments = &methodMatcher.mBestMethodGenericArguments;
genericInferContext.InferGenericArguments(unspecializedMethod);
}
if (genericArg == NULL)