mirror of
https://github.com/beefytech/Beef.git
synced 2025-06-08 03:28:20 +02:00
Add min and max vector intrinsics and codegen infrastructure
This commit is contained in:
parent
7dad948f20
commit
bf36bf4b95
5 changed files with 205 additions and 1 deletions
|
@ -27,6 +27,12 @@ namespace System.Numerics
|
|||
|
||||
// Swizzle property whose getter and setter are both bound to the compiler's
// "shuffle3210" intrinsic (no Beef-side body; the accessors are extern).
// NOTE(review): the name/intrinsic suggest the four lanes in reversed order — confirm.
public extern float4 wzyx { [Intrinsic("shuffle3210")] get; [Intrinsic("shuffle3210")] set; }
|
||||
|
||||
// Two-operand float4 `min`, declared extern and bound to the compiler's "min"
// intrinsic; the implementation is supplied by the code generator, not by Beef code.
[Intrinsic("min")]
|
||||
public static extern float4 min(float4 lhs, float4 rhs);
|
||||
|
||||
// Two-operand float4 `max`, declared extern and bound to the compiler's "max"
// intrinsic; the implementation is supplied by the code generator, not by Beef code.
[Intrinsic("max")]
|
||||
public static extern float4 max(float4 lhs, float4 rhs);
|
||||
|
||||
// float4 + float4, declared extern and bound to the compiler's "add" intrinsic.
[Intrinsic("add")]
|
||||
public static extern float4 operator+(float4 lhs, float4 rhs);
|
||||
[Intrinsic("add"), Commutable]
|
||||
|
|
|
@ -4,6 +4,11 @@ namespace System.Numerics.X86
|
|||
{
|
||||
// Forwards Runtime.Features.SSE so callers can check for SSE before using this type.
public static bool IsSupported => Runtime.Features.SSE;
|
||||
|
||||
// Forwards to float4.min: both v128 operands are cast to float4 for the call
// and the float4 result is cast back to v128.
[Inline]
public static v128 min_ps(v128 a, v128 b)
{
	return (.) float4.min((.) a, (.) b);
}
|
||||
// Forwards to float4.max: both v128 operands are cast to float4 for the call
// and the float4 result is cast back to v128.
[Inline]
public static v128 max_ps(v128 a, v128 b)
{
	return (.) float4.max((.) a, (.) b);
}
|
||||
|
||||
[Inline]
|
||||
public static v128 add_ss(v128 a, v128 b)
|
||||
{
|
||||
|
|
|
@ -456,9 +456,11 @@ enum BfIRIntrinsic : uint8
|
|||
BfIRIntrinsic_Lt,
|
||||
BfIRIntrinsic_LtE,
|
||||
BfIRIntrinsic_Malloc,
|
||||
BfIRIntrinsic_Max,
|
||||
BfIRIntrinsic_MemCpy,
|
||||
BfIRIntrinsic_MemMove,
|
||||
BfIRIntrinsic_MemSet,
|
||||
BfIRIntrinsic_Min,
|
||||
BfIRIntrinsic_Mod,
|
||||
BfIRIntrinsic_Mul,
|
||||
BfIRIntrinsic_Neq,
|
||||
|
|
|
@ -172,9 +172,11 @@ static const BuiltinEntry gIntrinEntries[] =
|
|||
{"lt"},
|
||||
{"lte"},
|
||||
{"malloc"},
|
||||
{"max"},
|
||||
{"memcpy"},
|
||||
{"memmove"},
|
||||
{"memset"},
|
||||
{"min"},
|
||||
{"mod"},
|
||||
{"mul"},
|
||||
{"neq"},
|
||||
|
@ -2860,10 +2862,12 @@ void BfIRCodeGen::HandleNextCmd()
|
|||
{ llvm::Intrinsic::log2, 0, -1},
|
||||
{ (llvm::Intrinsic::ID)-2, -1}, // lt
|
||||
{ (llvm::Intrinsic::ID)-2, -1}, // lte
|
||||
{ (llvm::Intrinsic::ID)-2}, // memset
|
||||
{ (llvm::Intrinsic::ID)-2}, // malloc
|
||||
{ (llvm::Intrinsic::ID)-2, -1}, // max
|
||||
{ llvm::Intrinsic::memcpy, 0, 1, 2},
|
||||
{ llvm::Intrinsic::memmove, 0, 2},
|
||||
{ llvm::Intrinsic::memset, 0, 2},
|
||||
{ (llvm::Intrinsic::ID)-2, -1}, // min
|
||||
{ (llvm::Intrinsic::ID)-2, -1}, // mod
|
||||
{ (llvm::Intrinsic::ID)-2, -1}, // mul
|
||||
{ (llvm::Intrinsic::ID)-2, -1}, // neq
|
||||
|
@ -3292,6 +3296,91 @@ void BfIRCodeGen::HandleNextCmd()
|
|||
}
|
||||
}
|
||||
break;
|
||||
case BfIRIntrinsic_Min:
|
||||
case BfIRIntrinsic_Max:
|
||||
{
|
||||
// Get arguments as vectors
|
||||
auto val0 = TryToVector(args[0]);
|
||||
if (val0 == NULL)
|
||||
FatalError("Intrinsic argument error");
|
||||
|
||||
auto val1 = TryToVector(args[1]);
|
||||
if (val1 == NULL)
|
||||
FatalError("Intrinsic argument error");
|
||||
|
||||
// Make sure both argument types are the same
|
||||
auto vecType = llvm::dyn_cast<llvm::VectorType>(val0->getType());
|
||||
if (vecType != llvm::dyn_cast<llvm::VectorType>(val1->getType()))
|
||||
FatalError("Intrinsic argument error");
|
||||
|
||||
// Make sure the type is not scalable
|
||||
if (vecType->getElementCount().isScalable())
|
||||
FatalError("Intrinsic argument error");
|
||||
|
||||
// Make sure the element type is either float or double
|
||||
auto elemType = vecType->getElementType();
|
||||
if (!elemType->isFloatTy() && !elemType->isDoubleTy())
|
||||
FatalError("Intrinsic argument error");
|
||||
|
||||
// Get some properties for easier access
|
||||
bool isFloat = elemType->isFloatTy();
|
||||
bool isMin = intrinsicData->mIntrinsic == BfIRIntrinsic_Min;
|
||||
auto elemCount = vecType->getElementCount().getFixedValue();
|
||||
|
||||
// Get the intrinsic function
|
||||
const char* funcName;
|
||||
|
||||
if (isFloat)
|
||||
{
|
||||
if (elemCount == 4)
|
||||
{
|
||||
funcName = isMin ? "llvm.x86.sse.min.ps" : "llvm.x86.sse.max.ps";
|
||||
SetActiveFunctionSimdType(BfIRSimdType_SSE);
|
||||
}
|
||||
else if (elemCount == 8)
|
||||
{
|
||||
funcName = isMin ? "llvm.x86.avx.min.ps.256" : "llvm.x86.avx.max.ps.256";
|
||||
SetActiveFunctionSimdType(BfIRSimdType_AVX2);
|
||||
}
|
||||
else if (elemCount == 16)
|
||||
{
|
||||
funcName = isMin ? "llvm.x86.avx512.min.ps.512" : "llvm.x86.avx512.max.ps.512";
|
||||
SetActiveFunctionSimdType(BfIRSimdType_AVX512);
|
||||
}
|
||||
else
|
||||
FatalError("Intrinsic argument error");
|
||||
}
|
||||
else
|
||||
{
|
||||
if (elemCount == 2)
|
||||
{
|
||||
funcName = isMin ? "llvm.x86.sse.min.pd" : "llvm.x86.sse.max.pd";
|
||||
SetActiveFunctionSimdType(BfIRSimdType_SSE);
|
||||
}
|
||||
else if (elemCount == 4)
|
||||
{
|
||||
funcName = isMin ? "llvm.x86.avx.min.pd.256" : "llvm.x86.avx.max.pd.256";
|
||||
SetActiveFunctionSimdType(BfIRSimdType_AVX2);
|
||||
}
|
||||
else if (elemCount == 8)
|
||||
{
|
||||
funcName = isMin ? "llvm.x86.avx512.min.pd.512" : "llvm.x86.avx512.max.pd.512";
|
||||
SetActiveFunctionSimdType(BfIRSimdType_AVX512);
|
||||
}
|
||||
else
|
||||
FatalError("Intrinsic argument error");
|
||||
}
|
||||
|
||||
auto func = mLLVMModule->getOrInsertFunction(funcName, vecType, vecType, vecType);
|
||||
|
||||
// Call intrinsic
|
||||
llvm::SmallVector<llvm::Value*, 2> args;
|
||||
args.push_back(val0);
|
||||
args.push_back(val1);
|
||||
|
||||
SetResult(curId, mIRBuilder->CreateCall(func, args));
|
||||
}
|
||||
break;
|
||||
case BfIRIntrinsic_Cpuid:
|
||||
{
|
||||
llvm::Type* elemType = llvm::Type::getInt32Ty(*mLLVMContext);
|
||||
|
@ -4883,6 +4972,55 @@ void BfIRCodeGen::SetConfigConst(int idx, int value)
|
|||
mConfigConsts64.Add(constVal);
|
||||
}
|
||||
|
||||
// Records that the function currently being generated (mActiveFunction) needs
// at least the given SIMD feature level. The entry in mFunctionsUsingSimd is
// only ever raised, never lowered, so the strongest requirement seen wins.
void BfIRCodeGen::SetActiveFunctionSimdType(BfIRSimdType type)
{
	BfIRSimdType currentType = BfIRSimdType_None;
	bool contains = mFunctionsUsingSimd.TryGetValue(mActiveFunction, &currentType);

	// Relies on BfIRSimdType enumerators being ordered weakest to strongest
	if (!contains || type > currentType)
		mFunctionsUsingSimd[mActiveFunction] = type;
}
|
||||
|
||||
// Returns the LLVM "target-features" string for the given SIMD level. Each
// level lists itself plus the feature sets it builds on. Unknown levels
// (including BfIRSimdType_None) yield an empty string.
//
// The function returns a reference, so every result must outlive the call:
// the strings live in function-local statics. Returning the literals directly
// would bind the reference to a temporary that is destroyed on return.
const StringImpl& BfIRCodeGen::GetSimdTypeString(BfIRSimdType type)
{
	static const StringImpl sseFeatures("+sse,+mmx");
	static const StringImpl sse2Features("+sse2,+sse,+mmx");
	static const StringImpl avxFeatures("+avx,+sse4.2,+sse4.1,+sse3,+sse2,+sse,+mmx");
	static const StringImpl avx2Features("+avx2,+avx,+sse4.2,+sse4.1,+sse3,+sse2,+sse,+mmx");
	static const StringImpl avx512Features("+avx512f,+avx2,+avx,+sse4.2,+sse4.1,+sse3,+sse2,+sse,+mmx");
	static const StringImpl noFeatures("");

	switch (type)
	{
	case BfIRSimdType_SSE:
		return sseFeatures;
	case BfIRSimdType_SSE2:
		return sse2Features;
	case BfIRSimdType_AVX:
		return avxFeatures;
	case BfIRSimdType_AVX2:
		return avx2Features;
	case BfIRSimdType_AVX512:
		return avx512Features;
	default:
		return noFeatures;
	}
}
|
||||
|
||||
// Derives a SIMD level from a function's existing "target-features" attribute.
// The attribute text is searched for feature tokens from strongest to weakest
// and the first match wins; a function without the attribute, or without any
// of the known tokens, reports BfIRSimdType_None.
BfIRSimdType BfIRCodeGen::GetSimdTypeFromFunction(llvm::Function* function)
{
	if (!function->hasFnAttribute("target-features"))
		return BfIRSimdType_None;

	auto features = function->getFnAttribute("target-features").getValueAsString();

	struct FeatureToken
	{
		const char* mToken;
		BfIRSimdType mType;
	};

	// Order matters: e.g. "+avx" is a substring of "+avx2" and "+avx512f",
	// so the stronger tokens have to be tested first.
	static const FeatureToken tokens[] =
	{
		{ "+avx512f", BfIRSimdType_AVX512 },
		{ "+avx2", BfIRSimdType_AVX2 },
		{ "+avx", BfIRSimdType_AVX },
		{ "+sse2", BfIRSimdType_SSE2 },
		{ "+sse", BfIRSimdType_SSE },
	};

	for (const auto& entry : tokens)
	{
		if (features.contains(entry.mToken))
			return entry.mType;
	}

	return BfIRSimdType_None;
}
|
||||
|
||||
llvm::Value* BfIRCodeGen::GetLLVMValue(int id)
|
||||
{
|
||||
auto& result = mResults[id];
|
||||
|
@ -5473,6 +5611,8 @@ llvm::Expected<llvm::BitcodeModule> FindThinLTOModule(llvm::MemoryBufferRef MBRe
|
|||
|
||||
bool BfIRCodeGen::WriteObjectFile(const StringImpl& outFileName)
|
||||
{
|
||||
ApplySimdFeatures();
|
||||
|
||||
// {
|
||||
// PassManagerBuilderWrapper pmBuilder;
|
||||
//
|
||||
|
@ -5601,6 +5741,40 @@ bool BfIRCodeGen::WriteIR(const StringImpl& outFileName, StringImpl& error)
|
|||
return true;
|
||||
}
|
||||
|
||||
void BfIRCodeGen::ApplySimdFeatures()
|
||||
{
|
||||
Array<std::tuple<llvm::Function*, BfIRSimdType>> functionsToProcess;
|
||||
|
||||
for (auto pair : mFunctionsUsingSimd)
|
||||
functionsToProcess.Add({ pair.mKey, pair.mValue });
|
||||
|
||||
while (functionsToProcess.Count() > 0)
|
||||
{
|
||||
auto tuple = functionsToProcess.front();
|
||||
functionsToProcess.RemoveAt(0);
|
||||
|
||||
auto function = std::get<0>(tuple);
|
||||
auto simdType = std::get<1>(tuple);
|
||||
|
||||
auto currentSimdType = GetSimdTypeFromFunction(function);
|
||||
simdType = simdType > currentSimdType ? simdType : currentSimdType;
|
||||
|
||||
function->addFnAttr("target-features", GetSimdTypeString(simdType).c_str());
|
||||
|
||||
if (function->hasFnAttribute(llvm::Attribute::AlwaysInline))
|
||||
{
|
||||
for (auto user : function->users())
|
||||
{
|
||||
if (auto call = llvm::dyn_cast<llvm::CallInst>(user))
|
||||
{
|
||||
auto func = call->getFunction();
|
||||
functionsToProcess.Add({ func, simdType });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int BfIRCodeGen::GetIntrinsicId(const StringImpl& name)
|
||||
{
|
||||
auto itr = std::lower_bound(std::begin(gIntrinEntries), std::end(gIntrinEntries), name);
|
||||
|
|
|
@ -90,6 +90,16 @@ enum BfIRSizeAlignKind
|
|||
BfIRSizeAlignKind_Aligned,
|
||||
};
|
||||
|
||||
// SIMD feature levels a generated function may require. The numeric order is
// significant: levels are compared with '>' to keep the strongest requirement,
// so enumerators must stay sorted from weakest to strongest.
enum BfIRSimdType
{
	BfIRSimdType_None = 0,
	BfIRSimdType_SSE = 1,
	BfIRSimdType_SSE2 = 2,
	BfIRSimdType_AVX = 3,
	BfIRSimdType_AVX2 = 4,
	BfIRSimdType_AVX512 = 5
};
|
||||
|
||||
class BfIRCodeGen : public BfIRCodeGenBase
|
||||
{
|
||||
public:
|
||||
|
@ -130,6 +140,7 @@ public:
|
|||
Dictionary<llvm::Type*, int> mTypeToTypeIdMap;
|
||||
HashSet<llvm::BasicBlock*> mLockedBlocks;
|
||||
OwnedArray<BfIRIntrinsicData> mIntrinsicData;
|
||||
Dictionary<llvm::Function*, BfIRSimdType> mFunctionsUsingSimd;
|
||||
|
||||
public:
|
||||
void InitTarget();
|
||||
|
@ -222,6 +233,10 @@ public:
|
|||
void SetCodeGenOptions(BfCodeGenOptions codeGenOptions);
|
||||
void SetConfigConst(int idx, int value) override;
|
||||
|
||||
void SetActiveFunctionSimdType(BfIRSimdType type);
|
||||
const StringImpl& GetSimdTypeString(BfIRSimdType type);
|
||||
BfIRSimdType GetSimdTypeFromFunction(llvm::Function* function);
|
||||
|
||||
llvm::Value* GetLLVMValue(int streamId);
|
||||
llvm::Type* GetLLVMType(int streamId);
|
||||
llvm::BasicBlock* GetLLVMBlock(int streamId);
|
||||
|
@ -234,6 +249,8 @@ public:
|
|||
bool WriteObjectFile(const StringImpl& outFileName);
|
||||
bool WriteIR(const StringImpl& outFileName, StringImpl& error);
|
||||
|
||||
void ApplySimdFeatures();
|
||||
|
||||
static int GetIntrinsicId(const StringImpl& name);
|
||||
static const char* GetIntrinsicName(int intrinId);
|
||||
static void SetAsmKind(BfAsmKind asmKind);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue