1
0
Fork 0
mirror of https://github.com/beefytech/Beef.git synced 2025-06-08 11:38:21 +02:00

Fixes NaN comparisons

This commit is contained in:
Brian Fiete 2024-10-16 13:25:17 -04:00
parent 7c0293620a
commit 55298ee884
8 changed files with 256 additions and 38 deletions

View file

@ -2106,6 +2106,8 @@ String BeMCContext::ToString(const BeMCOperand& operand)
if (cmpResult.mResultVRegIdx != -1)
result += StrFormat("<vreg%d>", cmpResult.mResultVRegIdx);
result += " ";
if (cmpResult.mInverted)
result += "!";
result += BeDumpContext::ToString(cmpResult.mCmpKind);
return result;
}
@ -3518,9 +3520,18 @@ void BeMCContext::CreateCondBr(BeMCBlock* mcBlock, BeMCOperand& testVal, const B
{
// Beef-specific: assuming CMP results aren't stomped
auto& cmpResult = mCmpResults[testVal.mCmpResultIdx];
AllocInst(BeMCInstKind_CondBr, trueBlock, BeMCOperand::FromCmpKind(cmpResult.mCmpKind));
if (cmpResult.mInverted)
{
AllocInst(BeMCInstKind_CondBr, falseBlock, BeMCOperand::FromCmpKind(cmpResult.mCmpKind));
AllocInst(BeMCInstKind_Br, trueBlock);
}
else
{
AllocInst(BeMCInstKind_CondBr, trueBlock, BeMCOperand::FromCmpKind(cmpResult.mCmpKind));
AllocInst(BeMCInstKind_Br, falseBlock);
}
}
else if (testVal.mKind == BeMCOperandKind_Phi)
{
auto phi = testVal.mPhi;
@ -6135,22 +6146,30 @@ uint8 BeMCContext::GetJumpOpCode(BeCmpKind cmpKind, bool isLong)
case BeCmpKind_SLT: // JL
return 0x8C;
case BeCmpKind_ULT: // JB
case BeCmpKind_OLT: // JB
return 0x82;
case BeCmpKind_SLE: // JLE
return 0x8E;
case BeCmpKind_ULE: // JBE
case BeCmpKind_OLE: // JBE
return 0x86;
case BeCmpKind_EQ: // JE
return 0x84;
case BeCmpKind_NE: // JNE
return 0x85;
case BeCmpKind_OEQ: // JNP
return 0x8B;
case BeCmpKind_UNE: // JP
return 0x8A;
case BeCmpKind_SGT: // JG
return 0x8F;
case BeCmpKind_UGT: // JA
case BeCmpKind_OGT: // JA
return 0x87;
case BeCmpKind_SGE: // JGE
return 0x8D;
case BeCmpKind_UGE: // JAE
case BeCmpKind_OGE: // JAE
return 0x83;
case BeCmpKind_NB: // JNB
return 0x83;
@ -6167,22 +6186,30 @@ uint8 BeMCContext::GetJumpOpCode(BeCmpKind cmpKind, bool isLong)
case BeCmpKind_SLT: // JL
return 0x7C;
case BeCmpKind_ULT: // JB
case BeCmpKind_OLT: // JB
return 0x72;
case BeCmpKind_SLE: // JLE
return 0x7E;
case BeCmpKind_ULE: // JBE
case BeCmpKind_OLE: // JBE
return 0x76;
case BeCmpKind_EQ: // JE
return 0x74;
case BeCmpKind_NE: // JNE
return 0x75;
case BeCmpKind_OEQ: // JNP
return 0x7B;
case BeCmpKind_UNE: // JP
return 0x7A;
case BeCmpKind_SGT: // JG
return 0x7F;
case BeCmpKind_UGT: // JA
case BeCmpKind_OGT: // JA
return 0x77;
case BeCmpKind_SGE: // JGE
return 0x7D;
case BeCmpKind_UGE: // JAE
case BeCmpKind_OGE: // JAE
return 0x73;
case BeCmpKind_NB: // JNB
return 0x73;
@ -6669,6 +6696,8 @@ void BeMCContext::InitializedPassHelper(BeMCBlock* mcBlock, BeVTrackingGenContex
{
auto cmpToBoolInst = AllocInst(BeMCInstKind_CmpToBool, BeMCOperand::FromCmpKind(cmpResult.mCmpKind), BeMCOperand(), instIdx + 1);
cmpToBoolInst->mResult = BeMCOperand::FromVReg(cmpResult.mResultVRegIdx);
if (cmpResult.mInverted)
AllocInst(BeMCInstKind_Not, cmpToBoolInst->mResult, instIdx + 2);
}
inst->mResult = BeMCOperand();
@ -10384,6 +10413,14 @@ bool BeMCContext::DoLegalization()
if (arg0Type->IsFloat())
{
if (!arg0.IsNativeReg())
{
// We need an <xmm> for reg0. We're not allowed to reorder SwapCmpSides due to NaN handling
ReplaceWithNewVReg(inst->mArg0, instIdx, true, true);
isFinalRun = false;
break;
}
// Cmp <r/m>, <xmm> is not valid, only Cmp <xmm>, <r/m>
if ((!arg0.IsNativeReg()) && (arg1.IsNativeReg()))
{
@ -11347,7 +11384,8 @@ bool BeMCContext::DoJumpRemovePass()
auto nextNextInst = mcBlock->mInstructions[nextNextIdx];
if ((nextInst->mKind == BeMCInstKind_Br) &&
(nextNextInst->mKind == BeMCInstKind_Label) && (inst->mArg0 == nextNextInst->mArg0))
(nextNextInst->mKind == BeMCInstKind_Label) && (inst->mArg0 == nextNextInst->mArg0) &&
(!BeModule::IsCmpOrdered(inst->mArg1.mCmpKind)))
{
didWork = true;
inst->mArg0 = nextInst->mArg0;
@ -14815,17 +14853,17 @@ void BeMCContext::DoCodeEmission()
{
case BeMCInstForm_XMM32_FRM32:
case BeMCInstForm_XMM32_IMM:
// COMISS
// UCOMISS
EmitREX(inst->mArg0, inst->mArg1, false);
Emit(0x0F); Emit(0x2F);
Emit(0x0F); Emit(0x2E);
EmitModRM(inst->mArg0, inst->mArg1);
break;
case BeMCInstForm_XMM64_FRM64:
case BeMCInstForm_XMM64_IMM:
// COMISD
// UCOMISD
Emit(0x66);
EmitREX(inst->mArg0, inst->mArg1, false);
Emit(0x0F); Emit(0x2F);
Emit(0x0F); Emit(0x2E);
EmitModRM(inst->mArg0, inst->mArg1);
break;
default:
@ -15102,25 +15140,51 @@ void BeMCContext::DoCodeEmission()
}
break;
case BeMCInstKind_CondBr:
{
for (int pass = 0; pass < 2; pass++)
{
if (inst->mArg0.mKind == BeMCOperandKind_Immediate_i64)
{
if (pass == 1)
break;
mOut.Write(GetJumpOpCode(inst->mArg1.mCmpKind, false));
mOut.Write((uint8)inst->mArg0.mImmediate);
}
else
{
BeCmpKind cmpKind = inst->mArg1.mCmpKind;
if (pass == 1)
{
switch (cmpKind)
{
case BeCmpKind_OEQ:
cmpKind = BeCmpKind_EQ;
break;
case BeCmpKind_UNE:
cmpKind = BeCmpKind_NE;
break;
default:
cmpKind = BeCmpKind_None;
}
if (cmpKind == BeCmpKind_None)
break;
}
BF_ASSERT(inst->mArg0.mKind == BeMCOperandKind_Label);
BeMCJump jump;
jump.mCodeOffset = funcCodePos;
jump.mLabelIdx = inst->mArg0.mLabelIdx;
// Speculative make it a short jump
jump.mJumpKind = 0;
jump.mCmpKind = inst->mArg1.mCmpKind;
jump.mCmpKind = cmpKind;;
deferredJumps.push_back(jump);
mOut.Write(GetJumpOpCode(jump.mCmpKind, false));
mOut.Write((uint8)0);
funcCodePos += 2;
}
}
}
break;
@ -16687,19 +16751,31 @@ void BeMCContext::Generate(BeFunction* function)
if (valType->IsFloat())
{
// These operations are set up to properly handle NaN comparisons
switch (cmpResult.mCmpKind)
{
case BeCmpKind_SLT:
cmpResult.mCmpKind = BeCmpKind_ULT;
cmpResult.mCmpKind = BeCmpKind_OLE;
cmpResult.mInverted = true;
BF_SWAP(mcInst->mArg0, mcInst->mArg1);
break;
case BeCmpKind_SLE:
cmpResult.mCmpKind = BeCmpKind_ULE;
cmpResult.mCmpKind = BeCmpKind_OLT;
cmpResult.mInverted = true;
BF_SWAP(mcInst->mArg0, mcInst->mArg1);
break;
case BeCmpKind_SGT:
cmpResult.mCmpKind = BeCmpKind_UGT;
cmpResult.mCmpKind = BeCmpKind_OGT;
break;
case BeCmpKind_SGE:
cmpResult.mCmpKind = BeCmpKind_UGE;
cmpResult.mCmpKind = BeCmpKind_OGE;
break;
case BeCmpKind_EQ:
cmpResult.mCmpKind = BeCmpKind_UNE;
cmpResult.mInverted = true;
break;
case BeCmpKind_NE:
cmpResult.mCmpKind = BeCmpKind_UNE;
break;
}
}
@ -16710,6 +16786,7 @@ void BeMCContext::Generate(BeFunction* function)
result.mCmpResultIdx = cmpResultIdx;
mcInst->mResult = result;
break;
}
break;
case BeObjectAccessCheckInst::TypeId:
@ -18083,7 +18160,7 @@ void BeMCContext::Generate(BeFunction* function)
BEMC_ASSERT(retCount == 1);
bool wantDebug = mDebugging;
wantDebug |= function->mName == "?Test@Program@BeefTest@bf@@SAXXZ";
wantDebug |= function->mName == "?GetVal@TestProgram@BeefTest@bf@@SATint@@M@Z";
//wantDebug |= function->mName == "?Testos@Fartso@@SAHPEA1@HH@Z";
//wantDebug |= function->mName == "?GetYoopA@Fartso@@QEAAUYoop@@XZ";
//"?TestVals@Fartso@@QEAATint@@XZ";

View file

@ -166,11 +166,13 @@ struct BeCmpResult
{
BeCmpKind mCmpKind;
int mResultVRegIdx;
bool mInverted;
BeCmpResult()
{
mCmpKind = BeCmpKind_None;
mResultVRegIdx = -1;
mInverted = false;
}
};

View file

@ -1726,14 +1726,18 @@ void BeDumpContext::ToString(StringImpl& str, BeCmpKind cmpKind)
{
case BeCmpKind_SLT: str += "slt"; return;
case BeCmpKind_ULT: str += "ult"; return;
case BeCmpKind_OLT: str += "olt"; return;
case BeCmpKind_SLE: str += "sle"; return;
case BeCmpKind_ULE: str += "ule"; return;
case BeCmpKind_OLE: str += "ole"; return;
case BeCmpKind_EQ: str += "eq"; return;
case BeCmpKind_NE: str += "ne"; return;
case BeCmpKind_SGT: str += "sgt"; return;
case BeCmpKind_UGT: str += "ugt"; return;
case BeCmpKind_OGT: str += "ogt"; return;
case BeCmpKind_SGE: str += "sge"; return;
case BeCmpKind_UGE: str += "uge"; return;
case BeCmpKind_OGE: str += "oge"; return;
case BeCmpKind_NB: str += "nb"; return;
case BeCmpKind_NO: str += "no"; return;
default:
@ -3017,22 +3021,34 @@ BeCmpKind BeModule::InvertCmp(BeCmpKind cmpKind)
return BeCmpKind_SGE;
case BeCmpKind_ULT:
return BeCmpKind_UGE;
case BeCmpKind_OLT:
return BeCmpKind_OGE;
case BeCmpKind_SLE:
return BeCmpKind_SGT;
case BeCmpKind_ULE:
return BeCmpKind_UGT;
case BeCmpKind_OLE:
return BeCmpKind_OGT;
case BeCmpKind_EQ:
return BeCmpKind_NE;
case BeCmpKind_OEQ:
return BeCmpKind_UNE;
case BeCmpKind_NE:
return BeCmpKind_EQ;
case BeCmpKind_UNE:
return BeCmpKind_OEQ;
case BeCmpKind_SGT:
return BeCmpKind_SLE;
case BeCmpKind_UGT:
return BeCmpKind_ULE;
case BeCmpKind_OGT:
return BeCmpKind_OLE;
case BeCmpKind_SGE:
return BeCmpKind_SLT;
case BeCmpKind_UGE:
return BeCmpKind_ULT;
case BeCmpKind_OGE:
return BeCmpKind_OLT;
}
return cmpKind;
}
@ -3045,10 +3061,14 @@ BeCmpKind BeModule::SwapCmpSides(BeCmpKind cmpKind)
return BeCmpKind_SGT;
case BeCmpKind_ULT:
return BeCmpKind_UGT;
case BeCmpKind_OLT:
return BeCmpKind_OGT;
case BeCmpKind_SLE:
return BeCmpKind_SGE;
case BeCmpKind_ULE:
return BeCmpKind_UGE;
case BeCmpKind_OLE:
return BeCmpKind_OGE;
case BeCmpKind_EQ:
return BeCmpKind_EQ;
case BeCmpKind_NE:
@ -3057,14 +3077,33 @@ BeCmpKind BeModule::SwapCmpSides(BeCmpKind cmpKind)
return BeCmpKind_SLT;
case BeCmpKind_UGT:
return BeCmpKind_ULT;
case BeCmpKind_OGT:
return BeCmpKind_OLT;
case BeCmpKind_SGE:
return BeCmpKind_SLE;
case BeCmpKind_UGE:
return BeCmpKind_ULE;
case BeCmpKind_OGE:
return BeCmpKind_OLE;
}
return cmpKind;
}
bool BeModule::IsCmpOrdered(BeCmpKind cmpKind)
{
switch (cmpKind)
{
case BeCmpKind_OLT:
case BeCmpKind_OLE:
case BeCmpKind_OGT:
case BeCmpKind_OGE:
case BeCmpKind_OEQ:
case BeCmpKind_UNE:
return true;
}
return false;
}
void BeModule::AddInst(BeInst* inst)
{
inst->mDbgLoc = mCurDbgLoc;

View file

@ -858,14 +858,20 @@ enum BeCmpKind
BeCmpKind_SLT,
BeCmpKind_ULT,
BeCmpKind_OLT,
BeCmpKind_SLE,
BeCmpKind_ULE,
BeCmpKind_OLE,
BeCmpKind_EQ,
BeCmpKind_OEQ,
BeCmpKind_NE,
BeCmpKind_UNE,
BeCmpKind_SGT,
BeCmpKind_UGT,
BeCmpKind_OGT,
BeCmpKind_SGE,
BeCmpKind_UGE,
BeCmpKind_OGE,
BeCmpKind_NB,
BeCmpKind_NO,
};
@ -2371,6 +2377,7 @@ public:
static BeCmpKind InvertCmp(BeCmpKind cmpKind);
static BeCmpKind SwapCmpSides(BeCmpKind cmpKind);
static bool IsCmpOrdered(BeCmpKind cmpKind);
void SetActiveFunction(BeFunction* function);
BeArgument* GetArgument(int arg);
BeBlock* CreateBlock(const StringImpl& name);

View file

@ -297,12 +297,13 @@ String BfIRConstHolder::ToString(BfIRValue irValue)
}
else if (constant->mTypeCode == BfTypeCode_NullPtr)
{
String ret = "null";
String ret;
if (constant->mIRType)
{
ret += "\n";
ret += ToString(constant->mIRType);
ret += " ";
}
ret += "null";
return ret;
}
else if (constant->mTypeCode == BfTypeCode_Boolean)
@ -334,13 +335,13 @@ String BfIRConstHolder::ToString(BfIRValue irValue)
{
auto bitcast = (BfConstantBitCast*)constant;
BfIRValue targetConst(BfIRValueFlags_Const, bitcast->mTarget);
return ToString(targetConst) + " BitCast to " + ToString(bitcast->mToType);
return ToString(bitcast->mToType) += " bitcast " + ToString(targetConst);
}
else if (constant->mConstType == BfConstType_Box)
{
auto box = (BfConstantBox*)constant;
BfIRValue targetConst(BfIRValueFlags_Const, box->mTarget);
return ToString(targetConst) + " box to " + ToString(box->mToType);
return ToString(box->mToType) + " box " + ToString(targetConst);
}
else if (constant->mConstType == BfConstType_GEP32_1)
{
@ -376,7 +377,7 @@ String BfIRConstHolder::ToString(BfIRValue irValue)
{
auto constAgg = (BfConstantAgg*)constant;
String str = ToString(constAgg->mType);
str += "(";
str += " (";
for (int i = 0; i < (int)constAgg->mValues.size(); i++)
{
@ -4778,6 +4779,12 @@ BfIRValue BfIRBuilder::CreateInBoundsGEP(BfIRValue val, int idx0, int idx1)
{
if (val.IsConst())
{
#ifdef _DEBUG
auto targetConstant = GetConstant(val);
BF_ASSERT((mBfIRCodeGen == NULL) ||
((targetConstant->mTypeCode != BfTypeCode_NullPtr) && (targetConstant->mConstType != BfConstType_BitCastNull)));
#endif
auto constGEP = mTempAlloc.Alloc<BfConstantGEP32_2>();
constGEP->mConstType = BfConstType_GEP32_2;
constGEP->mTarget = val.mId;
@ -4804,6 +4811,10 @@ BfIRValue BfIRBuilder::CreateInBoundsGEP(BfIRValue val, BfIRValue idx0)
auto constant = GetConstant(val);
if (constant != NULL)
{
#ifdef _DEBUG
//BF_ASSERT((constant->mTypeCode != BfTypeCode_NullPtr) && (constant->mConstType != BfConstType_BitCastNull));
#endif
if (constant->mConstType == BfConstType_IntToPtr)
{
auto fromPtrToInt = (BfConstantIntToPtr*)constant;
@ -5051,6 +5062,12 @@ void BfIRBuilder::CreateValueScopeHardEnd(BfIRValue scopeStart)
BfIRValue BfIRBuilder::CreateLoad(BfIRValue val, bool isVolatile)
{
#ifdef _DEBUG
// auto targetConstant = GetConstant(val);
// if (targetConstant != NULL)
// BF_ASSERT((targetConstant->mTypeCode != BfTypeCode_NullPtr) && (targetConstant->mConstType != BfConstType_BitCastNull));
#endif
BfIRValue retVal = WriteCmd(BfIRCmd_Load, val, isVolatile);
NEW_CMD_INSERTED_IRVALUE;
return retVal;
@ -5058,6 +5075,12 @@ BfIRValue BfIRBuilder::CreateLoad(BfIRValue val, bool isVolatile)
BfIRValue BfIRBuilder::CreateAlignedLoad(BfIRValue val, int align, bool isVolatile)
{
#ifdef _DEBUG
// auto targetConstant = GetConstant(val);
// if (targetConstant != NULL)
// BF_ASSERT((targetConstant->mTypeCode != BfTypeCode_NullPtr) && (targetConstant->mConstType != BfConstType_BitCastNull));
#endif
BfIRValue retVal = WriteCmd(BfIRCmd_AlignedLoad, val, align, isVolatile);
NEW_CMD_INSERTED_IRVALUE;
return retVal;

View file

@ -2400,7 +2400,7 @@ void BfIRCodeGen::HandleNextCmd()
CMD_PARAM(llvm::Value*, lhs);
CMD_PARAM(llvm::Value*, rhs);
if (lhs->getType()->isFloatingPointTy())
SetResult(curId, mIRBuilder->CreateFCmpONE(lhs, rhs));
SetResult(curId, mIRBuilder->CreateFCmpUNE(lhs, rhs));
else
SetResult(curId, mIRBuilder->CreateICmpNE(lhs, rhs));
}
@ -2450,7 +2450,7 @@ void BfIRCodeGen::HandleNextCmd()
CMD_PARAM(llvm::Value*, lhs);
CMD_PARAM(llvm::Value*, rhs);
if (lhs->getType()->isFloatingPointTy())
SetResult(curId, mIRBuilder->CreateFCmpUGT(lhs, rhs));
SetResult(curId, mIRBuilder->CreateFCmpOGT(lhs, rhs));
else
SetResult(curId, mIRBuilder->CreateICmpSGT(lhs, rhs));
}

View file

@ -10126,7 +10126,17 @@ BfIRValue BfModule::AllocFromType(BfType* type, const BfAllocTarget& allocTarget
{
if ((mBfIRBuilder->mIgnoreWrites) ||
((mCompiler->mIsResolveOnly) && (!mIsComptimeModule)))
{
if (mBfIRBuilder->mIgnoreWrites)
{
return GetDefaultValue(typeInstance);
}
else
{
// Fake with alloca
return mBfIRBuilder->CreateAlloca(allocType);
}
}
auto classVDataType = ResolveTypeDef(mCompiler->mClassVDataTypeDef);
auto vData = mBfIRBuilder->CreateBitCast(vDataRef, mBfIRBuilder->MapTypeInstPtr(classVDataType->ToTypeInstance()));

View file

@ -51,6 +51,66 @@ namespace Tests
FloatParseErrTest("6e+");
}
[Test]
public static void TestCmp()
{
float fNeg = -1;
float fNan = float.NaN;
if (fNeg < 0)
{
}
else
{
Test.FatalError();
}
if (fNeg > 0)
Test.FatalError();
if (fNan < 0)
Test.FatalError();
if (fNan <= 0)
Test.FatalError();
if (fNan > 0)
Test.FatalError();
if (fNan >= 0)
Test.FatalError();
if (fNan == 0)
Test.FatalError();
if (fNan != 0)
{
}
else
{
Test.FatalError();
}
if (fNan == fNan)
Test.FatalError();
if (fNan != fNan)
{
}
else
{
Test.FatalError();
}
bool b0 = fNan < 0;
bool b1 = fNan > 0;
bool b2 = fNan == fNan;
bool b3 = fNan != fNan;
bool b4 = fNan != 0;
Test.Assert(!b0);
Test.Assert(!b1);
Test.Assert(!b2);
Test.Assert(b3);
Test.Assert(b4);
}
public static void MinMaxTest<T>(T expectedMinValue, T expectedMaxValue)
where T : IMinMaxValue<T>
where int : operator T <=> T