1
0
Fork 0
mirror of https://github.com/beefytech/Beef.git synced 2025-06-10 20:42:21 +02:00

Merge pull request #1824 from MineGame159/simd_improvements

Simd improvements
This commit is contained in:
Brian Fiete 2023-04-17 11:47:11 -07:00 committed by GitHub
commit a1dbea2574
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 374 additions and 25 deletions

View file

@ -27,6 +27,12 @@ namespace System.Numerics
public extern float4 wzyx { [Intrinsic("shuffle3210")] get; [Intrinsic("shuffle3210")] set; }
[Intrinsic("min")]
public static extern float4 min(float4 lhs, float4 rhs);
[Intrinsic("max")]
public static extern float4 max(float4 lhs, float4 rhs);
[Intrinsic("add")]
public static extern float4 operator+(float4 lhs, float4 rhs);
[Intrinsic("add"), Commutable]

View file

@ -2,8 +2,21 @@ namespace System.Numerics.X86
{
static class SSE
{
[Intrinsic(":add_ps")]
public static extern v128 add_ps(v128 a, v128 b);
/// Reports the SSE flag from the runtime CPU feature query (`Runtime.Features`).
public static bool IsSupported => Runtime.Features.SSE;
// The packed-float wrappers below reinterpret both `v128` operands as `float4`, apply the
// corresponding `float4` operator (or `float4.min`/`float4.max`), and cast the result back
// to `v128`. Presumably each operates lane-wise on four floats - confirm against `float4`.
/// a + b, computed through the `float4` addition operator.
[Inline]
public static v128 add_ps(v128 a, v128 b) => (.) ((float4) a + (float4) b);
/// a - b, computed through the `float4` subtraction operator.
[Inline]
public static v128 sub_ps(v128 a, v128 b) => (.) ((float4) a - (float4) b);
/// a * b, computed through the `float4` multiplication operator.
[Inline]
public static v128 mul_ps(v128 a, v128 b) => (.) ((float4) a * (float4) b);
/// a / b, computed through the `float4` division operator.
[Inline]
public static v128 div_ps(v128 a, v128 b) => (.) ((float4) a / (float4) b);
/// Forwards to `float4.min` with both operands cast to `float4`.
[Inline]
public static v128 min_ps(v128 a, v128 b) => (.) float4.min((.) a, (.) b);
/// Forwards to `float4.max` with both operands cast to `float4`.
[Inline]
public static v128 max_ps(v128 a, v128 b) => (.) float4.max((.) a, (.) b);
[Inline]
public static v128 add_ss(v128 a, v128 b)
@ -99,8 +112,6 @@ namespace System.Numerics.X86
public static extern int32 cvt_ss2si(v128 a);
public static extern v128 div_ps(v128 a, v128 b);
public static extern v128 div_ss(v128 a, v128 b);
public static extern v128 loadu_ps(void* ptr);
@ -111,12 +122,8 @@ namespace System.Numerics.X86
public static extern v128 load_ps(void* ptr);
public static extern v128 max_ps(v128 a, v128 b);
public static extern v128 max_ss(v128 a, v128 b);
public static extern v128 min_ps(v128 a, v128 b);
public static extern v128 min_ss(v128 a, v128 b);
public static extern v128 movehl_ps(v128 a, v128 b);
@ -127,8 +134,6 @@ namespace System.Numerics.X86
public static extern v128 move_ss(v128 a, v128 b);
public static extern v128 mul_ps(v128 a, v128 b);
public static extern v128 mul_ss(v128 a, v128 b);
public static extern v128 or_ps(v128 a, v128 b);
@ -169,8 +174,6 @@ namespace System.Numerics.X86
public static extern void stream_ps(void* mem_addr, v128 a);
public static extern v128 sub_ps(v128 a, v128 b);
public static extern v128 sub_ss(v128 a, v128 b);
public static extern void TRANSPOSE4_PS(ref v128 row0, ref v128 row1, ref v128 row2, ref v128 row3);

View file

@ -2,5 +2,6 @@ namespace System.Numerics.X86
{
/// Entry point for SSE2 functionality; currently only exposes the availability check.
static class SSE2
{
	/// Reports the SSE2 flag from the runtime CPU feature query (`Runtime.Features`).
	public static bool IsSupported
	{
		get
		{
			return Runtime.Features.SSE2;
		}
	}
}
}

View file

@ -6,6 +6,12 @@ using System.Collections;
namespace System
{
/// Set of SIMD instruction-set flags; one boolean per instruction set, all false by default.
struct RuntimeFeatures
{
	// 128-bit instruction sets
	public bool SSE;
	public bool SSE2;

	// AVX family
	public bool AVX;
	public bool AVX2;
	public bool AVX512;
}
[StaticInitPriority(101)]
static class Runtime
{
@ -359,6 +365,9 @@ namespace System
static List<ErrorHandler> sErrorHandlers ~ DeleteContainerAndItems!(_);
static bool sInsideErrorHandler;
static bool sQueriedFeatures = false;
static RuntimeFeatures sFeatures;
public static this()
{
BfRtCallbacks.sCallbacks.Init();
@ -466,5 +475,91 @@ namespace System
}
return .ContinueFailure;
}
/// CPU feature flags for the running machine.
/// The flags are computed on first access and cached in `sFeatures`; later reads return
/// the cached copy. On non-x86 targets every flag stays at its default value.
public static RuntimeFeatures Features
{
	get
	{
		// Fast path: already queried
		if (sQueriedFeatures)
			return sFeatures;

#if BF_MACHINE_X86 || BF_MACHINE_X64
		QueryFeaturesX86();
#else
		sFeatures = .();
		sQueriedFeatures = true;
#endif

		return sFeatures;
	}
}
#if BF_MACHINE_X86 || BF_MACHINE_X64
/// Detects SSE/SSE2/AVX/AVX2/AVX-512 support through CPUID and XGETBV, stores the result in
/// `sFeatures`, and then sets `sQueriedFeatures`.
private static void QueryFeaturesX86()
{
	// Detect into a local and publish it in one step at the end. Setting `sQueriedFeatures`
	// before the flags are filled in would let a concurrent reader of `Features` pick up a
	// partially populated `sFeatures`.
	RuntimeFeatures features = .();

	// Scratch target for CPUID output registers we do not care about
	uint32 _ = 0;

	// 0: Basic information (EAX = highest supported basic leaf)
	uint32 maxBasicLeaf = 0;
	cpuid(0, 0, &maxBasicLeaf, &_, &_, &_);

	// A max leaf below 1 means there is no feature leaf to read (earlier Intel 486,
	// CPUID not implemented); every flag then stays false.
	if (maxBasicLeaf >= 1)
	{
		// 1: Processor Info and Feature Bits
		uint32 procInfoEcx = 0;
		uint32 procInfoEdx = 0;
		cpuid(1, 0, &_, &_, &procInfoEcx, &procInfoEdx);

		features.SSE = (procInfoEdx & (1 << 25)) != 0;
		features.SSE2 = (procInfoEdx & (1 << 26)) != 0;

		// 7: Extended Features
		// Only query the leaf when the CPU reports it; an unsupported leaf does not
		// return meaningful feature bits, so EBX is treated as zero in that case.
		uint32 extendedFeaturesEbx = 0;
		if (maxBasicLeaf >= 7)
			cpuid(7, 0, &_, &extendedFeaturesEbx, &_, &_);

		// `XSAVE` (ECX bit 26): the CPU can save/restore extended register state.
		// `OSXSAVE` (ECX bit 27): the OS has enabled XSAVE, which makes XGETBV usable and
		// tells us the OS participates in saving vector registers on context switches.
		bool cpuHasXsave = (procInfoEcx & (1 << 26)) != 0;
		bool osUsesXsave = (procInfoEcx & (1 << 27)) != 0;
		if (cpuHasXsave && osUsesXsave)
		{
			// XCR0 lists the register state components the OS saves and restores:
			// bits 1-2 (mask 6) cover SSE + AVX state, bits 5-7 (mask 224) cover the
			// AVX-512 opmask and upper ZMM state.
			uint64 xcr0 = xgetbv(0);
			bool avxSupport = (xcr0 & 6) == 6;
			bool avx512Support = (xcr0 & 224) == 224;

			// AVX flags are only reported when both the CPU and the OS support the AVX state
			if (avxSupport)
			{
				features.AVX = (procInfoEcx & (1 << 28)) != 0;
				features.AVX2 = (extendedFeaturesEbx & (1 << 5)) != 0;

				// For AVX-512 the OS also needs to save/restore the extended state
				if (avx512Support)
					features.AVX512 = (extendedFeaturesEbx & (1 << 16)) != 0;
			}
		}
	}

	sFeatures = features;
	sQueriedFeatures = true;
}
/// Compiler intrinsic "cpuid": queries the given CPUID leaf/subleaf and writes the four result
/// registers through the `eax`, `ebx`, `ecx` and `edx` pointers.
[Intrinsic("cpuid")]
private static extern void cpuid(uint32 leaf, uint32 subleaf, uint32* eax, uint32* ebx, uint32* ecx, uint32* edx);
/// Compiler intrinsic "xgetbv": returns the 64-bit value of the extended control register
/// selected by `xcr`.
[Intrinsic("xgetbv")]
private static extern uint64 xgetbv(uint32 xcr);
#endif
}
}

View file

@ -441,6 +441,7 @@ enum BfIRIntrinsic : uint8
BfIRIntrinsic_BSwap,
BfIRIntrinsic_Cast,
BfIRIntrinsic_Cos,
BfIRIntrinsic_Cpuid,
BfIRIntrinsic_DebugTrap,
BfIRIntrinsic_Div,
BfIRIntrinsic_Eq,
@ -455,9 +456,11 @@ enum BfIRIntrinsic : uint8
BfIRIntrinsic_Lt,
BfIRIntrinsic_LtE,
BfIRIntrinsic_Malloc,
BfIRIntrinsic_Max,
BfIRIntrinsic_MemCpy,
BfIRIntrinsic_MemMove,
BfIRIntrinsic_MemSet,
BfIRIntrinsic_Min,
BfIRIntrinsic_Mod,
BfIRIntrinsic_Mul,
BfIRIntrinsic_Neq,
@ -477,6 +480,7 @@ enum BfIRIntrinsic : uint8
BfIRIntrinsic_VAArg,
BfIRIntrinsic_VAEnd,
BfIRIntrinsic_VAStart,
BfIRIntrinsic_Xgetbv,
BfIRIntrinsic_Xor,
BfIRIntrinsic_COUNT,

View file

@ -157,6 +157,7 @@ static const BuiltinEntry gIntrinEntries[] =
{"bswap"},
{"cast"},
{"cos"},
{"cpuid"},
{"debugtrap"},
{"div"},
{"eq"},
@ -171,9 +172,11 @@ static const BuiltinEntry gIntrinEntries[] =
{"lt"},
{"lte"},
{"malloc"},
{"max"},
{"memcpy"},
{"memmove"},
{"memset"},
{"min"},
{"mod"},
{"mul"},
{"neq"},
@ -193,6 +196,7 @@ static const BuiltinEntry gIntrinEntries[] =
{"va_arg"},
{"va_end"},
{"va_start"},
{"xgetbv"},
{"xor"},
};
@ -2844,6 +2848,7 @@ void BfIRCodeGen::HandleNextCmd()
{ llvm::Intrinsic::bswap, -1},
{ (llvm::Intrinsic::ID)-2, -1}, // cast,
{ llvm::Intrinsic::cos, 0, -1},
{ (llvm::Intrinsic::ID)-2, -1}, // cpuid
{ llvm::Intrinsic::debugtrap, -1}, // debugtrap,
{ (llvm::Intrinsic::ID)-2, -1}, // div
{ (llvm::Intrinsic::ID)-2, -1}, // eq
@ -2857,10 +2862,12 @@ void BfIRCodeGen::HandleNextCmd()
{ llvm::Intrinsic::log2, 0, -1},
{ (llvm::Intrinsic::ID)-2, -1}, // lt
{ (llvm::Intrinsic::ID)-2, -1}, // lte
{ (llvm::Intrinsic::ID)-2}, // memset
{ (llvm::Intrinsic::ID)-2}, // malloc
{ (llvm::Intrinsic::ID)-2, -1}, // max
{ llvm::Intrinsic::memcpy, 0, 1, 2},
{ llvm::Intrinsic::memmove, 0, 2},
{ llvm::Intrinsic::memset, 0, 2},
{ (llvm::Intrinsic::ID)-2, -1}, // min
{ (llvm::Intrinsic::ID)-2, -1}, // mod
{ (llvm::Intrinsic::ID)-2, -1}, // mul
{ (llvm::Intrinsic::ID)-2, -1}, // neq
@ -2880,6 +2887,7 @@ void BfIRCodeGen::HandleNextCmd()
{ (llvm::Intrinsic::ID)-2, -1}, // va_arg,
{ llvm::Intrinsic::vaend, -1}, // va_end,
{ llvm::Intrinsic::vastart, -1}, // va_start,
{ (llvm::Intrinsic::ID)-2, -1}, // xgetbv
{ (llvm::Intrinsic::ID)-2, -1}, // xor
};
BF_STATIC_ASSERT(BF_ARRAY_COUNT(intrinsics) == BfIRIntrinsic_COUNT);
@ -3068,18 +3076,7 @@ void BfIRCodeGen::HandleNextCmd()
{
case BfIRIntrinsic__PLATFORM:
{
if (intrinsicData->mName == "add_ps")
{
auto val0 = TryToVector(args[0], llvm::Type::getFloatTy(*mLLVMContext));
auto val1 = TryToVector(args[0], llvm::Type::getFloatTy(*mLLVMContext));
//SetResult(curId, TryToVector(mIRBuilder->CreateFAdd(val0, val1), GetElemType(args[0])));
SetResult(curId, mIRBuilder->CreateFAdd(val0, val1));
}
else
{
FatalError(StrFormat("Unable to find intrinsic '%s'", intrinsicData->mName.c_str()));
}
FatalError(StrFormat("Unable to find intrinsic '%s'", intrinsicData->mName.c_str()));
}
break;
@ -3299,6 +3296,147 @@ void BfIRCodeGen::HandleNextCmd()
}
}
break;
case BfIRIntrinsic_Min:
case BfIRIntrinsic_Max:
	{
		// Floating point vector min/max, lowered to the x86 packed min/max intrinsics.
		// Supported shapes: float x4/x8/x16 and double x2/x4/x8. The enclosing function is
		// tagged with the lowest SIMD level that provides the chosen instruction.

		// Get arguments as vectors
		auto val0 = TryToVector(args[0]);
		if (val0 == NULL)
			FatalError("Intrinsic argument error");

		auto val1 = TryToVector(args[1]);
		if (val1 == NULL)
			FatalError("Intrinsic argument error");

		// Both arguments must be vectors of the very same type
		auto vecType = llvm::dyn_cast<llvm::VectorType>(val0->getType());
		if ((vecType == NULL) || (vecType != llvm::dyn_cast<llvm::VectorType>(val1->getType())))
			FatalError("Intrinsic argument error");

		// Make sure the type is not scalable
		if (vecType->getElementCount().isScalable())
			FatalError("Intrinsic argument error");

		// Make sure the element type is either float or double
		auto elemType = vecType->getElementType();
		if (!elemType->isFloatTy() && !elemType->isDoubleTy())
			FatalError("Intrinsic argument error");

		// Get some properties for easier access
		bool isFloat = elemType->isFloatTy();
		bool isMin = intrinsicData->mIntrinsic == BfIRIntrinsic_Min;
		auto elemCount = vecType->getElementCount().getFixedValue();

		// Select the intrinsic function
		const char* funcName = NULL;
		// The 512-bit AVX-512 min/max intrinsics take a trailing i32 rounding/SAE operand
		bool hasRoundingOperand = false;

		if (isFloat)
		{
			if (elemCount == 4)
			{
				funcName = isMin ? "llvm.x86.sse.min.ps" : "llvm.x86.sse.max.ps";
				SetActiveFunctionSimdType(BfIRSimdType_SSE);
			}
			else if (elemCount == 8)
			{
				// 256-bit packed float min/max is an AVX instruction; AVX2 is not needed
				funcName = isMin ? "llvm.x86.avx.min.ps.256" : "llvm.x86.avx.max.ps.256";
				SetActiveFunctionSimdType(BfIRSimdType_AVX);
			}
			else if (elemCount == 16)
			{
				funcName = isMin ? "llvm.x86.avx512.min.ps.512" : "llvm.x86.avx512.max.ps.512";
				hasRoundingOperand = true;
				SetActiveFunctionSimdType(BfIRSimdType_AVX512);
			}
			else
				FatalError("Intrinsic argument error");
		}
		else
		{
			if (elemCount == 2)
			{
				// Packed double min/max belongs to SSE2; there is no "sse" variant of it
				funcName = isMin ? "llvm.x86.sse2.min.pd" : "llvm.x86.sse2.max.pd";
				SetActiveFunctionSimdType(BfIRSimdType_SSE2);
			}
			else if (elemCount == 4)
			{
				// 256-bit packed double min/max is an AVX instruction; AVX2 is not needed
				funcName = isMin ? "llvm.x86.avx.min.pd.256" : "llvm.x86.avx.max.pd.256";
				SetActiveFunctionSimdType(BfIRSimdType_AVX);
			}
			else if (elemCount == 8)
			{
				funcName = isMin ? "llvm.x86.avx512.min.pd.512" : "llvm.x86.avx512.max.pd.512";
				hasRoundingOperand = true;
				SetActiveFunctionSimdType(BfIRSimdType_AVX512);
			}
			else
				FatalError("Intrinsic argument error");
		}

		// Declare the intrinsic with the exact signature LLVM expects
		llvm::SmallVector<llvm::Type*, 3> paramTypes;
		paramTypes.push_back(vecType);
		paramTypes.push_back(vecType);
		if (hasRoundingOperand)
			paramTypes.push_back(llvm::Type::getInt32Ty(*mLLVMContext));
		auto funcType = llvm::FunctionType::get(vecType, paramTypes, false);
		auto func = mLLVMModule->getOrInsertFunction(funcName, funcType);

		// Call intrinsic
		llvm::SmallVector<llvm::Value*, 3> callArgs;
		callArgs.push_back(val0);
		callArgs.push_back(val1);
		if (hasRoundingOperand)
			callArgs.push_back(mIRBuilder->getInt32(4)); // 4 = use the current rounding direction
		SetResult(curId, mIRBuilder->CreateCall(func, callArgs));
	}
	break;
