From 4c5c89bab51cdd0ae7a494f050f897152369add0 Mon Sep 17 00:00:00 2001 From: Brian Fiete Date: Thu, 17 Feb 2022 05:51:05 -0500 Subject: [PATCH] Comptime GetCustomAttribute for type/field/method --- BeefLibs/corlib/src/Reflection/FieldInfo.bf | 7 ++ BeefLibs/corlib/src/Reflection/MethodInfo.bf | 5 ++ BeefLibs/corlib/src/Type.bf | 2 + IDEHelper/Compiler/BfModuleTypeUtils.cpp | 2 +- IDEHelper/Compiler/BfStmtEvaluator.cpp | 1 + IDEHelper/Compiler/CeMachine.cpp | 79 +++++++++++++++++--- IDEHelper/Compiler/CeMachine.h | 6 +- IDEHelper/Tests/src/Comptime.bf | 21 ++++++ 8 files changed, 110 insertions(+), 13 deletions(-) diff --git a/BeefLibs/corlib/src/Reflection/FieldInfo.bf b/BeefLibs/corlib/src/Reflection/FieldInfo.bf index 88e79c34..e656455e 100644 --- a/BeefLibs/corlib/src/Reflection/FieldInfo.bf +++ b/BeefLibs/corlib/src/Reflection/FieldInfo.bf @@ -229,6 +229,13 @@ namespace System.Reflection public Result GetCustomAttribute() where T : Attribute { + if (Compiler.IsComptime) + { + T val = ?; + if (Type.[Friend]Comptime_Field_GetCustomAttribute((int32)mTypeInstance.TypeId, mFieldData.mCustomAttributesIdx, (.)typeof(T).TypeId, &val)) + return val; + return .Err; + } return mTypeInstance.[Friend]GetCustomAttribute(mFieldData.mCustomAttributesIdx); } diff --git a/BeefLibs/corlib/src/Reflection/MethodInfo.bf b/BeefLibs/corlib/src/Reflection/MethodInfo.bf index 922a27d1..4524be83 100644 --- a/BeefLibs/corlib/src/Reflection/MethodInfo.bf +++ b/BeefLibs/corlib/src/Reflection/MethodInfo.bf @@ -95,7 +95,12 @@ namespace System.Reflection public Result GetCustomAttribute() where T : Attribute { if (Compiler.IsComptime) + { + T val = ?; + if (Type.[Friend]Comptime_Method_GetCustomAttribute(mData.mComptimeMethodInstance, (.)typeof(T).TypeId, &val)) + return val; return .Err; + } return mTypeInstance.[Friend]GetCustomAttribute(mData.mMethodData.mCustomAttributesIdx); } diff --git a/BeefLibs/corlib/src/Type.bf b/BeefLibs/corlib/src/Type.bf index 710ebc3e..c93e2ce4 100644 --- a/BeefLibs/corlib/src/Type.bf +++ b/BeefLibs/corlib/src/Type.bf @@ -529,6 +529,8 @@ namespace System static extern String Comptime_Type_ToString(int32 typeId); static extern Type Comptime_GetSpecializedType(Type unspecializedType, Span typeArgs); static extern bool Comptime_Type_GetCustomAttribute(int32 typeId, int32 attributeId, void* dataPtr); + static extern bool Comptime_Field_GetCustomAttribute(int32 typeId, int32 fieldIdx, int32 attributeId, void* dataPtr); + static extern bool Comptime_Method_GetCustomAttribute(int64 methodHandle, int32 attributeId, void* dataPtr); static extern int32 Comptime_GetMethodCount(int32 typeId); static extern int64 Comptime_GetMethod(int32 typeId, int32 methodIdx); static extern String Comptime_Method_ToString(int64 methodHandle); diff --git a/IDEHelper/Compiler/BfModuleTypeUtils.cpp b/IDEHelper/Compiler/BfModuleTypeUtils.cpp index f1137ba1..b2fb6d78 100644 --- a/IDEHelper/Compiler/BfModuleTypeUtils.cpp +++ b/IDEHelper/Compiler/BfModuleTypeUtils.cpp @@ -2222,7 +2222,7 @@ void BfModule::HandleCEAttributes(CeEmitContext* ceEmitContext, BfTypeInstance* SetAndRestoreValue prevEmitContext(mCompiler->mCEMachine->mCurEmitContext, ceEmitContext); auto ceContext = mCompiler->mCEMachine->AllocContext(); - BfIRValue attrVal = ceContext->CreateAttribute(customAttribute.mRef, this, typeInstance->mConstHolder, &customAttribute); + BfIRValue attrVal =ceContext->CreateAttribute(customAttribute.mRef, this, typeInstance->mConstHolder, &customAttribute); for (int baseIdx = 0; baseIdx < checkDepth; baseIdx++) attrVal = mBfIRBuilder->CreateExtractValue(attrVal, 0); diff --git a/IDEHelper/Compiler/BfStmtEvaluator.cpp b/IDEHelper/Compiler/BfStmtEvaluator.cpp index bb7579ba..363411c3 100644 --- a/IDEHelper/Compiler/BfStmtEvaluator.cpp +++ b/IDEHelper/Compiler/BfStmtEvaluator.cpp @@ -1422,6 +1422,7 @@ BfLocalVariable* BfModule::HandleVariableDeclaration(BfVariableDeclaration* varD auto eqVal = mBfIRBuilder->CreateCmpEQ(dscVal, GetConstValue(tagId, dscType)); exprEvaluator->mResult = BfTypedValue(eqVal, boolType); + PopulateType(outType); if (!outType->IsValuelessType()) { auto outPtrType = CreatePointerType(outType); diff --git a/IDEHelper/Compiler/CeMachine.cpp b/IDEHelper/Compiler/CeMachine.cpp index e8c63fac..faadf893 100644 --- a/IDEHelper/Compiler/CeMachine.cpp +++ b/IDEHelper/Compiler/CeMachine.cpp @@ -3458,7 +3458,7 @@ bool CeContext::GetStringFromStringView(addr_ce addr, StringImpl& str) return true; } -bool CeContext::GetCustomAttribute(BfCustomAttributes* customAttributes, int attributeTypeId, addr_ce resultAddr) +bool CeContext::GetCustomAttribute(BfModule* module, BfIRConstHolder* constHolder, BfCustomAttributes* customAttributes, int attributeTypeId, addr_ce resultAddr) { if (customAttributes == NULL) return false; @@ -3470,17 +3470,26 @@ bool CeContext::GetCustomAttribute(BfCustomAttributes* customAttributes, int att auto customAttr = customAttributes->Get(attributeType); if (customAttr == NULL) return false; - - if (resultAddr != 0) + + auto ceContext = mCeMachine->AllocContext(); + BfIRValue foreignValue = ceContext->CreateAttribute(mCurTargetSrc, module, constHolder, customAttr); + auto foreignConstant = module->mBfIRBuilder->GetConstant(foreignValue); + if (foreignConstant->mConstType == BfConstType_AggCE) { - - } + auto constAggData = (BfConstantAggCE*)foreignConstant; + auto value = ceContext->CreateConstant(module, ceContext->mMemory.mVals + constAggData->mCEAddr, customAttr->mType); + if (!value) + Fail("Failed to encoded attribute"); + auto attrConstant = module->mBfIRBuilder->GetConstant(value); + if (!WriteConstant(module, resultAddr, attrConstant, customAttr->mType)) + Fail("Failed to decode attribute"); + } + + mCeMachine->ReleaseContext(ceContext); return true; } - - //#define CE_GETC(T) *((T*)(addr += sizeof(T)) - 1) #define CE_GETC(T) *(T*)(mMemory.mVals + addr) @@ -4172,12 +4181,13 @@ BfIRValue CeContext::CreateConstant(BfModule* module, uint8* ptr, BfType* bfType return BfIRValue(); } -BfIRValue CeContext::CreateAttribute(BfAstNode* targetSrc, BfModule* module, BfIRConstHolder* constHolder, BfCustomAttribute* customAttribute) +BfIRValue CeContext::CreateAttribute(BfAstNode* targetSrc, BfModule* module, BfIRConstHolder* constHolder, BfCustomAttribute* customAttribute, addr_ce ceAttrAddr) { SetAndRestoreValue prevIgnoreWrites(module->mBfIRBuilder->mIgnoreWrites, true); module->mContext->mUnreifiedModule->PopulateType(customAttribute->mType); - auto ceAttrAddr = CeMalloc(customAttribute->mType->mSize) - mMemory.mVals; + if (ceAttrAddr == 0) + ceAttrAddr = CeMalloc(customAttribute->mType->mSize) - mMemory.mVals; BfIRValue ceAttrVal = module->mBfIRBuilder->CreateConstAggCE(module->mBfIRBuilder->MapType(customAttribute->mType, BfIRPopulateType_Identity), ceAttrAddr); BfTypedValue ceAttrTypedValue(ceAttrVal, customAttribute->mType); @@ -5115,11 +5125,52 @@ bool CeContext::Execute(CeFunction* startFunction, uint8* startStackPtr, uint8* { auto typeInst = type->ToTypeInstance(); if (typeInst != NULL) - success = GetCustomAttribute(typeInst->mCustomAttributes, attributeTypeId, resultPtr); + success = GetCustomAttribute(mCurModule, typeInst->mConstHolder, typeInst->mCustomAttributes, attributeTypeId, resultPtr); } *(addr_ce*)(stackPtr + 0) = success; } + else if (checkFunction->mFunctionKind == CeFunctionKind_Field_GetCustomAttribute) + { + int32 typeId = *(int32*)((uint8*)stackPtr + 1); + int32 fieldIdx = *(int32*)((uint8*)stackPtr + 1 + 4); + int32 attributeTypeId = *(int32*)((uint8*)stackPtr + 1 + 4 + 4); + addr_ce resultPtr = *(addr_ce*)((uint8*)stackPtr + 1 + 4 + 4 + 4); + + BfType* type = GetBfType(typeId); + bool success = false; + if (type != NULL) + { + auto typeInst = type->ToTypeInstance(); + if (typeInst != NULL) + { + if (typeInst->mDefineState < BfTypeDefineState_CETypeInit) + mCurModule->PopulateType(typeInst); + if (fieldIdx < typeInst->mFieldInstances.mSize) + { + auto& fieldInstance = typeInst->mFieldInstances[fieldIdx]; + success = GetCustomAttribute(mCurModule, typeInst->mConstHolder, fieldInstance.mCustomAttributes, attributeTypeId, resultPtr); + } + } + } + + *(addr_ce*)(stackPtr + 0) = success; + } + else if (checkFunction->mFunctionKind == CeFunctionKind_Method_GetCustomAttribute) + { + int64 methodHandle = *(int64*)((uint8*)stackPtr + 1); + int32 attributeTypeId = *(int32*)((uint8*)stackPtr + 1 + 8); + addr_ce resultPtr = *(addr_ce*)((uint8*)stackPtr + 1 + 8 + 4); + + auto methodInstance = mCeMachine->GetMethodInstance(methodHandle); + if (methodInstance == NULL) + { + _Fail("Invalid method instance"); + return false; + } + bool success = GetCustomAttribute(mCurModule, methodInstance->GetOwner()->mConstHolder, methodInstance->GetCustomAttributes(), attributeTypeId, resultPtr); + *(addr_ce*)(stackPtr + 0) = success; + } else if (checkFunction->mFunctionKind == CeFunctionKind_GetMethodCount) { int32 typeId = *(int32*)((uint8*)stackPtr + 4); @@ -8070,6 +8121,14 @@ void CeMachine::CheckFunctionKind(CeFunction* ceFunction) { ceFunction->mFunctionKind = CeFunctionKind_Type_GetCustomAttribute; } + else if (methodDef->mName == "Comptime_Field_GetCustomAttribute") + { + ceFunction->mFunctionKind = CeFunctionKind_Field_GetCustomAttribute; + } + else if (methodDef->mName == "Comptime_Method_GetCustomAttribute") + { + ceFunction->mFunctionKind = CeFunctionKind_Method_GetCustomAttribute; + } else if (methodDef->mName == "Comptime_GetMethod") { ceFunction->mFunctionKind = CeFunctionKind_GetMethod; diff --git a/IDEHelper/Compiler/CeMachine.h b/IDEHelper/Compiler/CeMachine.h index 61aedcc8..58808656 100644 --- a/IDEHelper/Compiler/CeMachine.h +++ b/IDEHelper/Compiler/CeMachine.h @@ -326,6 +326,8 @@ enum CeFunctionKind CeFunctionKind_GetReflectSpecializedType, CeFunctionKind_Type_ToString, CeFunctionKind_Type_GetCustomAttribute, + CeFunctionKind_Field_GetCustomAttribute, + CeFunctionKind_Method_GetCustomAttribute, CeFunctionKind_GetMethodCount, CeFunctionKind_GetMethod, CeFunctionKind_Method_ToString, @@ -906,11 +908,11 @@ public: bool CheckMemory(addr_ce addr, int32 size); bool GetStringFromAddr(addr_ce strInstAddr, StringImpl& str); bool GetStringFromStringView(addr_ce addr, StringImpl& str); - bool GetCustomAttribute(BfCustomAttributes* customAttributes, int attributeTypeId, addr_ce resultAddr); + bool GetCustomAttribute(BfModule* module, BfIRConstHolder* constHolder, BfCustomAttributes* customAttributes, int attributeTypeId, addr_ce resultAddr); bool WriteConstant(BfModule* module, addr_ce addr, BfConstant* constant, BfType* type, bool isParams = false); BfIRValue CreateConstant(BfModule* module, uint8* ptr, BfType* type, BfType** outType = NULL); - BfIRValue CreateAttribute(BfAstNode* targetSrc, BfModule* module, BfIRConstHolder* constHolder, BfCustomAttribute* customAttribute); + BfIRValue CreateAttribute(BfAstNode* targetSrc, BfModule* module, BfIRConstHolder* constHolder, BfCustomAttribute* customAttribute, addr_ce ceAttrAddr = 0); bool Execute(CeFunction* startFunction, uint8* startStackPtr, uint8* startFramePtr, BfType*& returnType); BfTypedValue Call(BfAstNode* targetSrc, BfModule* module, BfMethodInstance* methodInstance, const BfSizedArray& args, CeEvalFlags flags, BfType* expectingType); diff --git a/IDEHelper/Tests/src/Comptime.bf b/IDEHelper/Tests/src/Comptime.bf index b327b839..725082f1 100644 --- a/IDEHelper/Tests/src/Comptime.bf +++ b/IDEHelper/Tests/src/Comptime.bf @@ -7,6 +7,21 @@ namespace Tests { class Comptime { + [AttributeUsage(.All)] + struct AddFieldAttribute : Attribute + { + public Type mType; + public String mName; + public int mVal; + + public this(Type type, String name, int val) + { + mType = type; + mName = name; + mVal = val; + } + } + [AttributeUsage(.All)] struct IFaceAAttribute : Attribute, IComptimeTypeApply { @@ -33,6 +48,7 @@ namespace Tests Compiler.EmitTypeBody(type, scope $""" public int32 m{mMemberName} = {mInitVal}; public int32 GetVal{mMemberName}() => mC; + """); } } @@ -65,6 +81,7 @@ namespace Tests } } + [AddField(typeof(float), "D", 4)] [IFaceA("C", InitVal=345)] class ClassA { @@ -76,7 +93,10 @@ namespace Tests Compiler.EmitTypeBody(typeof(Self), """ public int32 mB = 234; public int32 GetValB() => mB; + """); + if (var addFieldAttr = typeof(Self).GetCustomAttribute()) + Compiler.EmitTypeBody(typeof(Self), scope $"public {addFieldAttr.mType} {addFieldAttr.mName} = {addFieldAttr.mVal};"); } } @@ -469,6 +489,7 @@ namespace Tests Test.Assert(ca.GetValB() == 234); Test.Assert(ca.mC == 345); Test.Assert(ca.GetValC() == 345); + Test.Assert(ca.D == 4); StructA sa = .(); Test.Assert(sa.mA == 123);