1
0
Fork 0
mirror of https://github.com/beefytech/Beef.git synced 2025-06-08 03:28:20 +02:00

Add min and max vector intrinsics and codegen infrastructure

This commit is contained in:
MineGame159 2023-04-02 15:07:15 +02:00
parent 7dad948f20
commit bf36bf4b95
5 changed files with 205 additions and 1 deletions

View file

@ -27,6 +27,12 @@ namespace System.Numerics
/// Swizzle property whose getter and setter are both implemented by the "shuffle3210" compiler intrinsic.
/// NOTE(review): the name suggests the components are accessed in reversed (w, z, y, x) order - confirm against the intrinsic's lowering.
public extern float4 wzyx { [Intrinsic("shuffle3210")] get; [Intrinsic("shuffle3210")] set; }
/// Minimum of two float4 vectors. Has no Beef body: the call is bound to the "min" compiler intrinsic
/// and lowered by the code generator.
[Intrinsic("min")]
public static extern float4 min(float4 lhs, float4 rhs);
/// Maximum of two float4 vectors. Has no Beef body: the call is bound to the "max" compiler intrinsic
/// and lowered by the code generator.
[Intrinsic("max")]
public static extern float4 max(float4 lhs, float4 rhs);
/// Addition of two float4 vectors, bound to the "add" compiler intrinsic.
[Intrinsic("add")]
public static extern float4 operator+(float4 lhs, float4 rhs);
[Intrinsic("add"), Commutable]

View file

