diff --git a/BeefLibs/corlib/src/Numerics/Float4.bf b/BeefLibs/corlib/src/Numerics/Float4.bf index e9b1dfb0..9927918f 100644 --- a/BeefLibs/corlib/src/Numerics/Float4.bf +++ b/BeefLibs/corlib/src/Numerics/Float4.bf @@ -27,6 +27,12 @@ namespace System.Numerics public extern float4 wzyx { [Intrinsic("shuffle3210")] get; [Intrinsic("shuffle3210")] set; } + [Intrinsic("min")] + public static extern float4 min(float4 lhs, float4 rhs); + + [Intrinsic("max")] + public static extern float4 max(float4 lhs, float4 rhs); + [Intrinsic("add")] public static extern float4 operator+(float4 lhs, float4 rhs); [Intrinsic("add"), Commutable] diff --git a/BeefLibs/corlib/src/Numerics/X86/SSE.bf b/BeefLibs/corlib/src/Numerics/X86/SSE.bf index ab771500..3c729924 100644 --- a/BeefLibs/corlib/src/Numerics/X86/SSE.bf +++ b/BeefLibs/corlib/src/Numerics/X86/SSE.bf @@ -4,6 +4,11 @@ namespace System.Numerics.X86 { public static bool IsSupported => Runtime.Features.SSE; + [Inline] + public static v128 min_ps(v128 a, v128 b) => (.) float4.min((.) a, (.) b); + [Inline] + public static v128 max_ps(v128 a, v128 b) => (.) float4.max((.) a, (.) b); + [Inline] public static v128 add_ss(v128 a, v128 b) { diff --git a/IDEHelper/Compiler/BfIRBuilder.h b/IDEHelper/Compiler/BfIRBuilder.h index ac358994..03d69c3c 100644 --- a/IDEHelper/Compiler/BfIRBuilder.h +++ b/IDEHelper/Compiler/BfIRBuilder.h @@ -456,9 +456,11 @@ enum BfIRIntrinsic : uint8 BfIRIntrinsic_Lt, BfIRIntrinsic_LtE, BfIRIntrinsic_Malloc, + BfIRIntrinsic_Max, BfIRIntrinsic_MemCpy, BfIRIntrinsic_MemMove, BfIRIntrinsic_MemSet, + BfIRIntrinsic_Min, BfIRIntrinsic_Mod, BfIRIntrinsic_Mul, BfIRIntrinsic_Neq, diff --git a/IDEHelper/Compiler/BfIRCodeGen.cpp b/IDEHelper/Compiler/BfIRCodeGen.cpp index 063c1d49..ae28322d 100644 --- a/IDEHelper/Compiler/BfIRCodeGen.cpp +++ b/IDEHelper/Compiler/BfIRCodeGen.cpp @@ -172,9 +172,11 @@ static const BuiltinEntry gIntrinEntries[] = {"lt"}, {"lte"}, {"malloc"}, + {"max"}, {"memcpy"}, {"memmove"}, {"memset"}, + {"min"}, {"mod"}, {"mul"}, {"neq"}, @@ -2860,10 +2862,12 @@ void BfIRCodeGen::HandleNextCmd() { llvm::Intrinsic::log2, 0, -1}, { (llvm::Intrinsic::ID)-2, -1}, // lt { (llvm::Intrinsic::ID)-2, -1}, // lte - { (llvm::Intrinsic::ID)-2}, // memset + { (llvm::Intrinsic::ID)-2}, // malloc + { (llvm::Intrinsic::ID)-2, -1}, // max { llvm::Intrinsic::memcpy, 0, 1, 2}, { llvm::Intrinsic::memmove, 0, 2}, { llvm::Intrinsic::memset, 0, 2}, + { (llvm::Intrinsic::ID)-2, -1}, // min { (llvm::Intrinsic::ID)-2, -1}, // mod { (llvm::Intrinsic::ID)-2, -1}, // mul { (llvm::Intrinsic::ID)-2, -1}, // neq @@ -3292,6 +3296,91 @@ void BfIRCodeGen::HandleNextCmd() } } break; + case BfIRIntrinsic_Min: + case BfIRIntrinsic_Max: + { + // Get arguments as vectors + auto val0 = TryToVector(args[0]); + if (val0 == NULL) + FatalError("Intrinsic argument error"); + + auto val1 = TryToVector(args[1]); + if (val1 == NULL) + FatalError("Intrinsic argument error"); + + // Make sure both argument types are the same + auto vecType = llvm::dyn_cast(val0->getType()); + if (vecType != llvm::dyn_cast(val1->getType())) + FatalError("Intrinsic argument error"); + + // Make sure the type is not scalable + if (vecType->getElementCount().isScalable()) + FatalError("Intrinsic argument error"); + + // Make sure the element type is either float or double + auto elemType = vecType->getElementType(); + if (!elemType->isFloatTy() && !elemType->isDoubleTy()) + FatalError("Intrinsic argument error"); + + // Get some properties for easier access + bool isFloat = elemType->isFloatTy(); + bool isMin = intrinsicData->mIntrinsic == BfIRIntrinsic_Min; + auto elemCount = vecType->getElementCount().getFixedValue(); + + // Get the intrinsic function + const char* funcName; + + if (isFloat) + { + if (elemCount == 4) + { + funcName = isMin ? "llvm.x86.sse.min.ps" : "llvm.x86.sse.max.ps"; + SetActiveFunctionSimdType(BfIRSimdType_SSE); + } + else if (elemCount == 8) + { + funcName = isMin ? "llvm.x86.avx.min.ps.256" : "llvm.x86.avx.max.ps.256"; + SetActiveFunctionSimdType(BfIRSimdType_AVX2); + } + else if (elemCount == 16) + { + funcName = isMin ? "llvm.x86.avx512.min.ps.512" : "llvm.x86.avx512.max.ps.512"; + SetActiveFunctionSimdType(BfIRSimdType_AVX512); + } + else + FatalError("Intrinsic argument error"); + } + else + { + if (elemCount == 2) + { + funcName = isMin ? "llvm.x86.sse.min.pd" : "llvm.x86.sse.max.pd"; + SetActiveFunctionSimdType(BfIRSimdType_SSE); + } + else if (elemCount == 4) + { + funcName = isMin ? "llvm.x86.avx.min.pd.256" : "llvm.x86.avx.max.pd.256"; + SetActiveFunctionSimdType(BfIRSimdType_AVX2); + } + else if (elemCount == 8) + { + funcName = isMin ? "llvm.x86.avx512.min.pd.512" : "llvm.x86.avx512.max.pd.512"; + SetActiveFunctionSimdType(BfIRSimdType_AVX512); + } + else + FatalError("Intrinsic argument error"); + } + + auto func = mLLVMModule->getOrInsertFunction(funcName, vecType, vecType, vecType); + + // Call intrinsic + llvm::SmallVector args; + args.push_back(val0); + args.push_back(val1); + + SetResult(curId, mIRBuilder->CreateCall(func, args)); + } + break; case BfIRIntrinsic_Cpuid: { llvm::Type* elemType = llvm::Type::getInt32Ty(*mLLVMContext); @@ -4883,6 +4972,55 @@ void BfIRCodeGen::SetConfigConst(int idx, int value) mConfigConsts64.Add(constVal); } +void BfIRCodeGen::SetActiveFunctionSimdType(BfIRSimdType type) +{ + BfIRSimdType currentType; + bool contains = mFunctionsUsingSimd.TryGetValue(mActiveFunction, ¤tType); + + if (!contains || type > currentType) + mFunctionsUsingSimd[mActiveFunction] = type; +} + +const StringImpl& BfIRCodeGen::GetSimdTypeString(BfIRSimdType type) +{ + switch (type) + { + case BfIRSimdType_SSE: + return "+sse,+mmx"; + case BfIRSimdType_SSE2: + return "+sse2,+sse,+mmx"; + case BfIRSimdType_AVX: + return "+avx,+sse4.2,+sse4.1,+sse3,+sse2,+sse,+mmx"; + case BfIRSimdType_AVX2: + return "+avx2,+avx,+sse4.2,+sse4.1,+sse3,+sse2,+sse,+mmx"; + case BfIRSimdType_AVX512: + return "+avx512f,+avx2,+avx,+sse4.2,+sse4.1,+sse3,+sse2,+sse,+mmx"; + default: + return ""; + } +} + +BfIRSimdType BfIRCodeGen::GetSimdTypeFromFunction(llvm::Function* function) +{ + if (function->hasFnAttribute("target-features")) + { + auto str = function->getFnAttribute("target-features").getValueAsString(); + + if (str.contains("+avx512f")) + return BfIRSimdType_AVX512; + if (str.contains("+avx2")) + return BfIRSimdType_AVX2; + if (str.contains("+avx")) + return BfIRSimdType_AVX; + if (str.contains("+sse2")) + return BfIRSimdType_SSE2; + if (str.contains("+sse")) + return BfIRSimdType_SSE; + } + + return BfIRSimdType_None; +} + llvm::Value* BfIRCodeGen::GetLLVMValue(int id) { auto& result = mResults[id]; @@ -5473,6 +5611,8 @@ llvm::Expected FindThinLTOModule(llvm::MemoryBufferRef MBRe bool BfIRCodeGen::WriteObjectFile(const StringImpl& outFileName) { + ApplySimdFeatures(); + // { // PassManagerBuilderWrapper pmBuilder; // @@ -5601,6 +5741,40 @@ bool BfIRCodeGen::WriteIR(const StringImpl& outFileName, StringImpl& error) return true; } +void BfIRCodeGen::ApplySimdFeatures() +{ + Array> functionsToProcess; + + for (auto pair : mFunctionsUsingSimd) + functionsToProcess.Add({ pair.mKey, pair.mValue }); + + while (functionsToProcess.Count() > 0) + { + auto tuple = functionsToProcess.front(); + functionsToProcess.RemoveAt(0); + + auto function = std::get<0>(tuple); + auto simdType = std::get<1>(tuple); + + auto currentSimdType = GetSimdTypeFromFunction(function); + simdType = simdType > currentSimdType ? simdType : currentSimdType; + + function->addFnAttr("target-features", GetSimdTypeString(simdType).c_str()); + + if (function->hasFnAttribute(llvm::Attribute::AlwaysInline)) + { + for (auto user : function->users()) + { + if (auto call = llvm::dyn_cast(user)) + { + auto func = call->getFunction(); + functionsToProcess.Add({ func, simdType }); + } + } + } + } +} + int BfIRCodeGen::GetIntrinsicId(const StringImpl& name) { auto itr = std::lower_bound(std::begin(gIntrinEntries), std::end(gIntrinEntries), name); diff --git a/IDEHelper/Compiler/BfIRCodeGen.h b/IDEHelper/Compiler/BfIRCodeGen.h index 28773467..9303622f 100644 --- a/IDEHelper/Compiler/BfIRCodeGen.h +++ b/IDEHelper/Compiler/BfIRCodeGen.h @@ -90,6 +90,16 @@ enum BfIRSizeAlignKind BfIRSizeAlignKind_Aligned, }; +enum BfIRSimdType +{ + BfIRSimdType_None, + BfIRSimdType_SSE, + BfIRSimdType_SSE2, + BfIRSimdType_AVX, + BfIRSimdType_AVX2, + BfIRSimdType_AVX512 +}; + class BfIRCodeGen : public BfIRCodeGenBase { public: @@ -130,6 +140,7 @@ public: Dictionary mTypeToTypeIdMap; HashSet mLockedBlocks; OwnedArray mIntrinsicData; + Dictionary mFunctionsUsingSimd; public: void InitTarget(); @@ -222,6 +233,10 @@ public: void SetCodeGenOptions(BfCodeGenOptions codeGenOptions); void SetConfigConst(int idx, int value) override; + void SetActiveFunctionSimdType(BfIRSimdType type); + const StringImpl& GetSimdTypeString(BfIRSimdType type); + BfIRSimdType GetSimdTypeFromFunction(llvm::Function* function); + llvm::Value* GetLLVMValue(int streamId); llvm::Type* GetLLVMType(int streamId); llvm::BasicBlock* GetLLVMBlock(int streamId); @@ -234,6 +249,8 @@ public: bool WriteObjectFile(const StringImpl& outFileName); bool WriteIR(const StringImpl& outFileName, StringImpl& error); + void ApplySimdFeatures(); + static int GetIntrinsicId(const StringImpl& name); static const char* GetIntrinsicName(int intrinId); static void SetAsmKind(BfAsmKind asmKind);