case BfIRIntrinsic_Cpuid:
	{
		// cpuid(leaf, subleaf, eax*, ebx*, ecx*, edx*)
		// Emits an inline-asm CPUID that takes the two i32 inputs and stores the four result
		// registers through the four i32 pointers.
		// NOTE(review): the asm template uses `xchgq`/`%rbx`, which looks x86-64 specific -
		// confirm this path is never reached when targeting 32-bit x86.
		llvm::Type* int32Type = llvm::Type::getInt32Ty(*mLLVMContext);

		// Expect exactly six arguments: two i32 inputs...
		bool inputsValid = (args.size() == 6) && args[0]->getType()->isIntegerTy(32) && args[1]->getType()->isIntegerTy(32);
		if (!inputsValid)
			FatalError("Intrinsic argument error");

		// ...followed by four pointers to i32
		for (int argIdx = 2; argIdx < 6; argIdx++)
		{
			llvm::Type* argType = args[argIdx]->getType();
			bool isInt32Ptr = argType->isPointerTy() && argType->getPointerElementType()->isIntegerTy(32);
			if (!isInt32Ptr)
				FatalError("Intrinsic argument error");
		}

		// The asm yields { eax, ebx, ecx, edx } as a struct of four i32
		llvm::SmallVector<llvm::Type*, 4> resultFieldTypes(4, int32Type);
		llvm::Type* resultType = llvm::StructType::get(*mLLVMContext, resultFieldTypes);

		// Signature of the asm "function": (i32 leaf, i32 subleaf) -> { i32, i32, i32, i32 }
		llvm::SmallVector<llvm::Type*, 2> asmParamTypes(2, int32Type);
		llvm::FunctionType* asmFuncType = llvm::FunctionType::get(resultType, asmParamTypes, false);

		// Output 1 is a general register swapped with rbx around the CPUID; inputs 0 and 2 are
		// tied to the eax and ecx outputs.
		llvm::InlineAsm* cpuidAsm = llvm::InlineAsm::get(asmFuncType, "xchgq %rbx,${1:q}\ncpuid\nxchgq %rbx,${1:q}", "={ax},=r,={cx},={dx},0,2,~{dirflag},~{fpsr},~{flags}", false);

		// Call asm function
		llvm::SmallVector<llvm::Value*, 2> asmArgs;
		asmArgs.push_back(args[0]);
		asmArgs.push_back(args[1]);
		llvm::Value* registers = mIRBuilder->CreateCall(cpuidAsm, asmArgs);

		// Write each result register through the matching output pointer
		for (int regIdx = 0; regIdx < 4; regIdx++)
		{
			llvm::Value* regValue = mIRBuilder->CreateExtractValue(registers, (unsigned)regIdx);
			mIRBuilder->CreateStore(regValue, args[2 + regIdx]);
		}
	}
	break;