@ -4,6 +4,11 @@ namespace System.Numerics.X86
{
/// Reports whether this instruction set is available; evaluates to Runtime.Features.SSE.
public static bool IsSupported => Runtime.Features.SSE;
/// Minimum of two v128 values treated as float4: both operands are cast to float4, passed to
/// float4.min, and the result is cast back to v128 (the `(.)` casts infer the target type).
[Inline]
public static v128 min_ps(v128 a, v128 b) => (.) float4.min((.) a, (.) b);
/// Maximum of two v128 values treated as float4: both operands are cast to float4, passed to
/// float4.max, and the result is cast back to v128 (the `(.)` casts infer the target type).
[Inline]
public static v128 max_ps(v128 a, v128 b) => (.) float4.max((.) a, (.) b);
[Inline]
public static v128 add_ss(v128 a, v128 b)
{

View file

@ -456,9 +456,11 @@ enum BfIRIntrinsic : uint8
BfIRIntrinsic_Lt,
BfIRIntrinsic_LtE,
BfIRIntrinsic_Malloc,
BfIRIntrinsic_Max,
BfIRIntrinsic_MemCpy,
BfIRIntrinsic_MemMove,
BfIRIntrinsic_MemSet,
BfIRIntrinsic_Min,
BfIRIntrinsic_Mod,
BfIRIntrinsic_Mul,
BfIRIntrinsic_Neq,

View file

@ -172,9 +172,11 @@ static const BuiltinEntry gIntrinEntries[] =
{"lt"},
{"lte"},
{"malloc"},
{"max"},
{"memcpy"},
{"memmove"},
{"memset"},
{"min"},
{"mod"},
{"mul"},
{"neq"},
@ -2860,10 +2862,12 @@ void BfIRCodeGen::HandleNextCmd()
{ llvm::Intrinsic::log2, 0, -1},
{ (llvm::Intrinsic::ID)-2, -1}, // lt
{ (llvm::Intrinsic::ID)-2, -1}, // lte
{ (llvm::Intrinsic::ID)-2}, // memset
{ (llvm::Intrinsic::ID)-2}, // malloc
{ (llvm::Intrinsic::ID)-2, -1}, // max
{ llvm::Intrinsic::memcpy, 0, 1, 2},
{ llvm::Intrinsic::memmove, 0, 2},
{ llvm::Intrinsic::memset, 0, 2},
{ (llvm::Intrinsic::ID)-2, -1}, // min
{ (llvm::Intrinsic::ID)-2, -1}, // mod
{ (llvm::Intrinsic::ID)-2, -1}, // mul
{ (llvm::Intrinsic::ID)-2, -1}, // neq
@ -3292,6 +3296,91 @@ void BfIRCodeGen::HandleNextCmd()
}
}
break;
case BfIRIntrinsic_Min:
case BfIRIntrinsic_Max:
	{
		// Lowers the "min"/"max" intrinsics to the x86 packed floating-point min/max
		// LLVM intrinsics. Only fixed-width vectors of float (4/8/16 lanes) or double
		// (2/4/8 lanes) are accepted; anything else is reported through FatalError.
		// The active function is tagged with the SIMD level the chosen instruction needs.

		// Get arguments as vectors
		auto val0 = TryToVector(args[0]);
		if (val0 == NULL)
			FatalError("Intrinsic argument error");

		auto val1 = TryToVector(args[1]);
		if (val1 == NULL)
			FatalError("Intrinsic argument error");

		// Make sure both arguments really are vectors and that their types are the same.
		// The explicit NULL check matters: two non-vector operands would otherwise
		// compare equal (NULL == NULL) and be dereferenced below.
		auto vecType = llvm::dyn_cast<llvm::VectorType>(val0->getType());
		if ((vecType == NULL) || (vecType != llvm::dyn_cast<llvm::VectorType>(val1->getType())))
			FatalError("Intrinsic argument error");

		// Make sure the type is not scalable
		if (vecType->getElementCount().isScalable())
			FatalError("Intrinsic argument error");

		// Make sure the element type is either float or double
		auto elemType = vecType->getElementType();
		if (!elemType->isFloatTy() && !elemType->isDoubleTy())
			FatalError("Intrinsic argument error");

		// Get some properties for easier access
		bool isFloat = elemType->isFloatTy();
		bool isMin = intrinsicData->mIntrinsic == BfIRIntrinsic_Min;
		auto elemCount = vecType->getElementCount().getFixedValue();

		// Get the intrinsic function
		const char* funcName = NULL;
		// The 512-bit AVX-512 min/max intrinsics take a trailing i32 rounding/SAE operand
		bool hasRoundingArg = false;

		if (isFloat)
		{
			if (elemCount == 4)
			{
				funcName = isMin ? "llvm.x86.sse.min.ps" : "llvm.x86.sse.max.ps";
				SetActiveFunctionSimdType(BfIRSimdType_SSE);
			}
			else if (elemCount == 8)
			{
				// NOTE(review): 256-bit float min/max are AVX instructions; AVX2 is requested
				// here - confirm whether BfIRSimdType_AVX would be sufficient.
				funcName = isMin ? "llvm.x86.avx.min.ps.256" : "llvm.x86.avx.max.ps.256";
				SetActiveFunctionSimdType(BfIRSimdType_AVX2);
			}
			else if (elemCount == 16)
			{
				funcName = isMin ? "llvm.x86.avx512.min.ps.512" : "llvm.x86.avx512.max.ps.512";
				hasRoundingArg = true;
				SetActiveFunctionSimdType(BfIRSimdType_AVX512);
			}
			else
				FatalError("Intrinsic argument error");
		}
		else
		{
			if (elemCount == 2)
			{
				// Packed-double min/max belong to SSE2: the LLVM intrinsics live in the
				// "sse2" namespace and the instruction needs the SSE2 feature enabled.
				funcName = isMin ? "llvm.x86.sse2.min.pd" : "llvm.x86.sse2.max.pd";
				SetActiveFunctionSimdType(BfIRSimdType_SSE2);
			}
			else if (elemCount == 4)
			{
				// NOTE(review): see the 8 x float case above regarding AVX vs AVX2.
				funcName = isMin ? "llvm.x86.avx.min.pd.256" : "llvm.x86.avx.max.pd.256";
				SetActiveFunctionSimdType(BfIRSimdType_AVX2);
			}
			else if (elemCount == 8)
			{
				funcName = isMin ? "llvm.x86.avx512.min.pd.512" : "llvm.x86.avx512.max.pd.512";
				hasRoundingArg = true;
				SetActiveFunctionSimdType(BfIRSimdType_AVX512);
			}
			else
				FatalError("Intrinsic argument error");
		}

		// Build the declaration so that it matches the intrinsic's real signature
		llvm::Type* int32Type = llvm::Type::getInt32Ty(*mLLVMContext);

		llvm::SmallVector<llvm::Type*, 3> paramTypes;
		paramTypes.push_back(vecType);
		paramTypes.push_back(vecType);
		if (hasRoundingArg)
			paramTypes.push_back(int32Type);

		auto funcType = llvm::FunctionType::get(vecType, paramTypes, false);
		auto func = mLLVMModule->getOrInsertFunction(funcName, funcType);

		// Call intrinsic (named callArgs so the outer 'args' is not shadowed)
		llvm::SmallVector<llvm::Value*, 3> callArgs;
		callArgs.push_back(val0);
		callArgs.push_back(val1);
		if (hasRoundingArg)
			callArgs.push_back(llvm::ConstantInt::get(int32Type, 4)); // 4 == _MM_FROUND_CUR_DIRECTION

		SetResult(curId, mIRBuilder->CreateCall(func, callArgs));
	}
	break;
