From 37f72cd3b67779fdff0878942f4b5f8a8d9270cc Mon Sep 17 00:00:00 2001 From: Brian Fiete Date: Sat, 22 Mar 2025 15:34:59 -0400 Subject: [PATCH] Added ability to dynamically cast delegates with compatible signatures --- BeefLibs/corlib/src/Object.bf | 6 + IDEHelper/Compiler/BfCompiler.cpp | 6 + IDEHelper/Compiler/BfDefBuilder.cpp | 21 +++ IDEHelper/Compiler/BfModule.cpp | 228 +++++++++++++++++++--------- IDEHelper/Compiler/BfModule.h | 4 + IDEHelper/Compiler/BfSystem.h | 1 + IDEHelper/Compiler/CeMachine.cpp | 18 ++- IDEHelper/Tests/src/Delegates.bf | 9 +- 8 files changed, 217 insertions(+), 76 deletions(-) diff --git a/BeefLibs/corlib/src/Object.bf b/BeefLibs/corlib/src/Object.bf index 0c70314f..a959af3d 100644 --- a/BeefLibs/corlib/src/Object.bf +++ b/BeefLibs/corlib/src/Object.bf @@ -130,6 +130,12 @@ namespace System { return null; } + + [NoShow] + public virtual Object DynamicCastToSignature(int32 sig) + { + return null; + } #endif int IHashable.GetHashCode() diff --git a/IDEHelper/Compiler/BfCompiler.cpp b/IDEHelper/Compiler/BfCompiler.cpp index 3a1303a2..aff318a3 100644 --- a/IDEHelper/Compiler/BfCompiler.cpp +++ b/IDEHelper/Compiler/BfCompiler.cpp @@ -5622,6 +5622,12 @@ void BfCompiler::MarkStringPool(BfModule* module) stringPoolEntry.mLastUsedRevision = mRevision; } + for (int stringId : module->mSignatureIdRefs) + { + BfStringPoolEntry& stringPoolEntry = module->mContext->mStringObjectIdMap[stringId]; + stringPoolEntry.mLastUsedRevision = mRevision; + } + /*if (module->mOptModule != NULL) MarkStringPool(module->mOptModule);*/ auto altModule = module->mNextAltModule; diff --git a/IDEHelper/Compiler/BfDefBuilder.cpp b/IDEHelper/Compiler/BfDefBuilder.cpp index e5eba3b3..b433e1b7 100644 --- a/IDEHelper/Compiler/BfDefBuilder.cpp +++ b/IDEHelper/Compiler/BfDefBuilder.cpp @@ -1398,6 +1398,27 @@ void BfDefBuilder::AddDynamicCastMethods(BfTypeDef* typeDef) methodDef->mReturnTypeRef = typeDef->mSystem->mDirectObjectTypeRef; methodDef->mIsNoReflect = true; } + + if ((typeDef->mIsDelegate) && (!typeDef->mIsClosure)) + { + auto methodDef = new BfMethodDef(); + methodDef->mIdx = (int)typeDef->mMethods.size(); + typeDef->mMethods.push_back(methodDef); + methodDef->mDeclaringType = typeDef; + methodDef->mName = BF_METHODNAME_DYNAMICCAST_SIGNATURE; + methodDef->mProtection = BfProtection_Protected; + methodDef->mIsStatic = false; + methodDef->mMethodType = BfMethodType_Normal; + methodDef->mIsVirtual = true; + methodDef->mIsOverride = true; + + auto paramDef = new BfParameterDef(); + paramDef->mName = "sig"; + paramDef->mTypeRef = typeDef->mSystem->mDirectInt32TypeRef; + methodDef->mParams.push_back(paramDef); + methodDef->mReturnTypeRef = typeDef->mSystem->mDirectObjectTypeRef; + methodDef->mIsNoReflect = true; + } } void BfDefBuilder::AddParam(BfMethodDef* methodDef, BfTypeReference* typeRef, const StringImpl& paramName) diff --git a/IDEHelper/Compiler/BfModule.cpp b/IDEHelper/Compiler/BfModule.cpp index 6e4197dc..052888ce 100644 --- a/IDEHelper/Compiler/BfModule.cpp +++ b/IDEHelper/Compiler/BfModule.cpp @@ -5053,91 +5053,106 @@ void BfModule::CreateDynamicCastMethod() } bool isInterfacePass = mCurMethodInstance->mMethodDef->mName == BF_METHODNAME_DYNAMICCAST_INTERFACE; - + bool isSignaturePass = mCurMethodInstance->mMethodDef->mName == BF_METHODNAME_DYNAMICCAST_SIGNATURE; + auto func = mCurMethodState->mIRFunction; auto thisValue = mBfIRBuilder->GetArgument(0); auto typeIdValue = mBfIRBuilder->GetArgument(1); - auto intPtrType = GetPrimitiveType(BfTypeCode_IntPtr); - auto int32Type = GetPrimitiveType(BfTypeCode_Int32); - typeIdValue = CastToValue(NULL, BfTypedValue(typeIdValue, intPtrType), int32Type, (BfCastFlags)(BfCastFlags_Explicit | BfCastFlags_SilentFail)); - auto thisObject = mBfIRBuilder->CreateBitCast(thisValue, mBfIRBuilder->MapType(objType)); auto trueBB = mBfIRBuilder->CreateBlock("check.true"); //auto falseBB = mBfIRBuilder->CreateBlock("check.false"); auto exitBB = mBfIRBuilder->CreateBlock("exit"); - SizedArray typeMatches; - SizedArray exChecks; - FindSubTypes(mCurTypeInstance, &typeMatches, &exChecks, isInterfacePass); - - if ((mCurTypeInstance->IsGenericTypeInstance()) && (!mCurTypeInstance->IsUnspecializedType())) - { - // Add 'unbound' type id to cast list so things like "List is List<>" work - auto genericTypeInst = mCurTypeInstance->mTypeDef; - BfTypeVector genericArgs; - for (int i = 0; i < (int) genericTypeInst->mGenericParamDefs.size(); i++) - genericArgs.push_back(GetGenericParamType(BfGenericParamKind_Type, i)); - auto unboundType = ResolveTypeDef(mCurTypeInstance->mTypeDef->GetDefinition(), genericArgs, BfPopulateType_Declaration); - typeMatches.push_back(unboundType->mTypeId); - } - - if (mCurTypeInstance->IsBoxed()) - { - BfBoxedType* boxedType = (BfBoxedType*)mCurTypeInstance; - BfTypeInstance* innerType = boxedType->mElementType->ToTypeInstance(); - - FindSubTypes(innerType, &typeMatches, &exChecks, isInterfacePass); - - if (innerType->IsTypedPrimitive()) - { - auto underlyingType = innerType->GetUnderlyingType(); - typeMatches.push_back(underlyingType->mTypeId); - } - - auto innerTypeInst = innerType->ToTypeInstance(); - if ((innerTypeInst->IsInstanceOf(mCompiler->mSizedArrayTypeDef)) || - (innerTypeInst->IsInstanceOf(mCompiler->mPointerTTypeDef)) || - (innerTypeInst->IsInstanceOf(mCompiler->mMethodRefTypeDef))) - { - PopulateType(innerTypeInst); - //TODO: What case was this supposed to handle? - //typeMatches.push_back(innerTypeInst->mFieldInstances[0].mResolvedType->mTypeId); - } - } - - auto curBlock = mBfIRBuilder->GetInsertBlock(); - - BfIRValue vDataPtr; - if (!exChecks.empty()) - { - BfType* intPtrType = GetPrimitiveType(BfTypeCode_IntPtr); - auto ptrPtrType = mBfIRBuilder->GetPointerTo(mBfIRBuilder->GetPointerTo(mBfIRBuilder->MapType(intPtrType))); - auto vDataPtrPtr = mBfIRBuilder->CreateBitCast(thisValue, ptrPtrType); - vDataPtr = FixClassVData(mBfIRBuilder->CreateLoad(vDataPtrPtr/*, "vtable"*/)); - } - - auto switchStatement = mBfIRBuilder->CreateSwitch(typeIdValue, exitBB, (int)typeMatches.size() + (int)exChecks.size()); - for (auto typeMatch : typeMatches) - mBfIRBuilder->AddSwitchCase(switchStatement, GetConstValue32(typeMatch), trueBB); - Array incomingFalses; - for (auto ifaceTypeInst : exChecks) + BfIRBlock curBlock; + + if (isSignaturePass) { - BfIRBlock nextBB = mBfIRBuilder->CreateBlock("exCheck", true); - mBfIRBuilder->AddSwitchCase(switchStatement, GetConstValue32(ifaceTypeInst->mTypeId), nextBB); - mBfIRBuilder->SetInsertPoint(nextBB); + //auto falseBB = mBfIRBuilder->CreateBlock("check.false"); + curBlock = mBfIRBuilder->GetInsertBlock(); - BfIRValue slotOfs = GetInterfaceSlotNum(ifaceTypeInst); + auto signatureId = GetDelegateSignatureId(mCurTypeInstance); + auto eqResult = mBfIRBuilder->CreateCmpEQ(typeIdValue, GetConstValue32(signatureId)); + mBfIRBuilder->CreateCondBr(eqResult, trueBB, exitBB); + } + else + { + auto intPtrType = GetPrimitiveType(BfTypeCode_IntPtr); + auto int32Type = GetPrimitiveType(BfTypeCode_Int32); + typeIdValue = CastToValue(NULL, BfTypedValue(typeIdValue, intPtrType), int32Type, (BfCastFlags)(BfCastFlags_Explicit | BfCastFlags_SilentFail)); - auto ifacePtrPtr = mBfIRBuilder->CreateInBoundsGEP(vDataPtr, slotOfs/*, "iface"*/); - auto ifacePtr = mBfIRBuilder->CreateLoad(ifacePtrPtr); + SizedArray typeMatches; + SizedArray exChecks; + FindSubTypes(mCurTypeInstance, &typeMatches, &exChecks, isInterfacePass); - auto cmpResult = mBfIRBuilder->CreateCmpNE(ifacePtr, mBfIRBuilder->CreateConst(BfTypeCode_IntPtr, 0)); - mBfIRBuilder->CreateCondBr(cmpResult, trueBB, exitBB); + if ((mCurTypeInstance->IsGenericTypeInstance()) && (!mCurTypeInstance->IsUnspecializedType())) + { + // Add 'unbound' type id to cast list so things like "List is List<>" work + auto genericTypeInst = mCurTypeInstance->mTypeDef; + BfTypeVector genericArgs; + for (int i = 0; i < (int)genericTypeInst->mGenericParamDefs.size(); i++) + genericArgs.push_back(GetGenericParamType(BfGenericParamKind_Type, i)); + auto unboundType = ResolveTypeDef(mCurTypeInstance->mTypeDef->GetDefinition(), genericArgs, BfPopulateType_Declaration); + typeMatches.push_back(unboundType->mTypeId); + } - incomingFalses.push_back(nextBB); + if (mCurTypeInstance->IsBoxed()) + { + BfBoxedType* boxedType = (BfBoxedType*)mCurTypeInstance; + BfTypeInstance* innerType = boxedType->mElementType->ToTypeInstance(); + + FindSubTypes(innerType, &typeMatches, &exChecks, isInterfacePass); + + if (innerType->IsTypedPrimitive()) + { + auto underlyingType = innerType->GetUnderlyingType(); + typeMatches.push_back(underlyingType->mTypeId); + } + + auto innerTypeInst = innerType->ToTypeInstance(); + if ((innerTypeInst->IsInstanceOf(mCompiler->mSizedArrayTypeDef)) || + (innerTypeInst->IsInstanceOf(mCompiler->mPointerTTypeDef)) || + (innerTypeInst->IsInstanceOf(mCompiler->mMethodRefTypeDef))) + { + PopulateType(innerTypeInst); + //TODO: What case was this supposed to handle? + //typeMatches.push_back(innerTypeInst->mFieldInstances[0].mResolvedType->mTypeId); + } + } + + curBlock = mBfIRBuilder->GetInsertBlock(); + + BfIRValue vDataPtr; + if (!exChecks.empty()) + { + BfType* intPtrType = GetPrimitiveType(BfTypeCode_IntPtr); + auto ptrPtrType = mBfIRBuilder->GetPointerTo(mBfIRBuilder->GetPointerTo(mBfIRBuilder->MapType(intPtrType))); + auto vDataPtrPtr = mBfIRBuilder->CreateBitCast(thisValue, ptrPtrType); + vDataPtr = FixClassVData(mBfIRBuilder->CreateLoad(vDataPtrPtr/*, "vtable"*/)); + } + + auto switchStatement = mBfIRBuilder->CreateSwitch(typeIdValue, exitBB, (int)typeMatches.size() + (int)exChecks.size()); + for (auto typeMatch : typeMatches) + mBfIRBuilder->AddSwitchCase(switchStatement, GetConstValue32(typeMatch), trueBB); + + for (auto ifaceTypeInst : exChecks) + { + BfIRBlock nextBB = mBfIRBuilder->CreateBlock("exCheck", true); + mBfIRBuilder->AddSwitchCase(switchStatement, GetConstValue32(ifaceTypeInst->mTypeId), nextBB); + mBfIRBuilder->SetInsertPoint(nextBB); + + BfIRValue slotOfs = GetInterfaceSlotNum(ifaceTypeInst); + + auto ifacePtrPtr = mBfIRBuilder->CreateInBoundsGEP(vDataPtr, slotOfs/*, "iface"*/); + auto ifacePtr = mBfIRBuilder->CreateLoad(ifacePtrPtr); + + auto cmpResult = mBfIRBuilder->CreateCmpNE(ifacePtr, mBfIRBuilder->CreateConst(BfTypeCode_IntPtr, 0)); + mBfIRBuilder->CreateCondBr(cmpResult, trueBB, exitBB); + + incomingFalses.push_back(nextBB); + } } mBfIRBuilder->AddBlock(trueBB); @@ -10947,8 +10962,25 @@ void BfModule::EmitDynamicCastCheck(const BfTypedValue& targetValue, BfType* tar auto typeTypeInstance = ResolveTypeDef(mCompiler->mReflectTypeInstanceTypeDef)->ToTypeInstance(); - if (mCompiler->mOptions.mAllowHotSwapping) + if (targetType->IsDelegate()) { + // Delegate signature check + int signatureId = GetDelegateSignatureId(targetType->ToTypeInstance()); + BfExprEvaluator exprEvaluator(this); + + AddBasicBlock(checkBB); + auto objectParam = mBfIRBuilder->CreateBitCast(targetValue.mValue, mBfIRBuilder->MapType(mContext->mBfObjectType)); + auto moduleMethodInstance = GetMethodByName(mContext->mBfObjectType, "DynamicCastToSignature"); + SizedArray irArgs; + irArgs.push_back(objectParam); + irArgs.push_back(GetConstValue32(signatureId)); + auto callResult = exprEvaluator.CreateCall(NULL, moduleMethodInstance.mMethodInstance, moduleMethodInstance.mFunc, false, irArgs); + auto cmpResult = mBfIRBuilder->CreateCmpNE(callResult.mValue, GetDefaultValue(callResult.mType)); + irb->CreateCondBr(cmpResult, trueBlock, falseBlock); + } + else if (mCompiler->mOptions.mAllowHotSwapping) + { + // "Slow" check BfExprEvaluator exprEvaluator(this); AddBasicBlock(checkBB); @@ -10967,7 +10999,7 @@ void BfModule::EmitDynamicCastCheck(const BfTypedValue& targetValue, BfType* tar BfIRValue vDataPtr = irb->CreateBitCast(targetValue.mValue, irb->MapType(intPtrType)); vDataPtr = irb->CreateLoad(vDataPtr); if ((mCompiler->mOptions.mObjectHasDebugFlags) && (!mIsComptimeModule)) - vDataPtr = irb->CreateAnd(vDataPtr, irb->CreateConst(BfTypeCode_IntPtr, (uint64)~0xFFULL)); + vDataPtr = irb->CreateAnd(vDataPtr, irb->CreateConst(BfTypeCode_IntPtr, (uint64)~0xFFULL)); if (targetType->IsInterface()) { @@ -17291,9 +17323,61 @@ BfType* BfModule::GetDelegateReturnType(BfType* delegateType) BfMethodInstance* BfModule::GetDelegateInvokeMethod(BfTypeInstance* typeInstance) { + if (typeInstance->IsClosure()) + typeInstance = typeInstance->mBaseType; return GetRawMethodInstanceAtIdx(typeInstance, 0, "Invoke"); } +String BfModule::GetDelegateSignatureString(BfTypeInstance* typeInstance) +{ + auto invokeMethod = GetDelegateInvokeMethod(typeInstance); + if (invokeMethod == NULL) + return ""; + + String sigString = ""; + sigString = TypeToString(invokeMethod->mReturnType); + sigString += "("; + for (int paramIdx = 0; paramIdx < invokeMethod->GetParamCount(); paramIdx++) + { + if (paramIdx > 0) + sigString += ", "; + + auto paramKind = invokeMethod->GetParamKind(paramIdx); + + if (paramKind == BfParamKind_Params) + { + sigString += "params "; + } + + auto paramType = invokeMethod->GetParamType(paramIdx); + sigString += TypeToString(paramType); + + if (paramKind == BfParamKind_ExplicitThis) + sigString += " this"; + } + sigString += ")"; + return sigString; +} + +int BfModule::GetSignatureId(const StringImpl& str) +{ + int strId = mContext->GetStringLiteralId(str); + mSignatureIdRefs.Add(strId); + return strId; +} + +int BfModule::GetDelegateSignatureId(BfTypeInstance* typeInstance) +{ + BF_ASSERT(typeInstance->IsDelegate()); + if (typeInstance->mTypeInfoEx == NULL) + { + typeInstance->mTypeInfoEx = new BfTypeInfoEx(); + auto signature = GetDelegateSignatureString(typeInstance); + typeInstance->mTypeInfoEx->mMinValue = GetSignatureId(signature); + } + return (int)typeInstance->mTypeInfoEx->mMinValue; +} + void BfModule::CreateDelegateInvokeMethod() { // Clear out debug loc - otherwise we'll single step onto the delegate type declaration @@ -22544,7 +22628,7 @@ void BfModule::ProcessMethod(BfMethodInstance* methodInstance, bool isInlineDup, mCurMethodState->mLeftBlockUncond = true; } } - else if ((methodDef->mName == BF_METHODNAME_DYNAMICCAST) || (methodDef->mName == BF_METHODNAME_DYNAMICCAST_INTERFACE)) + else if ((methodDef->mName == BF_METHODNAME_DYNAMICCAST) || (methodDef->mName == BF_METHODNAME_DYNAMICCAST_INTERFACE) || (methodDef->mName == BF_METHODNAME_DYNAMICCAST_SIGNATURE)) { if (mCurTypeInstance->IsObject()) { diff --git a/IDEHelper/Compiler/BfModule.h b/IDEHelper/Compiler/BfModule.h index 8fd8cdac..455e06c5 100644 --- a/IDEHelper/Compiler/BfModule.h +++ b/IDEHelper/Compiler/BfModule.h @@ -1552,6 +1552,7 @@ public: Dictionary mStringCharPtrPool; Array mStringPoolRefs; HashSet mUnreifiedStringPoolRefs; + HashSet mSignatureIdRefs; Array mPrevIRBuilders; // Before extensions BfIRBuilder* mBfIRBuilder; @@ -2019,6 +2020,9 @@ public: void CreateDelegateInvokeMethod(); BfType* GetDelegateReturnType(BfType* delegateType); BfMethodInstance* GetDelegateInvokeMethod(BfTypeInstance* typeInstance); + String GetDelegateSignatureString(BfTypeInstance* typeInstance); + int GetSignatureId(const StringImpl& str); + int GetDelegateSignatureId(BfTypeInstance* typeInstance); String GetLocalMethodName(const StringImpl& baseName, BfAstNode* anchorNode, BfMethodState* declMethodState, BfMixinState* declMixinState); BfMethodDef* GetLocalMethodDef(BfLocalMethod* localMethod); BfModuleMethodInstance GetLocalMethodInstance(BfLocalMethod* localMethod, const BfTypeVector& methodGenericArguments, BfMethodInstance* methodInstance = NULL, bool force = false); diff --git a/IDEHelper/Compiler/BfSystem.h b/IDEHelper/Compiler/BfSystem.h index df0ea772..9ad8755c 100644 --- a/IDEHelper/Compiler/BfSystem.h +++ b/IDEHelper/Compiler/BfSystem.h @@ -851,6 +851,7 @@ enum BfCallingConvention : uint8 #define BF_METHODNAME_FIND_TLS_MEMBERS "GCFindTLSMembers" #define BF_METHODNAME_DYNAMICCAST "DynamicCastToTypeId" #define BF_METHODNAME_DYNAMICCAST_INTERFACE "DynamicCastToInterface" +#define BF_METHODNAME_DYNAMICCAST_SIGNATURE "DynamicCastToSignature" #define BF_METHODNAME_CALCAPPEND "this$calcAppend" #define BF_METHODNAME_ENUM_HASFLAG "HasFlag" #define BF_METHODNAME_ENUM_GETUNDERLYING "get__Underlying" diff --git a/IDEHelper/Compiler/CeMachine.cpp b/IDEHelper/Compiler/CeMachine.cpp index a07ac637..b97872ad 100644 --- a/IDEHelper/Compiler/CeMachine.cpp +++ b/IDEHelper/Compiler/CeMachine.cpp @@ -8110,16 +8110,28 @@ bool CeContext::Execute(CeFunction* startFunction, uint8* startStackPtr, uint8* { CE_CHECKADDR(valueAddr, sizeof(int32)); - auto ifaceType = GetBfType(ifaceId); + auto wantType = GetBfType(ifaceId); int32 objTypeId = *(int32*)(memStart + valueAddr); auto valueType = GetBfType(objTypeId); - if ((ifaceType == NULL) || (valueType == NULL)) + if ((wantType == NULL) || (valueType == NULL)) { _Fail("Invalid type in CeOp_DynamicCastCheck"); return false; } - if (ceModule->TypeIsSubTypeOf(valueType->ToTypeInstance(), ifaceType->ToTypeInstance(), false)) + bool matches = false; + if (ceModule->TypeIsSubTypeOf(valueType->ToTypeInstance(), wantType->ToTypeInstance(), false)) + { + matches = true; + } + else if ((valueType->IsDelegate()) && (wantType->IsDelegate())) + { + int valueSignatureId = ceModule->GetDelegateSignatureId(valueType->ToTypeInstance()); + int checkSignatureId = ceModule->GetDelegateSignatureId(wantType->ToTypeInstance()); + matches = valueSignatureId == checkSignatureId; + } + + if (matches) CeSetAddrVal(&result, valueAddr, ptrSize); else CeSetAddrVal(&result, 0, ptrSize); diff --git a/IDEHelper/Tests/src/Delegates.bf b/IDEHelper/Tests/src/Delegates.bf index 260bec7c..49ef08ae 100644 --- a/IDEHelper/Tests/src/Delegates.bf +++ b/IDEHelper/Tests/src/Delegates.bf @@ -232,9 +232,16 @@ namespace Tests public static void TestCasting() { - delegate int(int, int) dlg0 = null; + delegate int(int, int) dlg0 = scope (a, b) => 1; delegate int(int a, int b) dlg1 = dlg0; delegate int(int a2, int b2) dlg2 = (.)dlg1; + delegate int(float a, float b) dlg3 = scope (a, b) => 2; + + Object obj = dlg0; + dlg1 = (.)obj; + dlg2 = (.)obj; + Test.Assert(obj is delegate int(int a, int b)); + Test.Assert(!(obj is delegate int(float a, float b))); function int(int, int) func0 = null; function int(int a, int b) func1 = func0;