From e1d7939081446981abf8f8a17d5d5669a4c021a8 Mon Sep 17 00:00:00 2001 From: Brian Fiete Date: Sat, 7 May 2022 11:40:55 -0700 Subject: [PATCH] Improved virtual overrides in extensions --- IDEHelper/Backend/BeIRCodeGen.cpp | 8 ++++++ IDEHelper/Compiler/BfCompiler.cpp | 17 +++++++++++ IDEHelper/Compiler/BfContext.cpp | 3 +- IDEHelper/Compiler/BfIRBuilder.cpp | 6 ++++ IDEHelper/Compiler/BfIRBuilder.h | 2 ++ IDEHelper/Compiler/BfIRCodeGen.cpp | 10 ++++++- IDEHelper/Compiler/BfModule.cpp | 36 +++++++++++++++++++++--- IDEHelper/Compiler/BfResolvedTypeUtils.h | 16 +++++++---- IDEHelper/Tests/LibA/src/LibA0.bf | 13 +++++++++ IDEHelper/Tests/TestsB/src/TestsB0.bf | 4 +++ IDEHelper/Tests/src/Extensions.bf | 12 ++++++++ 11 files changed, 116 insertions(+), 11 deletions(-) diff --git a/IDEHelper/Backend/BeIRCodeGen.cpp b/IDEHelper/Backend/BeIRCodeGen.cpp index 6123a4dd..b4b7e497 100644 --- a/IDEHelper/Backend/BeIRCodeGen.cpp +++ b/IDEHelper/Backend/BeIRCodeGen.cpp @@ -2291,6 +2291,14 @@ void BeIRCodeGen::HandleNextCmd() SetResult(curId, mBeModule->CreateFunction(type, linkageType, name)); } break; + case BfIRCmd_SetFunctionName: + { + CMD_PARAM(BeValue*, func); + CMD_PARAM(String, name); + BeFunction* beFunc = BeValueDynCast(func); + beFunc->mName = name; + } + break; case BfIRCmd_EnsureFunctionPatchable: { diff --git a/IDEHelper/Compiler/BfCompiler.cpp b/IDEHelper/Compiler/BfCompiler.cpp index add65531..d73277eb 100644 --- a/IDEHelper/Compiler/BfCompiler.cpp +++ b/IDEHelper/Compiler/BfCompiler.cpp @@ -5878,6 +5878,23 @@ void BfCompiler::PopulateReified() checkType = checkType->mBaseType; } + + for (auto& reifyDep : typeInst->mReifyMethodDependencies) + { + if ((reifyDep.mDepMethod.mTypeInstance == NULL) || + (reifyDep.mDepMethod.mTypeInstance->IsIncomplete())) + continue; + + BfMethodInstance* depMethod = reifyDep.mDepMethod; + if (depMethod == NULL) + continue; + + if ((depMethod->mIsReified) && (depMethod->mMethodInstanceGroup->IsImplemented())) + { + auto methodDef = typeInst->mTypeDef->mMethods[reifyDep.mMethodIdx]; + typeInst->mModule->GetMethodInstance(typeInst, methodDef, BfTypeVector()); + } + } } } } diff --git a/IDEHelper/Compiler/BfContext.cpp b/IDEHelper/Compiler/BfContext.cpp index a302e8a4..c98b22e9 100644 --- a/IDEHelper/Compiler/BfContext.cpp +++ b/IDEHelper/Compiler/BfContext.cpp @@ -1225,7 +1225,8 @@ void BfContext::RebuildType(BfType* type, bool deleteOnDemandTypes, bool rebuild delete typeInst->mAttributeData; typeInst->mAttributeData = NULL; typeInst->mVirtualMethodTableSize = 0; - typeInst->mVirtualMethodTable.Clear(); + typeInst->mVirtualMethodTable.Clear(); + typeInst->mReifyMethodDependencies.Clear(); typeInst->mSize = -1; typeInst->mAlign = -1; typeInst->mInstSize = -1; diff --git a/IDEHelper/Compiler/BfIRBuilder.cpp b/IDEHelper/Compiler/BfIRBuilder.cpp index b5d18930..3e47e63e 100644 --- a/IDEHelper/Compiler/BfIRBuilder.cpp +++ b/IDEHelper/Compiler/BfIRBuilder.cpp @@ -5314,6 +5314,12 @@ BfIRFunction BfIRBuilder::CreateFunction(BfIRFunctionType funcType, BfIRLinkageT return retVal; } +void BfIRBuilder::SetFunctionName(BfIRValue func, const StringImpl& name) +{ + WriteCmd(BfIRCmd_SetFunctionName, func, name); + NEW_CMD_INSERTED_IRVALUE; +} + void BfIRBuilder::EnsureFunctionPatchable() { BfIRValue retVal = WriteCmd(BfIRCmd_EnsureFunctionPatchable); diff --git a/IDEHelper/Compiler/BfIRBuilder.h b/IDEHelper/Compiler/BfIRBuilder.h index bdaf5373..a892b11e 100644 --- a/IDEHelper/Compiler/BfIRBuilder.h +++ b/IDEHelper/Compiler/BfIRBuilder.h @@ -278,6 +278,7 @@ enum BfIRCmd : uint8 BfIRCmd_GetIntrinsic, BfIRCmd_CreateFunctionType, BfIRCmd_CreateFunction, + BfIRCmd_SetFunctionName, BfIRCmd_EnsureFunctionPatchable, BfIRCmd_RemapBindFunction, BfIRCmd_SetActiveFunction, @@ -1305,6 +1306,7 @@ public: BfIRFunctionType MapMethod(BfMethodInstance* methodInstance); BfIRFunctionType CreateFunctionType(BfIRType resultType, const BfSizedArray& paramTypes, bool isVarArg = false); BfIRFunction CreateFunction(BfIRFunctionType funcType, BfIRLinkageType linkageType, const StringImpl& name); + void SetFunctionName(BfIRValue func, const StringImpl& name); void EnsureFunctionPatchable(); BfIRValue RemapBindFunction(BfIRValue func); void SetActiveFunction(BfIRFunction func); diff --git a/IDEHelper/Compiler/BfIRCodeGen.cpp b/IDEHelper/Compiler/BfIRCodeGen.cpp index 14acee63..273ab80f 100644 --- a/IDEHelper/Compiler/BfIRCodeGen.cpp +++ b/IDEHelper/Compiler/BfIRCodeGen.cpp @@ -2959,7 +2959,7 @@ void BfIRCodeGen::HandleNextCmd() { CMD_PARAM(llvm::FunctionType*, type); BfIRLinkageType linkageType = (BfIRLinkageType)mStream->Read(); - CMD_PARAM(String, name); + CMD_PARAM(String, name); auto func = mLLVMModule->getFunction(name.c_str()); if ((func == NULL) || (func->getFunctionType() != type)) @@ -2968,6 +2968,14 @@ void BfIRCodeGen::HandleNextCmd() SetResult(curId, func); } break; + case BfIRCmd_SetFunctionName: + { + CMD_PARAM(llvm::Value*, func); + CMD_PARAM(String, name); + llvm::Function* llvmFunc = llvm::dyn_cast(func); + llvmFunc->setName(name.c_str()); + } + break; case BfIRCmd_EnsureFunctionPatchable: { int minPatchSize = 5; diff --git a/IDEHelper/Compiler/BfModule.cpp b/IDEHelper/Compiler/BfModule.cpp index e3533c94..65c682c0 100644 --- a/IDEHelper/Compiler/BfModule.cpp +++ b/IDEHelper/Compiler/BfModule.cpp @@ -22547,7 +22547,9 @@ void BfModule::SetupIRFunction(BfMethodInstance* methodInstance, StringImpl& man auto checkMethodInstance = mCurTypeInstance->mMethodInstanceGroups[checkMethod->mIdx].mDefault; if (checkMethodInstance == NULL) continue; - if ((checkMethodInstance->mIRFunction == prevFunc) && (checkMethodInstance->mMethodDef->mMethodDeclaration != NULL)) + if ((checkMethodInstance->mIRFunction == prevFunc) && + (checkMethodInstance->mMethodDef->mMethodDeclaration != NULL) && + (checkMethodInstance->mVirtualTableIdx < 0)) { BfAstNode* refNode = methodDef->GetRefNode(); if (auto propertyMethodDeclaration = methodDef->GetPropertyMethodDeclaration()) @@ -24633,12 +24635,38 @@ bool BfModule::SlotVirtualMethod(BfMethodInstance* methodInstance, BfAmbiguityCo auto declMethodInstance = (BfMethodInstance*)typeInstance->mVirtualMethodTable[virtualMethodMatchIdx].mDeclaringMethod; _AddVirtualDecl(declMethodInstance); setMethodInstance->mVirtualTableIdx = virtualMethodMatchIdx; + + auto& implMethodRef = typeInstance->mVirtualMethodTable[virtualMethodMatchIdx].mImplementingMethod; + if ((!mCompiler->mIsResolveOnly) && (implMethodRef.mMethodNum >= 0) && + (implMethodRef.mTypeInstance == typeInstance) && (methodInstance->GetOwner() == typeInstance)) + { + auto prevImplMethodInstance = (BfMethodInstance*)implMethodRef; + if (prevImplMethodInstance->mMethodDef->mDeclaringType->mProject != methodInstance->mMethodDef->mDeclaringType->mProject) + { + // We may need to have to previous method reified when we must re-slot in another project during vdata creation + BfReifyMethodDependency dep; + dep.mDepMethod = typeInstance->mVirtualMethodTable[virtualMethodMatchIdx].mDeclaringMethod; + dep.mMethodIdx = implMethodRef.mMethodNum; + typeInstance->mReifyMethodDependencies.Add(dep); + } + + if (!methodInstance->mMangleWithIdx) + { + // Keep mangled names from conflicting + methodInstance->mMangleWithIdx = true; + if ((methodInstance->mIRFunction) && (methodInstance->mDeclModule->mIsModuleMutable)) + { + StringT<4096> mangledName; + BfMangler::Mangle(mangledName, mCompiler->GetMangleKind(), methodInstance); + methodInstance->mDeclModule->mBfIRBuilder->SetFunctionName(methodInstance->mIRFunction, mangledName); + } + } + } + typeInstance->mVirtualMethodTable[virtualMethodMatchIdx].mImplementingMethod = setMethodInstance; } } - - - + if (methodOverriden != NULL) { CheckOverridenMethod(methodInstance, methodOverriden); diff --git a/IDEHelper/Compiler/BfResolvedTypeUtils.h b/IDEHelper/Compiler/BfResolvedTypeUtils.h index 97b6ab81..b94ccfcd 100644 --- a/IDEHelper/Compiler/BfResolvedTypeUtils.h +++ b/IDEHelper/Compiler/BfResolvedTypeUtils.h @@ -1903,16 +1903,22 @@ public: class BfCeTypeInfo; +struct BfReifyMethodDependency +{ +public: + BfNonGenericMethodRef mDepMethod; + int mMethodIdx; +}; + // Instance of struct or class class BfTypeInstance : public BfDependedType { -public: +public: int mSignatureRevision; int mLastNonGenericUsedRevision; int mInheritanceId; int mInheritanceCount; BfModule* mModule; - BfTypeDef* mTypeDef; BfTypeInstance* mBaseType; BfCustomAttributes* mCustomAttributes; @@ -1920,12 +1926,12 @@ public: BfTypeInfoEx* mTypeInfoEx; BfGenericTypeInfo* mGenericTypeInfo; BfCeTypeInfo* mCeTypeInfo; - Array mInterfaces; - Array mInterfaceMethodTable; + Array mInterfaceMethodTable; Array mMethodInstanceGroups; Array mOperatorInfo; - Array mVirtualMethodTable; + Array mVirtualMethodTable; + Array mReifyMethodDependencies; BfHotTypeData* mHotTypeData; int mVirtualMethodTableSize; // With hot reloading, mVirtualMethodTableSize can be larger than mInterfaceMethodTable (live vtable versioning) Array mFieldInstances; diff --git a/IDEHelper/Tests/LibA/src/LibA0.bf b/IDEHelper/Tests/LibA/src/LibA0.bf index 09e6a449..fbe84f0f 100644 --- a/IDEHelper/Tests/LibA/src/LibA0.bf +++ b/IDEHelper/Tests/LibA/src/LibA0.bf @@ -45,6 +45,19 @@ namespace LibA T val = default; return Overload0(val); } + + public virtual int GetA() + { + return 1; + } + } + + extension LibA0 + { + public override int GetA() + { + return 2; + } } struct Handler diff --git a/IDEHelper/Tests/TestsB/src/TestsB0.bf b/IDEHelper/Tests/TestsB/src/TestsB0.bf index e571e317..a7dca87b 100644 --- a/IDEHelper/Tests/TestsB/src/TestsB0.bf +++ b/IDEHelper/Tests/TestsB/src/TestsB0.bf @@ -42,6 +42,10 @@ namespace TestsB Test.Assert(ca.mB == 1008); Test.Assert(ca.mC == 9); Test.Assert(ca.GetVal2() == 11); + + LibA.LibA0 la0 = scope .(); + int la0a = la0.GetA(); + Test.Assert(la0a == 2); } } diff --git a/IDEHelper/Tests/src/Extensions.bf b/IDEHelper/Tests/src/Extensions.bf index a2974ad8..6aef2980 100644 --- a/IDEHelper/Tests/src/Extensions.bf +++ b/IDEHelper/Tests/src/Extensions.bf @@ -57,6 +57,14 @@ extension LibClassA namespace LibA { + extension LibA0 + { + public new override int GetA() + { + return 3; + } + } + extension LibA3 { this @@ -386,6 +394,10 @@ namespace Tests delete ca; Test.Assert(LibClassA.sMagic == 7771); + LibA.LibA0 la0 = scope .(); + int la0a = la0.GetA(); + Test.Assert(la0a == 3); + LibA.LibA3 la3 = scope .(); Test.Assert(la3.mA == 114); Test.Assert(la3.mB == 7);