1
0
Fork 0
mirror of https://github.com/beefytech/Beef.git synced 2025-06-08 11:38:21 +02:00

Make case expression work with nullable Result<T>

This commit is contained in:
Brian Fiete 2025-01-31 10:15:43 -08:00
parent 206023f4a6
commit 319755ca36
3 changed files with 96 additions and 22 deletions

View file

@ -3797,26 +3797,8 @@ void BfExprEvaluator::Visit(BfVariableDeclaration* varDecl)
mModule->HandleVariableDeclaration(varDecl, this);
}
void BfExprEvaluator::Visit(BfCaseExpression* caseExpr)
void BfExprEvaluator::DoCaseExpression(BfTypedValue caseValAddr, BfCaseExpression* caseExpr)
{
if (caseExpr->mEqualsNode != NULL)
{
mModule->Warn(0, "Deprecated case syntax", caseExpr->mEqualsNode);
}
BfTypedValue caseValAddr;
if (caseExpr->mValueExpression != NULL)
caseValAddr = mModule->CreateValueFromExpression(caseExpr->mValueExpression, NULL, (BfEvalExprFlags)(mBfEvalExprFlags & BfEvalExprFlags_InheritFlags));
if ((caseValAddr.mType != NULL) && (caseValAddr.mType->IsPointer()))
{
caseValAddr = mModule->LoadValue(caseValAddr);
caseValAddr = BfTypedValue(caseValAddr.mValue, caseValAddr.mType->GetUnderlyingType(), true);
}
if (caseValAddr.mType != NULL)
mModule->mBfIRBuilder->PopulateType(caseValAddr.mType);
if ((mModule->mCurMethodState != NULL) && (mModule->mCurMethodState->mDeferredLocalAssignData != NULL))
mModule->mCurMethodState->mDeferredLocalAssignData->BreakExtendChain();
@ -3839,6 +3821,7 @@ void BfExprEvaluator::Visit(BfCaseExpression* caseExpr)
}
}
auto boolType = mModule->GetPrimitiveType(BfTypeCode_Boolean);
bool isPayloadEnum = (caseValAddr.mType != NULL) && (caseValAddr.mType->IsPayloadEnum());
auto tupleExpr = BfNodeDynCast<BfTupleExpression>(caseExpr->mCaseExpression);
@ -3875,7 +3858,7 @@ void BfExprEvaluator::Visit(BfCaseExpression* caseExpr)
// An example of requiring clearing is: if ((result case .Ok(out val)) || (force))
if (hasOut)
clearOutOnMismatch = !CheckVariableDeclaration(caseExpr, true, true, true);
bool hadConditional = false;
if (isPayloadEnum)
{
@ -3918,10 +3901,9 @@ void BfExprEvaluator::Visit(BfCaseExpression* caseExpr)
return;
}
auto boolType = mModule->GetPrimitiveType(BfTypeCode_Boolean);
BfTypedValue caseMatch;
if (caseExpr->mCaseExpression != NULL)
caseMatch = mModule->CreateValueFromExpression(caseExpr->mCaseExpression, caseValAddr.mType, BfEvalExprFlags_AllowEnumId);
caseMatch = mModule->CreateValueFromExpression(caseExpr->mCaseExpression, caseValAddr.mType, BfEvalExprFlags_AllowEnumId);
if ((!caseMatch) || (!caseValAddr))
{
mResult = mModule->GetDefaultTypedValue(boolType);
@ -3952,6 +3934,79 @@ void BfExprEvaluator::Visit(BfCaseExpression* caseExpr)
PerformBinaryOperation(caseExpr->mCaseExpression, caseExpr->mValueExpression, BfBinaryOp_Equality, caseExpr->mEqualsNode, BfBinOpFlag_None, caseValAddr, caseMatch);
}
void BfExprEvaluator::Visit(BfCaseExpression* caseExpr)
{
if (caseExpr->mEqualsNode != NULL)
{
mModule->Warn(0, "Deprecated case syntax", caseExpr->mEqualsNode);
}
auto boolType = mModule->GetPrimitiveType(BfTypeCode_Boolean);
BfTypedValue caseValAddr;
if (caseExpr->mValueExpression != NULL)
caseValAddr = mModule->CreateValueFromExpression(caseExpr->mValueExpression, NULL, (BfEvalExprFlags)(mBfEvalExprFlags & BfEvalExprFlags_InheritFlags));
if ((caseValAddr.mType != NULL) && (caseValAddr.mType->IsPointer()))
{
caseValAddr = mModule->LoadValue(caseValAddr);
caseValAddr = BfTypedValue(caseValAddr.mValue, caseValAddr.mType->GetUnderlyingType(), true);
}
BfIRValue hasValueValue;
if (caseValAddr.mType != NULL)
mModule->mBfIRBuilder->PopulateType(caseValAddr.mType);
if ((caseValAddr.mType != NULL) && (caseValAddr.mType->IsNullable()))
{
auto nullableElementType = caseValAddr.mType->GetUnderlyingType();
hasValueValue = mModule->ExtractValue(caseValAddr, nullableElementType->IsValuelessType() ? 1 : 2);
if (!nullableElementType->IsValuelessType())
caseValAddr = BfTypedValue(mModule->ExtractValue(caseValAddr, 1), nullableElementType); // value
else
caseValAddr = BfTypedValue(mModule->mBfIRBuilder->GetFakeVal(), nullableElementType);
}
BfIRBlock nullBB;
BfIRBlock endBB;
if (hasValueValue)
{
auto caseBB = mModule->mBfIRBuilder->CreateBlock("caseexpr.case");
endBB = mModule->mBfIRBuilder->CreateBlock("caseexpr.end");
mModule->mBfIRBuilder->CreateCondBr(hasValueValue, caseBB, endBB);
nullBB = mModule->mBfIRBuilder->GetInsertBlock();
mModule->mBfIRBuilder->AddBlock(caseBB);
mModule->mBfIRBuilder->SetInsertPoint(caseBB);
}
DoCaseExpression(caseValAddr, caseExpr);
if (!mResult)
mResult = mModule->GetDefaultTypedValue(boolType);
else
{
BF_ASSERT(mResult.mType == boolType);
}
if (hasValueValue)
{
auto endCaseBB = mModule->mBfIRBuilder->GetInsertBlock();
mModule->mBfIRBuilder->CreateBr(endBB);
mModule->mBfIRBuilder->AddBlock(endBB);
mModule->mBfIRBuilder->SetInsertPoint(endBB);
auto phiValue = mModule->mBfIRBuilder->CreatePhi(mModule->mBfIRBuilder->MapType(boolType), 2);
mModule->mBfIRBuilder->AddPhiIncoming(phiValue, mModule->GetDefaultValue(boolType), nullBB);
mModule->mBfIRBuilder->AddPhiIncoming(phiValue, mResult.mValue, endCaseBB);
mResult = BfTypedValue(phiValue, boolType);
}
}
void BfExprEvaluator::Visit(BfTypedValueExpression* typedValueExpr)
{
mResult = typedValueExpr->mTypedValue;

View file

@ -510,6 +510,7 @@ public:
bool HasVariableDeclaration(BfAstNode* checkNode);
void DoInvocation(BfInvocationExpression* invocationExpr);
void DoInvocation(BfAstNode* target, BfMethodBoundExpression* methodBoundExpr, const BfSizedArray<BfExpression*>& args, const BfMethodGenericArguments& methodGenericArgs, BfTypedValue* outCascadeValue = NULL);
void DoCaseExpression(BfTypedValue caseValAddr, BfCaseExpression* caseExpr);
int GetMixinVariable();
void CheckLocalMethods(BfAstNode* targetSrc, BfTypeInstance* typeInstance, const StringImpl& methodName, BfMethodMatcher& methodMatcher, BfMethodType methodType);
void InjectMixin(BfAstNode* targetSrc, BfTypedValue target, bool allowImplicitThis, const StringImpl& name, const BfSizedArray<BfExpression*>& arguments, const BfMethodGenericArguments& methodGenericArgs);

View file

@ -114,6 +114,24 @@ namespace Tests
iNull ??= iNull2;
Test.Assert(iNull == 123);
Result<int32> ir = .Ok(234);
Result<int32>? irn = ir;
if (irn case .Ok(let val))
{
Test.Assert(val == 234);
}
else
{
Test.FatalError();
}
irn = null;
if (irn case .Ok(let val))
{
Test.FatalError();
}
}
}
}