1
0
Fork 0
mirror of https://github.com/beefytech/Beef.git synced 2025-06-14 14:24:10 +02:00

Fixes some handling of const payload enums cases

This commit is contained in:
Brian Fiete 2025-02-04 10:23:22 -08:00
parent 8b7dd19f4b
commit 568e54821d
2 changed files with 67 additions and 3 deletions

View file

@ -2897,6 +2897,7 @@ BfTypedValue BfModule::TryCaseEnumMatch(BfTypedValue enumVal, BfTypedValue tagVa
tagVal = GetDefaultTypedValue(tagType); tagVal = GetDefaultTypedValue(tagType);
} }
BfIRValue eqResult;
for (int fieldIdx = 0; fieldIdx < (int)enumType->mFieldInstances.size(); fieldIdx++) for (int fieldIdx = 0; fieldIdx < (int)enumType->mFieldInstances.size(); fieldIdx++)
{ {
auto fieldInstance = &enumType->mFieldInstances[fieldIdx]; auto fieldInstance = &enumType->mFieldInstances[fieldIdx];
@ -2935,6 +2936,18 @@ BfTypedValue BfModule::TryCaseEnumMatch(BfTypedValue enumVal, BfTypedValue tagVa
auto dscrType = enumType->GetDiscriminatorType(); auto dscrType = enumType->GetDiscriminatorType();
BfIRValue eqResult = mBfIRBuilder->CreateCmpEQ(tagVal.mValue, mBfIRBuilder->CreateConst(dscrType->mTypeDef->mTypeCode, tagId)); BfIRValue eqResult = mBfIRBuilder->CreateCmpEQ(tagVal.mValue, mBfIRBuilder->CreateConst(dscrType->mTypeDef->mTypeCode, tagId));
bool isConstMatch = false;
bool isConstIgnore = false;
if (auto constant = mBfIRBuilder->GetConstant(eqResult))
{
isConstMatch = constant->mBool;
isConstIgnore = !constant->mBool;
}
SetAndRestoreValue<bool> prevIgnoreWrite(mBfIRBuilder->mIgnoreWrites);
if (isConstIgnore)
mBfIRBuilder->mIgnoreWrites = true;
BfIRBlock falseBlockStart; BfIRBlock falseBlockStart;
BfIRBlock falseBlockEnd; BfIRBlock falseBlockEnd;
BfIRBlock doneBlockStart; BfIRBlock doneBlockStart;
@ -2954,13 +2967,17 @@ BfTypedValue BfModule::TryCaseEnumMatch(BfTypedValue enumVal, BfTypedValue tagVa
BfIRBlock matchedBlockEnd = matchedBlockStart; BfIRBlock matchedBlockEnd = matchedBlockStart;
if (matchBlock != NULL) if (matchBlock != NULL)
*matchBlock = matchedBlockStart; *matchBlock = matchedBlockStart;
mBfIRBuilder->CreateCondBr(eqResult, matchedBlockStart, falseBlockStart ? falseBlockStart : doneBlockStart);
if (isConstMatch)
mBfIRBuilder->CreateBr(matchedBlockStart);
else
mBfIRBuilder->CreateCondBr(eqResult, matchedBlockStart, falseBlockStart ? falseBlockStart : doneBlockStart);
mBfIRBuilder->AddBlock(matchedBlockStart); mBfIRBuilder->AddBlock(matchedBlockStart);
mBfIRBuilder->SetInsertPoint(doneBlockEnd); mBfIRBuilder->SetInsertPoint(doneBlockEnd);
BfIRValue phiVal; BfIRValue phiVal;
if (eqBlock == NULL) if ((eqBlock == NULL) && (!isConstIgnore) && (!isConstMatch))
phiVal = mBfIRBuilder->CreatePhi(mBfIRBuilder->MapType(boolType), 1 + (int)tupleType->mFieldInstances.size()); phiVal = mBfIRBuilder->CreatePhi(mBfIRBuilder->MapType(boolType), 1 + (int)tupleType->mFieldInstances.size());
mBfIRBuilder->SetInsertPoint(matchedBlockEnd); mBfIRBuilder->SetInsertPoint(matchedBlockEnd);
@ -3102,6 +3119,8 @@ BfTypedValue BfModule::TryCaseEnumMatch(BfTypedValue enumVal, BfTypedValue tagVa
if (phiVal) if (phiVal)
return BfTypedValue(phiVal, boolType); return BfTypedValue(phiVal, boolType);
else if (eqResult)
return BfTypedValue(eqResult, boolType);
else else
return GetDefaultTypedValue(boolType); return GetDefaultTypedValue(boolType);
} }
@ -4789,6 +4808,11 @@ void BfModule::Visit(BfSwitchStatement* switchStmt)
else else
{ {
eqTypedResult = TryCaseEnumMatch(switchValueAddr, enumTagVal, caseExpr, &caseBlock, &notEqBB, &matchBlock, tagId, hadConditional, false, prevHadFallthrough); eqTypedResult = TryCaseEnumMatch(switchValueAddr, enumTagVal, caseExpr, &caseBlock, &notEqBB, &matchBlock, tagId, hadConditional, false, prevHadFallthrough);
if (auto constant = mBfIRBuilder->GetConstant(eqTypedResult.mValue))
{
if (constant->mBool)
mayHaveMatch = true;
}
if (hadConditional) if (hadConditional)
hadCondCase = true; hadCondCase = true;
} }
@ -4918,7 +4942,8 @@ void BfModule::Visit(BfSwitchStatement* switchStmt)
BfExprEvaluator exprEvaluator(this); BfExprEvaluator exprEvaluator(this);
BfAstNode* refNode = switchCase->mColonToken; BfAstNode* refNode = switchCase->mColonToken;
if ((caseValue.mType->IsPayloadEnum()) && (caseValue.mValue.IsConst()) && (switchValue.mType == caseValue.mType) && (isEnumDescValue)) if ((caseValue.mType->IsPayloadEnum()) && (caseValue.mValue.IsConst()) && (switchValue.mType == caseValue.mType) &&
((isEnumDescValue) || (constantInt != NULL)))
{ {
if (!enumTagVal) if (!enumTagVal)
{ {

View file

@ -34,6 +34,13 @@ namespace Tests
} }
} }
enum ETest
{
case A(int a);
case B(float f);
case C;
}
[Test] [Test]
public static void TestBasics() public static void TestBasics()
{ {
@ -65,6 +72,38 @@ namespace Tests
result = 4; result = 4;
} }
Test.Assert(result == 4); Test.Assert(result == 4);
result = 0;
const int constVal = 123;
switch (constVal)
{
case 10:
result = 1;
case 123:
result = 2;
default:
result = 3;
}
Test.Assert(result == 2);
result = 99;
const Result<int> iResult = .Err;
bool eq = iResult case .Ok(ref result);
Test.Assert(result == 99);
const ETest t = .B(234.5f);
switch (t)
{
case .A(let a):
result = 1;
case .B(let b):
result = (.)b;
case .C:
result = 3;
default:
result = 4;
}
Test.Assert(result == 234);
} }
} }
} }