1
0
Fork 0
mirror of https://github.com/beefytech/Beef.git synced 2025-06-07 19:18:19 +02:00

Improved constraint checks where generic param type constraint passes

This commit is contained in:
Brian Fiete 2025-05-28 11:57:28 +02:00
parent 116d9c6f01
commit b7725d0ed0
4 changed files with 76 additions and 1 deletions

View file

@ -23367,6 +23367,15 @@ void BfExprEvaluator::PerformUnaryOperation_OnResult(BfExpression* unaryOpExpr,
mResult = opResult;
return;
}
auto typeConstraint = mModule->GetGenericParamInstanceTypeConstraint(mResult.mType);
if ((typeConstraint != NULL) && (!typeConstraint->IsGenericParam()))
{
// Handle cases such as 'where T : float'
mResult.mType = typeConstraint;
PerformUnaryOperation_OnResult(unaryOpExpr, unaryOp, opToken, opFlags);
return;
}
}
switch (unaryOp)
@ -24947,7 +24956,7 @@ void BfExprEvaluator::PerformBinaryOperation(BfAstNode* leftExpression, BfAstNod
BfBinaryOp findBinaryOp = binaryOp;
bool isComparison = (binaryOp >= BfBinaryOp_Equality) && (binaryOp <= BfBinaryOp_LessThanOrEqual);
for (int pass = 0; pass < 2; pass++)
{
BfBinaryOp oppositeBinaryOp = BfGetOppositeBinaryOp(findBinaryOp);
@ -25316,6 +25325,48 @@ void BfExprEvaluator::PerformBinaryOperation(BfAstNode* leftExpression, BfAstNod
findBinaryOp = flippedBinaryOp;
}
auto _FixOpCheckGenericParam = [&](BfTypedValue& typedVal)
{
if ((typedVal.mType != NULL) && (typedVal.mType->IsGenericParam()))
{
auto genericParamInstance = mModule->GetGenericParamInstance((BfGenericParamType*)typedVal.mType);
if (genericParamInstance->mTypeConstraint != NULL)
{
typedVal.mType = genericParamInstance->mTypeConstraint;
return true;
}
}
return false;
};
auto leftTypeConstraint = mModule->GetGenericParamInstanceTypeConstraint(leftValue.mType);
auto rightTypeConstraint = mModule->GetGenericParamInstanceTypeConstraint(rightValue.mType);
if ((leftTypeConstraint != NULL) || (rightTypeConstraint != NULL))
{
// Handle cases such as 'where T : float'
bool needNewCheck = false;
BfTypedValue newLeftValue = leftValue;
if ((leftTypeConstraint != NULL) && (!leftTypeConstraint->IsGenericParam()))
{
newLeftValue.mType = leftTypeConstraint;
needNewCheck = true;
}
BfTypedValue newRightValue = rightValue;
if ((rightTypeConstraint != NULL) && (!rightTypeConstraint->IsGenericParam()))
{
newRightValue.mType = rightTypeConstraint;
needNewCheck = true;
}
if (needNewCheck)
{
PerformBinaryOperation(leftExpression, rightExpression, binaryOp, opToken, flags, newLeftValue, newRightValue);
return;
}
}
bool resultHandled = false;
if (((origLeftType != NULL) && (origLeftType->IsIntUnknown())) ||
((origRightType != NULL) && (origRightType->IsIntUnknown())))

View file

@ -1972,6 +1972,7 @@ public:
bool IsUnboundGeneric(BfType* type);
BfGenericParamInstance* GetGenericTypeParamInstance(int paramIdx, BfFailHandleKind failHandleKind = BfFailHandleKind_Normal);
BfGenericParamInstance* GetGenericParamInstance(BfGenericParamType* type, bool checkMixinBind = false, BfFailHandleKind failHandleKind = BfFailHandleKind_Normal);
BfType* GetGenericParamInstanceTypeConstraint(BfType* type, bool checkMixinBind = false, BfFailHandleKind failHandleKind = BfFailHandleKind_Normal);
void GetActiveTypeGenericParamInstances(SizedArray<BfGenericParamInstance*, 4>& genericParamInstance);
BfGenericParamInstance* GetMergedGenericParamData(BfType* type, BfGenericParamFlags& outFlags, BfType*& outTypeConstraint);
BfTypeInstance* GetBaseType(BfTypeInstance* typeInst);

View file

@ -9861,6 +9861,16 @@ BfGenericParamInstance* BfModule::GetGenericParamInstance(BfGenericParamType* ty
return GetGenericTypeParamInstance(type->mGenericParamIdx, failHandleKind);
}
BfType* BfModule::GetGenericParamInstanceTypeConstraint(BfType* type, bool checkMixinBind, BfFailHandleKind failHandleKind)
{
if (!type->IsGenericParam())
return NULL;
auto genericParamInstance = GetGenericParamInstance((BfGenericParamType*)type, checkMixinBind, failHandleKind);
if (genericParamInstance != NULL)
return genericParamInstance->mTypeConstraint;
return NULL;
}
bool BfModule::ResolveTypeResult_Validate(BfAstNode* typeRef, BfType* resolvedTypeRef)
{
if ((typeRef == NULL) || (resolvedTypeRef == NULL))

View file

@ -7,6 +7,19 @@ namespace Tests
{
class Constraints
{
struct Vector2<T>
{
public T mX;
public T mY;
}
extension Vector2<T> where T : float
{
public T LengthSquared => mX * mX + mY * mY;
public T Length => Math.Sqrt(LengthSquared);
public T NegX = -mX;
}
class Dicto : Dictionary<int, float>
{