1
0
Fork 0
mirror of https://github.com/beefytech/Beef.git synced 2025-06-10 04:22:20 +02:00

Improved function binding to virtual methods

This commit is contained in:
Brian Fiete 2021-07-05 14:36:39 -07:00
parent ee06457b2f
commit 46cc3d088b
5 changed files with 75 additions and 22 deletions

View file

@ -1617,7 +1617,7 @@ bool BfMethodMatcher::CheckMethod(BfTypeInstance* targetTypeInstance, BfTypeInst
return false; return false;
} }
if ((checkMethod->mIsVirtual) && (!checkMethod->mIsOverride) && (!mBypassVirtual) && if ((checkMethod->mIsVirtual) && (!checkMethod->mIsOverride) && (!mBypassVirtual) &&
(targetTypeInstance != NULL) && (targetTypeInstance->IsObject())) (targetTypeInstance != NULL) && (targetTypeInstance->IsObject()))
{ {
mModule->PopulateType(targetTypeInstance, BfPopulateType_DataAndMethods); mModule->PopulateType(targetTypeInstance, BfPopulateType_DataAndMethods);
@ -6394,6 +6394,7 @@ BfTypedValue BfExprEvaluator::CreateCall(BfAstNode* targetSrc, const BfTypedValu
if (delegateInfo->mHasExplicitThis) if (delegateInfo->mHasExplicitThis)
{ {
target = mModule->GetDefaultTypedValue(delegateInfo->mParams[0], false, BfDefaultValueKind_Addr); target = mModule->GetDefaultTypedValue(delegateInfo->mParams[0], false, BfDefaultValueKind_Addr);
bypassVirtual = true;
} }
} }
else if (bindResult->mBindType->IsFunction()) else if (bindResult->mBindType->IsFunction())
@ -8158,7 +8159,7 @@ BfTypedValue BfExprEvaluator::MatchMethod(BfAstNode* targetSrc, BfMethodBoundExp
methodMatcher.mAllowImplicitThis = allowImplicitThis; methodMatcher.mAllowImplicitThis = allowImplicitThis;
methodMatcher.mAllowStatic = !target.mValue; methodMatcher.mAllowStatic = !target.mValue;
methodMatcher.mAllowNonStatic = !methodMatcher.mAllowStatic; methodMatcher.mAllowNonStatic = !methodMatcher.mAllowStatic;
methodMatcher.mAutoFlushAmbiguityErrors = !wantsExtensionCheck; methodMatcher.mAutoFlushAmbiguityErrors = !wantsExtensionCheck;
if (allowImplicitThis) if (allowImplicitThis)
{ {
if (mModule->mCurMethodState == NULL) if (mModule->mCurMethodState == NULL)
@ -11643,7 +11644,7 @@ void BfExprEvaluator::Visit(BfDelegateBindExpression* delegateBindExpr)
return; return;
} }
} }
result = mModule->CastToFunction(delegateBindExpr->mTarget, bindResult.mOrigTarget, bindResult.mMethodInstance, mExpectingType); result = mModule->CastToFunction(delegateBindExpr->mTarget, bindResult.mOrigTarget, bindResult.mMethodInstance, mExpectingType, BfCastFlags_None, bindResult.mFunc);
} }
if (result) if (result)
mResult = BfTypedValue(result, mExpectingType); mResult = BfTypedValue(result, mExpectingType);

View file

@ -198,7 +198,7 @@ public:
bool mAllowStatic; bool mAllowStatic;
bool mAllowNonStatic; bool mAllowNonStatic;
bool mSkipImplicitParams; bool mSkipImplicitParams;
bool mAutoFlushAmbiguityErrors; bool mAutoFlushAmbiguityErrors;
BfEvalExprFlags mBfEvalExprFlags; BfEvalExprFlags mBfEvalExprFlags;
int mMethodCheckCount; int mMethodCheckCount;
BfType* mExplicitInterfaceCheck; BfType* mExplicitInterfaceCheck;

View file

