From 46cc3d088b722ce8da61c4c71c4a698bd83ddc7e Mon Sep 17 00:00:00 2001 From: Brian Fiete Date: Mon, 5 Jul 2021 14:36:39 -0700 Subject: [PATCH] Improved function binding to virtual methods --- IDEHelper/Compiler/BfExprEvaluator.cpp | 7 +-- IDEHelper/Compiler/BfExprEvaluator.h | 2 +- IDEHelper/Compiler/BfModule.h | 2 +- IDEHelper/Compiler/BfModuleTypeUtils.cpp | 56 +++++++++++++++++------- IDEHelper/Tests/src/Functions.bf | 30 ++++++++++++- 5 files changed, 75 insertions(+), 22 deletions(-) diff --git a/IDEHelper/Compiler/BfExprEvaluator.cpp b/IDEHelper/Compiler/BfExprEvaluator.cpp index cc51ebe6..8bb3f776 100644 --- a/IDEHelper/Compiler/BfExprEvaluator.cpp +++ b/IDEHelper/Compiler/BfExprEvaluator.cpp @@ -1617,7 +1617,7 @@ bool BfMethodMatcher::CheckMethod(BfTypeInstance* targetTypeInstance, BfTypeInst return false; } - if ((checkMethod->mIsVirtual) && (!checkMethod->mIsOverride) && (!mBypassVirtual) && + if ((checkMethod->mIsVirtual) && (!checkMethod->mIsOverride) && (!mBypassVirtual) && (targetTypeInstance != NULL) && (targetTypeInstance->IsObject())) { mModule->PopulateType(targetTypeInstance, BfPopulateType_DataAndMethods); @@ -6394,6 +6394,7 @@ BfTypedValue BfExprEvaluator::CreateCall(BfAstNode* targetSrc, const BfTypedValu if (delegateInfo->mHasExplicitThis) { target = mModule->GetDefaultTypedValue(delegateInfo->mParams[0], false, BfDefaultValueKind_Addr); + bypassVirtual = true; } } else if (bindResult->mBindType->IsFunction()) @@ -8158,7 +8159,7 @@ BfTypedValue BfExprEvaluator::MatchMethod(BfAstNode* targetSrc, BfMethodBoundExp methodMatcher.mAllowImplicitThis = allowImplicitThis; methodMatcher.mAllowStatic = !target.mValue; methodMatcher.mAllowNonStatic = !methodMatcher.mAllowStatic; - methodMatcher.mAutoFlushAmbiguityErrors = !wantsExtensionCheck; + methodMatcher.mAutoFlushAmbiguityErrors = !wantsExtensionCheck; if (allowImplicitThis) { if (mModule->mCurMethodState == NULL) @@ -11643,7 +11644,7 @@ void BfExprEvaluator::Visit(BfDelegateBindExpression* delegateBindExpr) return; } } - result = mModule->CastToFunction(delegateBindExpr->mTarget, bindResult.mOrigTarget, bindResult.mMethodInstance, mExpectingType); + result = mModule->CastToFunction(delegateBindExpr->mTarget, bindResult.mOrigTarget, bindResult.mMethodInstance, mExpectingType, BfCastFlags_None, bindResult.mFunc); } if (result) mResult = BfTypedValue(result, mExpectingType); diff --git a/IDEHelper/Compiler/BfExprEvaluator.h b/IDEHelper/Compiler/BfExprEvaluator.h index 46a8fc9a..3e7c7329 100644 --- a/IDEHelper/Compiler/BfExprEvaluator.h +++ b/IDEHelper/Compiler/BfExprEvaluator.h @@ -198,7 +198,7 @@ public: bool mAllowStatic; bool mAllowNonStatic; bool mSkipImplicitParams; - bool mAutoFlushAmbiguityErrors; + bool mAutoFlushAmbiguityErrors; BfEvalExprFlags mBfEvalExprFlags; int mMethodCheckCount; BfType* mExplicitInterfaceCheck; diff --git a/IDEHelper/Compiler/BfModule.h b/IDEHelper/Compiler/BfModule.h index 0b541a33..873aa130 100644 --- a/IDEHelper/Compiler/BfModule.h +++ b/IDEHelper/Compiler/BfModule.h @@ -1601,7 +1601,7 @@ public: bool CanCast(BfTypedValue typedVal, BfType* toType, BfCastFlags castFlags = BfCastFlags_None); bool AreSplatsCompatible(BfType* fromType, BfType* toType, bool* outNeedsMemberCasting); BfTypedValue BoxValue(BfAstNode* srcNode, BfTypedValue typedVal, BfType* toType /*Can be System.Object or interface*/, const BfAllocTarget& allocTarget, bool callDtor = true); - BfIRValue CastToFunction(BfAstNode* srcNode, const BfTypedValue& targetValue, BfMethodInstance* methodInstance, BfType* toType, BfCastFlags castFlags = BfCastFlags_None); + BfIRValue CastToFunction(BfAstNode* srcNode, const BfTypedValue& targetValue, BfMethodInstance* methodInstance, BfType* toType, BfCastFlags castFlags = BfCastFlags_None, BfIRValue irFunc = BfIRValue()); BfIRValue CastToValue(BfAstNode* srcNode, BfTypedValue val, BfType* toType, BfCastFlags castFlags = BfCastFlags_None, BfCastResultFlags* resultFlags = NULL); BfTypedValue Cast(BfAstNode* srcNode, const BfTypedValue& val, BfType* toType, BfCastFlags castFlags = BfCastFlags_None); BfPrimitiveType* GetIntCoercibleType(BfType* type); diff --git a/IDEHelper/Compiler/BfModuleTypeUtils.cpp b/IDEHelper/Compiler/BfModuleTypeUtils.cpp index de7d4aab..5cd10923 100644 --- a/IDEHelper/Compiler/BfModuleTypeUtils.cpp +++ b/IDEHelper/Compiler/BfModuleTypeUtils.cpp @@ -10880,11 +10880,33 @@ bool BfModule::AreSplatsCompatible(BfType* fromType, BfType* toType, bool* outNe return true; } -BfIRValue BfModule::CastToFunction(BfAstNode* srcNode, const BfTypedValue& targetValue, BfMethodInstance* methodInstance, BfType* toType, BfCastFlags castFlags) +BfIRValue BfModule::CastToFunction(BfAstNode* srcNode, const BfTypedValue& targetValue, BfMethodInstance* methodInstance, BfType* toType, BfCastFlags castFlags, BfIRValue irFunc) { auto invokeMethodInstance = GetDelegateInvokeMethod(toType->ToTypeInstance()); - if (invokeMethodInstance->IsExactMatch(methodInstance, false, true)) + bool methodsThisMatch = true; + if (invokeMethodInstance->mMethodDef->mIsStatic != methodInstance->mMethodDef->mIsStatic) + methodsThisMatch = false; + else + { + if (!methodInstance->mMethodDef->mIsStatic) + { + BfType* thisType = methodInstance->GetThisType(); + if (thisType->IsPointer()) + thisType = thisType->GetUnderlyingType(); + BfType* invokeThisType = invokeMethodInstance->GetThisType(); + if (invokeThisType->IsPointer()) + invokeThisType = invokeThisType->GetUnderlyingType(); + if (!TypeIsSubTypeOf(thisType->ToTypeInstance(), invokeThisType->ToTypeInstance())) + methodsThisMatch = false; + } + } + + bool methodMatches = methodsThisMatch; + if (methodMatches) + methodMatches = invokeMethodInstance->IsExactMatch(methodInstance, false, false); + + if (methodMatches) { if (methodInstance->GetOwner()->IsFunction()) { @@ -10892,19 +10914,23 @@ BfIRValue BfModule::CastToFunction(BfAstNode* srcNode, const BfTypedValue& targe return targetValue.mValue; } - BfModuleMethodInstance methodRefMethod; - if (methodInstance->mDeclModule == this) - methodRefMethod = methodInstance; - else - methodRefMethod = ReferenceExternalMethodInstance(methodInstance); - auto dataType = GetPrimitiveType(BfTypeCode_IntPtr); - if (!methodRefMethod.mFunc) + BfIRFunction bindFuncVal = irFunc; + if (!bindFuncVal) { - if ((!methodInstance->mIsUnspecialized) && (HasCompiledOutput())) - AssertErrorState(); - return GetDefaultValue(dataType); + BfModuleMethodInstance methodRefMethod; + if (methodInstance->mDeclModule == this) + methodRefMethod = methodInstance; + else + methodRefMethod = ReferenceExternalMethodInstance(methodInstance); + auto dataType = GetPrimitiveType(BfTypeCode_IntPtr); + if (!methodRefMethod.mFunc) + { + if ((!methodInstance->mIsUnspecialized) && (HasCompiledOutput())) + AssertErrorState(); + return GetDefaultValue(dataType); + } + bindFuncVal = methodRefMethod.mFunc; } - auto bindFuncVal = methodRefMethod.mFunc; if (mCompiler->mOptions.mAllowHotSwapping) bindFuncVal = mBfIRBuilder->RemapBindFunction(bindFuncVal); return mBfIRBuilder->CreatePtrToInt(bindFuncVal, BfTypeCode_IntPtr); @@ -10912,7 +10938,7 @@ BfIRValue BfModule::CastToFunction(BfAstNode* srcNode, const BfTypedValue& targe if ((castFlags & BfCastFlags_SilentFail) == 0) { - if (invokeMethodInstance->IsExactMatch(methodInstance, true, true)) + if ((methodsThisMatch) && (invokeMethodInstance->IsExactMatch(methodInstance, true, false))) { Fail(StrFormat("Non-static method '%s' cannot match '%s' because it contains captured variables, consider using a delegate or removing captures", MethodToString(methodInstance).c_str(), TypeToString(toType).c_str()), srcNode); } @@ -10940,7 +10966,7 @@ BfIRValue BfModule::CastToFunction(BfAstNode* srcNode, const BfTypedValue& targe invokeThisWasPtr = true; } - if (invokeThisType == thisType) + if (TypeIsSubTypeOf(thisType->ToTypeInstance(), invokeThisType->ToTypeInstance())) { if (invokeThisWasPtr != thisWasPtr) { diff --git a/IDEHelper/Tests/src/Functions.bf b/IDEHelper/Tests/src/Functions.bf index 669c1a4a..24f22151 100644 --- a/IDEHelper/Tests/src/Functions.bf +++ b/IDEHelper/Tests/src/Functions.bf @@ -8,19 +8,32 @@ namespace Tests { class ClassA { - int mA = 123; + public int mA = 123; public int GetA(float f) { return mA + (int)f; } + public virtual int GetB(float f) + { + return mA + (int)f + 1000; + } + public int GetT(T val) where T : var { return mA + (int)val; } } + class ClassB : ClassA + { + public override int GetB(float f) + { + return mA + (int)f + 2000; + } + } + struct StructA { int mA = 123; @@ -159,7 +172,9 @@ namespace Tests [Test] public static void TestBasics() { - ClassA ca = scope .(); + ClassA ca = scope ClassA(); + ClassA ca2 = scope ClassB(); + StructA sa = .(); StructB sb = .(); @@ -169,6 +184,17 @@ namespace Tests func0 = => ca.GetT; Test.Assert(func0(ca, 100.0f) == 223); + func0 = => ca.GetB; + Test.Assert(func0(ca, 100.0f) == 1223); + func0 = => ca2.GetB; + Test.Assert(func0(ca, 100.0f) == 2223); + func0 = => ClassA.GetB; + Test.Assert(func0(ca, 100.0f) == 1223); + func0 = => ClassB.GetB; + Test.Assert(func0(ca, 100.0f) == 2223); + + func0 = => ca.GetA; + function int (StructA this, float f) func1 = => sa.GetA; var func1b = func1; Test.Assert(func1(sa, 100.0f) == 23623);