diff --git a/IDEHelper/Compiler/BfExprEvaluator.cpp b/IDEHelper/Compiler/BfExprEvaluator.cpp index c7815fea..ed75baf9 100644 --- a/IDEHelper/Compiler/BfExprEvaluator.cpp +++ b/IDEHelper/Compiler/BfExprEvaluator.cpp @@ -3402,6 +3402,7 @@ void BfExprEvaluator::Visit(BfBlock* blockExpr) } mModule->VisitEmbeddedStatement(blockExpr, this, BfNodeIsA(blockExpr) ? BfEmbeddedStatementFlags_Unscoped : BfEmbeddedStatementFlags_None); + mResult = mModule->SanitizeAddr(mResult); } bool BfExprEvaluator::CheckVariableDeclaration(BfAstNode* checkNode, bool requireSimpleIfExpr, bool exprMustBeTrue, bool silentFail) @@ -17164,6 +17165,7 @@ void BfExprEvaluator::InjectMixin(BfAstNode* targetSrc, BfTypedValue target, boo if (!exprNode->IsA()) { // Mixin expression result + SetAndRestoreValue prevFlags(mBfEvalExprFlags, (BfEvalExprFlags)(mBfEvalExprFlags | BfEvalExprFlags_AllowRefExpr)); mModule->UpdateSrcPos(exprNode); VisitChild(exprNode); FinishExpressionResult(); @@ -17187,6 +17189,7 @@ void BfExprEvaluator::InjectMixin(BfAstNode* targetSrc, BfTypedValue target, boo } mResult = mModule->LoadValue(mResult); + mResult = mModule->SanitizeAddr(mResult); int localIdx = startLocalIdx; @@ -19497,7 +19500,7 @@ void BfExprEvaluator::PerformAssignment(BfAssignmentExpression* assignExpr, bool } ResolveGenericType(); - auto ptr = mResult; + auto ptr = mModule->RemoveRef(mResult); mResult = BfTypedValue(); if (mPropDef != NULL) @@ -21569,8 +21572,7 @@ void BfExprEvaluator::PerformUnaryOperation_OnResult(BfExpression* unaryOpExpr, if (!mResult) return; - if (mResult.mType->IsRef()) - mResult.mType = mResult.mType->GetUnderlyingType(); + mResult = mModule->RemoveRef(mResult); if (mResult.mType->IsVar()) { diff --git a/IDEHelper/Compiler/BfModule.cpp b/IDEHelper/Compiler/BfModule.cpp index b72caf73..46b13bf6 100644 --- a/IDEHelper/Compiler/BfModule.cpp +++ b/IDEHelper/Compiler/BfModule.cpp @@ -8791,13 +8791,26 @@ BfTypedValue BfModule::CreateValueFromExpression(BfExprEvaluator& exprEvaluator, { // Only allow a 'ref' type if we have an explicit 'ref' operator bool allowRef = false; - BfExpression* checkExpr = expr; - while (auto parenExpr = BfNodeDynCast(checkExpr)) - checkExpr = parenExpr->mExpression; - if (auto unaryOp = BfNodeDynCast(checkExpr)) + BfAstNode* checkExpr = expr; + + while (checkExpr != NULL) { - if ((unaryOp->mOp == BfUnaryOp_Ref) || (unaryOp->mOp == BfUnaryOp_Mut)) - allowRef = true; + if (auto parenExpr = BfNodeDynCast(checkExpr)) + checkExpr = parenExpr->mExpression; + else if (auto unaryOp = BfNodeDynCast(checkExpr)) + { + if ((unaryOp->mOp == BfUnaryOp_Ref) || (unaryOp->mOp == BfUnaryOp_Mut)) + allowRef = true; + break; + } + if (auto block = BfNodeDynCast(checkExpr)) + { + if (block->mChildArr.mSize == 0) + break; + checkExpr = block->mChildArr.back(); + } + else + break; } if (!allowRef) typedVal = RemoveRef(typedVal); @@ -12480,6 +12493,23 @@ BfTypedValue BfModule::RemoveRef(BfTypedValue typedValue) return typedValue; } +BfTypedValue BfModule::SanitizeAddr(BfTypedValue typedValue) +{ + if (!typedValue) + return typedValue; + + if (typedValue.mType->IsRef()) + { + typedValue = LoadValue(typedValue); + + auto copiedVal = BfTypedValue(CreateAlloca(typedValue.mType), typedValue.mType, true); + mBfIRBuilder->CreateStore(typedValue.mValue, copiedVal.mValue); + return copiedVal; + } + + return typedValue; +} + BfTypedValue BfModule::ToRef(BfTypedValue typedValue, BfRefType* refType) { if (refType == NULL) diff --git a/IDEHelper/Compiler/BfModule.h b/IDEHelper/Compiler/BfModule.h index f4fc6582..1264f81b 100644 --- a/IDEHelper/Compiler/BfModule.h +++ b/IDEHelper/Compiler/BfModule.h @@ -1700,6 +1700,7 @@ public: void EmitDynamicCastCheck(BfTypedValue typedVal, BfType* type, bool allowNull); void CheckStaticAccess(BfTypeInstance* typeInstance); BfTypedValue RemoveRef(BfTypedValue typedValue); + BfTypedValue SanitizeAddr(BfTypedValue typedValue); BfTypedValue ToRef(BfTypedValue typedValue, BfRefType* refType = NULL); BfTypedValue LoadOrAggregateValue(BfTypedValue typedValue); BfTypedValue LoadValue(BfTypedValue typedValue, BfAstNode* refNode = NULL, bool isVolatile = false); diff --git a/IDEHelper/Compiler/BfReducer.cpp b/IDEHelper/Compiler/BfReducer.cpp index a9f62ae0..8e9fb70d 100644 --- a/IDEHelper/Compiler/BfReducer.cpp +++ b/IDEHelper/Compiler/BfReducer.cpp @@ -4427,15 +4427,10 @@ BfAstNode* BfReducer::DoCreateStatement(BfAstNode* node, CreateStmtFlags createS (unaryOperatorExpr->mOp == BfUnaryOp_Decrement) || (unaryOperatorExpr->mOp == BfUnaryOp_PostDecrement); - if ((unaryOperatorExpr->mOp == BfUnaryOp_Ref) || (unaryOperatorExpr->mOp == BfUnaryOp_Mut) || (unaryOperatorExpr->mOp == BfUnaryOp_Out)) + if (unaryOperatorExpr->mOp == BfUnaryOp_Out) { - if (unaryOperatorExpr->mOp == BfUnaryOp_Ref) - Fail("Cannot use 'ref' in this context", unaryOperatorExpr); - else if (unaryOperatorExpr->mOp == BfUnaryOp_Mut) - Fail("Cannot use 'mut' in this context", unaryOperatorExpr); - else - Fail("Cannot use 'out' in this context", unaryOperatorExpr); - return NULL; + unaryOperatorExpr->mOp = BfUnaryOp_Ref; + Fail("Cannot use 'out' in this context", unaryOperatorExpr); } } diff --git a/IDEHelper/Compiler/BfStmtEvaluator.cpp b/IDEHelper/Compiler/BfStmtEvaluator.cpp index ea686fe8..3cb62937 100644 --- a/IDEHelper/Compiler/BfStmtEvaluator.cpp +++ b/IDEHelper/Compiler/BfStmtEvaluator.cpp @@ -3538,7 +3538,7 @@ void BfModule::VisitCodeBlock(BfBlock* block) else if ((mCurMethodInstance != NULL) && (mCurMethodInstance->IsMixin()) && (mCurMethodState->mCurScope == &mCurMethodState->mHeadScope)) { // Only in mixin definition - result ignored - CreateValueFromExpression(expr); + CreateValueFromExpression(expr, NULL, BfEvalExprFlags_AllowRefExpr); break; } else diff --git a/IDEHelper/Tests/src/Mixins.bf b/IDEHelper/Tests/src/Mixins.bf index b7078cf3..993e06b9 100644 --- a/IDEHelper/Tests/src/Mixins.bf +++ b/IDEHelper/Tests/src/Mixins.bf @@ -73,6 +73,12 @@ namespace Tests total + 100 } + static mixin GetRef(var a) + { + a += 1000; + ref a + } + [Test] public static void TestBasics() { @@ -125,6 +131,14 @@ namespace Tests AppendAndNullify!(str0); Test.Assert(str0 == null); Test.Assert(str1 == "AB"); + + int b = 12; + GetRef!(b) += 200; + Test.Assert(b == 1212); + + var c = { ref b }; + c = 99; + Test.Assert(b == 99); } }