@ -1601,7 +1601,7 @@ public:
bool CanCast(BfTypedValue typedVal, BfType* toType, BfCastFlags castFlags = BfCastFlags_None); bool CanCast(BfTypedValue typedVal, BfType* toType, BfCastFlags castFlags = BfCastFlags_None);
bool AreSplatsCompatible(BfType* fromType, BfType* toType, bool* outNeedsMemberCasting); bool AreSplatsCompatible(BfType* fromType, BfType* toType, bool* outNeedsMemberCasting);
BfTypedValue BoxValue(BfAstNode* srcNode, BfTypedValue typedVal, BfType* toType /*Can be System.Object or interface*/, const BfAllocTarget& allocTarget, bool callDtor = true); BfTypedValue BoxValue(BfAstNode* srcNode, BfTypedValue typedVal, BfType* toType /*Can be System.Object or interface*/, const BfAllocTarget& allocTarget, bool callDtor = true);
BfIRValue CastToFunction(BfAstNode* srcNode, const BfTypedValue& targetValue, BfMethodInstance* methodInstance, BfType* toType, BfCastFlags castFlags = BfCastFlags_None); BfIRValue CastToFunction(BfAstNode* srcNode, const BfTypedValue& targetValue, BfMethodInstance* methodInstance, BfType* toType, BfCastFlags castFlags = BfCastFlags_None, BfIRValue irFunc = BfIRValue());
BfIRValue CastToValue(BfAstNode* srcNode, BfTypedValue val, BfType* toType, BfCastFlags castFlags = BfCastFlags_None, BfCastResultFlags* resultFlags = NULL); BfIRValue CastToValue(BfAstNode* srcNode, BfTypedValue val, BfType* toType, BfCastFlags castFlags = BfCastFlags_None, BfCastResultFlags* resultFlags = NULL);
BfTypedValue Cast(BfAstNode* srcNode, const BfTypedValue& val, BfType* toType, BfCastFlags castFlags = BfCastFlags_None); BfTypedValue Cast(BfAstNode* srcNode, const BfTypedValue& val, BfType* toType, BfCastFlags castFlags = BfCastFlags_None);
BfPrimitiveType* GetIntCoercibleType(BfType* type); BfPrimitiveType* GetIntCoercibleType(BfType* type);

View file

