From 26506efc1e11ae244d44ae471552b452b8c7b7af Mon Sep 17 00:00:00 2001 From: Brian Fiete Date: Mon, 31 Jan 2022 15:41:05 -0500 Subject: [PATCH] Improved generic param reflection in comptime --- BeefLibs/corlib/src/Type.bf | 24 ++++++++++++++++ IDE/mintest/minlib/src/System/Type.bf | 24 ++++++++++++++++ IDEHelper/Compiler/BfCompiler.cpp | 3 ++ IDEHelper/Compiler/BfCompiler.h | 1 + IDEHelper/Compiler/BfExprEvaluator.cpp | 11 ++----- IDEHelper/Compiler/BfModule.cpp | 26 ++++++++++++++--- IDEHelper/Compiler/CeMachine.cpp | 40 +++++++++++++++++++++++--- IDEHelper/Compiler/CeMachine.h | 8 ++++-- IDEHelper/Tests/src/Comptime.bf | 25 ++++++++++++++++ 9 files changed, 143 insertions(+), 19 deletions(-) diff --git a/BeefLibs/corlib/src/Type.bf b/BeefLibs/corlib/src/Type.bf index 099a20dd..2da60586 100644 --- a/BeefLibs/corlib/src/Type.bf +++ b/BeefLibs/corlib/src/Type.bf @@ -490,6 +490,7 @@ namespace System static extern Type Comptime_GetTypeById(int32 typeId); static extern Type Comptime_GetTypeByName(StringView name); + static extern String Comptime_Type_ToString(int32 typeId); static extern Type Comptime_GetSpecializedType(Type unspecializedType, Span typeArgs); static extern bool Comptime_Type_GetCustomAttribute(int32 typeId, int32 attributeId, void* dataPtr); static extern int32 Comptime_GetMethodCount(int32 typeId); @@ -556,6 +557,12 @@ namespace System } } + void ComptimeToString(String strBuffer) + { + if (Compiler.IsComptime) + strBuffer.Append(Comptime_Type_ToString((.)mTypeId)); + } + public virtual void GetFullName(String strBuffer) { GetBasicName(strBuffer); @@ -1285,6 +1292,23 @@ namespace System.Reflection } } + [Ordered, AlwaysInclude(AssumeInstantiated=true)] + class GenericParamType : Type + { + public override void GetName(String strBuffer) + { + if (Compiler.IsComptime) + this.[Friend]ComptimeToString(strBuffer); + else + strBuffer.Append("$GenericParam"); + } + + public override void GetFullName(String strBuffer) + { + GetName(strBuffer); + } + } + public enum TypeFlags : uint32 { UnspecializedGeneric = 0x0001, diff --git a/IDE/mintest/minlib/src/System/Type.bf b/IDE/mintest/minlib/src/System/Type.bf index d78c37ca..bd6df73e 100644 --- a/IDE/mintest/minlib/src/System/Type.bf +++ b/IDE/mintest/minlib/src/System/Type.bf @@ -482,6 +482,7 @@ namespace System static extern Type Comptime_GetTypeById(int32 typeId); static extern Type Comptime_GetTypeByName(StringView name); + static extern String Comptime_Type_ToString(int32 typeId); static extern Type Comptime_GetSpecializedType(Type unspecializedType, Span typeArgs); static extern bool Comptime_Type_GetCustomAttribute(int32 typeId, int32 attributeId, void* dataPtr); static extern int32 Comptime_GetMethodCount(int32 typeId); @@ -548,6 +549,12 @@ namespace System } } + void ComptimeToString(String strBuffer) + { + if (Compiler.IsComptime) + strBuffer.Append(Comptime_Type_ToString((.)mTypeId)); + } + public virtual void GetFullName(String strBuffer) { GetBasicName(strBuffer); @@ -1261,6 +1268,23 @@ namespace System.Reflection } } + [Ordered, AlwaysInclude(AssumeInstantiated=true)] + class GenericParamType : Type + { + public override void GetName(String strBuffer) + { + if (Compiler.IsComptime) + this.[Friend]ComptimeToString(strBuffer); + else + strBuffer.Append("$GenericParam"); + } + + public override void GetFullName(String strBuffer) + { + GetName(strBuffer); + } + } + public enum TypeFlags : uint32 { UnspecializedGeneric = 0x0001, diff --git a/IDEHelper/Compiler/BfCompiler.cpp b/IDEHelper/Compiler/BfCompiler.cpp index 2fdd2d13..32bf61a1 100644 --- a/IDEHelper/Compiler/BfCompiler.cpp +++ b/IDEHelper/Compiler/BfCompiler.cpp @@ -439,6 +439,7 @@ BfCompiler::BfCompiler(BfSystem* bfSystem, bool isResolveOnly) mPointerTypeDef = NULL; mReflectTypeIdTypeDef = NULL; mReflectArrayType = NULL; + mReflectGenericParamType = NULL; mReflectFieldDataDef = NULL; mReflectFieldSplatDataDef = NULL; mReflectMethodDataDef = NULL; @@ -1345,6 +1346,7 @@ void BfCompiler::CreateVData(BfVDataModule* bfModule) reflectTypeSet.Add(vdataContext->mUnreifiedModule->ResolveTypeDef(mReflectSpecializedGenericType)); reflectTypeSet.Add(vdataContext->mUnreifiedModule->ResolveTypeDef(mReflectUnspecializedGenericType)); reflectTypeSet.Add(vdataContext->mUnreifiedModule->ResolveTypeDef(mReflectArrayType)); + reflectTypeSet.Add(vdataContext->mUnreifiedModule->ResolveTypeDef(mReflectGenericParamType)); SmallVector typeDataVector; for (auto type : vdataTypeList) @@ -6835,6 +6837,7 @@ bool BfCompiler::DoCompile(const StringImpl& outputDirectory) mPointerTypeDef = _GetRequiredType("System.Pointer", 0); mReflectTypeIdTypeDef = _GetRequiredType("System.Reflection.TypeId"); mReflectArrayType = _GetRequiredType("System.Reflection.ArrayType"); + mReflectGenericParamType = _GetRequiredType("System.Reflection.GenericParamType"); mReflectFieldDataDef = _GetRequiredType("System.Reflection.TypeInstance.FieldData"); mReflectFieldSplatDataDef = _GetRequiredType("System.Reflection.TypeInstance.FieldSplatData"); mReflectMethodDataDef = _GetRequiredType("System.Reflection.TypeInstance.MethodData"); diff --git a/IDEHelper/Compiler/BfCompiler.h b/IDEHelper/Compiler/BfCompiler.h index aa0defd3..7878ee96 100644 --- a/IDEHelper/Compiler/BfCompiler.h +++ b/IDEHelper/Compiler/BfCompiler.h @@ -398,6 +398,7 @@ public: BfTypeDef* mPointerTypeDef; BfTypeDef* mReflectTypeIdTypeDef; BfTypeDef* mReflectArrayType; + BfTypeDef* mReflectGenericParamType; BfTypeDef* mReflectFieldDataDef; BfTypeDef* mReflectFieldSplatDataDef; BfTypeDef* mReflectMethodDataDef; diff --git a/IDEHelper/Compiler/BfExprEvaluator.cpp b/IDEHelper/Compiler/BfExprEvaluator.cpp index 3e850418..8d453cdf 100644 --- a/IDEHelper/Compiler/BfExprEvaluator.cpp +++ b/IDEHelper/Compiler/BfExprEvaluator.cpp @@ -10593,15 +10593,8 @@ void BfExprEvaluator::Visit(BfTypeOfExpression* typeOfExpr) return; } - if ((type->IsGenericParam()) && (!mModule->mIsComptimeModule)) - { - mResult = BfTypedValue(mModule->mBfIRBuilder->GetUndefConstValue(mModule->mBfIRBuilder->MapType(typeType)), typeType); - } - else - { - mModule->AddDependency(type, mModule->mCurTypeInstance, BfDependencyMap::DependencyFlag_ExprTypeReference); - mResult = BfTypedValue(mModule->CreateTypeDataRef(type), typeType); - } + mModule->AddDependency(type, mModule->mCurTypeInstance, BfDependencyMap::DependencyFlag_ExprTypeReference); + mResult = BfTypedValue(mModule->CreateTypeDataRef(type), typeType); } bool BfExprEvaluator::LookupTypeProp(BfTypeOfExpression* typeOfExpr, BfIdentifierNode* propName) diff --git a/IDEHelper/Compiler/BfModule.cpp b/IDEHelper/Compiler/BfModule.cpp index 408fc845..fdfcc528 100644 --- a/IDEHelper/Compiler/BfModule.cpp +++ b/IDEHelper/Compiler/BfModule.cpp @@ -5396,11 +5396,13 @@ BfIRValue BfModule::CreateTypeData(BfType* type, Dictionary& usedStrin typeDataSource = ResolveTypeDef(mCompiler->mReflectSizedArrayType)->ToTypeInstance(); else if (type->IsConstExprValue()) typeDataSource = ResolveTypeDef(mCompiler->mReflectConstExprType)->ToTypeInstance(); - else - typeDataSource = mContext->mBfTypeType; - - if (type->IsGenericParam()) + else if (type->IsGenericParam()) + { typeFlags |= BfTypeFlags_GenericParam; + typeDataSource = ResolveTypeDef(mCompiler->mReflectGenericParamType)->ToTypeInstance(); + } + else + typeDataSource = mContext->mBfTypeType; if ((!mTypeDataRefs.ContainsKey(typeDataSource)) && (typeDataSource != type)) { @@ -5636,6 +5638,22 @@ BfIRValue BfModule::CreateTypeData(BfType* type, Dictionary& usedStrin mBfIRBuilder->GlobalVar_SetAlignment(typeDataVar, mSystem->mPtrSize); typeDataVar = mBfIRBuilder->CreateBitCast(typeDataVar, mBfIRBuilder->MapType(mContext->mBfTypeType)); } + else if (type->IsGenericParam()) + { + auto genericParamType = (BfGenericParamType*)type; + SizedArray genericParamTypeDataParms = + { + typeData + }; + + auto reflectGenericParamType = ResolveTypeDef(mCompiler->mReflectGenericParamType)->ToTypeInstance(); + FixConstValueParams(reflectGenericParamType, genericParamTypeDataParms); + auto genericParamTypeData = mBfIRBuilder->CreateConstAgg_Value(mBfIRBuilder->MapTypeInst(reflectGenericParamType, BfIRPopulateType_Full), genericParamTypeDataParms); + typeDataVar = mBfIRBuilder->CreateGlobalVariable(mBfIRBuilder->MapTypeInst(reflectGenericParamType), true, + BfIRLinkageType_External, genericParamTypeData, typeDataName); + mBfIRBuilder->GlobalVar_SetAlignment(typeDataVar, mSystem->mPtrSize); + typeDataVar = mBfIRBuilder->CreateBitCast(typeDataVar, mBfIRBuilder->MapType(mContext->mBfTypeType)); + } else { typeDataVar = mBfIRBuilder->CreateGlobalVariable(mBfIRBuilder->MapTypeInst(mContext->mBfTypeType), true, diff --git a/IDEHelper/Compiler/CeMachine.cpp b/IDEHelper/Compiler/CeMachine.cpp index a7a58ae3..7070b9c8 100644 --- a/IDEHelper/Compiler/CeMachine.cpp +++ b/IDEHelper/Compiler/CeMachine.cpp @@ -2913,6 +2913,9 @@ CeContext::CeContext() mCurFrame = NULL; mCurModule = NULL; mCurMethodInstance = NULL; + mCallerMethodInstance = NULL; + mCallerTypeInstance = NULL; + mCallerActiveTypeDef = NULL; mCurExpectingType = NULL; mCurEmitContext = NULL; } @@ -2983,18 +2986,20 @@ BfError* CeContext::Fail(const CeFrame& curFrame, const StringImpl& str) err += " "; } - auto contextMethodInstance = mCurModule->mCurMethodInstance; + auto contextMethodInstance = mCallerMethodInstance; + auto contextTypeInstance = mCallerTypeInstance; if (stackIdx > 1) { auto func = mCallStack[stackIdx - 1].mFunction; contextMethodInstance = func->mCeFunctionInfo->mMethodInstance; + contextTypeInstance = contextMethodInstance->GetOwner(); } err += StrFormat("in comptime "); // { - SetAndRestoreValue prevTypeInstance(mCeMachine->mCeModule->mCurTypeInstance, (contextMethodInstance != NULL) ? contextMethodInstance->GetOwner() : NULL); + SetAndRestoreValue prevTypeInstance(mCeMachine->mCeModule->mCurTypeInstance, contextTypeInstance); SetAndRestoreValue prevMethodInstance(mCeMachine->mCeModule->mCurMethodInstance, contextMethodInstance); if (ceFunction->mMethodInstance != NULL) @@ -3033,7 +3038,7 @@ BfError* CeContext::Fail(const CeFrame& curFrame, const StringImpl& str) void CeContext::FixProjectRelativePath(StringImpl& path) { BfProject* activeProject = NULL; - auto activeTypeDef = mCurModule->GetActiveTypeDef(); + auto activeTypeDef = mCallerActiveTypeDef; if (activeTypeDef != NULL) activeProject = activeTypeDef->mProject; if (activeProject != NULL) @@ -3145,7 +3150,7 @@ addr_ce CeContext::GetReflectType(int typeId) return *addrPtr; auto ceModule = mCeMachine->mCeModule; - SetAndRestoreValue ignoreWrites(ceModule->mBfIRBuilder->mIgnoreWrites, false); + SetAndRestoreValue ignoreWrites(ceModule->mBfIRBuilder->mIgnoreWrites, false); if (ceModule->mContext->mBfTypeType == NULL) ceModule->mContext->ReflectInit(); @@ -4064,8 +4069,14 @@ BfTypedValue CeContext::Call(BfAstNode* targetSrc, BfModule* module, BfMethodIns SetAndRestoreValue prevTargetSrc(mCurTargetSrc, targetSrc); SetAndRestoreValue prevModule(mCurModule, module); SetAndRestoreValue prevMethodInstance(mCurMethodInstance, methodInstance); + SetAndRestoreValue prevCallerMethodInstance(mCallerMethodInstance, module->mCurMethodInstance); + SetAndRestoreValue prevCallerTypeInstance(mCallerTypeInstance, module->mCurTypeInstance); + SetAndRestoreValue prevCallerActiveTypeDef(mCallerActiveTypeDef, module->GetActiveTypeDef()); SetAndRestoreValue prevExpectingType(mCurExpectingType, expectingType); + SetAndRestoreValue moduleCurMethodInstance(module->mCurMethodInstance, methodInstance); + SetAndRestoreValue moduleCurTypeInstance(module->mCurTypeInstance, methodInstance->GetOwner()); + SetAndRestoreValue prevCurExecuteId(mCurModule->mCompiler->mCurCEExecuteId, mCeMachine->mExecuteId); // Reentrancy may occur as methods need defining @@ -4854,6 +4865,23 @@ bool CeContext::Execute(CeFunction* startFunction, uint8* startStackPtr, uint8* _FixVariables(); CeSetAddrVal(stackPtr + 0, reflectType, ptrSize); } + else if (checkFunction->mFunctionKind == CeFunctionKind_Type_ToString) + { + int32 typeId = *(int32*)((uint8*)stackPtr + ptrSize); + + BfType* type = GetBfType(typeId); + bool success = false; + if (type == NULL) + { + _Fail("Invalid type"); + return false; + } + + SetAndRestoreValue prevMethodInstance(mCeMachine->mCeModule->mCurMethodInstance, mCallerMethodInstance); + SetAndRestoreValue prevTypeInstance(mCeMachine->mCeModule->mCurTypeInstance, mCallerTypeInstance); + CeSetAddrVal(stackPtr + 0, GetString(mCeMachine->mCeModule->TypeToString(type)), ptrSize); + _FixVariables(); + } else if (checkFunction->mFunctionKind == CeFunctionKind_Type_GetCustomAttribute) { int32 typeId = *(int32*)((uint8*)stackPtr + 1); @@ -7786,6 +7814,10 @@ void CeMachine::CheckFunctionKind(CeFunction* ceFunction) { ceFunction->mFunctionKind = CeFunctionKind_GetReflectSpecializedType; } + else if (methodDef->mName == "Comptime_Type_ToString") + { + ceFunction->mFunctionKind = CeFunctionKind_Type_ToString; + } else if (methodDef->mName == "Comptime_Type_GetCustomAttribute") { ceFunction->mFunctionKind = CeFunctionKind_Type_GetCustomAttribute; diff --git a/IDEHelper/Compiler/CeMachine.h b/IDEHelper/Compiler/CeMachine.h index d8a6d085..9b961ae0 100644 --- a/IDEHelper/Compiler/CeMachine.h +++ b/IDEHelper/Compiler/CeMachine.h @@ -324,6 +324,7 @@ enum CeFunctionKind CeFunctionKind_GetReflectTypeById, CeFunctionKind_GetReflectTypeByName, CeFunctionKind_GetReflectSpecializedType, + CeFunctionKind_Type_ToString, CeFunctionKind_Type_GetCustomAttribute, CeFunctionKind_GetMethodCount, CeFunctionKind_GetMethod, @@ -701,7 +702,7 @@ public: mCurDbgLoc = NULL; mFrameSize = 0; } - + void Fail(const StringImpl& error); CeOperand FrameAlloc(BeType* type); @@ -867,6 +868,9 @@ public: Dictionary mInternalDataMap; int mCurHandleId; + BfMethodInstance* mCallerMethodInstance; + BfTypeInstance* mCallerTypeInstance; + BfTypeDef* mCallerActiveTypeDef; BfMethodInstance* mCurMethodInstance; BfType* mCurExpectingType; BfAstNode* mCurTargetSrc; @@ -877,7 +881,7 @@ public: public: CeContext(); ~CeContext(); - + BfError* Fail(const StringImpl& error); BfError* Fail(const CeFrame& curFrame, const StringImpl& error); diff --git a/IDEHelper/Tests/src/Comptime.bf b/IDEHelper/Tests/src/Comptime.bf index 2405f6d9..6bcaef69 100644 --- a/IDEHelper/Tests/src/Comptime.bf +++ b/IDEHelper/Tests/src/Comptime.bf @@ -405,6 +405,25 @@ namespace Tests } } + public struct GetTupleField where C : const int + { + public typealias Type = comptype(GetTupleFieldType(typeof(TTuple), C)); + + [Comptime] + private static Type GetTupleFieldType(Type type, int index) + { + if (type.IsGenericParam) + { + Compiler.Assert(type.IsGenericParam); + String tName = type.GetFullName(.. scope .()); + Compiler.Assert(tName == "TTuple"); + return typeof(var); + } + Compiler.Assert(type.IsTuple); + return type.GetField(index).Get().FieldType; + } + } + [Test] public static void TestBasics() { @@ -468,6 +487,12 @@ namespace Tests ++idx; } Test.Assert(idx == 2); + + var tuple = ((int16)1, 2.3f); + GetTupleField.Type tupType0; + GetTupleField.Type tupType1; + Test.Assert(typeof(decltype(tupType0)) == typeof(int16)); + Test.Assert(typeof(decltype(tupType1)) == typeof(float)); } } }