/// True when Runtime.Features reports SSE for the running CPU.
public static bool IsSupported
{
	get
	{
		return Runtime.Features.SSE;
	}
}

/// Casts both operands to float4, applies float4's `+` operator and casts the result back to v128.
[Inline]
public static v128 add_ps(v128 a, v128 b)
{
	return (v128)((float4)a + (float4)b);
}

/// Casts both operands to float4, applies float4's `-` operator and casts the result back to v128.
[Inline]
public static v128 sub_ps(v128 a, v128 b)
{
	return (v128)((float4)a - (float4)b);
}

/// Casts both operands to float4, applies float4's `*` operator and casts the result back to v128.
[Inline]
public static v128 mul_ps(v128 a, v128 b)
{
	return (v128)((float4)a * (float4)b);
}

/// Casts both operands to float4, applies float4's `/` operator and casts the result back to v128.
[Inline]
public static v128 div_ps(v128 a, v128 b)
{
	return (v128)((float4)a / (float4)b);
}

/// Forwards to float4.min (the "min" compiler intrinsic) and casts the result back to v128.
[Inline]
public static v128 min_ps(v128 a, v128 b)
{
	return (v128)float4.min((float4)a, (float4)b);
}

/// Forwards to float4.max (the "max" compiler intrinsic) and casts the result back to v128.
[Inline]
public static v128 max_ps(v128 a, v128 b)
{
	return (v128)float4.max((float4)a, (float4)b);
}
/// SIMD capabilities of the running CPU.
/// The first access performs the detection; later accesses return the cached sFeatures value.
public static RuntimeFeatures Features
{
	get
	{
		if (!sQueriedFeatures)
		{
#if BF_MACHINE_X86 || BF_MACHINE_X64
			QueryFeaturesX86();
#else
			// No detection is implemented for this architecture: keep a default-constructed value.
			sFeatures = .();
			sQueriedFeatures = true;
#endif
		}

		return sFeatures;
	}
}

#if BF_MACHINE_X86 || BF_MACHINE_X64
/// Fills sFeatures from CPUID (leaves 0, 1 and 7) and XGETBV, and marks the query as done.
/// AVX, AVX2 and AVX-512 are only reported when both the CPU advertises them and the OS
/// has enabled the matching register state in XCR0.
private static void QueryFeaturesX86()
{
	sFeatures = .();
	sQueriedFeatures = true;

	// Scratch target for the CPUID output registers this function does not read.
	uint32 _ = 0;

	// Leaf 0: EAX holds the highest basic leaf the CPU implements.
	uint32 maxBasicLeaf = 0;
	cpuid(0, 0, &maxBasicLeaf, &_, &_, &_);

	if (maxBasicLeaf < 1)
	{
		// Leaf 1 (the feature bits) is not available, so nothing can be detected.
		return;
	}

	// Leaf 1: Processor Info and Feature Bits
	uint32 procInfoEcx = 0;
	uint32 procInfoEdx = 0;
	cpuid(1, 0, &_, &_, &procInfoEcx, &procInfoEdx);

	sFeatures.SSE = (procInfoEdx & (1 << 25)) != 0;
	sFeatures.SSE2 = (procInfoEdx & (1 << 26)) != 0;

	// Leaf 7: Extended Features.
	// Querying a leaf above the reported maximum does not return zeros; the CPU answers
	// with the data of a different leaf, which would then be misread as AVX2/AVX-512 bits.
	// Only issue the query when the leaf exists and otherwise keep EBX at 0.
	uint32 extendedFeaturesEbx = 0;
	if (maxBasicLeaf >= 7)
		cpuid(7, 0, &_, &extendedFeaturesEbx, &_, &_);

	// `XSAVE` and `AVX` support:
	if ((procInfoEcx & (1 << 26)) != 0)
	{
		// Here the CPU supports `XSAVE`

		// Detect `OSXSAVE`, that is, whether the OS is AVX enabled and
		// supports saving the state of the AVX/AVX2 vector registers on
		// context-switches
		if ((procInfoEcx & (1 << 27)) != 0)
		{
			// The OS has signaled the CPU that it saves and restores extended register
			// state. XCR0 tells which state components it actually enabled.
			uint64 xcr0 = xgetbv(0);

			// Mask 6 = XCR0 bits 1-2 (XMM and YMM state).
			bool avxSupport = (xcr0 & 6) == 6;
			// Mask 224 = XCR0 bits 5-7 (opmask, upper ZMM halves, ZMM16-31).
			bool avx512Support = (xcr0 & 224) == 224;

			// AVX is only reported when the OS saves/restores the AVX registers
			if (avxSupport)
			{
				sFeatures.AVX = (procInfoEcx & (1 << 28)) != 0;
				sFeatures.AVX2 = (extendedFeaturesEbx & (1 << 5)) != 0;

				// For AVX-512 the OS also needs to support saving/restoring
				// the extended state, only then we enable AVX-512 support:
				if (avx512Support)
					sFeatures.AVX512 = (extendedFeaturesEbx & (1 << 16)) != 0;
			}
		}
	}
}

