1
0
Fork 0
mirror of https://github.com/beefytech/Beef.git synced 2025-06-10 20:42:21 +02:00

Arithmetic overflow checks

This commit is contained in:
Brian Fiete 2022-01-11 08:17:09 -05:00
parent 1f0d2dcc82
commit eb375362a1
29 changed files with 503 additions and 87 deletions

View file

@ -590,6 +590,11 @@ void BeIRCodeGen::Read(bool& val)
BE_MEM_END("bool");
}
void BeIRCodeGen::Read(int8& val)
{
val = mStream->Read();
}
void BeIRCodeGen::Read(BeIRTypeEntry*& type)
{
BE_MEM_START;
@ -1432,21 +1437,24 @@ void BeIRCodeGen::HandleNextCmd()
{
CMD_PARAM(BeValue*, lhs);
CMD_PARAM(BeValue*, rhs);
SetResult(curId, mBeModule->CreateBinaryOp(BeBinaryOpKind_Add, lhs, rhs));
CMD_PARAM(int8, overflowCheckKind);
SetResult(curId, mBeModule->CreateBinaryOp(BeBinaryOpKind_Add, lhs, rhs, (BfOverflowCheckKind)overflowCheckKind));
}
break;
case BfIRCmd_Sub:
{
CMD_PARAM(BeValue*, lhs);
CMD_PARAM(BeValue*, rhs);
SetResult(curId, mBeModule->CreateBinaryOp(BeBinaryOpKind_Subtract, lhs, rhs));
CMD_PARAM(int8, overflowCheckKind);
SetResult(curId, mBeModule->CreateBinaryOp(BeBinaryOpKind_Subtract, lhs, rhs, (BfOverflowCheckKind)overflowCheckKind));
}
break;
case BfIRCmd_Mul:
{
CMD_PARAM(BeValue*, lhs);
CMD_PARAM(BeValue*, rhs);
SetResult(curId, mBeModule->CreateBinaryOp(BeBinaryOpKind_Multiply, lhs, rhs));
CMD_PARAM(int8, overflowCheckKind);
SetResult(curId, mBeModule->CreateBinaryOp(BeBinaryOpKind_Multiply, lhs, rhs, (BfOverflowCheckKind)overflowCheckKind));
}
break;
case BfIRCmd_SDiv:
@ -2100,6 +2108,11 @@ void BeIRCodeGen::HandleNextCmd()
mBeModule->RemoveBlock(mActiveFunction, fromBlock);
}
break;
case BfIRCmd_GetInsertBlock:
{
SetResult(curId, mBeModule->mActiveBlock);
}
break;
case BfIRCmd_SetInsertPoint:
{
CMD_PARAM(BeBlock*, block);

View file

@ -119,6 +119,7 @@ public:
void Read(int64& i);
void Read(Val128& i);
void Read(bool& val);
void Read(int8& val);
void Read(BeIRTypeEntry*& type);
void Read(BeType*& beType);
void Read(BeFunctionType*& beType);

View file

@ -6123,6 +6123,10 @@ uint8 BeMCContext::GetJumpOpCode(BeCmpKind cmpKind, bool isLong)
return 0x8D;
case BeCmpKind_UGE: // JAE
return 0x83;
case BeCmpKind_NB: // JNB
return 0x83;
case BeCmpKind_NO: // JNO
return 0x81;
}
}
else
@ -6151,6 +6155,10 @@ uint8 BeMCContext::GetJumpOpCode(BeCmpKind cmpKind, bool isLong)
return 0x7D;
case BeCmpKind_UGE: // JAE
return 0x73;
case BeCmpKind_NB: // JNB
return 0x73;
case BeCmpKind_NO: // JNO
return 0x71;
}
}
@ -10116,8 +10124,10 @@ bool BeMCContext::DoLegalization()
break;
case BeMCInstKind_Mul:
case BeMCInstKind_IMul:
{
if (arg0Type->mSize == 1)
{
bool handled = false;
if ((arg0Type->mSize == 1) && (arg0Type->IsIntable()))
{
if ((!arg0.IsNativeReg()) || (arg0.mReg != X64Reg_AL) || (inst->mResult))
{
@ -10152,8 +10162,54 @@ bool BeMCContext::DoLegalization()
}
BF_ASSERT(!inst->mResult);
handled = true;
}
else
else if ((inst->mKind == BeMCInstKind_Mul) && (arg0Type->IsIntable()))
{
auto wantReg0 = ResizeRegister(X64Reg_RAX, arg0Type->mSize);
if ((!arg0.IsNativeReg()) || (arg0.mReg != wantReg0) || (inst->mResult))
{
auto srcVRegInfo = GetVRegInfo(inst->mArg0);
// unsigned multiplies can only be done on AX/EAX/RAX
AllocInst(BeMCInstKind_PreserveVolatiles, BeMCOperand::FromReg(X64Reg_RAX), instIdx++);
AllocInst(BeMCInstKind_PreserveVolatiles, BeMCOperand::FromReg(X64Reg_RDX), instIdx++);
auto vregInfo0 = GetVRegInfo(inst->mArg0);
if (vregInfo0 != NULL)
{
vregInfo0->mDisableRAX = true;
vregInfo0->mDisableRDX = true;
}
auto vregInfo1 = GetVRegInfo(inst->mArg1);
if (vregInfo1 != NULL)
{
vregInfo1->mDisableRAX = true;
vregInfo1->mDisableRDX = true;
}
AllocInst(BeMCInstKind_Mov, BeMCOperand::FromReg(wantReg0), inst->mArg0, instIdx++);
AllocInst(BeMCInstKind_Mov, inst->mResult ? inst->mResult : inst->mArg0, BeMCOperand::FromReg(wantReg0), instIdx++ + 1);
inst->mArg0 = BeMCOperand::FromReg(wantReg0);
inst->mResult = BeMCOperand();
AllocInst(BeMCInstKind_RestoreVolatiles, BeMCOperand::FromReg(X64Reg_RDX), instIdx++ + 1);
AllocInst(BeMCInstKind_RestoreVolatiles, BeMCOperand::FromReg(X64Reg_RAX), instIdx++ + 1);
isFinalRun = false;
break;
}
if (inst->mArg1.IsImmediateInt())
{
ReplaceWithNewVReg(inst->mArg1, instIdx, true, false);
}
BF_ASSERT(!inst->mResult);
handled = true;
}
if (handled)
{
if (inst->mResult)
{
@ -14404,6 +14460,51 @@ void BeMCContext::DoCodeEmission()
}
break;
case BeMCInstKind_Mul:
{
if (arg0Type->IsIntable())
{
bool isValid = true;
auto typeCode = GetType(inst->mArg1)->mTypeCode;
switch (typeCode)
{
case BeTypeCode_Int8:
isValid = inst->mArg0 == BeMCOperand::FromReg(X64Reg_AL);
break;
case BeTypeCode_Int16:
isValid = inst->mArg0 == BeMCOperand::FromReg(X64Reg_AX);
break;
case BeTypeCode_Int32:
isValid = inst->mArg0 == BeMCOperand::FromReg(X64Reg_EAX);
break;
case BeTypeCode_Int64:
isValid = inst->mArg0 == BeMCOperand::FromReg(X64Reg_RAX);
break;
default:
isValid = false;
}
if (!isValid)
SoftFail("Invalid mul arguments");
switch (typeCode)
{
case BeTypeCode_Int8:
EmitREX(BeMCOperand(), inst->mArg1, false);
Emit(0xF6);
EmitModRM(4, inst->mArg1);
break;
case BeTypeCode_Int16: Emit(0x66); // Fallthrough
case BeTypeCode_Int32:
case BeTypeCode_Int64:
EmitREX(BeMCOperand(), inst->mArg1, typeCode == BeTypeCode_Int64);
Emit(0xF7); EmitModRM(4, inst->mArg1);
break;
default:
NotImpl();
}
break;
}
}
//Fallthrough
case BeMCInstKind_IMul:
{
if (instForm == BeMCInstForm_XMM128_RM128)
@ -14932,17 +15033,25 @@ void BeMCContext::DoCodeEmission()
break;
case BeMCInstKind_CondBr:
{
BF_ASSERT(inst->mArg0.mKind == BeMCOperandKind_Label);
BeMCJump jump;
jump.mCodeOffset = funcCodePos;
jump.mLabelIdx = inst->mArg0.mLabelIdx;
// Speculative make it a short jump
jump.mJumpKind = 0;
jump.mCmpKind = inst->mArg1.mCmpKind;
deferredJumps.push_back(jump);
if (inst->mArg0.mKind == BeMCOperandKind_Immediate_i64)
{
mOut.Write(GetJumpOpCode(inst->mArg1.mCmpKind, false));
mOut.Write((uint8)inst->mArg0.mImmediate);
}
else
{
BF_ASSERT(inst->mArg0.mKind == BeMCOperandKind_Label);
BeMCJump jump;
jump.mCodeOffset = funcCodePos;
jump.mLabelIdx = inst->mArg0.mLabelIdx;
// Speculative make it a short jump
jump.mJumpKind = 0;
jump.mCmpKind = inst->mArg1.mCmpKind;
deferredJumps.push_back(jump);
mOut.Write(GetJumpOpCode(jump.mCmpKind, false));
mOut.Write((uint8)0);
mOut.Write(GetJumpOpCode(jump.mCmpKind, false));
mOut.Write((uint8)0);
}
}
break;
case BeMCInstKind_Br:
@ -15857,7 +15966,7 @@ void BeMCContext::Print(bool showVRegFlags, bool showVRegDetails)
OutputDebugStr(ToString(showVRegFlags, showVRegDetails));
}
BeMCOperand BeMCContext::AllocBinaryOp(BeMCInstKind instKind, const BeMCOperand& lhs, const BeMCOperand& rhs, BeMCBinIdentityKind identityKind)
BeMCOperand BeMCContext::AllocBinaryOp(BeMCInstKind instKind, const BeMCOperand& lhs, const BeMCOperand& rhs, BeMCBinIdentityKind identityKind, BeMCOverflowCheckKind overflowCheckKind)
{
if ((lhs.IsImmediate()) && (lhs.mKind == rhs.mKind))
{
@ -15918,6 +16027,13 @@ BeMCOperand BeMCContext::AllocBinaryOp(BeMCInstKind instKind, const BeMCOperand&
auto mcInst = AllocInst(instKind, lhs, rhs);
mcInst->mResult = result;
if (overflowCheckKind != BeMCOverflowCheckKind_None)
{
AllocInst(BeMCInstKind_CondBr, BeMCOperand::FromImmediate(1), BeMCOperand::FromCmpKind((overflowCheckKind == BeMCOverflowCheckKind_B) ? BeCmpKind_NB : BeCmpKind_NO));
AllocInst(BeMCInstKind_DbgBreak);
}
return result;
}
@ -16399,9 +16515,18 @@ void BeMCContext::Generate(BeFunction* function)
switch (castedInst->mOpKind)
{
case BeBinaryOpKind_Add: result = AllocBinaryOp(BeMCInstKind_Add, mcLHS, mcRHS, BeMCBinIdentityKind_Any_IsZero); break;
case BeBinaryOpKind_Subtract: result = AllocBinaryOp(BeMCInstKind_Sub, mcLHS, mcRHS, BeMCBinIdentityKind_Right_IsZero); break;
case BeBinaryOpKind_Multiply: result = AllocBinaryOp(BeMCInstKind_IMul, mcLHS, mcRHS, BeMCBinIdentityKind_Any_IsOne); break;
case BeBinaryOpKind_Add: result = AllocBinaryOp(BeMCInstKind_Add, mcLHS, mcRHS, BeMCBinIdentityKind_Any_IsZero,
((castedInst->mOverflowCheckKind & BfOverflowCheckKind_Signed) != 0) ? BeMCOverflowCheckKind_O :
((castedInst->mOverflowCheckKind & BfOverflowCheckKind_Unsigned) != 0) ? BeMCOverflowCheckKind_B : BeMCOverflowCheckKind_None);
break;
case BeBinaryOpKind_Subtract: result = AllocBinaryOp(BeMCInstKind_Sub, mcLHS, mcRHS, BeMCBinIdentityKind_Right_IsZero,
((castedInst->mOverflowCheckKind & BfOverflowCheckKind_Signed) != 0) ? BeMCOverflowCheckKind_O :
((castedInst->mOverflowCheckKind & BfOverflowCheckKind_Unsigned) != 0) ? BeMCOverflowCheckKind_B : BeMCOverflowCheckKind_None);
break;
case BeBinaryOpKind_Multiply: result = AllocBinaryOp(((castedInst->mOverflowCheckKind & BfOverflowCheckKind_Unsigned) != 0) ? BeMCInstKind_Mul : BeMCInstKind_IMul, mcLHS, mcRHS, BeMCBinIdentityKind_Any_IsOne,
((castedInst->mOverflowCheckKind & BfOverflowCheckKind_Signed) != 0) ? BeMCOverflowCheckKind_O :
((castedInst->mOverflowCheckKind & BfOverflowCheckKind_Unsigned) != 0) ? BeMCOverflowCheckKind_O : BeMCOverflowCheckKind_None);
break;
case BeBinaryOpKind_SDivide: result = AllocBinaryOp(BeMCInstKind_IDiv, mcLHS, mcRHS, BeMCBinIdentityKind_Right_IsOne); break;
case BeBinaryOpKind_UDivide: result = AllocBinaryOp(BeMCInstKind_Div, mcLHS, mcRHS, BeMCBinIdentityKind_Right_IsOne); break;
case BeBinaryOpKind_SModulus: result = AllocBinaryOp(BeMCInstKind_IRem, mcLHS, mcRHS, type->IsFloat() ? BeMCBinIdentityKind_None : BeMCBinIdentityKind_Right_IsOne_Result_Zero); break;

View file

@ -1293,6 +1293,13 @@ struct BeRMParamsInfo
}
};
enum BeMCOverflowCheckKind
{
BeMCOverflowCheckKind_None,
BeMCOverflowCheckKind_B,
BeMCOverflowCheckKind_O
};
// This class only processes one function per instantiation
class BeMCContext
{
@ -1367,7 +1374,7 @@ public:
BeMCInst* AllocInst(BeMCInstKind instKind, const BeMCOperand& arg0, const BeMCOperand& arg1, int insertIdx = -1);
void MergeInstFlags(BeMCInst* prevInst, BeMCInst* inst, BeMCInst* nextInst);
void RemoveInst(BeMCBlock* block, int instIdx, bool needChangesMerged = true, bool removeFromList = true);
BeMCOperand AllocBinaryOp(BeMCInstKind instKind, const BeMCOperand & lhs, const BeMCOperand & rhs, BeMCBinIdentityKind identityKind);
BeMCOperand AllocBinaryOp(BeMCInstKind instKind, const BeMCOperand & lhs, const BeMCOperand & rhs, BeMCBinIdentityKind identityKind, BeMCOverflowCheckKind overflowCheckKind = BeMCOverflowCheckKind_None);
BeMCOperand GetCallArgVReg(int argIdx, BeTypeCode typeCode);
BeMCOperand CreateCall(const BeMCOperand& func, const SizedArrayImpl<BeMCOperand>& args, BeType* retType = NULL, BfIRCallingConv callingConv = BfIRCallingConv_CDecl, bool structRet = false, bool noReturn = false, bool isVarArg = false);
BeMCOperand CreateCall(const BeMCOperand& func, const SizedArrayImpl<BeValue*>& args, BeType* retType = NULL, BfIRCallingConv callingConv = BfIRCallingConv_CDecl, bool structRet = false, bool noReturn = false, bool isVarArg = false);

View file

@ -1683,6 +1683,8 @@ void BeDumpContext::ToString(StringImpl& str, BeCmpKind cmpKind)
case BeCmpKind_UGT: str += "ugt"; return;
case BeCmpKind_SGE: str += "sge"; return;
case BeCmpKind_UGE: str += "uge"; return;
case BeCmpKind_NB: str += "nb"; return;
case BeCmpKind_NO: str += "no"; return;
default:
str += "???";
}
@ -3292,7 +3294,7 @@ BeCmpInst* BeModule::CreateCmp(BeCmpKind cmpKind, BeValue* lhs, BeValue* rhs)
return inst;
}
BeBinaryOpInst* BeModule::CreateBinaryOp(BeBinaryOpKind opKind, BeValue* lhs, BeValue* rhs)
BeBinaryOpInst* BeModule::CreateBinaryOp(BeBinaryOpKind opKind, BeValue* lhs, BeValue* rhs, BfOverflowCheckKind overflowCheckKind)
{
#ifdef _DEBUG
auto leftType = lhs->GetType();
@ -3303,6 +3305,7 @@ BeBinaryOpInst* BeModule::CreateBinaryOp(BeBinaryOpKind opKind, BeValue* lhs, Be
inst->mOpKind = opKind;
inst->mLHS = lhs;
inst->mRHS = rhs;
inst->mOverflowCheckKind = overflowCheckKind;
AddInst(inst);
return inst;
}

View file

@ -820,6 +820,7 @@ public:
BE_VALUE_TYPE(BeBinaryOpInst, BeInst);
BeBinaryOpKind mOpKind;
BfOverflowCheckKind mOverflowCheckKind;
BeValue* mLHS;
BeValue* mRHS;
@ -829,6 +830,7 @@ public:
{
hashCtx.Mixin(TypeId);
hashCtx.Mixin(mOpKind);
hashCtx.Mixin(mOverflowCheckKind);
mLHS->HashReference(hashCtx);
mRHS->HashReference(hashCtx);
}
@ -847,7 +849,9 @@ enum BeCmpKind
BeCmpKind_SGT,
BeCmpKind_UGT,
BeCmpKind_SGE,
BeCmpKind_UGE
BeCmpKind_UGE,
BeCmpKind_NB,
BeCmpKind_NO,
};
class BeCmpInst : public BeInst
@ -2338,7 +2342,7 @@ public:
BeNumericCastInst* CreateNumericCast(BeValue* value, BeType* toType, bool valSigned, bool toSigned);
BeBitCastInst* CreateBitCast(BeValue* value, BeType* toType);;
BeCmpInst* CreateCmp(BeCmpKind cmpKind, BeValue* lhs, BeValue* rhs);
BeBinaryOpInst* CreateBinaryOp(BeBinaryOpKind opKind, BeValue* lhs, BeValue* rhs);
BeBinaryOpInst* CreateBinaryOp(BeBinaryOpKind opKind, BeValue* lhs, BeValue* rhs, BfOverflowCheckKind overflowCheckKind = BfOverflowCheckKind_None);
BeAllocaInst* CreateAlloca(BeType* type);
BeLoadInst* CreateLoad(BeValue* value, bool isVolatile);