1
0
Fork 0
mirror of https://github.com/beefytech/Beef.git synced 2025-06-08 03:28:20 +02:00

Fixed payload enum switch case comparison

This commit is contained in:
Brian Fiete 2023-07-24 10:32:31 -07:00
parent 52f746aae9
commit 11bde5caf2
5 changed files with 65 additions and 1 deletions

View file

@ -1386,6 +1386,8 @@ void BeIRCodeGen::HandleNextCmd()
{ {
CMD_PARAM(BeValue*, lhs); CMD_PARAM(BeValue*, lhs);
CMD_PARAM(BeValue*, rhs); CMD_PARAM(BeValue*, rhs);
if (lhs->GetType() != rhs->GetType())
Fail("Type mismatch for CmpEQ");
SetResult(curId, mBeModule->CreateCmp(BeCmpKind_EQ, lhs, rhs)); SetResult(curId, mBeModule->CreateCmp(BeCmpKind_EQ, lhs, rhs));
} }
break; break;

View file

@ -4707,6 +4707,7 @@ void BfModule::Visit(BfSwitchStatement* switchStmt)
BfTypedValue caseValue; BfTypedValue caseValue;
BfIRBlock doBlock = caseBlock; BfIRBlock doBlock = caseBlock;
bool hadConditional = false; bool hadConditional = false;
bool isEnumDescValue = isPayloadEnum;
if (isPayloadEnum) if (isPayloadEnum)
{ {
auto dscrType = switchValue.mType->ToTypeInstance()->GetDiscriminatorType(); auto dscrType = switchValue.mType->ToTypeInstance()->GetDiscriminatorType();
@ -4766,6 +4767,7 @@ void BfModule::Visit(BfSwitchStatement* switchStmt)
caseValue = CreateValueFromExpression(caseExpr, switchValue.mType, (BfEvalExprFlags)(BfEvalExprFlags_AllowEnumId | BfEvalExprFlags_NoCast)); caseValue = CreateValueFromExpression(caseExpr, switchValue.mType, (BfEvalExprFlags)(BfEvalExprFlags_AllowEnumId | BfEvalExprFlags_NoCast));
if (!caseValue) if (!caseValue)
continue; continue;
isEnumDescValue = false;
} }
BfTypedValue caseIntVal = caseValue; BfTypedValue caseIntVal = caseValue;
@ -4835,7 +4837,7 @@ void BfModule::Visit(BfSwitchStatement* switchStmt)
BfExprEvaluator exprEvaluator(this); BfExprEvaluator exprEvaluator(this);
BfAstNode* refNode = switchCase->mColonToken; BfAstNode* refNode = switchCase->mColonToken;
if ((caseValue.mType->IsPayloadEnum()) && (caseValue.mValue.IsConst()) && (switchValue.mType == caseValue.mType)) if ((caseValue.mType->IsPayloadEnum()) && (caseValue.mValue.IsConst()) && (switchValue.mType == caseValue.mType) && (isEnumDescValue))
{ {
if (!enumTagVal) if (!enumTagVal)
{ {

View file

@ -297,6 +297,13 @@ namespace Tests
{ {
float mX; float mX;
}; };
struct StructX
{
int mA;
int mB;
char* mC;
};
}; };
int Interop::StructA::sVal = 1234; int Interop::StructA::sVal = 1234;
@ -492,6 +499,15 @@ extern "C" Interop::StructI Func2I(Interop::StructI arg0, int arg1)
return ret; return ret;
} }
extern "C" Interop::StructX Func2X(Interop::StructX arg0, int arg1)
{
Interop::StructX ret;
ret.mA = arg0.mA + arg1;
ret.mB = arg0.mB;
ret.mC = arg0.mC + 1;
return ret;
}
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
extern "C" int Func3A(Interop::StructA* ptr) extern "C" int Func3A(Interop::StructA* ptr)

View file

@ -84,6 +84,34 @@ namespace Tests
Test.Assert(a == 1); Test.Assert(a == 1);
Test.Assert(b == 2); Test.Assert(b == 2);
ee = default;
switch (ee)
{
case default:
Test.Assert(true);
default:
Test.Assert(false);
}
ee = .B(123);
switch (ee)
{
case default:
Test.Assert(false);
default:
Test.Assert(true);
}
switch (ee)
{
case .B(100):
Test.Assert(false);
case .B(123):
Test.Assert(true);
default:
Test.Assert(false);
}
EnumF ef = .EE(.C(3, 4)); EnumF ef = .EE(.C(3, 4));
switch (ef) switch (ef)
{ {

View file

@ -239,6 +239,14 @@ namespace Tests
public float mX; public float mX;
} }
[CRepr]
public struct StructX
{
public int32 mA;
public int32 mB;
public char8* mC;
}
[LinkName(.C)] [LinkName(.C)]
public static extern int32 Func0(int32 a, int32 b); public static extern int32 Func0(int32 a, int32 b);
[LinkName(.C)] [LinkName(.C)]
@ -307,6 +315,8 @@ namespace Tests
public static extern StructH Func2H(StructH arg0, int32 arg2); public static extern StructH Func2H(StructH arg0, int32 arg2);
[LinkName(.C)] [LinkName(.C)]
public static extern StructI Func2I(StructI arg0, int32 arg2); public static extern StructI Func2I(StructI arg0, int32 arg2);
[LinkName(.C)]
public static extern StructX Func2X(StructX arg0, int32 arg2);
[LinkName(.C)] [LinkName(.C)]
public static extern StructJ Func4J(StructJ arg0, StructJ arg1, StructJ arg2, StructJ arg3); public static extern StructJ Func4J(StructJ arg0, StructJ arg1, StructJ arg2, StructJ arg3);
@ -393,6 +403,7 @@ namespace Tests
StructU su = .() { mK = .(){mX = 3, mY = 4}}; StructU su = .() { mK = .(){mX = 3, mY = 4}};
StructV sv = .() { mX = 3, mY = 4}; StructV sv = .() { mX = 3, mY = 4};
StructW sw = .() { mX = 3 }; StructW sw = .() { mX = 3 };
StructX sx = .() { mA = 3, mB = 4, mC = "ABCD" };
void StartTest(String str) void StartTest(String str)
{ {
@ -510,6 +521,11 @@ namespace Tests
Test.Assert(si0.MethodI1(si1, 12).mA == (int8)193); Test.Assert(si0.MethodI1(si1, 12).mA == (int8)193);
Test.Assert(Func2I(si0, 12).mA == 102); Test.Assert(Func2I(si0, 12).mA == 102);
/*var sx1 = Func2X(sx, 100);
Test.Assert(sx1.mA == 103);
Test.Assert(sx1.mB == 4);
Test.Assert(sx1.mC == (char8*)"BCD"+1);*/
StructJ sj0; StructJ sj0;
sj0.mPtr = "ABC"; sj0.mPtr = "ABC";
sj0.mLength = 3; sj0.mLength = 3;