diff --git a/IDEHelper/Backend/BeIRCodeGen.cpp b/IDEHelper/Backend/BeIRCodeGen.cpp index 64d09dbb..58ac930d 100644 --- a/IDEHelper/Backend/BeIRCodeGen.cpp +++ b/IDEHelper/Backend/BeIRCodeGen.cpp @@ -1386,6 +1386,8 @@ void BeIRCodeGen::HandleNextCmd() { CMD_PARAM(BeValue*, lhs); CMD_PARAM(BeValue*, rhs); + if (lhs->GetType() != rhs->GetType()) + Fail("Type mismatch for CmpEQ"); SetResult(curId, mBeModule->CreateCmp(BeCmpKind_EQ, lhs, rhs)); } break; diff --git a/IDEHelper/Compiler/BfStmtEvaluator.cpp b/IDEHelper/Compiler/BfStmtEvaluator.cpp index 57a77ad9..04586c5e 100644 --- a/IDEHelper/Compiler/BfStmtEvaluator.cpp +++ b/IDEHelper/Compiler/BfStmtEvaluator.cpp @@ -4707,6 +4707,7 @@ void BfModule::Visit(BfSwitchStatement* switchStmt) BfTypedValue caseValue; BfIRBlock doBlock = caseBlock; bool hadConditional = false; + bool isEnumDescValue = isPayloadEnum; if (isPayloadEnum) { auto dscrType = switchValue.mType->ToTypeInstance()->GetDiscriminatorType(); @@ -4766,6 +4767,7 @@ void BfModule::Visit(BfSwitchStatement* switchStmt) caseValue = CreateValueFromExpression(caseExpr, switchValue.mType, (BfEvalExprFlags)(BfEvalExprFlags_AllowEnumId | BfEvalExprFlags_NoCast)); if (!caseValue) continue; + isEnumDescValue = false; } BfTypedValue caseIntVal = caseValue; @@ -4835,7 +4837,7 @@ void BfModule::Visit(BfSwitchStatement* switchStmt) BfExprEvaluator exprEvaluator(this); BfAstNode* refNode = switchCase->mColonToken; - if ((caseValue.mType->IsPayloadEnum()) && (caseValue.mValue.IsConst()) && (switchValue.mType == caseValue.mType)) + if ((caseValue.mType->IsPayloadEnum()) && (caseValue.mValue.IsConst()) && (switchValue.mType == caseValue.mType) && (isEnumDescValue)) { if (!enumTagVal) { diff --git a/IDEHelper/Tests/CLib/main.cpp b/IDEHelper/Tests/CLib/main.cpp index f5692d5c..b846611f 100644 --- a/IDEHelper/Tests/CLib/main.cpp +++ b/IDEHelper/Tests/CLib/main.cpp @@ -297,6 +297,13 @@ namespace Tests { float mX; }; + + struct StructX + { + int mA; + int mB; + char* mC; + }; }; int Interop::StructA::sVal = 1234; @@ -492,6 +499,15 @@ extern "C" Interop::StructI Func2I(Interop::StructI arg0, int arg1) return ret; } +extern "C" Interop::StructX Func2X(Interop::StructX arg0, int arg1) +{ + Interop::StructX ret; + ret.mA = arg0.mA + arg1; + ret.mB = arg0.mB; + ret.mC = arg0.mC + 1; + return ret; +} + ////////////////////////////////////////////////////////////////////////// extern "C" int Func3A(Interop::StructA* ptr) diff --git a/IDEHelper/Tests/src/Enums.bf b/IDEHelper/Tests/src/Enums.bf index 4492d688..44f014de 100644 --- a/IDEHelper/Tests/src/Enums.bf +++ b/IDEHelper/Tests/src/Enums.bf @@ -84,6 +84,34 @@ namespace Tests Test.Assert(a == 1); Test.Assert(b == 2); + ee = default; + switch (ee) + { + case default: + Test.Assert(true); + default: + Test.Assert(false); + } + + ee = .B(123); + switch (ee) + { + case default: + Test.Assert(false); + default: + Test.Assert(true); + } + + switch (ee) + { + case .B(100): + Test.Assert(false); + case .B(123): + Test.Assert(true); + default: + Test.Assert(false); + } + EnumF ef = .EE(.C(3, 4)); switch (ef) { diff --git a/IDEHelper/Tests/src/Interop.bf b/IDEHelper/Tests/src/Interop.bf index 141aa49b..5e243f12 100644 --- a/IDEHelper/Tests/src/Interop.bf +++ b/IDEHelper/Tests/src/Interop.bf @@ -239,6 +239,14 @@ namespace Tests public float mX; } + [CRepr] + public struct StructX + { + public int32 mA; + public int32 mB; + public char8* mC; + } + [LinkName(.C)] public static extern int32 Func0(int32 a, int32 b); [LinkName(.C)] @@ -307,6 +315,8 @@ namespace Tests public static extern StructH Func2H(StructH arg0, int32 arg2); [LinkName(.C)] public static extern StructI Func2I(StructI arg0, int32 arg2); + [LinkName(.C)] + public static extern StructX Func2X(StructX arg0, int32 arg2); [LinkName(.C)] public static extern StructJ Func4J(StructJ arg0, StructJ arg1, StructJ arg2, StructJ arg3); @@ -393,6 +403,7 @@ namespace Tests StructU su = .() { mK = .(){mX = 3, mY = 4}}; StructV sv = .() { mX = 3, mY = 4}; StructW sw = .() { mX = 3 }; + StructX sx = .() { mA = 3, mB = 4, mC = "ABCD" }; void StartTest(String str) { @@ -510,6 +521,11 @@ namespace Tests Test.Assert(si0.MethodI1(si1, 12).mA == (int8)193); Test.Assert(Func2I(si0, 12).mA == 102); + /*var sx1 = Func2X(sx, 100); + Test.Assert(sx1.mA == 103); + Test.Assert(sx1.mB == 4); + Test.Assert(sx1.mC == (char8*)"BCD"+1);*/ + StructJ sj0; sj0.mPtr = "ABC"; sj0.mLength = 3;