1
0
Fork 0
mirror of https://github.com/beefytech/Beef.git synced 2025-06-08 11:38:21 +02:00

More SIMD work

This commit is contained in:
Brian Fiete 2020-08-25 07:33:55 -07:00
parent b57cbe2d69
commit ca4b383339
19 changed files with 695 additions and 76 deletions

View file

@ -133,6 +133,7 @@ struct BuiltinEntry
static const BuiltinEntry gIntrinEntries[] =
{
{":PLATFORM"},
{"abs"},
{"add"},
{"and"},
@ -161,7 +162,8 @@ static const BuiltinEntry gIntrinEntries[] =
{"floor"},
{"free"},
{"gt"},
{"gte"},
{"gte"},
("index"),
{"log"},
{"log10"},
{"log2"},
@ -1094,6 +1096,71 @@ llvm::Value* BfIRCodeGen::TryToVector(llvm::Value* value)
return NULL;
}
// Attempts to produce an LLVM vector value whose element type is 'elemType'.
//
//  - If 'value' is already a vector with the requested element type it is returned as-is.
//    A vector value with a different element type is reported through FatalError.
//  - If 'value' is a pointer to an array or to a vector, the pointer is reinterpreted as a
//    pointer to a vector of 'elemType' covering the same number of bits, and that vector is
//    loaded with alignment 1.
//  - Any other kind of value yields NULL.
//
// When 'elemType' is NULL for a pointer source, the source's own element type is used.
llvm::Value* BfIRCodeGen::TryToVector(llvm::Value* value, llvm::Type* elemType)
{
    auto valueType = value->getType();
    if (auto vecType = llvm::dyn_cast<llvm::VectorType>(valueType))
    {
        if (vecType->getVectorElementType() == elemType)
            return value;

        //TODO: We need an alloca....
        FatalError("Failed to get vector");
        return value;
    }

    auto ptrType = llvm::dyn_cast<llvm::PointerType>(valueType);
    if (ptrType == NULL)
        return NULL;

    // Determine the element type and element count of the pointed-to aggregate
    auto ptrElemType = ptrType->getElementType();
    llvm::Type* srcElemType = NULL;
    int srcNumElements = 0;
    if (auto arrType = llvm::dyn_cast<llvm::ArrayType>(ptrElemType))
    {
        srcElemType = arrType->getArrayElementType();
        srcNumElements = (int)arrType->getArrayNumElements();
    }
    else if (auto vecType = llvm::dyn_cast<llvm::VectorType>(ptrElemType))
    {
        srcElemType = vecType->getVectorElementType();
        srcNumElements = (int)vecType->getVectorNumElements();
    }
    else
    {
        return NULL;
    }

    if (elemType == NULL)
        elemType = srcElemType;

    // When the requested element type differs from the source's, keep the total bit width
    // the same and derive how many 'elemType' elements fit in it. Previously only the
    // pointer-to-vector case did this; the pointer-to-array case ignored 'elemType'.
    int wantNumElements = srcNumElements;
    if (srcElemType != elemType)
    {
        auto dataLayout = llvm::DataLayout(mLLVMModule);
        wantNumElements = srcNumElements * (int)dataLayout.getTypeSizeInBits(srcElemType) / (int)dataLayout.getTypeSizeInBits(elemType);
    }

    auto newVecType = llvm::VectorType::get(elemType, wantNumElements);
    auto vecPtrType = newVecType->getPointerTo();

    // The bitcast is a no-op when 'value' already has the wanted pointer type
    auto ptrVal0 = mIRBuilder->CreateBitCast(value, vecPtrType);
    return mIRBuilder->CreateAlignedLoad(ptrVal0, 1);
}
// Returns the element type associated with 'value':
//  - for a vector value, the vector's element type
//  - for a pointer to an array or to a vector, that aggregate's element type
// Returns NULL for anything else.
llvm::Type* BfIRCodeGen::GetElemType(llvm::Value* value)
{
    llvm::Type* containerType = value->getType();

    // Look through a single level of pointer, remembering that we did so since
    // arrays are only accepted when reached through a pointer
    bool viaPointer = false;
    if (auto pointerType = llvm::dyn_cast<llvm::PointerType>(containerType))
    {
        containerType = pointerType->getElementType();
        viaPointer = true;
    }

    if (auto vectorType = llvm::dyn_cast<llvm::VectorType>(containerType))
        return vectorType->getVectorElementType();

    if (viaPointer)
    {
        if (auto arrayType = llvm::dyn_cast<llvm::ArrayType>(containerType))
            return arrayType->getArrayElementType();
    }

    return NULL;
}
bool BfIRCodeGen::TryMemCpy(llvm::Value* ptr, llvm::Value* val)
{
auto valType = val->getType();
@ -1160,23 +1227,31 @@ bool BfIRCodeGen::TryVectorCpy(llvm::Value* ptr, llvm::Value* val)
if (ptr->getType()->getPointerElementType() == val->getType())
return false;
auto valType = val->getType();
auto vecType = llvm::dyn_cast<llvm::VectorType>(valType);
if (vecType == NULL)
if (!llvm::isa<llvm::VectorType>(val->getType()))
{
return false;
for (int i = 0; i < (int)vecType->getVectorNumElements(); i++)
{
auto extract = mIRBuilder->CreateExtractElement(val, i);
llvm::Value* gepArgs[] = {
llvm::ConstantInt::get(llvm::Type::getInt32Ty(*mLLVMContext), 0),
llvm::ConstantInt::get(llvm::Type::getInt32Ty(*mLLVMContext), i) };
auto gep = mIRBuilder->CreateInBoundsGEP(ptr, llvm::makeArrayRef(gepArgs));
mIRBuilder->CreateStore(extract, gep);
}
auto usePtr = mIRBuilder->CreateBitCast(ptr, val->getType()->getPointerTo());
mIRBuilder->CreateAlignedStore(val, usePtr, 1);
// auto valType = val->getType();
// auto vecType = llvm::dyn_cast<llvm::VectorType>(valType);
// if (vecType == NULL)
// return false;
//
// for (int i = 0; i < (int)vecType->getVectorNumElements(); i++)
// {
// auto extract = mIRBuilder->CreateExtractElement(val, i);
//
// llvm::Value* gepArgs[] = {
// llvm::ConstantInt::get(llvm::Type::getInt32Ty(*mLLVMContext), 0),
// llvm::ConstantInt::get(llvm::Type::getInt32Ty(*mLLVMContext), i) };
// auto gep = mIRBuilder->CreateInBoundsGEP(ptr, llvm::makeArrayRef(gepArgs));
//
// mIRBuilder->CreateStore(extract, gep);
// }
return true;
}
@ -2215,6 +2290,7 @@ void BfIRCodeGen::HandleNextCmd()
static _Intrinsics intrinsics[] =
{
{ (llvm::Intrinsic::ID)-1, -1}, // PLATFORM,
{ llvm::Intrinsic::fabs, 0, -1},
{ (llvm::Intrinsic::ID)-2, -1}, // add,
{ (llvm::Intrinsic::ID)-2, -1}, // and,
@ -2243,7 +2319,8 @@ void BfIRCodeGen::HandleNextCmd()
{ llvm::Intrinsic::floor, 0, -1},
{ (llvm::Intrinsic::ID)-2, -1}, // free
{ (llvm::Intrinsic::ID)-2, -1}, // gt
{ (llvm::Intrinsic::ID)-2, -1}, // gte
{ (llvm::Intrinsic::ID)-2, -1}, // gte
{ (llvm::Intrinsic::ID)-2, -1}, // index
{ llvm::Intrinsic::log, 0, -1},
{ llvm::Intrinsic::log10, 0, -1},
{ llvm::Intrinsic::log2, 0, -1},
@ -2269,21 +2346,6 @@ void BfIRCodeGen::HandleNextCmd()
};
BF_STATIC_ASSERT(BF_ARRAY_COUNT(intrinsics) == BfIRIntrinsic_COUNT);
bool isFakeIntrinsic = (int)intrinsics[intrinId].mID == -2;
if (isFakeIntrinsic)
{
auto intrinsicData = mAlloc.Alloc<BfIRIntrinsicData>();
intrinsicData->mName = intrinName;
intrinsicData->mIntrinsic = (BfIRIntrinsic)intrinId;
intrinsicData->mReturnType = returnType;
BfIRCodeGenEntry entry;
entry.mKind = BfIRCodeGenEntryKind_IntrinsicData;
entry.mIntrinsicData = intrinsicData;
mResults.TryAdd(curId, entry);
break;
}
CmdParamVec<llvm::Type*> useParams;
if (intrinsics[intrinId].mArg0 != -1)
{
@ -2298,11 +2360,55 @@ void BfIRCodeGen::HandleNextCmd()
}
}
BF_ASSERT(intrinsics[intrinId].mID != (llvm::Intrinsic::ID) - 1);
func = llvm::Intrinsic::getDeclaration(mLLVMModule, intrinsics[intrinId].mID, useParams);
bool isFakeIntrinsic = (int)intrinsics[intrinId].mID == -2;
if (isFakeIntrinsic)
{
auto intrinsicData = mAlloc.Alloc<BfIRIntrinsicData>();
intrinsicData->mName = intrinName;
intrinsicData->mIntrinsic = (BfIRIntrinsic)intrinId;
intrinsicData->mReturnType = returnType;
BfIRCodeGenEntry entry;
entry.mKind = BfIRCodeGenEntryKind_IntrinsicData;
entry.mIntrinsicData = intrinsicData;
mResults.TryAdd(curId, entry);
break;
}
if (intrinId == BfIRIntrinsic__PLATFORM)
{
    // Platform intrinsic names are written as "<platform>:<name>". An empty platform part
    // (a name starting with ':') is recorded as intrinsic data and handled later rather
    // than being mapped to an LLVM intrinsic declaration here.
    int colonPos = (int)intrinName.IndexOf(':');
    String platName = intrinName.Substring(0, colonPos);
    String platIntrinName = intrinName.Substring(colonPos + 1);
    if (platName.IsEmpty())
    {
        auto intrinsicData = mAlloc.Alloc<BfIRIntrinsicData>();
        intrinsicData->mName = platIntrinName;
        intrinsicData->mIntrinsic = BfIRIntrinsic__PLATFORM;
        intrinsicData->mReturnType = returnType;

        BfIRCodeGenEntry entry;
        entry.mKind = BfIRCodeGenEntryKind_IntrinsicData;
        entry.mIntrinsicData = intrinsicData;
        mResults.TryAdd(curId, entry);
        break;
    }

    llvm::Intrinsic::ID intrin = llvm::Intrinsic::getIntrinsicForGCCBuiltin(platName.c_str(), platIntrinName.c_str());
    if ((int)intrin <= 0)
        FatalError(StrFormat("Unable to find intrinsic '%s'", intrinName.c_str()));
    else
    {
        // Declare the intrinsic that was actually resolved. The table entry for PLATFORM
        // is only a placeholder ID (-1) and must not be passed to getDeclaration.
        func = llvm::Intrinsic::getDeclaration(mLLVMModule, intrin, useParams);
    }
}
else
{
    BF_ASSERT(intrinsics[intrinId].mID != (llvm::Intrinsic::ID)-1);
    func = llvm::Intrinsic::getDeclaration(mLLVMModule, intrinsics[intrinId].mID, useParams);
}
mIntrinsicReverseMap[func] = intrinId;
SetResult(curId, func);
SetResult(curId, func);
}
break;
case BfIRCmd_CreateFunctionType:
@ -2414,6 +2520,23 @@ void BfIRCodeGen::HandleNextCmd()
switch (intrinsicData->mIntrinsic)
{
case BfIRIntrinsic__PLATFORM:
    {
        // Platform intrinsics that were recorded by name are lowered by hand here
        if (intrinsicData->mName == "add_ps")
        {
            // Float vector add of the two operands. The second operand must come from
            // args[1]; reading args[0] twice would compute a + a.
            auto floatType = llvm::Type::getFloatTy(*mLLVMContext);
            auto val0 = TryToVector(args[0], floatType);
            auto val1 = TryToVector(args[1], floatType);
            //SetResult(curId, TryToVector(mIRBuilder->CreateFAdd(val0, val1), GetElemType(args[0])));
            if ((val0 == NULL) || (val1 == NULL))
            {
                // TryToVector returns NULL for operands it cannot convert
                FatalError("Failed to get vector");
            }
            else
            {
                SetResult(curId, mIRBuilder->CreateFAdd(val0, val1));
            }
        }
        else
        {
            FatalError(StrFormat("Unable to find intrinsic '%s'", intrinsicData->mName.c_str()));
        }
    }
    break;
case BfIRIntrinsic_Add:
case BfIRIntrinsic_And:
case BfIRIntrinsic_Div:
@ -2453,7 +2576,11 @@ void BfIRCodeGen::HandleNextCmd()
{
auto ptrVal1 = mIRBuilder->CreateBitCast(args[1], vecType->getPointerTo());
val1 = mIRBuilder->CreateAlignedLoad(ptrVal1, 1);
}
}
else if (args[1]->getType()->isVectorTy())
{
val1 = args[1];
}
else
{
val1 = mIRBuilder->CreateInsertElement(llvm::UndefValue::get(vecType), args[1], (uint64)0);
@ -2644,6 +2771,19 @@ void BfIRCodeGen::HandleNextCmd()
}
}
break;
// Indexed element access: with two arguments the element is loaded and becomes the
// result; with three or more arguments args[2] is stored into the element instead.
case BfIRIntrinsic_Index:
{
// GEP indices {0, args[1]}: step through the pointer args[0], then select element args[1].
// NOTE(review): presumably args[0] is a pointer to an array/vector - confirm against callers.
llvm::Value* gepArgs[] = {
llvm::ConstantInt::get(llvm::Type::getInt32Ty(*mLLVMContext), 0),
args[1] };
auto gep = mIRBuilder->CreateInBoundsGEP(args[0], llvm::makeArrayRef(gepArgs));
// The store form produces no result for curId
if (args.size() >= 3)
mIRBuilder->CreateStore(args[2], gep);
else
SetResult(curId, mIRBuilder->CreateLoad(gep));
}
break;
case BfIRIntrinsic_AtomicCmpStore:
case BfIRIntrinsic_AtomicCmpStore_Weak:
case BfIRIntrinsic_AtomicCmpXChg:
@ -3016,6 +3156,11 @@ void BfIRCodeGen::HandleNextCmd()
int intrinId = -1;
if (mIntrinsicReverseMap.TryGetValue(funcPtr, &intrinId))
{
if (intrinId == BfIRIntrinsic__PLATFORM)
{
NOP;
}
if (intrinId == BfIRIntrinsic_MemSet)
{
int align = 1;
@ -4876,6 +5021,9 @@ int BfIRCodeGen::GetIntrinsicId(const StringImpl& name)
if (name.StartsWith("shuffle"))
return BfIRIntrinsic_Shuffle;
if (name.Contains(':'))
return BfIRIntrinsic__PLATFORM;
return -1;
}