case BfIRIntrinsic_Xgetbv:
	{
		// xgetbv(xcr): forwards the single i32 argument to `llvm.x86.xgetbv` and yields its
		// i64 result.
		bool argsValid = (args.size() == 1) && args[0]->getType()->isIntegerTy(32);
		if (!argsValid)
			FatalError("Intrinsic argument error");

		llvm::Type* resultType = llvm::Type::getInt64Ty(*mLLVMContext);
		llvm::Type* xcrType = llvm::Type::getInt32Ty(*mLLVMContext);
		auto xgetbvFunc = mLLVMModule->getOrInsertFunction("llvm.x86.xgetbv", resultType, xcrType);

		SetResult(curId, mIRBuilder->CreateCall(xgetbvFunc, args[0]));
	}
	break;
case BfIRIntrinsic_Not:
{
auto val0 = TryToVector(args[0]);
@ -4834,6 +4972,55 @@ void BfIRCodeGen::SetConfigConst(int idx, int value)
mConfigConsts64.Add(constVal);
}
// Records that the function currently being generated (mActiveFunction) needs at least the
// given SIMD level. An existing entry is only ever raised, never lowered.
void BfIRCodeGen::SetActiveFunctionSimdType(BfIRSimdType type)
{
	BfIRSimdType recordedType;
	bool alreadySufficient = mFunctionsUsingSimd.TryGetValue(mActiveFunction, &recordedType) && (recordedType >= type);
	if (alreadySufficient)
		return;

	mFunctionsUsingSimd[mActiveFunction] = type;
}
// Returns the LLVM "target-features" string for a SIMD level. Each level lists the feature
// itself plus the lower levels it builds on; unknown levels (including None) give an empty
// string.
//
// The strings live in function-local statics so the returned reference stays valid after the
// call. Returning the literals directly would bind the reference to a temporary StringImpl
// that is destroyed before the caller can use it.
const StringImpl& BfIRCodeGen::GetSimdTypeString(BfIRSimdType type)
{
	static const StringImpl sNoFeatures("");
	static const StringImpl sSSEFeatures("+sse,+mmx");
	static const StringImpl sSSE2Features("+sse2,+sse,+mmx");
	static const StringImpl sAVXFeatures("+avx,+sse4.2,+sse4.1,+sse3,+sse2,+sse,+mmx");
	static const StringImpl sAVX2Features("+avx2,+avx,+sse4.2,+sse4.1,+sse3,+sse2,+sse,+mmx");
	static const StringImpl sAVX512Features("+avx512f,+avx2,+avx,+sse4.2,+sse4.1,+sse3,+sse2,+sse,+mmx");

	switch (type)
	{
	case BfIRSimdType_SSE:
		return sSSEFeatures;
	case BfIRSimdType_SSE2:
		return sSSE2Features;
	case BfIRSimdType_AVX:
		return sAVXFeatures;
	case BfIRSimdType_AVX2:
		return sAVX2Features;
	case BfIRSimdType_AVX512:
		return sAVX512Features;
	default:
		return sNoFeatures;
	}
}
// Derives the SIMD level of a function from its "target-features" attribute.
// Returns BfIRSimdType_None when the attribute is missing or names none of the known features.
BfIRSimdType BfIRCodeGen::GetSimdTypeFromFunction(llvm::Function* function)
{
	if (!function->hasFnAttribute("target-features"))
		return BfIRSimdType_None;

	struct FeatureProbe
	{
		const char* mToken;
		BfIRSimdType mType;
	};

	// Ordered from highest to lowest level. The order also matters for matching: "+avx" is a
	// substring of "+avx2" and "+avx512f", and "+sse" of "+sse2", so the longer names go first.
	static const FeatureProbe probes[] =
	{
		{ "+avx512f", BfIRSimdType_AVX512 },
		{ "+avx2", BfIRSimdType_AVX2 },
		{ "+avx", BfIRSimdType_AVX },
		{ "+sse2", BfIRSimdType_SSE2 },
		{ "+sse", BfIRSimdType_SSE },
	};

	auto features = function->getFnAttribute("target-features").getValueAsString();
	for (const auto& probe : probes)
	{
		if (features.contains(probe.mToken))
			return probe.mType;
	}

	return BfIRSimdType_None;
}
llvm::Value* BfIRCodeGen::GetLLVMValue(int id)
{
auto& result = mResults[id];
@ -5424,6 +5611,8 @@ llvm::Expected<llvm::BitcodeModule> FindThinLTOModule(llvm::MemoryBufferRef MBRe
bool BfIRCodeGen::WriteObjectFile(const StringImpl& outFileName)
{
ApplySimdFeatures();
// {
// PassManagerBuilderWrapper pmBuilder;
//
@ -5552,6 +5741,40 @@ bool BfIRCodeGen::WriteIR(const StringImpl& outFileName, StringImpl& error)
return true;
}
void BfIRCodeGen::ApplySimdFeatures()
{
Array<std::tuple<llvm::Function*, BfIRSimdType>> functionsToProcess;
for (auto pair : mFunctionsUsingSimd)
functionsToProcess.Add({ pair.mKey, pair.mValue });
while (functionsToProcess.Count() > 0)
{
auto tuple = functionsToProcess.front();
functionsToProcess.RemoveAt(0);
auto function = std::get<0>(tuple);
auto simdType = std::get<1>(tuple);
auto currentSimdType = GetSimdTypeFromFunction(function);
simdType = simdType > currentSimdType ? simdType : currentSimdType;
function->addFnAttr("target-features", GetSimdTypeString(simdType).c_str());
if (function->hasFnAttribute(llvm::Attribute::AlwaysInline))
{
for (auto user : function->users())
{
if (auto call = llvm::dyn_cast<llvm::CallInst>(user))
{
auto func = call->getFunction();
functionsToProcess.Add({ func, simdType });
}
}
}
}
}
int BfIRCodeGen::GetIntrinsicId(const StringImpl& name)
{
auto itr = std::lower_bound(std::begin(gIntrinEntries), std::end(gIntrinEntries), name);

View file

@ -90,6 +90,16 @@ enum BfIRSizeAlignKind
BfIRSizeAlignKind_Aligned,
};
// SIMD feature levels, ordered from least to most capable.
// The numeric order is significant: levels are compared with '>' to pick the higher one.
enum BfIRSimdType
{
	BfIRSimdType_None = 0,
	BfIRSimdType_SSE = 1,
	BfIRSimdType_SSE2 = 2,
	BfIRSimdType_AVX = 3,
	BfIRSimdType_AVX2 = 4,
	BfIRSimdType_AVX512 = 5,
};
class BfIRCodeGen : public BfIRCodeGenBase
{
public:
@ -130,6 +140,7 @@ public:
Dictionary<llvm::Type*, int> mTypeToTypeIdMap;
HashSet<llvm::BasicBlock*> mLockedBlocks;
OwnedArray<BfIRIntrinsicData> mIntrinsicData;
Dictionary<llvm::Function*, BfIRSimdType> mFunctionsUsingSimd;
public:
void InitTarget();
@ -222,6 +233,10 @@ public:
void SetCodeGenOptions(BfCodeGenOptions codeGenOptions);
void SetConfigConst(int idx, int value) override;
void SetActiveFunctionSimdType(BfIRSimdType type);
const StringImpl& GetSimdTypeString(BfIRSimdType type);
BfIRSimdType GetSimdTypeFromFunction(llvm::Function* function);
llvm::Value* GetLLVMValue(int streamId);
llvm::Type* GetLLVMType(int streamId);
llvm::BasicBlock* GetLLVMBlock(int streamId);
@ -234,6 +249,8 @@ public:
bool WriteObjectFile(const StringImpl& outFileName);
bool WriteIR(const StringImpl& outFileName, StringImpl& error);
void ApplySimdFeatures();
static int GetIntrinsicId(const StringImpl& name);
static const char* GetIntrinsicName(int intrinId);
static void SetAsmKind(BfAsmKind asmKind);