case BfIRIntrinsic_Cpuid:
{
llvm::Type* elemType = llvm::Type::getInt32Ty(*mLLVMContext);
@ -4883,6 +4972,55 @@ void BfIRCodeGen::SetConfigConst(int idx, int value)
mConfigConsts64.Add(constVal);
}
// Records that the function currently being generated (mActiveFunction) needs at least
// the given SIMD level. A level already recorded for that function is only ever raised,
// never lowered.
void BfIRCodeGen::SetActiveFunctionSimdType(BfIRSimdType type)
{
	BfIRSimdType recordedLevel;
	if (mFunctionsUsingSimd.TryGetValue(mActiveFunction, &recordedLevel))
	{
		// An equal or higher requirement is already on record; nothing to do
		if (type <= recordedLevel)
			return;
	}

	mFunctionsUsingSimd[mActiveFunction] = type;
}
// Returns the LLVM "target-features" string that enables the given SIMD level together
// with the lower levels listed alongside it. Unknown levels (including BfIRSimdType_None)
// yield an empty string.
//
// The strings are kept in function-local statics: the function returns a reference, so
// returning a string literal directly would bind the reference to a temporary that is
// destroyed before the caller can use it.
const StringImpl& BfIRCodeGen::GetSimdTypeString(BfIRSimdType type)
{
	static const StringImpl sNone("");
	static const StringImpl sSSE("+sse,+mmx");
	static const StringImpl sSSE2("+sse2,+sse,+mmx");
	static const StringImpl sAVX("+avx,+sse4.2,+sse4.1,+sse3,+sse2,+sse,+mmx");
	static const StringImpl sAVX2("+avx2,+avx,+sse4.2,+sse4.1,+sse3,+sse2,+sse,+mmx");
	static const StringImpl sAVX512("+avx512f,+avx2,+avx,+sse4.2,+sse4.1,+sse3,+sse2,+sse,+mmx");

	switch (type)
	{
	case BfIRSimdType_SSE:
		return sSSE;
	case BfIRSimdType_SSE2:
		return sSSE2;
	case BfIRSimdType_AVX:
		return sAVX;
	case BfIRSimdType_AVX2:
		return sAVX2;
	case BfIRSimdType_AVX512:
		return sAVX512;
	default:
		return sNone;
	}
}
// Derives the SIMD level a function is already compiled for by scanning its
// "target-features" attribute for known feature tokens. Returns BfIRSimdType_None when
// the attribute is absent or contains none of the tokens.
BfIRSimdType BfIRCodeGen::GetSimdTypeFromFunction(llvm::Function* function)
{
	if (!function->hasFnAttribute("target-features"))
		return BfIRSimdType_None;

	const auto features = function->getFnAttribute("target-features").getValueAsString();

	struct FeatureLevel
	{
		const char* mToken;
		BfIRSimdType mLevel;
	};

	// Ordered from highest to lowest: this is a substring search, so the longer tokens
	// ("+avx512f", "+avx2", "+sse2") must be tested before the prefixes they contain.
	static const FeatureLevel kFeatureLevels[] =
	{
		{ "+avx512f", BfIRSimdType_AVX512 },
		{ "+avx2", BfIRSimdType_AVX2 },
		{ "+avx", BfIRSimdType_AVX },
		{ "+sse2", BfIRSimdType_SSE2 },
		{ "+sse", BfIRSimdType_SSE },
	};

	for (const auto& entry : kFeatureLevels)
	{
		if (features.contains(entry.mToken))
			return entry.mLevel;
	}

	return BfIRSimdType_None;
}
llvm::Value* BfIRCodeGen::GetLLVMValue(int id)
{
auto& result = mResults[id];
@ -5473,6 +5611,8 @@ llvm::Expected<llvm::BitcodeModule> FindThinLTOModule(llvm::MemoryBufferRef MBRe
bool BfIRCodeGen::WriteObjectFile(const StringImpl& outFileName)
{
ApplySimdFeatures();
// {
// PassManagerBuilderWrapper pmBuilder;
//
@ -5601,6 +5741,40 @@ bool BfIRCodeGen::WriteIR(const StringImpl& outFileName, StringImpl& error)
return true;
}
void BfIRCodeGen::ApplySimdFeatures()
{
Array<std::tuple<llvm::Function*, BfIRSimdType>> functionsToProcess;
for (auto pair : mFunctionsUsingSimd)
functionsToProcess.Add({ pair.mKey, pair.mValue });
while (functionsToProcess.Count() > 0)
{
auto tuple = functionsToProcess.front();
functionsToProcess.RemoveAt(0);
auto function = std::get<0>(tuple);
auto simdType = std::get<1>(tuple);
auto currentSimdType = GetSimdTypeFromFunction(function);
simdType = simdType > currentSimdType ? simdType : currentSimdType;
function->addFnAttr("target-features", GetSimdTypeString(simdType).c_str());
if (function->hasFnAttribute(llvm::Attribute::AlwaysInline))
{
for (auto user : function->users())
{
if (auto call = llvm::dyn_cast<llvm::CallInst>(user))
{
auto func = call->getFunction();
functionsToProcess.Add({ func, simdType });
}
}
}
}
}
int BfIRCodeGen::GetIntrinsicId(const StringImpl& name)
{
auto itr = std::lower_bound(std::begin(gIntrinEntries), std::end(gIntrinEntries), name);

View file

@ -90,6 +90,16 @@ enum BfIRSizeAlignKind
BfIRSizeAlignKind_Aligned,
};
// SIMD instruction-set levels a generated function may require.
// The numeric order is significant: levels are compared with '<' / '>' to pick the
// higher requirement, so entries must stay sorted from least to most capable.
enum BfIRSimdType
{
	BfIRSimdType_None = 0,
	BfIRSimdType_SSE = 1,
	BfIRSimdType_SSE2 = 2,
	BfIRSimdType_AVX = 3,
	BfIRSimdType_AVX2 = 4,
	BfIRSimdType_AVX512 = 5
};
class BfIRCodeGen : public BfIRCodeGenBase
{
public:
@ -130,6 +140,7 @@ public:
Dictionary<llvm::Type*, int> mTypeToTypeIdMap;
HashSet<llvm::BasicBlock*> mLockedBlocks;
OwnedArray<BfIRIntrinsicData> mIntrinsicData;
// Highest SIMD level required by each function that used a SIMD intrinsic; filled by
// SetActiveFunctionSimdType and consumed by ApplySimdFeatures.
Dictionary<llvm::Function*, BfIRSimdType> mFunctionsUsingSimd;
public:
void InitTarget();
@ -222,6 +233,10 @@ public:
void SetCodeGenOptions(BfCodeGenOptions codeGenOptions);
void SetConfigConst(int idx, int value) override;
// Records that the active function needs at least the given SIMD level (only ever raises it).
void SetActiveFunctionSimdType(BfIRSimdType type);
// Returns the LLVM "target-features" string for the given SIMD level (empty for unknown levels).
const StringImpl& GetSimdTypeString(BfIRSimdType type);
// Infers the SIMD level from a function's existing "target-features" attribute, or
// BfIRSimdType_None if there is none.
BfIRSimdType GetSimdTypeFromFunction(llvm::Function* function);
llvm::Value* GetLLVMValue(int streamId);
llvm::Type* GetLLVMType(int streamId);
llvm::BasicBlock* GetLLVMBlock(int streamId);
@ -234,6 +249,8 @@ public:
bool WriteObjectFile(const StringImpl& outFileName);
bool WriteIR(const StringImpl& outFileName, StringImpl& error);
// Applies the recorded SIMD requirements as "target-features" attributes, propagating them
// to functions that call always-inline SIMD functions. Called from WriteObjectFile.
void ApplySimdFeatures();
static int GetIntrinsicId(const StringImpl& name);
static const char* GetIntrinsicName(int intrinId);
static void SetAsmKind(BfAsmKind asmKind);