diff --git a/IDEHelper/Compiler/BfStmtEvaluator.cpp b/IDEHelper/Compiler/BfStmtEvaluator.cpp index 425bf643..babb4991 100644 --- a/IDEHelper/Compiler/BfStmtEvaluator.cpp +++ b/IDEHelper/Compiler/BfStmtEvaluator.cpp @@ -2897,6 +2897,7 @@ BfTypedValue BfModule::TryCaseEnumMatch(BfTypedValue enumVal, BfTypedValue tagVa tagVal = GetDefaultTypedValue(tagType); } + BfIRValue eqResult; for (int fieldIdx = 0; fieldIdx < (int)enumType->mFieldInstances.size(); fieldIdx++) { auto fieldInstance = &enumType->mFieldInstances[fieldIdx]; @@ -2935,6 +2936,18 @@ BfTypedValue BfModule::TryCaseEnumMatch(BfTypedValue enumVal, BfTypedValue tagVa auto dscrType = enumType->GetDiscriminatorType(); BfIRValue eqResult = mBfIRBuilder->CreateCmpEQ(tagVal.mValue, mBfIRBuilder->CreateConst(dscrType->mTypeDef->mTypeCode, tagId)); + bool isConstMatch = false; + bool isConstIgnore = false; + if (auto constant = mBfIRBuilder->GetConstant(eqResult)) + { + isConstMatch = constant->mBool; + isConstIgnore = !constant->mBool; + } + + SetAndRestoreValue prevIgnoreWrite(mBfIRBuilder->mIgnoreWrites); + if (isConstIgnore) + mBfIRBuilder->mIgnoreWrites = true; + BfIRBlock falseBlockStart; BfIRBlock falseBlockEnd; BfIRBlock doneBlockStart; @@ -2954,13 +2967,17 @@ BfTypedValue BfModule::TryCaseEnumMatch(BfTypedValue enumVal, BfTypedValue tagVa BfIRBlock matchedBlockEnd = matchedBlockStart; if (matchBlock != NULL) *matchBlock = matchedBlockStart; - mBfIRBuilder->CreateCondBr(eqResult, matchedBlockStart, falseBlockStart ? falseBlockStart : doneBlockStart); + + if (isConstMatch) + mBfIRBuilder->CreateBr(matchedBlockStart); + else + mBfIRBuilder->CreateCondBr(eqResult, matchedBlockStart, falseBlockStart ? falseBlockStart : doneBlockStart); mBfIRBuilder->AddBlock(matchedBlockStart); mBfIRBuilder->SetInsertPoint(doneBlockEnd); BfIRValue phiVal; - if (eqBlock == NULL) + if ((eqBlock == NULL) && (!isConstIgnore) && (!isConstMatch)) phiVal = mBfIRBuilder->CreatePhi(mBfIRBuilder->MapType(boolType), 1 + (int)tupleType->mFieldInstances.size()); mBfIRBuilder->SetInsertPoint(matchedBlockEnd); @@ -3102,6 +3119,8 @@ BfTypedValue BfModule::TryCaseEnumMatch(BfTypedValue enumVal, BfTypedValue tagVa if (phiVal) return BfTypedValue(phiVal, boolType); + else if (eqResult) + return BfTypedValue(eqResult, boolType); else return GetDefaultTypedValue(boolType); } @@ -4789,6 +4808,11 @@ void BfModule::Visit(BfSwitchStatement* switchStmt) else { eqTypedResult = TryCaseEnumMatch(switchValueAddr, enumTagVal, caseExpr, &caseBlock, ¬EqBB, &matchBlock, tagId, hadConditional, false, prevHadFallthrough); + if (auto constant = mBfIRBuilder->GetConstant(eqTypedResult.mValue)) + { + if (constant->mBool) + mayHaveMatch = true; + } if (hadConditional) hadCondCase = true; } @@ -4918,7 +4942,8 @@ void BfModule::Visit(BfSwitchStatement* switchStmt) BfExprEvaluator exprEvaluator(this); BfAstNode* refNode = switchCase->mColonToken; - if ((caseValue.mType->IsPayloadEnum()) && (caseValue.mValue.IsConst()) && (switchValue.mType == caseValue.mType) && (isEnumDescValue)) + if ((caseValue.mType->IsPayloadEnum()) && (caseValue.mValue.IsConst()) && (switchValue.mType == caseValue.mType) && + ((isEnumDescValue) || (constantInt != NULL))) { if (!enumTagVal) { diff --git a/IDEHelper/Tests/src/Switches.bf b/IDEHelper/Tests/src/Switches.bf index 22d101fc..b6949317 100644 --- a/IDEHelper/Tests/src/Switches.bf +++ b/IDEHelper/Tests/src/Switches.bf @@ -34,6 +34,13 @@ namespace Tests } } + enum ETest + { + case A(int a); + case B(float f); + case C; + } + [Test] public static void TestBasics() { @@ -65,6 +72,38 @@ namespace Tests result = 4; } Test.Assert(result == 4); + + result = 0; + const int constVal = 123; + switch (constVal) + { + case 10: + result = 1; + case 123: + result = 2; + default: + result = 3; + } + Test.Assert(result == 2); + + result = 99; + const Result iResult = .Err; + bool eq = iResult case .Ok(ref result); + Test.Assert(result == 99); + + const ETest t = .B(234.5f); + switch (t) + { + case .A(let a): + result = 1; + case .B(let b): + result = (.)b; + case .C: + result = 3; + default: + result = 4; + } + Test.Assert(result == 234); } } }