From be0733d37ca2a81bfe46635a0681414626171265 Mon Sep 17 00:00:00 2001 From: Brian Fiete Date: Sat, 8 Mar 2025 11:02:07 -0800 Subject: [PATCH] Fixed deferred function call --- IDEHelper/Compiler/BfExprEvaluator.cpp | 31 +++++++++++++++++--------- IDEHelper/Compiler/BfExprEvaluator.h | 13 ++++++++--- IDEHelper/Compiler/BfModule.cpp | 22 +++++++++++++----- IDEHelper/Compiler/BfModule.h | 5 ++++- IDEHelper/Compiler/BfStmtEvaluator.cpp | 24 ++++++++++++++++---- IDEHelper/Tests/src/Functions.bf | 11 +++++++++ 6 files changed, 81 insertions(+), 25 deletions(-) diff --git a/IDEHelper/Compiler/BfExprEvaluator.cpp b/IDEHelper/Compiler/BfExprEvaluator.cpp index ab66895e..a0b5f033 100644 --- a/IDEHelper/Compiler/BfExprEvaluator.cpp +++ b/IDEHelper/Compiler/BfExprEvaluator.cpp @@ -3356,8 +3356,7 @@ BfExprEvaluator::BfExprEvaluator(BfModule* module) mExpectingType = NULL; mFunctionBindResult = NULL; mExplicitCast = false; - mDeferCallRef = NULL; - mDeferScopeAlloc = NULL; + mDeferCallData = NULL; mPrefixedAttributeState = NULL; mResolveGenericParam = true; mNoBind = false; @@ -6938,7 +6937,7 @@ BfTypedValue BfExprEvaluator::CreateCall(BfAstNode* targetSrc, BfMethodInstance* if (methodInstance->mVirtualTableIdx != -1) { - if ((!bypassVirtual) && (mDeferCallRef == NULL)) + if ((!bypassVirtual) && (mDeferCallData == NULL)) { if ((methodDef->mIsOverride) && (mModule->mCurMethodInstance->mIsReified)) { @@ -7107,9 +7106,12 @@ BfTypedValue BfExprEvaluator::CreateCall(BfAstNode* targetSrc, BfMethodInstance* return _GetDefaultReturnValue(); } - if (mDeferCallRef != NULL) + if (mDeferCallData != NULL) { - mModule->AddDeferredCall(BfModuleMethodInstance(methodInstance, func), irArgs, mDeferScopeAlloc, mDeferCallRef, bypassVirtual); + if (mDeferCallData->mFuncAlloca_Orig == func) + mModule->AddDeferredCall(BfModuleMethodInstance(methodInstance, mDeferCallData->mFuncAlloca), irArgs, mDeferCallData->mScopeAlloc, mDeferCallData->mRefNode, bypassVirtual, false, true); + else + mModule->AddDeferredCall(BfModuleMethodInstance(methodInstance, func), irArgs, mDeferCallData->mScopeAlloc, mDeferCallData->mRefNode, bypassVirtual); return mModule->GetFakeTypedValue(returnType); } @@ -7861,6 +7863,13 @@ BfTypedValue BfExprEvaluator::CreateCall(BfAstNode* targetSrc, const BfTypedValu auto funcType = mModule->mBfIRBuilder->MapMethod(moduleMethodInstance.mMethodInstance); auto funcPtrType = mModule->mBfIRBuilder->GetPointerTo(funcType); moduleMethodInstance.mFunc = mModule->mBfIRBuilder->CreateIntToPtr(target.mValue, funcPtrType); + + if (mDeferCallData != NULL) + { + mDeferCallData->mFuncAlloca_Orig = moduleMethodInstance.mFunc; + mDeferCallData->mFuncAlloca = mModule->CreateAlloca(funcPtrType, target.mType->mAlign, false, "FuncAlloca"); + mModule->mBfIRBuilder->CreateStore(mDeferCallData->mFuncAlloca_Orig, mDeferCallData->mFuncAlloca); + } } else if (!methodDef->mIsStatic) { @@ -8070,7 +8079,7 @@ BfTypedValue BfExprEvaluator::CreateCall(BfAstNode* targetSrc, const BfTypedValu autoComplete->mIsCapturingMethodMatchInfo = wasCapturingMatchInfo; }); - BfScopeData* boxScopeData = mDeferScopeAlloc; + BfScopeData* boxScopeData = (mDeferCallData != NULL) ? mDeferCallData->mScopeAlloc : NULL; if ((boxScopeData == NULL) && (mModule->mCurMethodState != NULL)) boxScopeData = mModule->mCurMethodState->mCurScope; @@ -9254,9 +9263,9 @@ BfTypedValue BfExprEvaluator::ResolveArgValue(BfResolvedArg& resolvedArg, BfType if ((argValue) && (argValue.mType != wantType) && (wantType != NULL)) { - if ((mDeferScopeAlloc != NULL) && (wantType == mModule->mContext->mBfObjectType)) + if ((mDeferCallData != NULL) && (wantType == mModule->mContext->mBfObjectType)) { - BfAllocTarget allocTarget(mDeferScopeAlloc); + BfAllocTarget allocTarget(mDeferCallData->mScopeAlloc); argValue = mModule->BoxValue(expr, argValue, wantType, allocTarget, ((mBfEvalExprFlags & BfEvalExprFlags_Comptime) != 0) ? BfCastFlags_WantsConst : BfCastFlags_None); } else @@ -17731,7 +17740,7 @@ void BfExprEvaluator::InjectMixin(BfAstNode* targetSrc, BfTypedValue target, boo if (mModule->mCurMethodState == NULL) return; - if (mDeferCallRef != NULL) + if (mDeferCallData != NULL) { mModule->Fail("Mixins cannot be directly deferred. Consider wrapping in a block.", targetSrc); } @@ -20013,7 +20022,7 @@ BfTypedValue BfExprEvaluator::GetResult(bool clearResult, bool resolveGenericTyp if (!handled) { SetAndRestoreValue prevFunctionBindResult(mFunctionBindResult, NULL); - SetAndRestoreValue prevDeferCallRef(mDeferCallRef, NULL); + SetAndRestoreValue prevDeferCallRef(mDeferCallData, NULL); BfMethodDef* matchedMethod = GetPropertyMethodDef(mPropDef, BfMethodType_PropertyGetter, mPropCheckedKind, mPropTarget); if (matchedMethod == NULL) @@ -23198,7 +23207,7 @@ BfTypedValue BfExprEvaluator::PerformUnaryOperation_TryOperator(const BfTypedVal else { SetAndRestoreValue prevFlags(mBfEvalExprFlags, (BfEvalExprFlags)(mBfEvalExprFlags | BfEvalExprFlags_NoAutoComplete)); - SetAndRestoreValue prevDeferCallRef(mDeferCallRef, NULL); + SetAndRestoreValue prevDeferCallRef(mDeferCallData, NULL); result = CreateCall(&methodMatcher, callTarget); } diff --git a/IDEHelper/Compiler/BfExprEvaluator.h b/IDEHelper/Compiler/BfExprEvaluator.h index fc4859e0..345cf366 100644 --- a/IDEHelper/Compiler/BfExprEvaluator.h +++ b/IDEHelper/Compiler/BfExprEvaluator.h @@ -401,6 +401,14 @@ enum BfBinOpFlags BfBinOpFlag_DeferRight = 0x20 }; +struct BfDeferCallData +{ + BfAstNode* mRefNode; + BfScopeData* mScopeAlloc; + BfIRValue mFuncAlloca; // When we need to load + BfIRValue mFuncAlloca_Orig; +}; + class BfExprEvaluator : public BfStructuralVisitor { public: @@ -422,9 +430,8 @@ public: BfAttributeState* mPrefixedAttributeState; BfTypedValue* mReceivingValue; BfFunctionBindResult* mFunctionBindResult; - SizedArray mIndexerValues; - BfAstNode* mDeferCallRef; - BfScopeData* mDeferScopeAlloc; + SizedArray mIndexerValues; + BfDeferCallData* mDeferCallData; bool mUsedAsStatement; bool mPropDefBypassVirtual; bool mExplicitCast; diff --git a/IDEHelper/Compiler/BfModule.cpp b/IDEHelper/Compiler/BfModule.cpp index c3067bee..c2f2561d 100644 --- a/IDEHelper/Compiler/BfModule.cpp +++ b/IDEHelper/Compiler/BfModule.cpp @@ -2080,23 +2080,23 @@ void BfModule::RestoreScopeState() mCurMethodState->mTailScope = mCurMethodState->mCurScope; } -BfIRValue BfModule::CreateAlloca(BfType* type, bool addLifetime, const char* name, BfIRValue arraySize) +BfIRValue BfModule::CreateAlloca(BfIRType irType, int align, bool addLifetime, const char* name, BfIRValue arraySize) { if (mBfIRBuilder->mIgnoreWrites) return mBfIRBuilder->GetFakeVal(); - BF_ASSERT((*(int8*)&addLifetime == 1) || (*(int8*)&addLifetime == 0)); - mBfIRBuilder->PopulateType(type); + BF_ASSERT((*(int8*)&addLifetime == 1) || (*(int8*)&addLifetime == 0)); auto prevInsertBlock = mBfIRBuilder->GetInsertBlock(); if (!mBfIRBuilder->mIgnoreWrites) BF_ASSERT(!prevInsertBlock.IsFake()); mBfIRBuilder->SetInsertPoint(mCurMethodState->mIRHeadBlock); BfIRValue allocaInst; if (arraySize) - allocaInst = mBfIRBuilder->CreateAlloca(mBfIRBuilder->MapType(type), arraySize); + allocaInst = mBfIRBuilder->CreateAlloca(irType, arraySize); else - allocaInst = mBfIRBuilder->CreateAlloca(mBfIRBuilder->MapType(type)); - mBfIRBuilder->SetAllocaAlignment(allocaInst, type->mAlign); + allocaInst = mBfIRBuilder->CreateAlloca(irType); + if (align > 0) + mBfIRBuilder->SetAllocaAlignment(allocaInst, align); mBfIRBuilder->ClearDebugLocation(allocaInst); if (name != NULL) mBfIRBuilder->SetName(allocaInst, name); @@ -2110,6 +2110,16 @@ BfIRValue BfModule::CreateAlloca(BfType* type, bool addLifetime, const char* nam return allocaInst; } +BfIRValue BfModule::CreateAlloca(BfType* type, bool addLifetime, const char* name, BfIRValue arraySize) +{ + if (mBfIRBuilder->mIgnoreWrites) + return mBfIRBuilder->GetFakeVal(); + + BF_ASSERT((*(int8*)&addLifetime == 1) || (*(int8*)&addLifetime == 0)); + mBfIRBuilder->PopulateType(type); + return CreateAlloca(mBfIRBuilder->MapType(type), type->mAlign, addLifetime, name, arraySize); +} + BfIRValue BfModule::CreateAllocaInst(BfTypeInstance* typeInst, bool addLifetime, const char* name) { if (mBfIRBuilder->mIgnoreWrites) diff --git a/IDEHelper/Compiler/BfModule.h b/IDEHelper/Compiler/BfModule.h index 798a7ea9..8fd8cdac 100644 --- a/IDEHelper/Compiler/BfModule.h +++ b/IDEHelper/Compiler/BfModule.h @@ -323,6 +323,7 @@ public: bool mCastThis; bool mArgsNeedLoad; bool mIgnored; + bool mIsAllocaFunc; SLIList mDynList; BfIRValue mDynCallTail; @@ -1481,6 +1482,7 @@ enum BfDeferredBlockFlags BfDeferredBlockFlag_DoNullChecks = 2, BfDeferredBlockFlag_SkipObjectAccessCheck = 4, BfDeferredBlockFlag_MoveNewBlocksToEnd = 8, + BfDeferredBlockFlag_IsAllocaFunc = 0x10 }; enum BfGetCustomAttributesFlags @@ -1696,6 +1698,7 @@ public: BfTypedValue FlushNullConditional(BfTypedValue result, bool ignoreNullable = false); void NewScopeState(bool createLexicalBlock = true, bool flushValueScope = true); // returns prev scope data + BfIRValue CreateAlloca(BfIRType irType, int align, bool addLifetime = true, const char* name = NULL, BfIRValue arraySize = BfIRValue()); BfIRValue CreateAlloca(BfType* type, bool addLifetime = true, const char* name = NULL, BfIRValue arraySize = BfIRValue()); BfIRValue CreateAllocaInst(BfTypeInstance* typeInst, bool addLifetime = true, const char* name = NULL); BfDeferredCallEntry* AddStackAlloc(BfTypedValue val, BfIRValue arraySize, BfAstNode* refNode, BfScopeData* scope, bool condAlloca = false, bool mayEscape = false, BfIRBlock valBlock = BfIRBlock()); @@ -1721,7 +1724,7 @@ public: void EmitDeferredCall(BfModuleMethodInstance moduleMethodInstance, SizedArrayImpl& llvmArgs, BfDeferredBlockFlags flags = BfDeferredBlockFlag_None); bool AddDeferredCallEntry(BfDeferredCallEntry* deferredCallEntry, BfScopeData* scope); BfDeferredCallEntry* AddDeferredBlock(BfBlock* block, BfScopeData* scope, Array* captures = NULL); - BfDeferredCallEntry* AddDeferredCall(const BfModuleMethodInstance& moduleMethodInstance, SizedArrayImpl& llvmArgs, BfScopeData* scope, BfAstNode* srcNode = NULL, bool bypassVirtual = false, bool doNullCheck = false); + BfDeferredCallEntry* AddDeferredCall(const BfModuleMethodInstance& moduleMethodInstance, SizedArrayImpl& llvmArgs, BfScopeData* scope, BfAstNode* srcNode = NULL, bool bypassVirtual = false, bool doNullCheck = false, bool isAllocaFunc = false); void EmitDeferredCall(BfScopeData* scopeData, BfDeferredCallEntry& deferredCallEntry, bool moveBlocks); void EmitDeferredCallProcessor(BfScopeData* scopeData, SLIList& callEntries, BfIRValue callTail); void EmitDeferredCallProcessorInstances(BfScopeData* scopeData); diff --git a/IDEHelper/Compiler/BfStmtEvaluator.cpp b/IDEHelper/Compiler/BfStmtEvaluator.cpp index 245fd968..4312adcd 100644 --- a/IDEHelper/Compiler/BfStmtEvaluator.cpp +++ b/IDEHelper/Compiler/BfStmtEvaluator.cpp @@ -543,11 +543,12 @@ BfDeferredCallEntry* BfModule::AddDeferredBlock(BfBlock* block, BfScopeData* sco return deferredCallEntry; } -BfDeferredCallEntry* BfModule::AddDeferredCall(const BfModuleMethodInstance& moduleMethodInstance, SizedArrayImpl& llvmArgs, BfScopeData* scopeData, BfAstNode* srcNode, bool bypassVirtual, bool doNullCheck) +BfDeferredCallEntry* BfModule::AddDeferredCall(const BfModuleMethodInstance& moduleMethodInstance, SizedArrayImpl& llvmArgs, BfScopeData* scopeData, BfAstNode* srcNode, bool bypassVirtual, bool doNullCheck, bool isAllocaFunc) { BfDeferredCallEntry* deferredCallEntry = new BfDeferredCallEntry(); BF_ASSERT(moduleMethodInstance); deferredCallEntry->mModuleMethodInstance = moduleMethodInstance; + deferredCallEntry->mIsAllocaFunc = isAllocaFunc; for (auto arg : llvmArgs) { @@ -783,7 +784,12 @@ void BfModule::EmitDeferredCall(BfModuleMethodInstance moduleMethodInstance, Siz } BfExprEvaluator expressionEvaluator(this); - expressionEvaluator.CreateCall(NULL, moduleMethodInstance.mMethodInstance, moduleMethodInstance.mFunc, ((flags & BfDeferredBlockFlag_BypassVirtual) != 0), llvmArgs); + + auto func = moduleMethodInstance.mFunc; + if ((flags & BfDeferredBlockFlag_IsAllocaFunc) != 0) + func = mBfIRBuilder->CreateLoad(func); + + expressionEvaluator.CreateCall(NULL, moduleMethodInstance.mMethodInstance, func, ((flags & BfDeferredBlockFlag_BypassVirtual) != 0), llvmArgs); if ((flags & BfDeferredBlockFlag_DoNullChecks) != 0) { @@ -914,6 +920,8 @@ void BfModule::EmitDeferredCall(BfScopeData* scopeData, BfDeferredCallEntry& def flags = (BfDeferredBlockFlags)(flags | BfDeferredBlockFlag_DoNullChecks | BfDeferredBlockFlag_SkipObjectAccessCheck | BfDeferredBlockFlag_MoveNewBlocksToEnd); if (moveBlocks) flags = (BfDeferredBlockFlags)(flags | BfDeferredBlockFlag_MoveNewBlocksToEnd); + if (deferredCallEntry.mIsAllocaFunc) + flags = (BfDeferredBlockFlags)(flags | BfDeferredBlockFlag_IsAllocaFunc); EmitDeferredCall(deferredCallEntry.mModuleMethodInstance, args, flags); } @@ -926,6 +934,7 @@ void BfModule::EmitDeferredCallProcessor(BfScopeData* scopeData, SLIList MapType; @@ -951,6 +960,7 @@ void BfModule::EmitDeferredCallProcessor(BfScopeData* scopeData, SLIListmModuleMethodInstance = moduleMethodInstance; callInfo->mBypassVirtual = deferredCallEntry->mBypassVirtual; + callInfo->mIsAllocaFunc = deferredCallEntry->mIsAllocaFunc; } else { @@ -1118,6 +1128,7 @@ void BfModule::EmitDeferredCallProcessor(BfScopeData* scopeData, SLIListmMethodDef; auto methodOwner = methodInstance->mMethodInstanceGroup->mOwner; @@ -1204,6 +1215,8 @@ void BfModule::EmitDeferredCallProcessor(BfScopeData* scopeData, SLIListCreateBr(condBB); @@ -7336,9 +7349,12 @@ void BfModule::Visit(BfDeferStatement* deferStmt) } else if (auto exprStmt = BfNodeDynCast(deferStmt->mTargetNode)) { + BfDeferCallData deferCallData; + deferCallData.mRefNode = exprStmt->mExpression; + deferCallData.mScopeAlloc = scope; + BfExprEvaluator expressionEvaluator(this); - expressionEvaluator.mDeferCallRef = exprStmt->mExpression; - expressionEvaluator.mDeferScopeAlloc = scope; + expressionEvaluator.mDeferCallData = &deferCallData; expressionEvaluator.VisitChild(exprStmt->mExpression); if (mCurMethodState->mPendingNullConditional != NULL) FlushNullConditional(expressionEvaluator.mResult, true); diff --git a/IDEHelper/Tests/src/Functions.bf b/IDEHelper/Tests/src/Functions.bf index 7dcaf804..6bf073b9 100644 --- a/IDEHelper/Tests/src/Functions.bf +++ b/IDEHelper/Tests/src/Functions.bf @@ -182,6 +182,13 @@ namespace Tests { sVal = 123; } + + public static void TestDefer() + { + function void() func = => Func; + if (func != null) + defer:: func.Invoke(); + } } public static int UseFunc0(function int (T this, float f) func, T a, float b) @@ -254,6 +261,10 @@ namespace Tests ClassC.Test(); Test.Assert(Zoop.sVal == 123); + + Zoop.sVal = 0; + Zoop.TestDefer(); + Test.Assert(Zoop.sVal == 123); } } }