diff --git a/IDEHelper/Compiler/BfExprEvaluator.cpp b/IDEHelper/Compiler/BfExprEvaluator.cpp index dac8d336..7fcbfffb 100644 --- a/IDEHelper/Compiler/BfExprEvaluator.cpp +++ b/IDEHelper/Compiler/BfExprEvaluator.cpp @@ -3797,26 +3797,8 @@ void BfExprEvaluator::Visit(BfVariableDeclaration* varDecl) mModule->HandleVariableDeclaration(varDecl, this); } -void BfExprEvaluator::Visit(BfCaseExpression* caseExpr) +void BfExprEvaluator::DoCaseExpression(BfTypedValue caseValAddr, BfCaseExpression* caseExpr) { - if (caseExpr->mEqualsNode != NULL) - { - mModule->Warn(0, "Deprecated case syntax", caseExpr->mEqualsNode); - } - - BfTypedValue caseValAddr; - if (caseExpr->mValueExpression != NULL) - caseValAddr = mModule->CreateValueFromExpression(caseExpr->mValueExpression, NULL, (BfEvalExprFlags)(mBfEvalExprFlags & BfEvalExprFlags_InheritFlags)); - - if ((caseValAddr.mType != NULL) && (caseValAddr.mType->IsPointer())) - { - caseValAddr = mModule->LoadValue(caseValAddr); - caseValAddr = BfTypedValue(caseValAddr.mValue, caseValAddr.mType->GetUnderlyingType(), true); - } - - if (caseValAddr.mType != NULL) - mModule->mBfIRBuilder->PopulateType(caseValAddr.mType); - if ((mModule->mCurMethodState != NULL) && (mModule->mCurMethodState->mDeferredLocalAssignData != NULL)) mModule->mCurMethodState->mDeferredLocalAssignData->BreakExtendChain(); @@ -3839,6 +3821,7 @@ void BfExprEvaluator::Visit(BfCaseExpression* caseExpr) } } + auto boolType = mModule->GetPrimitiveType(BfTypeCode_Boolean); bool isPayloadEnum = (caseValAddr.mType != NULL) && (caseValAddr.mType->IsPayloadEnum()); auto tupleExpr = BfNodeDynCast(caseExpr->mCaseExpression); @@ -3875,7 +3858,7 @@ void BfExprEvaluator::Visit(BfCaseExpression* caseExpr) // An example of requiring clearing is: if ((result case .Ok(out val)) || (force)) if (hasOut) clearOutOnMismatch = !CheckVariableDeclaration(caseExpr, true, true, true); - + bool hadConditional = false; if (isPayloadEnum) { @@ -3918,10 +3901,9 @@ void BfExprEvaluator::Visit(BfCaseExpression* caseExpr) return; } - auto boolType = mModule->GetPrimitiveType(BfTypeCode_Boolean); BfTypedValue caseMatch; if (caseExpr->mCaseExpression != NULL) - caseMatch = mModule->CreateValueFromExpression(caseExpr->mCaseExpression, caseValAddr.mType, BfEvalExprFlags_AllowEnumId); + caseMatch = mModule->CreateValueFromExpression(caseExpr->mCaseExpression, caseValAddr.mType, BfEvalExprFlags_AllowEnumId); if ((!caseMatch) || (!caseValAddr)) { mResult = mModule->GetDefaultTypedValue(boolType); @@ -3952,6 +3934,79 @@ void BfExprEvaluator::Visit(BfCaseExpression* caseExpr) PerformBinaryOperation(caseExpr->mCaseExpression, caseExpr->mValueExpression, BfBinaryOp_Equality, caseExpr->mEqualsNode, BfBinOpFlag_None, caseValAddr, caseMatch); } +void BfExprEvaluator::Visit(BfCaseExpression* caseExpr) +{ + if (caseExpr->mEqualsNode != NULL) + { + mModule->Warn(0, "Deprecated case syntax", caseExpr->mEqualsNode); + } + + auto boolType = mModule->GetPrimitiveType(BfTypeCode_Boolean); + BfTypedValue caseValAddr; + if (caseExpr->mValueExpression != NULL) + caseValAddr = mModule->CreateValueFromExpression(caseExpr->mValueExpression, NULL, (BfEvalExprFlags)(mBfEvalExprFlags & BfEvalExprFlags_InheritFlags)); + + if ((caseValAddr.mType != NULL) && (caseValAddr.mType->IsPointer())) + { + caseValAddr = mModule->LoadValue(caseValAddr); + caseValAddr = BfTypedValue(caseValAddr.mValue, caseValAddr.mType->GetUnderlyingType(), true); + } + + BfIRValue hasValueValue; + if (caseValAddr.mType != NULL) + mModule->mBfIRBuilder->PopulateType(caseValAddr.mType); + + if ((caseValAddr.mType != NULL) && (caseValAddr.mType->IsNullable())) + { + auto nullableElementType = caseValAddr.mType->GetUnderlyingType(); + hasValueValue = mModule->ExtractValue(caseValAddr, nullableElementType->IsValuelessType() ? 1 : 2); + + if (!nullableElementType->IsValuelessType()) + caseValAddr = BfTypedValue(mModule->ExtractValue(caseValAddr, 1), nullableElementType); // value + else + caseValAddr = BfTypedValue(mModule->mBfIRBuilder->GetFakeVal(), nullableElementType); + } + + BfIRBlock nullBB; + BfIRBlock endBB; + + if (hasValueValue) + { + auto caseBB = mModule->mBfIRBuilder->CreateBlock("caseexpr.case"); + endBB = mModule->mBfIRBuilder->CreateBlock("caseexpr.end"); + + mModule->mBfIRBuilder->CreateCondBr(hasValueValue, caseBB, endBB); + nullBB = mModule->mBfIRBuilder->GetInsertBlock(); + + mModule->mBfIRBuilder->AddBlock(caseBB); + mModule->mBfIRBuilder->SetInsertPoint(caseBB); + } + + DoCaseExpression(caseValAddr, caseExpr); + + if (!mResult) + mResult = mModule->GetDefaultTypedValue(boolType); + else + { + BF_ASSERT(mResult.mType == boolType); + } + + if (hasValueValue) + { + auto endCaseBB = mModule->mBfIRBuilder->GetInsertBlock(); + mModule->mBfIRBuilder->CreateBr(endBB); + + mModule->mBfIRBuilder->AddBlock(endBB); + mModule->mBfIRBuilder->SetInsertPoint(endBB); + + auto phiValue = mModule->mBfIRBuilder->CreatePhi(mModule->mBfIRBuilder->MapType(boolType), 2); + mModule->mBfIRBuilder->AddPhiIncoming(phiValue, mModule->GetDefaultValue(boolType), nullBB); + mModule->mBfIRBuilder->AddPhiIncoming(phiValue, mResult.mValue, endCaseBB); + + mResult = BfTypedValue(phiValue, boolType); + } +} + void BfExprEvaluator::Visit(BfTypedValueExpression* typedValueExpr) { mResult = typedValueExpr->mTypedValue; diff --git a/IDEHelper/Compiler/BfExprEvaluator.h b/IDEHelper/Compiler/BfExprEvaluator.h index b648030d..0911d924 100644 --- a/IDEHelper/Compiler/BfExprEvaluator.h +++ b/IDEHelper/Compiler/BfExprEvaluator.h @@ -510,6 +510,7 @@ public: bool HasVariableDeclaration(BfAstNode* checkNode); void DoInvocation(BfInvocationExpression* invocationExpr); void DoInvocation(BfAstNode* target, BfMethodBoundExpression* methodBoundExpr, const BfSizedArray& args, const BfMethodGenericArguments& methodGenericArgs, BfTypedValue* outCascadeValue = NULL); + void DoCaseExpression(BfTypedValue caseValAddr, BfCaseExpression* caseExpr); int GetMixinVariable(); void CheckLocalMethods(BfAstNode* targetSrc, BfTypeInstance* typeInstance, const StringImpl& methodName, BfMethodMatcher& methodMatcher, BfMethodType methodType); void InjectMixin(BfAstNode* targetSrc, BfTypedValue target, bool allowImplicitThis, const StringImpl& name, const BfSizedArray& arguments, const BfMethodGenericArguments& methodGenericArgs); diff --git a/IDEHelper/Tests/src/Nullable.bf b/IDEHelper/Tests/src/Nullable.bf index 237d0542..120abbd4 100644 --- a/IDEHelper/Tests/src/Nullable.bf +++ b/IDEHelper/Tests/src/Nullable.bf @@ -114,6 +114,24 @@ namespace Tests iNull ??= iNull2; Test.Assert(iNull == 123); + + Result ir = .Ok(234); + Result? irn = ir; + + if (irn case .Ok(let val)) + { + Test.Assert(val == 234); + } + else + { + Test.FatalError(); + } + + irn = null; + if (irn case .Ok(let val)) + { + Test.FatalError(); + } } } }