@ -10880,11 +10880,33 @@ bool BfModule::AreSplatsCompatible(BfType* fromType, BfType* toType, bool* outNe
return true; return true;
} }
BfIRValue BfModule::CastToFunction(BfAstNode* srcNode, const BfTypedValue& targetValue, BfMethodInstance* methodInstance, BfType* toType, BfCastFlags castFlags) BfIRValue BfModule::CastToFunction(BfAstNode* srcNode, const BfTypedValue& targetValue, BfMethodInstance* methodInstance, BfType* toType, BfCastFlags castFlags, BfIRValue irFunc)
{ {
auto invokeMethodInstance = GetDelegateInvokeMethod(toType->ToTypeInstance()); auto invokeMethodInstance = GetDelegateInvokeMethod(toType->ToTypeInstance());
if (invokeMethodInstance->IsExactMatch(methodInstance, false, true)) bool methodsThisMatch = true;
if (invokeMethodInstance->mMethodDef->mIsStatic != methodInstance->mMethodDef->mIsStatic)
methodsThisMatch = false;
else
{
if (!methodInstance->mMethodDef->mIsStatic)
{
BfType* thisType = methodInstance->GetThisType();
if (thisType->IsPointer())
thisType = thisType->GetUnderlyingType();
BfType* invokeThisType = invokeMethodInstance->GetThisType();
if (invokeThisType->IsPointer())
invokeThisType = invokeThisType->GetUnderlyingType();
if (!TypeIsSubTypeOf(thisType->ToTypeInstance(), invokeThisType->ToTypeInstance()))
methodsThisMatch = false;
}
}
bool methodMatches = methodsThisMatch;
if (methodMatches)
methodMatches = invokeMethodInstance->IsExactMatch(methodInstance, false, false);
if (methodMatches)
{ {
if (methodInstance->GetOwner()->IsFunction()) if (methodInstance->GetOwner()->IsFunction())
{ {
@ -10892,19 +10914,23 @@ BfIRValue BfModule::CastToFunction(BfAstNode* srcNode, const BfTypedValue& targe
return targetValue.mValue; return targetValue.mValue;
} }
BfModuleMethodInstance methodRefMethod; BfIRFunction bindFuncVal = irFunc;
if (methodInstance->mDeclModule == this) if (!bindFuncVal)
methodRefMethod = methodInstance;
else
methodRefMethod = ReferenceExternalMethodInstance(methodInstance);
auto dataType = GetPrimitiveType(BfTypeCode_IntPtr);
if (!methodRefMethod.mFunc)
{ {
if ((!methodInstance->mIsUnspecialized) && (HasCompiledOutput())) BfModuleMethodInstance methodRefMethod;
AssertErrorState(); if (methodInstance->mDeclModule == this)
return GetDefaultValue(dataType); methodRefMethod = methodInstance;
else
methodRefMethod = ReferenceExternalMethodInstance(methodInstance);
auto dataType = GetPrimitiveType(BfTypeCode_IntPtr);
if (!methodRefMethod.mFunc)
{
if ((!methodInstance->mIsUnspecialized) && (HasCompiledOutput()))
AssertErrorState();
return GetDefaultValue(dataType);
}
bindFuncVal = methodRefMethod.mFunc;
} }
auto bindFuncVal = methodRefMethod.mFunc;
if (mCompiler->mOptions.mAllowHotSwapping) if (mCompiler->mOptions.mAllowHotSwapping)
bindFuncVal = mBfIRBuilder->RemapBindFunction(bindFuncVal); bindFuncVal = mBfIRBuilder->RemapBindFunction(bindFuncVal);
return mBfIRBuilder->CreatePtrToInt(bindFuncVal, BfTypeCode_IntPtr); return mBfIRBuilder->CreatePtrToInt(bindFuncVal, BfTypeCode_IntPtr);
@ -10912,7 +10938,7 @@ BfIRValue BfModule::CastToFunction(BfAstNode* srcNode, const BfTypedValue& targe
if ((castFlags & BfCastFlags_SilentFail) == 0) if ((castFlags & BfCastFlags_SilentFail) == 0)
{ {
if (invokeMethodInstance->IsExactMatch(methodInstance, true, true)) if ((methodsThisMatch) && (invokeMethodInstance->IsExactMatch(methodInstance, true, false)))
{ {
Fail(StrFormat("Non-static method '%s' cannot match '%s' because it contains captured variables, consider using a delegate or removing captures", MethodToString(methodInstance).c_str(), TypeToString(toType).c_str()), srcNode); Fail(StrFormat("Non-static method '%s' cannot match '%s' because it contains captured variables, consider using a delegate or removing captures", MethodToString(methodInstance).c_str(), TypeToString(toType).c_str()), srcNode);
} }
@ -10940,7 +10966,7 @@ BfIRValue BfModule::CastToFunction(BfAstNode* srcNode, const BfTypedValue& targe
invokeThisWasPtr = true; invokeThisWasPtr = true;
} }
if (invokeThisType == thisType) if (TypeIsSubTypeOf(thisType->ToTypeInstance(), invokeThisType->ToTypeInstance()))
{ {
if (invokeThisWasPtr != thisWasPtr) if (invokeThisWasPtr != thisWasPtr)
{ {

View file

@ -8,19 +8,32 @@ namespace Tests
{ {
class ClassA class ClassA
{ {
int mA = 123; public int mA = 123;
public int GetA(float f) public int GetA(float f)
{ {
return mA + (int)f; return mA + (int)f;
} }
public virtual int GetB(float f)
{
return mA + (int)f + 1000;
}
public int GetT<T>(T val) where T : var public int GetT<T>(T val) where T : var
{ {
return mA + (int)val; return mA + (int)val;
} }
} }
class ClassB : ClassA
{
public override int GetB(float f)
{
return mA + (int)f + 2000;
}
}
struct StructA struct StructA
{ {
int mA = 123; int mA = 123;
@ -159,7 +172,9 @@ namespace Tests
[Test] [Test]
public static void TestBasics() public static void TestBasics()
{ {
ClassA ca = scope .(); ClassA ca = scope ClassA();
ClassA ca2 = scope ClassB();
StructA sa = .(); StructA sa = .();
StructB sb = .(); StructB sb = .();
@ -169,6 +184,17 @@ namespace Tests
func0 = => ca.GetT<float>; func0 = => ca.GetT<float>;
Test.Assert(func0(ca, 100.0f) == 223); Test.Assert(func0(ca, 100.0f) == 223);
func0 = => ca.GetB;
Test.Assert(func0(ca, 100.0f) == 1223);
func0 = => ca2.GetB;
Test.Assert(func0(ca, 100.0f) == 2223);
func0 = => ClassA.GetB;
Test.Assert(func0(ca, 100.0f) == 1223);
func0 = => ClassB.GetB;
Test.Assert(func0(ca, 100.0f) == 2223);
func0 = => ca.GetA;
function int (StructA this, float f) func1 = => sa.GetA; function int (StructA this, float f) func1 = => sa.GetA;
var func1b = func1; var func1b = func1;
Test.Assert(func1(sa, 100.0f) == 23623); Test.Assert(func1(sa, 100.0f) == 23623);