/// CPUID instruction: queries `leaf`/`subleaf` and stores the four result registers
/// through the given pointers.
[Intrinsic("cpuid")]
private static extern void cpuid(uint32 leaf, uint32 subleaf, uint32* eax, uint32* ebx, uint32* ecx, uint32* edx);

/// XGETBV instruction: returns the value of extended control register `xcr`.
[Intrinsic("xgetbv")]
private static extern uint64 xgetbv(uint32 xcr);
#endif
(llvm::Intrinsic::ID)-2, -1}, // cast, { llvm::Intrinsic::cos, 0, -1}, + { (llvm::Intrinsic::ID)-2, -1}, // cpuid { llvm::Intrinsic::debugtrap, -1}, // debugtrap, { (llvm::Intrinsic::ID)-2, -1}, // div { (llvm::Intrinsic::ID)-2, -1}, // eq @@ -2857,10 +2862,12 @@ void BfIRCodeGen::HandleNextCmd() { llvm::Intrinsic::log2, 0, -1}, { (llvm::Intrinsic::ID)-2, -1}, // lt { (llvm::Intrinsic::ID)-2, -1}, // lte - { (llvm::Intrinsic::ID)-2}, // memset + { (llvm::Intrinsic::ID)-2}, // malloc + { (llvm::Intrinsic::ID)-2, -1}, // max { llvm::Intrinsic::memcpy, 0, 1, 2}, { llvm::Intrinsic::memmove, 0, 2}, { llvm::Intrinsic::memset, 0, 2}, + { (llvm::Intrinsic::ID)-2, -1}, // min { (llvm::Intrinsic::ID)-2, -1}, // mod { (llvm::Intrinsic::ID)-2, -1}, // mul { (llvm::Intrinsic::ID)-2, -1}, // neq @@ -2880,6 +2887,7 @@ void BfIRCodeGen::HandleNextCmd() { (llvm::Intrinsic::ID)-2, -1}, // va_arg, { llvm::Intrinsic::vaend, -1}, // va_end, { llvm::Intrinsic::vastart, -1}, // va_start, + { (llvm::Intrinsic::ID)-2, -1}, // xgetbv { (llvm::Intrinsic::ID)-2, -1}, // xor }; BF_STATIC_ASSERT(BF_ARRAY_COUNT(intrinsics) == BfIRIntrinsic_COUNT); @@ -3068,18 +3076,7 @@ void BfIRCodeGen::HandleNextCmd() { case BfIRIntrinsic__PLATFORM: { - if (intrinsicData->mName == "add_ps") - { - auto val0 = TryToVector(args[0], llvm::Type::getFloatTy(*mLLVMContext)); - auto val1 = TryToVector(args[0], llvm::Type::getFloatTy(*mLLVMContext)); - //SetResult(curId, TryToVector(mIRBuilder->CreateFAdd(val0, val1), GetElemType(args[0]))); - - SetResult(curId, mIRBuilder->CreateFAdd(val0, val1)); - } - else - { - FatalError(StrFormat("Unable to find intrinsic '%s'", intrinsicData->mName.c_str())); - } + FatalError(StrFormat("Unable to find intrinsic '%s'", intrinsicData->mName.c_str())); } break; @@ -3299,6 +3296,147 @@ void BfIRCodeGen::HandleNextCmd() } } break; + case BfIRIntrinsic_Min: + case BfIRIntrinsic_Max: + { + // Get arguments as vectors + auto val0 = TryToVector(args[0]); + if (val0 == NULL) + 
FatalError("Intrinsic argument error"); + + auto val1 = TryToVector(args[1]); + if (val1 == NULL) + FatalError("Intrinsic argument error"); + + // Make sure both argument types are the same + auto vecType = llvm::dyn_cast(val0->getType()); + if (vecType != llvm::dyn_cast(val1->getType())) + FatalError("Intrinsic argument error"); + + // Make sure the type is not scalable + if (vecType->getElementCount().isScalable()) + FatalError("Intrinsic argument error"); + + // Make sure the element type is either float or double + auto elemType = vecType->getElementType(); + if (!elemType->isFloatTy() && !elemType->isDoubleTy()) + FatalError("Intrinsic argument error"); + + // Get some properties for easier access + bool isFloat = elemType->isFloatTy(); + bool isMin = intrinsicData->mIntrinsic == BfIRIntrinsic_Min; + auto elemCount = vecType->getElementCount().getFixedValue(); + + // Get the intrinsic function + const char* funcName; + + if (isFloat) + { + if (elemCount == 4) + { + funcName = isMin ? "llvm.x86.sse.min.ps" : "llvm.x86.sse.max.ps"; + SetActiveFunctionSimdType(BfIRSimdType_SSE); + } + else if (elemCount == 8) + { + funcName = isMin ? "llvm.x86.avx.min.ps.256" : "llvm.x86.avx.max.ps.256"; + SetActiveFunctionSimdType(BfIRSimdType_AVX2); + } + else if (elemCount == 16) + { + funcName = isMin ? "llvm.x86.avx512.min.ps.512" : "llvm.x86.avx512.max.ps.512"; + SetActiveFunctionSimdType(BfIRSimdType_AVX512); + } + else + FatalError("Intrinsic argument error"); + } + else + { + if (elemCount == 2) + { + funcName = isMin ? "llvm.x86.sse.min.pd" : "llvm.x86.sse.max.pd"; + SetActiveFunctionSimdType(BfIRSimdType_SSE); + } + else if (elemCount == 4) + { + funcName = isMin ? "llvm.x86.avx.min.pd.256" : "llvm.x86.avx.max.pd.256"; + SetActiveFunctionSimdType(BfIRSimdType_AVX2); + } + else if (elemCount == 8) + { + funcName = isMin ? 
"llvm.x86.avx512.min.pd.512" : "llvm.x86.avx512.max.pd.512"; + SetActiveFunctionSimdType(BfIRSimdType_AVX512); + } + else + FatalError("Intrinsic argument error"); + } + + auto func = mLLVMModule->getOrInsertFunction(funcName, vecType, vecType, vecType); + + // Call intrinsic + llvm::SmallVector args; + args.push_back(val0); + args.push_back(val1); + + SetResult(curId, mIRBuilder->CreateCall(func, args)); + } + break; + case BfIRIntrinsic_Cpuid: + { + llvm::Type* elemType = llvm::Type::getInt32Ty(*mLLVMContext); + + // Check argument errors + if (args.size() != 6 || !args[0]->getType()->isIntegerTy(32) || !args[1]->getType()->isIntegerTy(32)) + FatalError("Intrinsic argument error"); + + for (int i = 2; i < 6; i++) + { + llvm::Type* type = args[i]->getType(); + + if (!type->isPointerTy() || !type->getPointerElementType()->isIntegerTy(32)) + FatalError("Intrinsic argument error"); + } + + // Get asm return type + llvm::SmallVector asmReturnTypes; + asmReturnTypes.push_back(elemType); + asmReturnTypes.push_back(elemType); + asmReturnTypes.push_back(elemType); + asmReturnTypes.push_back(elemType); + + llvm::Type* returnType = llvm::StructType::get(*mLLVMContext, asmReturnTypes); + + // Get asm function + llvm::SmallVector funcParams; + funcParams.push_back(elemType); + funcParams.push_back(elemType); + + llvm::FunctionType* funcType = llvm::FunctionType::get(returnType, funcParams, false); + llvm::InlineAsm* func = llvm::InlineAsm::get(funcType, "xchgq %rbx,${1:q}\ncpuid\nxchgq %rbx,${1:q}", "={ax},=r,={cx},={dx},0,2,~{dirflag},~{fpsr},~{flags}", false); + + // Call asm function + llvm::SmallVector funcArgs; + funcArgs.push_back(args[0]); + funcArgs.push_back(args[1]); + + llvm::Value* asmResult = mIRBuilder->CreateCall(func, funcArgs); + + // Store results + mIRBuilder->CreateStore(mIRBuilder->CreateExtractValue(asmResult, 0), args[2]); + mIRBuilder->CreateStore(mIRBuilder->CreateExtractValue(asmResult, 1), args[3]); + 
mIRBuilder->CreateStore(mIRBuilder->CreateExtractValue(asmResult, 2), args[4]); + mIRBuilder->CreateStore(mIRBuilder->CreateExtractValue(asmResult, 3), args[5]); + } + break; + case BfIRIntrinsic_Xgetbv: + { + if (args.size() != 1 || !args[0]->getType()->isIntegerTy(32)) + FatalError("Intrinsic argument error"); + + auto func = mLLVMModule->getOrInsertFunction("llvm.x86.xgetbv", llvm::Type::getInt64Ty(*mLLVMContext), llvm::Type::getInt32Ty(*mLLVMContext)); + SetResult(curId, mIRBuilder->CreateCall(func, args[0])); + } + break; case BfIRIntrinsic_Not: { auto val0 = TryToVector(args[0]); @@ -4834,6 +4972,55 @@ void BfIRCodeGen::SetConfigConst(int idx, int value) mConfigConsts64.Add(constVal); } +void BfIRCodeGen::SetActiveFunctionSimdType(BfIRSimdType type) +{ + BfIRSimdType currentType; + bool contains = mFunctionsUsingSimd.TryGetValue(mActiveFunction, ¤tType); + + if (!contains || type > currentType) + mFunctionsUsingSimd[mActiveFunction] = type; +} + +const StringImpl& BfIRCodeGen::GetSimdTypeString(BfIRSimdType type) +{ + switch (type) + { + case BfIRSimdType_SSE: + return "+sse,+mmx"; + case BfIRSimdType_SSE2: + return "+sse2,+sse,+mmx"; + case BfIRSimdType_AVX: + return "+avx,+sse4.2,+sse4.1,+sse3,+sse2,+sse,+mmx"; + case BfIRSimdType_AVX2: + return "+avx2,+avx,+sse4.2,+sse4.1,+sse3,+sse2,+sse,+mmx"; + case BfIRSimdType_AVX512: + return "+avx512f,+avx2,+avx,+sse4.2,+sse4.1,+sse3,+sse2,+sse,+mmx"; + default: + return ""; + } +} + +BfIRSimdType BfIRCodeGen::GetSimdTypeFromFunction(llvm::Function* function) +{ + if (function->hasFnAttribute("target-features")) + { + auto str = function->getFnAttribute("target-features").getValueAsString(); + + if (str.contains("+avx512f")) + return BfIRSimdType_AVX512; + if (str.contains("+avx2")) + return BfIRSimdType_AVX2; + if (str.contains("+avx")) + return BfIRSimdType_AVX; + if (str.contains("+sse2")) + return BfIRSimdType_SSE2; + if (str.contains("+sse")) + return BfIRSimdType_SSE; + } + + return BfIRSimdType_None; +} + 
llvm::Value* BfIRCodeGen::GetLLVMValue(int id) { auto& result = mResults[id]; @@ -5424,6 +5611,8 @@ llvm::Expected FindThinLTOModule(llvm::MemoryBufferRef MBRe bool BfIRCodeGen::WriteObjectFile(const StringImpl& outFileName) { + ApplySimdFeatures(); + // { // PassManagerBuilderWrapper pmBuilder; // @@ -5552,6 +5741,40 @@ bool BfIRCodeGen::WriteIR(const StringImpl& outFileName, StringImpl& error) return true; } +void BfIRCodeGen::ApplySimdFeatures() +{ + Array> functionsToProcess; + + for (auto pair : mFunctionsUsingSimd) + functionsToProcess.Add({ pair.mKey, pair.mValue }); + + while (functionsToProcess.Count() > 0) + { + auto tuple = functionsToProcess.front(); + functionsToProcess.RemoveAt(0); + + auto function = std::get<0>(tuple); + auto simdType = std::get<1>(tuple); + + auto currentSimdType = GetSimdTypeFromFunction(function); + simdType = simdType > currentSimdType ? simdType : currentSimdType; + + function->addFnAttr("target-features", GetSimdTypeString(simdType).c_str()); + + if (function->hasFnAttribute(llvm::Attribute::AlwaysInline)) + { + for (auto user : function->users()) + { + if (auto call = llvm::dyn_cast(user)) + { + auto func = call->getFunction(); + functionsToProcess.Add({ func, simdType }); + } + } + } + } +} + int BfIRCodeGen::GetIntrinsicId(const StringImpl& name) { auto itr = std::lower_bound(std::begin(gIntrinEntries), std::end(gIntrinEntries), name); diff --git a/IDEHelper/Compiler/BfIRCodeGen.h b/IDEHelper/Compiler/BfIRCodeGen.h index 28773467..9303622f 100644 --- a/IDEHelper/Compiler/BfIRCodeGen.h +++ b/IDEHelper/Compiler/BfIRCodeGen.h @@ -90,6 +90,16 @@ enum BfIRSizeAlignKind BfIRSizeAlignKind_Aligned, }; +enum BfIRSimdType +{ + BfIRSimdType_None, + BfIRSimdType_SSE, + BfIRSimdType_SSE2, + BfIRSimdType_AVX, + BfIRSimdType_AVX2, + BfIRSimdType_AVX512 +}; + class BfIRCodeGen : public BfIRCodeGenBase { public: @@ -130,6 +140,7 @@ public: Dictionary mTypeToTypeIdMap; HashSet mLockedBlocks; OwnedArray mIntrinsicData; + Dictionary 
mFunctionsUsingSimd; public: void InitTarget(); @@ -222,6 +233,10 @@ public: void SetCodeGenOptions(BfCodeGenOptions codeGenOptions); void SetConfigConst(int idx, int value) override; + void SetActiveFunctionSimdType(BfIRSimdType type); + const StringImpl& GetSimdTypeString(BfIRSimdType type); + BfIRSimdType GetSimdTypeFromFunction(llvm::Function* function); + llvm::Value* GetLLVMValue(int streamId); llvm::Type* GetLLVMType(int streamId); llvm::BasicBlock* GetLLVMBlock(int streamId); @@ -234,6 +249,8 @@ public: bool WriteObjectFile(const StringImpl& outFileName); bool WriteIR(const StringImpl& outFileName, StringImpl& error); + void ApplySimdFeatures(); + static int GetIntrinsicId(const StringImpl& name); static const char* GetIntrinsicName(int intrinId); static void SetAsmKind(BfAsmKind asmKind);