From a35618651466aee43db9c74c08b9d8003a72bf08 Mon Sep 17 00:00:00 2001 From: Brian Fiete Date: Mon, 22 Jan 2024 08:12:15 -0500 Subject: [PATCH] Fixed comptime reflected static field accesses --- BeefLibs/corlib/src/Reflection/FieldInfo.bf | 8 ++ BeefLibs/corlib/src/Type.bf | 1 + IDEHelper/Compiler/BfModule.cpp | 9 +- IDEHelper/Compiler/BfModuleTypeUtils.cpp | 12 ++ IDEHelper/Compiler/CeMachine.cpp | 142 +++++++++++++++++--- IDEHelper/Compiler/CeMachine.h | 2 + IDEHelper/Tests/src/Comptime.bf | 19 +++ 7 files changed, 177 insertions(+), 16 deletions(-) diff --git a/BeefLibs/corlib/src/Reflection/FieldInfo.bf b/BeefLibs/corlib/src/Reflection/FieldInfo.bf index fbead848..bdd92491 100644 --- a/BeefLibs/corlib/src/Reflection/FieldInfo.bf +++ b/BeefLibs/corlib/src/Reflection/FieldInfo.bf @@ -327,6 +327,14 @@ namespace System.Reflection if (!mFieldData.mFlags.HasFlag(FieldFlags.Static)) return .Err(.InvalidTargetType); + if (Compiler.IsComptime) + { + void* dataPtr = Type.[Friend]Comptime_Field_GetStatic((int32)mTypeInstance.TypeId, (int32)mFieldData.mData); + if (dataPtr != null) + value = *(TMember*)dataPtr; + return .Ok; + } + targetDataAddr = null; } else diff --git a/BeefLibs/corlib/src/Type.bf b/BeefLibs/corlib/src/Type.bf index fc2bb3c6..0b134109 100644 --- a/BeefLibs/corlib/src/Type.bf +++ b/BeefLibs/corlib/src/Type.bf @@ -567,6 +567,7 @@ namespace System static extern Type Comptime_Method_GetGenericArg(int64 methodHandle, int32 genericArgIdx); static extern String Comptime_Field_GetName(int64 fieldHandle); static extern ComptimeFieldInfo Comptime_Field_GetInfo(int64 fieldHandle); + static extern void* Comptime_Field_GetStatic(int32 typeId, int32 fieldIdx); protected static Type GetType(TypeId typeId) { diff --git a/IDEHelper/Compiler/BfModule.cpp b/IDEHelper/Compiler/BfModule.cpp index cc2ef6e4..4ee2286f 100644 --- a/IDEHelper/Compiler/BfModule.cpp +++ b/IDEHelper/Compiler/BfModule.cpp @@ -5859,8 +5859,15 @@ BfIRValue BfModule::CreateFieldData(BfFieldInstance* fieldInstance, int customAt else if (fieldInstance->GetFieldDef()->mIsStatic) { BfTypedValue refVal; - if (!mIsComptimeModule) // This can create circular reference issues for a `Self` static + if (mIsComptimeModule) + { + constValue = mBfIRBuilder->CreateConst(BfTypeCode_IntPtr, fieldInstance->mFieldIdx); + } + else + { refVal = ReferenceStaticField(fieldInstance); + } + if (refVal.mValue.IsConst()) { auto constant = mBfIRBuilder->GetConstant(refVal.mValue); diff --git a/IDEHelper/Compiler/BfModuleTypeUtils.cpp b/IDEHelper/Compiler/BfModuleTypeUtils.cpp index 68ff329e..8b056096 100644 --- a/IDEHelper/Compiler/BfModuleTypeUtils.cpp +++ b/IDEHelper/Compiler/BfModuleTypeUtils.cpp @@ -2140,6 +2140,15 @@ BfCEParseContext BfModule::CEEmitParse(BfTypeInstance* typeInstance, BfTypeDef* ceParseContext.mFailIdx = mCompiler->mPassInstance->mFailedIdx; ceParseContext.mWarnIdx = mCompiler->mPassInstance->mWarnIdx; + if (typeInstance->mTypeDef->mEmitParent == NULL) + { + if (typeInstance->mTypeDef->mNextRevision != NULL) + { + InternalError("CEEmitParse preconditions failed"); + return ceParseContext; + } + } + bool createdParser = false; int startSrcIdx = 0; @@ -5203,6 +5212,9 @@ void BfModule::DoPopulateType(BfType* resolvedTypeRef, BfPopulateType populateTy if (hadNewMembers) { + // Avoid getting stale cached comptime reflection info + mCompiler->mCeMachine->mCeModule->mTypeDataRefs.Remove(resolvedTypeRef); + // We need to avoid passing in BfPopulateType_Interfaces_All because it could cause us to miss out on new member processing, // including resizing the method group table DoPopulateType(resolvedTypeRef, BF_MAX(populateType, BfPopulateType_Data)); diff --git a/IDEHelper/Compiler/CeMachine.cpp b/IDEHelper/Compiler/CeMachine.cpp index 7cca34cb..5fed2a75 100644 --- a/IDEHelper/Compiler/CeMachine.cpp +++ b/IDEHelper/Compiler/CeMachine.cpp @@ -7,6 +7,7 @@ #include "BfReducer.h" #include "BfExprEvaluator.h" #include "BfResolvePass.h" +#include "BfMangler.h" #include "../Backend/BeIRCodeGen.h" #include "BeefySysLib/platform/PlatformHelper.h" #include "../DebugManager.h" @@ -5883,7 +5884,7 @@ bool CeContext::Execute(CeFunction* startFunction, uint8* startStackPtr, uint8* return false; }; - auto _CheckFunction = [&](CeFunction* checkFunction, bool& handled) + std::function _CheckFunction = [&](CeFunction* checkFunction, bool& handled) { if (checkFunction == NULL) { @@ -6326,6 +6327,106 @@ bool CeContext::Execute(CeFunction* startFunction, uint8* startStackPtr, uint8* _FixVariables(); CeSetAddrVal(stackPtr + 0, reflectType, ptrSize); } + else if (checkFunction->mFunctionKind == CeFunctionKind_Field_GetStatic) + { + int32 typeId = *(int32*)((uint8*)stackPtr + ptrSize); + int32 fieldIdx = *(int32*)((uint8*)stackPtr + ptrSize + 4); + + CeFunction* ctorCallFunction = NULL; + + BfType* bfType = GetBfType(typeId); + bool success = false; + if (bfType != NULL) + { + auto typeInst = bfType->ToTypeInstance(); + if (typeInst != NULL) + { + if (typeInst->mDefineState < BfTypeDefineState_CETypeInit) + mCurModule->PopulateType(typeInst); + if ((fieldIdx >= 0) && (fieldIdx < typeInst->mFieldInstances.mSize)) + { + auto& fieldInstance = typeInst->mFieldInstances[fieldIdx]; + + auto fieldType = fieldInstance.mResolvedType; + ceModule->PopulateType(fieldType, BfPopulateType_Full_Force); + + int64 fieldId = ((int64)typeId << 32) | fieldIdx; + + CeStaticFieldInfo* staticFieldInfo = NULL; + if (mStaticFieldIdMap.TryAdd(fieldId, NULL, &staticFieldInfo)) + { + if (mStaticCtorExecSet.TryAdd(typeId, NULL)) + { + BfTypeInstance* bfTypeInstance = NULL; + if (bfType != NULL) + bfTypeInstance = bfType->ToTypeInstance(); + if (bfTypeInstance == NULL) + { + _Fail("Invalid type"); + return false; + } + + auto methodDef = bfTypeInstance->mTypeDef->GetMethodByName("__BfStaticCtor"); + if (methodDef == NULL) + { + _Fail("No static ctor found"); + return false; + } + + auto moduleMethodInstance = ceModule->GetMethodInstance(bfTypeInstance, methodDef, BfTypeVector()); + if (!moduleMethodInstance) + { + _Fail("No static ctor instance found"); + return false; + } + + bool added = false; + ctorCallFunction = mCeMachine->GetFunction(moduleMethodInstance.mMethodInstance, moduleMethodInstance.mFunc, added); + if (ctorCallFunction->mInitializeState < CeFunction::InitializeState_Initialized) + mCeMachine->PrepareFunction(ctorCallFunction, NULL); + } + + _FixVariables(); + + StringT<4096> staticVarName; + BfMangler::Mangle(staticVarName, ceModule->mCompiler->GetMangleKind(), &fieldInstance); + + CeStaticFieldInfo* nameStaticFieldInfo = NULL; + mStaticFieldMap.TryAdd(staticVarName, NULL, &nameStaticFieldInfo); + + if (nameStaticFieldInfo->mAddr == 0) + { + int fieldSize = fieldInstance.mResolvedType->mSize; + CE_CHECKALLOC(fieldSize); + uint8* ptr = CeMalloc(fieldSize); + _FixVariables(); + if (fieldSize > 0) + memset(ptr, 0, fieldSize); + nameStaticFieldInfo->mAddr = (addr_ce)(ptr - memStart); + } + + staticFieldInfo->mAddr = nameStaticFieldInfo->mAddr; + } + + CeSetAddrVal(stackPtr + 0, staticFieldInfo->mAddr, ptrSize); + } + else if (fieldIdx != -1) + { + _Fail("Invalid field"); + return false; + } + } + } + + if (ctorCallFunction != NULL) + { + bool handled = false; + if (!_CheckFunction(ctorCallFunction, handled)) + return false; + if (!handled) + CE_CALL(ctorCallFunction); + } + } else if (checkFunction->mFunctionKind == CeFunctionKind_SetReturnType) { int32 typeId = *(int32*)((uint8*)stackPtr); @@ -7920,24 +8021,30 @@ bool CeContext::Execute(CeFunction* startFunction, uint8* startStackPtr, uint8* return false; } - auto methodDef = bfTypeInstance->mTypeDef->GetMethodByName("__BfStaticCtor"); - if (methodDef == NULL) + if (bfType->mDefineState == BfTypeDefineState_CETypeInit) { - _Fail("No static ctor found"); - return false; + // Don't create circular references } - - auto moduleMethodInstance = ceModule->GetMethodInstance(bfTypeInstance, methodDef, BfTypeVector()); - if (!moduleMethodInstance) + else { - _Fail("No static ctor instance found"); - return false; - } + auto methodDef = bfTypeInstance->mTypeDef->GetMethodByName("__BfStaticCtor"); + if (methodDef != NULL) + { + auto moduleMethodInstance = ceModule->GetMethodInstance(bfTypeInstance, methodDef, BfTypeVector()); + if (!moduleMethodInstance) + { + _Fail("No static ctor instance found"); + return false; + } - bool added = false; - ctorCallFunction = mCeMachine->GetFunction(moduleMethodInstance.mMethodInstance, moduleMethodInstance.mFunc, added); - if (ctorCallFunction->mInitializeState < CeFunction::InitializeState_Initialized) - mCeMachine->PrepareFunction(ctorCallFunction, NULL); + ceModule->PopulateType(bfTypeInstance, BfPopulateType_DataAndMethods); + + bool added = false; + ctorCallFunction = mCeMachine->GetFunction(moduleMethodInstance.mMethodInstance, moduleMethodInstance.mFunc, added); + if (ctorCallFunction->mInitializeState < CeFunction::InitializeState_Initialized) + mCeMachine->PrepareFunction(ctorCallFunction, NULL); + } + } } CeStaticFieldInfo* staticFieldInfo = NULL; @@ -9549,6 +9656,10 @@ void CeMachine::CheckFunctionKind(CeFunction* ceFunction) { ceFunction->mFunctionKind = CeFunctionKind_Method_GetGenericArg; } + else if (methodDef->mName == "Comptime_Field_GetStatic") + { + ceFunction->mFunctionKind = CeFunctionKind_Field_GetStatic; + } } else if (owner->IsInstanceOf(mCeModule->mCompiler->mCompilerTypeDef)) { @@ -10089,6 +10200,7 @@ void CeMachine::ReleaseContext(CeContext* ceContext) ceContext->mMemory.Dispose(); ceContext->mStaticCtorExecSet.Clear(); ceContext->mStaticFieldMap.Clear(); + ceContext->mStaticFieldIdMap.Clear(); ceContext->mHeap->Clear(BF_CE_MAX_CARRYOVER_HEAP); ceContext->mReflectTypeIdOffset = -1; mCurEmitContext = ceContext->mCurEmitContext; diff --git a/IDEHelper/Compiler/CeMachine.h b/IDEHelper/Compiler/CeMachine.h index 0741bc4f..6a35737f 100644 --- a/IDEHelper/Compiler/CeMachine.h +++ b/IDEHelper/Compiler/CeMachine.h @@ -447,6 +447,7 @@ enum CeFunctionKind CeFunctionKind_Method_GetInfo, CeFunctionKind_Method_GetParamInfo, CeFunctionKind_Method_GetGenericArg, + CeFunctionKind_Field_GetStatic, CeFunctionKind_SetReturnType, CeFunctionKind_Align, @@ -1106,6 +1107,7 @@ public: Dictionary mConstDataMap; HashSet mStaticCtorExecSet; Dictionary mStaticFieldMap; + Dictionary mStaticFieldIdMap; Dictionary mInternalDataMap; int mCurHandleId; diff --git a/IDEHelper/Tests/src/Comptime.bf b/IDEHelper/Tests/src/Comptime.bf index b4d4aeb9..66887837 100644 --- a/IDEHelper/Tests/src/Comptime.bf +++ b/IDEHelper/Tests/src/Comptime.bf @@ -481,6 +481,21 @@ namespace Tests } } + class ClassB + { + public static int mA = 123; + } + + class ClassC + { + [OnCompile(.TypeInit), Comptime] + static void Init() + { + typeof(ClassB).GetField("mA").Value.GetValue(null, var value); + Compiler.EmitTypeBody(typeof(Self), scope $"public static int sA = {1000 + value};"); + } + } + [Test] public static void TestBasics() { @@ -499,6 +514,8 @@ namespace Tests Test.Assert(sa.mC == 345); Test.Assert(sa.GetValC() == 345); + Test.Assert(ClassC.sA == 1123); + Compiler.Mixin("int val = 99;"); Test.Assert(val == 99); @@ -532,6 +549,8 @@ namespace Tests Test.Assert(typeof(decltype(f)) == typeof(float)); Test.Assert(ClassB.cTimesTen == 30); + + DictWrapper> dictWrap = scope .(); dictWrap.[Friend]mValue.Add(1, 2.3f); dictWrap.[Friend]mValue.Add(2, 3.4f);