diff --git a/IDEHelper/Compiler/BfExprEvaluator.cpp b/IDEHelper/Compiler/BfExprEvaluator.cpp index 8aef2de5..5ae11a55 100644 --- a/IDEHelper/Compiler/BfExprEvaluator.cpp +++ b/IDEHelper/Compiler/BfExprEvaluator.cpp @@ -22159,14 +22159,14 @@ void BfExprEvaluator::PerformBinaryOperation(BfExpression* leftExpression, BfExp if ((binaryOp == BfBinaryOp_NullCoalesce) && (PerformBinaryOperation_NullCoalesce(opToken, leftExpression, rightExpression, leftValue, wantType, NULL))) return; - + BfType* rightWantType = wantType; if (origWantType->IsIntUnknown()) rightWantType = NULL; else if ((mExpectingType != NULL) && (wantType != NULL) && (mExpectingType->IsIntegral()) && (wantType->IsIntegral()) && (mExpectingType->mSize > wantType->mSize) && ((binaryOp == BfBinaryOp_Add) || (binaryOp == BfBinaryOp_Subtract) || (binaryOp == BfBinaryOp_Multiply))) rightWantType = mExpectingType; - rightValue = mModule->CreateValueFromExpression(rightExpression, rightWantType, (BfEvalExprFlags)((mBfEvalExprFlags & BfEvalExprFlags_InheritFlags) | BfEvalExprFlags_NoCast)); + rightValue = mModule->CreateValueFromExpression(rightExpression, rightWantType, (BfEvalExprFlags)((mBfEvalExprFlags & BfEvalExprFlags_InheritFlags) | BfEvalExprFlags_NoCast | BfEvalExprFlags_AllowIntUnknown)); if ((rightWantType != wantType) && (rightValue.mType == rightWantType)) wantType = rightWantType; if ((!leftValue) || (!rightValue)) @@ -22288,6 +22288,57 @@ bool BfExprEvaluator::PerformBinaryOperation_NullCoalesce(BfTokenNode* opToken, return false; } +bool BfExprEvaluator::PerformBinaryOperation_Numeric(BfAstNode* leftExpression, BfAstNode* rightExpression, BfBinaryOp binaryOp, BfAstNode* opToken, BfBinOpFlags flags, BfTypedValue leftValue, BfTypedValue rightValue) +{ + switch (binaryOp) + { + case BfBinaryOp_Add: + case BfBinaryOp_Subtract: + case BfBinaryOp_Multiply: + case BfBinaryOp_Divide: + case BfBinaryOp_Modulus: + break; + default: + return false; + } + + auto wantType = mExpectingType; + if ((wantType == NULL) || + ((!wantType->IsFloat()) && (!wantType->IsIntegral()))) + wantType = NULL; + + auto leftType = mModule->GetClosestNumericCastType(leftValue, mExpectingType); + auto rightType = mModule->GetClosestNumericCastType(rightValue, mExpectingType); + + if (leftType != NULL) + { + if ((rightType == NULL) || (mModule->CanCast(mModule->GetFakeTypedValue(rightType), leftType))) + wantType = leftType; + else if ((rightType != NULL) && (mModule->CanCast(mModule->GetFakeTypedValue(leftType), rightType))) + wantType = rightType; + } + else if (rightType != NULL) + wantType = rightType; + + if (wantType == NULL) + wantType = mModule->GetPrimitiveType(BfTypeCode_IntPtr); + + auto convLeftValue = mModule->Cast(opToken, leftValue, wantType, BfCastFlags_SilentFail); + if (!convLeftValue) + return false; + + auto convRightValue = mModule->Cast(opToken, rightValue, wantType, BfCastFlags_SilentFail); + if (!convRightValue) + return false; + + mResult = BfTypedValue(); + + // Let the error come from here, if any - so we always return 'true' to avoid a second error + PerformBinaryOperation(leftExpression, rightExpression, binaryOp, opToken, flags, convLeftValue, convRightValue); + + return true; +} + void BfExprEvaluator::PerformBinaryOperation(BfExpression* leftExpression, BfExpression* rightExpression, BfBinaryOp binaryOp, BfTokenNode* opToken, BfBinOpFlags flags) { if ((binaryOp == BfBinaryOp_Range) || (binaryOp == BfBinaryOp_ClosedRange)) @@ -22586,6 +22637,9 @@ void BfExprEvaluator::PerformBinaryOperation(BfAstNode* leftExpression, BfAstNod if (rightValue.mType->IsRef()) rightValue.mType = rightValue.mType->GetUnderlyingType(); + BfType* origLeftType = leftValue.mType; + BfType* origRightType = rightValue.mType; + mModule->FixIntUnknown(leftValue, rightValue); // Prefer floats, prefer chars @@ -23269,12 +23323,30 @@ void BfExprEvaluator::PerformBinaryOperation(BfAstNode* leftExpression, BfAstNod if (flippedBinaryOp != BfBinaryOp_None) findBinaryOp = flippedBinaryOp; } - - auto prevResultType = resultType; - if ((leftValue.mType->IsPrimitiveType()) && (!rightValue.mType->IsTypedPrimitive())) - resultType = leftValue.mType; - if ((rightValue.mType->IsPrimitiveType()) && (!leftValue.mType->IsTypedPrimitive())) - resultType = rightValue.mType; + + bool resultHandled = false; + if (((origLeftType != NULL) && (origLeftType->IsIntUnknown())) || + ((origRightType != NULL) && (origRightType->IsIntUnknown()))) + { + if (!resultType->IsPrimitiveType()) + { + BfType* numericCastType = mModule->GetClosestNumericCastType(*resultTypedValue, mExpectingType); + if (numericCastType != NULL) + { + resultHandled = true; + resultType = numericCastType; + } + } + } + + if (!resultHandled) + { + auto prevResultType = resultType; + if ((leftValue.mType->IsPrimitiveType()) && (!origLeftType->IsIntUnknown()) && (!rightValue.mType->IsTypedPrimitive())) + resultType = leftValue.mType; + if ((rightValue.mType->IsPrimitiveType()) && (!origRightType->IsIntUnknown()) && (!leftValue.mType->IsTypedPrimitive())) + resultType = rightValue.mType; + } } } @@ -23641,6 +23713,9 @@ void BfExprEvaluator::PerformBinaryOperation(BfAstNode* leftExpression, BfAstNod return; } + if (PerformBinaryOperation_Numeric(leftExpression, rightExpression, binaryOp, opToken, flags, leftValue, rightValue)) + return; + if (mModule->PreFail()) { mModule->Fail(StrFormat("Operator '%s' cannot be applied to operands of type '%s'", @@ -23687,6 +23762,9 @@ void BfExprEvaluator::PerformBinaryOperation(BfAstNode* leftExpression, BfAstNod } } + if (PerformBinaryOperation_Numeric(leftExpression, rightExpression, binaryOp, opToken, flags, leftValue, rightValue)) + return; + mModule->Fail(StrFormat("Operator '%s' cannot be applied to operands of type '%s' and '%s'", BfGetOpName(binaryOp), mModule->TypeToString(leftValue.mType).c_str(), diff --git a/IDEHelper/Compiler/BfExprEvaluator.h b/IDEHelper/Compiler/BfExprEvaluator.h index e64acc21..ae71d503 100644 --- a/IDEHelper/Compiler/BfExprEvaluator.h +++ b/IDEHelper/Compiler/BfExprEvaluator.h @@ -462,6 +462,7 @@ public: bool CheckConstCompare(BfBinaryOp binaryOp, BfAstNode* opToken, const BfTypedValue& leftValue, const BfTypedValue& rightValue); void AddStrings(const BfTypedValue& leftValue, const BfTypedValue& rightValue, BfAstNode* refNode); bool PerformBinaryOperation_NullCoalesce(BfTokenNode* opToken, BfExpression* leftExpression, BfExpression* rightExpression, BfTypedValue leftValue, BfType* wantType, BfTypedValue* assignTo = NULL); + bool PerformBinaryOperation_Numeric(BfAstNode* leftExpression, BfAstNode* rightExpression, BfBinaryOp binaryOp, BfAstNode* opToken, BfBinOpFlags flags, BfTypedValue leftValue, BfTypedValue rightValue); void PerformBinaryOperation(BfType* resultType, BfIRValue convLeftValue, BfIRValue convRightValue, BfBinaryOp binaryOp, BfAstNode* opToken); void PerformBinaryOperation(BfAstNode* leftExpression, BfAstNode* rightExpression, BfBinaryOp binaryOp, BfAstNode* opToken, BfBinOpFlags flags, BfTypedValue leftValue, BfTypedValue rightValue); void PerformBinaryOperation(BfExpression* leftNode, BfExpression* rightNode, BfBinaryOp binaryOp, BfTokenNode* opToken, BfBinOpFlags flags, BfTypedValue leftValue); @@ -506,7 +507,7 @@ public: void PerformUnaryOperation(BfExpression* unaryOpExpr, BfUnaryOp unaryOp, BfTokenNode* opToken, BfUnaryOpFlags opFlags); BfTypedValue PerformUnaryOperation_TryOperator(const BfTypedValue& inValue, BfExpression* unaryOpExpr, BfUnaryOp unaryOp, BfTokenNode* opToken, BfUnaryOpFlags opFlags); void PerformUnaryOperation_OnResult(BfExpression* unaryOpExpr, BfUnaryOp unaryOp, BfTokenNode* opToken, BfUnaryOpFlags opFlags); - BfTypedValue PerformAssignment_CheckOp(BfAssignmentExpression* assignExpr, bool deferBinop, BfTypedValue& leftValue, BfTypedValue& rightValue, bool& evaluatedRight); + BfTypedValue PerformAssignment_CheckOp(BfAssignmentExpression* assignExpr, bool deferBinop, BfTypedValue& leftValue, BfTypedValue& rightValue, bool& evaluatedRight); void PerformAssignment(BfAssignmentExpression* assignExpr, bool evaluatedLeft, BfTypedValue rightValue, BfTypedValue* outCascadeValue = NULL); void PopulateDeferrredTupleAssignData(BfTupleExpression* tupleExr, DeferredTupleAssignData& deferredTupleAssignData); void AssignDeferrredTupleAssignData(BfAssignmentExpression* assignExpr, DeferredTupleAssignData& deferredTupleAssignData, BfTypedValue rightValue); diff --git a/IDEHelper/Compiler/BfModule.h b/IDEHelper/Compiler/BfModule.h index f2380503..8146c99b 100644 --- a/IDEHelper/Compiler/BfModule.h +++ b/IDEHelper/Compiler/BfModule.h @@ -1647,6 +1647,7 @@ public: void EmitDeferredCallProcessor(SLIList& callEntries, BfIRValue callTail); bool CanCast(BfTypedValue typedVal, BfType* toType, BfCastFlags castFlags = BfCastFlags_None); bool AreSplatsCompatible(BfType* fromType, BfType* toType, bool* outNeedsMemberCasting); + BfType* GetClosestNumericCastType(const BfTypedValue& typedVal, BfType* wantType); BfTypedValue BoxValue(BfAstNode* srcNode, BfTypedValue typedVal, BfType* toType /*Can be System.Object or interface*/, const BfAllocTarget& allocTarget, BfCastFlags castFlags = BfCastFlags_None); BfIRValue CastToFunction(BfAstNode* srcNode, const BfTypedValue& targetValue, BfMethodInstance* methodInstance, BfType* toType, BfCastFlags castFlags = BfCastFlags_None, BfIRValue irFunc = BfIRValue()); BfIRValue CastToValue(BfAstNode* srcNode, BfTypedValue val, BfType* toType, BfCastFlags castFlags = BfCastFlags_None, BfCastResultFlags* resultFlags = NULL); @@ -1832,8 +1833,8 @@ public: BfIRType GetIRLoweredType(BfTypeCode loweredTypeCode, BfTypeCode loweredTypeCode2); BfMethodRefType* CreateMethodRefType(BfMethodInstance* methodInstance, bool mustAlreadyExist = false); BfType* FixIntUnknown(BfType* type); - void FixIntUnknown(BfTypedValue& typedVal, BfType* matchType = NULL); - void FixIntUnknown(BfTypedValue& lhs, BfTypedValue& rhs); + void FixIntUnknown(BfTypedValue& typedVal, BfType* matchType = NULL); + void FixIntUnknown(BfTypedValue& lhs, BfTypedValue& rhs); void FixValueActualization(BfTypedValue& typedVal, bool force = false); bool TypeEquals(BfTypedValue& val, BfType* type); BfTypeDef* ResolveGenericInstanceDef(BfGenericInstanceTypeRef* genericTypeRef, BfType** outType = NULL, BfResolveTypeRefFlags resolveFlags = BfResolveTypeRefFlag_None); diff --git a/IDEHelper/Compiler/BfModuleTypeUtils.cpp b/IDEHelper/Compiler/BfModuleTypeUtils.cpp index a6d34ba1..8b6a10c3 100644 --- a/IDEHelper/Compiler/BfModuleTypeUtils.cpp +++ b/IDEHelper/Compiler/BfModuleTypeUtils.cpp @@ -12240,6 +12240,73 @@ bool BfModule::AreSplatsCompatible(BfType* fromType, BfType* toType, bool* outNe return true; } +BfType* BfModule::GetClosestNumericCastType(const BfTypedValue& typedVal, BfType* wantType) +{ + BfType* toType = wantType; + if ((toType == NULL) || + ((!toType->IsFloat()) && (!toType->IsIntegral()))) + toType = NULL; + + BfType* bestReturnType = NULL; + + if (typedVal.mType->IsTypedPrimitive()) + return NULL; + + auto checkType = typedVal.mType->ToTypeInstance(); + while (checkType != NULL) + { + for (auto operatorDef : checkType->mTypeDef->mOperators) + { + if (operatorDef->mOperatorDeclaration->mIsConvOperator) + { + if (operatorDef->IsExplicit()) + continue; + + auto returnType = CheckOperator(checkType, operatorDef, typedVal, BfTypedValue()); + if ((returnType != NULL) && + ((returnType->IsIntegral()) || (returnType->IsFloat()))) + { + bool canCastTo = true; + + if ((toType != NULL) && (!CanCast(GetFakeTypedValue(returnType), toType))) + canCastTo = false; + + if (canCastTo) + { + if (bestReturnType == NULL) + { + bestReturnType = returnType; + } + else + { + if (CanCast(GetFakeTypedValue(bestReturnType), returnType)) + { + bestReturnType = returnType; + } + } + } + } + } + } + + checkType = checkType->mBaseType; + } + + if ((toType == NULL) && (bestReturnType != NULL)) + { + auto intPtrType = GetPrimitiveType(BfTypeCode_IntPtr); + if (!CanCast(GetFakeTypedValue(bestReturnType), intPtrType)) + { + // If no 'wantType' is specified, try to get closest one to an intptr + auto otherType = GetClosestNumericCastType(typedVal, intPtrType); + if (otherType != NULL) + return otherType; + } + } + + return bestReturnType; +} + BfIRValue BfModule::CastToFunction(BfAstNode* srcNode, const BfTypedValue& targetValue, BfMethodInstance* methodInstance, BfType* toType, BfCastFlags castFlags, BfIRValue irFunc) { auto invokeMethodInstance = GetDelegateInvokeMethod(toType->ToTypeInstance()); diff --git a/IDEHelper/Tests/src/Numerics.bf b/IDEHelper/Tests/src/Numerics.bf index 3bd85f41..ba02aa1f 100644 --- a/IDEHelper/Tests/src/Numerics.bf +++ b/IDEHelper/Tests/src/Numerics.bf @@ -1,3 +1,5 @@ +#pragma warning disable 168 + using System; using System.Numerics; @@ -21,6 +23,22 @@ namespace Tests float4 v3 = v0.wzyx; Test.Assert(v3 === .(4, 3, 2, 1)); + + Result r0 = 123; + Result r1 = 2000; + Result r2 = 3000; + + uint16 v4 = r0 + 123; + var v5 = r0 + 2; + Test.Assert(v5.GetType() == typeof(uint16)); + var v6 = r0 + r0; + Test.Assert(v6.GetType() == typeof(uint16)); + var v7 = r0 + r1; + Test.Assert(v7.GetType() == typeof(int)); + var v8 = r0 + r2; + Test.Assert(v8.GetType() == typeof(uint32)); + var v9 = r2 + r0; + Test.Assert(v9.GetType() == typeof(uint32)); } } }