diff --git a/BeefLibs/Beefy2D/src/utils/StructuredData.bf b/BeefLibs/Beefy2D/src/utils/StructuredData.bf index 08dfd840..bd631001 100644 --- a/BeefLibs/Beefy2D/src/utils/StructuredData.bf +++ b/BeefLibs/Beefy2D/src/utils/StructuredData.bf @@ -423,7 +423,7 @@ namespace Beefy.utils } } - public void Get(StringView name, ref T val) where T : Enum + public void Get(StringView name, ref T val) where T : enum { Object obj = Get(name); if (obj == null) @@ -547,7 +547,7 @@ namespace Beefy.utils return (bool)aVal; } - public T GetEnum(String name, T defaultVal = default(T)) where T : Enum + public T GetEnum(String name, T defaultVal = default(T)) where T : enum { Object obj = Get(name); if (obj == null) @@ -566,7 +566,7 @@ namespace Beefy.utils return defaultVal; } - public bool GetEnum(String name, ref T val) where T : Enum + public bool GetEnum(String name, ref T val) where T : enum { Object obj = Get(name); if (obj == null) @@ -614,7 +614,7 @@ namespace Beefy.utils return; } - public T GetCurEnum(T theDefault = default) where T : Enum + public T GetCurEnum(T theDefault = default) where T : enum { Object obj = GetCurrent(); diff --git a/BeefLibs/corlib/src/Enum.bf b/BeefLibs/corlib/src/Enum.bf index 48712752..f39872e1 100644 --- a/BeefLibs/corlib/src/Enum.bf +++ b/BeefLibs/corlib/src/Enum.bf @@ -18,7 +18,7 @@ namespace System ((int32)iVal).ToString(strBuffer); } - public static Result Parse(StringView str, bool ignoreCase = false) where T : Enum + public static Result Parse(StringView str, bool ignoreCase = false) where T : enum { var typeInst = (TypeInstance)typeof(T); for (var field in typeInst.GetFields()) diff --git a/IDEHelper/Compiler/BfDefBuilder.cpp b/IDEHelper/Compiler/BfDefBuilder.cpp index 4ef08440..31b0dbfd 100644 --- a/IDEHelper/Compiler/BfDefBuilder.cpp +++ b/IDEHelper/Compiler/BfDefBuilder.cpp @@ -285,9 +285,10 @@ void BfDefBuilder::ParseGenericParams(BfGenericParamsDeclaration* genericParamsD if (!name.empty()) { - if ((name == "class") || (name == "struct") || (name == "struct*") || (name == "const") || (name == "var")) + if ((name == "class") || (name == "struct") || (name == "struct*") || (name == "const") || (name == "var") || (name == "interface") || (name == "enum")) { - int prevFlags = constraintDef->mGenericParamFlags & (BfGenericParamFlag_Class | BfGenericParamFlag_Struct | BfGenericParamFlag_StructPtr); + int prevFlags = constraintDef->mGenericParamFlags & + (BfGenericParamFlag_Class | BfGenericParamFlag_Struct | BfGenericParamFlag_StructPtr | BfGenericParamFlag_Interface | BfGenericParamFlag_Enum); if (prevFlags != 0) { String prevFlagName; @@ -295,8 +296,12 @@ void BfDefBuilder::ParseGenericParams(BfGenericParamsDeclaration* genericParamsD prevFlagName = "class"; else if (prevFlags & BfGenericParamFlag_Struct) prevFlagName = "struct"; - else // + else if (prevFlags & BfGenericParamFlag_StructPtr) prevFlagName = "struct*"; + else if (prevFlags & BfGenericParamFlag_Enum) + prevFlagName = "enum"; + else // interface + prevFlagName = "interface"; if (prevFlagName == name) Fail(StrFormat("Cannot specify '%s' twice", prevFlagName.c_str()), constraintNode); @@ -313,6 +318,10 @@ void BfDefBuilder::ParseGenericParams(BfGenericParamsDeclaration* genericParamsD constraintDef->mGenericParamFlags = (BfGenericParamFlags)(constraintDef->mGenericParamFlags | BfGenericParamFlag_StructPtr); else if (name == "const") constraintDef->mGenericParamFlags = (BfGenericParamFlags)(constraintDef->mGenericParamFlags | BfGenericParamFlag_Const); + else if (name == "interface") + constraintDef->mGenericParamFlags = (BfGenericParamFlags)(constraintDef->mGenericParamFlags | BfGenericParamFlag_Interface); + else if (name == "enum") + constraintDef->mGenericParamFlags = (BfGenericParamFlags)(constraintDef->mGenericParamFlags | BfGenericParamFlag_Enum); else //if (name == "var") constraintDef->mGenericParamFlags = (BfGenericParamFlags)(constraintDef->mGenericParamFlags | BfGenericParamFlag_Var); diff --git a/IDEHelper/Compiler/BfModule.cpp b/IDEHelper/Compiler/BfModule.cpp index 805342f3..9fc1a414 100644 --- a/IDEHelper/Compiler/BfModule.cpp +++ b/IDEHelper/Compiler/BfModule.cpp @@ -7247,6 +7247,29 @@ bool BfModule::CheckGenericConstraints(const BfGenericParamSource& genericParamS return false; } + if (genericParamInst->mGenericParamFlags & BfGenericParamFlag_Enum) + { + bool isEnum = checkArgType->IsEnum(); + if ((origCheckArgType->IsGenericParam()) && (checkArgType->IsInstanceOf(mCompiler->mEnumTypeDef))) + isEnum = true; + if (((checkGenericParamFlags & (BfGenericParamFlag_Enum | BfGenericParamFlag_Var)) == 0) && (!isEnum)) + { + if (!ignoreErrors) + *errorOut = Fail(StrFormat("The type '%s' must be an enum type in order to use it as parameter '%s' for '%s'", + TypeToString(origCheckArgType).c_str(), genericParamInst->GetName().c_str(), GenericParamSourceToString(genericParamSource).c_str()), checkArgTypeRef); + return false; + } + } + + if ((genericParamInst->mGenericParamFlags & BfGenericParamFlag_Interface) && + ((checkGenericParamFlags & (BfGenericParamFlag_Interface | BfGenericParamFlag_Var)) == 0) && (!checkArgType->IsInterface())) + { + if (!ignoreErrors) + *errorOut = Fail(StrFormat("The type '%s' must be an interface type in order to use it as parameter '%s' for '%s'", + TypeToString(origCheckArgType).c_str(), genericParamInst->GetName().c_str(), GenericParamSourceToString(genericParamSource).c_str()), checkArgTypeRef); + return false; + } + if ((genericParamInst->mGenericParamFlags & BfGenericParamFlag_Const) != 0) { if (((checkGenericParamFlags & BfGenericParamFlag_Const) == 0) && (!checkArgType->IsConstExprValue())) diff --git a/IDEHelper/Compiler/BfModuleTypeUtils.cpp b/IDEHelper/Compiler/BfModuleTypeUtils.cpp index ea9e539e..dee7a440 100644 --- a/IDEHelper/Compiler/BfModuleTypeUtils.cpp +++ b/IDEHelper/Compiler/BfModuleTypeUtils.cpp @@ -363,6 +363,9 @@ bool BfModule::AreConstraintsSubset(BfGenericParamInstance* checkInner, BfGeneri { // If the outer had a type flag and the inner has a specific type constraint, then see if those are compatible auto outerFlags = checkOuter->mGenericParamFlags; + if ((outerFlags & BfGenericParamFlag_Enum) != 0) + outerFlags |= BfGenericParamFlag_Struct; + if (checkOuter->mTypeConstraint != NULL) { if (checkOuter->mTypeConstraint->IsStruct()) @@ -371,9 +374,17 @@ bool BfModule::AreConstraintsSubset(BfGenericParamInstance* checkInner, BfGeneri outerFlags |= BfGenericParamFlag_StructPtr; else if (checkOuter->mTypeConstraint->IsObject()) outerFlags |= BfGenericParamFlag_Class; + else if (checkOuter->mTypeConstraint->IsEnum()) + outerFlags |= BfGenericParamFlag_Enum | BfGenericParamFlag_Struct; + else if (checkOuter->mTypeConstraint->IsInterface()) + outerFlags |= BfGenericParamFlag_Interface; } - if (((checkInner->mGenericParamFlags | outerFlags) & ~BfGenericParamFlag_Var) != (outerFlags & ~BfGenericParamFlag_Var)) + auto innerFlags = checkInner->mGenericParamFlags; + if ((innerFlags & BfGenericParamFlag_Enum) != 0) + innerFlags |= BfGenericParamFlag_Struct; + + if (((innerFlags | outerFlags) & ~BfGenericParamFlag_Var) != (outerFlags & ~BfGenericParamFlag_Var)) return false; } @@ -8481,7 +8492,7 @@ BfType* BfModule::ResolveTypeRef(BfTypeReference* typeRef, BfPopulateType popula { auto genericParam = GetGenericParamInstance((BfGenericParamType*)resolvedType); if (((genericParam->mTypeConstraint != NULL) && (genericParam->mTypeConstraint->IsValueType())) || - ((genericParam->mGenericParamFlags & (BfGenericParamFlag_Struct | BfGenericParamFlag_StructPtr)) != 0)) + ((genericParam->mGenericParamFlags & (BfGenericParamFlag_Struct | BfGenericParamFlag_StructPtr | BfGenericParamFlag_Enum)) != 0)) { resolvedType = CreatePointerType(resolvedType); } @@ -9970,7 +9981,7 @@ BfIRValue BfModule::CastToValue(BfAstNode* srcNode, BfTypedValue typedVal, BfTyp // Generic constrained with class or pointer type -> void* if (toType->IsVoidPtr()) { - if ((genericParamInst->mGenericParamFlags & (BfGenericParamFlag_Class | BfGenericParamFlag_StructPtr)) || + if (((genericParamInst->mGenericParamFlags & (BfGenericParamFlag_Class | BfGenericParamFlag_StructPtr | BfGenericParamFlag_Interface)) != 0) || ((genericParamInst->mTypeConstraint != NULL) && ((genericParamInst->mTypeConstraint->IsPointer()) || (genericParamInst->mTypeConstraint->IsInstanceOf(mCompiler->mFunctionTypeDef)) || @@ -9980,6 +9991,14 @@ BfIRValue BfModule::CastToValue(BfAstNode* srcNode, BfTypedValue typedVal, BfTyp } } + if (toType->IsInteger()) + { + if ((genericParamInst->mGenericParamFlags & BfGenericParamFlag_Enum) != 0) + { + return mBfIRBuilder->GetFakeVal(); + } + } + return BfIRValue(); }; @@ -10029,7 +10048,7 @@ BfIRValue BfModule::CastToValue(BfAstNode* srcNode, BfTypedValue typedVal, BfTyp if (typedVal.mType->IsNull()) { - bool allowCast = (genericParamInst->mGenericParamFlags & BfGenericParamFlag_Class) || (genericParamInst->mGenericParamFlags & BfGenericParamFlag_StructPtr); + bool allowCast = (genericParamInst->mGenericParamFlags & (BfGenericParamFlag_Class | BfGenericParamFlag_StructPtr | BfGenericParamFlag_Interface)) != 0; if ((!allowCast) && (genericParamInst->mTypeConstraint != NULL)) allowCast = genericParamInst->mTypeConstraint->IsObject() || genericParamInst->mTypeConstraint->IsPointer(); if (allowCast) @@ -10052,7 +10071,7 @@ BfIRValue BfModule::CastToValue(BfAstNode* srcNode, BfTypedValue typedVal, BfTyp if (explicitCast) { - if ((genericParamInst->mGenericParamFlags & BfGenericParamFlag_StructPtr) || + if (((genericParamInst->mGenericParamFlags & BfGenericParamFlag_StructPtr) != 0) || ((genericParamInst->mTypeConstraint != NULL) && genericParamInst->mTypeConstraint->IsInstanceOf(mCompiler->mFunctionTypeDef))) { auto voidPtrType = CreatePointerType(GetPrimitiveType(BfTypeCode_None)); @@ -10061,6 +10080,24 @@ BfIRValue BfModule::CastToValue(BfAstNode* srcNode, BfTypedValue typedVal, BfTyp return castedVal; } } + + if ((typedVal.mType->IsIntegral()) && ((genericParamInst->mGenericParamFlags & BfGenericParamFlag_Enum) != 0)) + { + bool allowCast = explicitCast; + if ((!allowCast) && (typedVal.mType->IsIntegral())) + { + // Allow implicit cast of zero + auto constant = mBfIRBuilder->GetConstant(typedVal.mValue); + if ((constant != NULL) && (mBfIRBuilder->IsInt(constant->mTypeCode))) + { + allowCast = constant->mInt64 == 0; + } + } + if (allowCast) + { + return mBfIRBuilder->GetFakeVal(); + } + } } if ((typedVal.mType->IsTypeInstance()) && (toType->IsTypeInstance())) diff --git a/IDEHelper/Compiler/BfReducer.cpp b/IDEHelper/Compiler/BfReducer.cpp index 562bc314..39679838 100644 --- a/IDEHelper/Compiler/BfReducer.cpp +++ b/IDEHelper/Compiler/BfReducer.cpp @@ -9496,6 +9496,8 @@ BfGenericConstraintsDeclaration* BfReducer::CreateGenericConstraintsDeclaration( case BfToken_Var: case BfToken_New: case BfToken_Delete: + case BfToken_Enum: + case BfToken_Interface: addToConstraint = true; break; case BfToken_Operator: diff --git a/IDEHelper/Compiler/BfSystem.h b/IDEHelper/Compiler/BfSystem.h index d892b757..00437bf6 100644 --- a/IDEHelper/Compiler/BfSystem.h +++ b/IDEHelper/Compiler/BfSystem.h @@ -592,18 +592,20 @@ public: enum BfGenericParamFlags : uint16 { - BfGenericParamFlag_None = 0, - BfGenericParamFlag_Class = 1, - BfGenericParamFlag_Struct = 2, - BfGenericParamFlag_StructPtr = 4, - BfGenericParamFlag_New = 8, - BfGenericParamFlag_Delete = 0x10, - BfGenericParamFlag_Var = 0x20, - BfGenericParamFlag_Const = 0x40, - BfGenericParamFlag_Equals = 0x80, - BfGenericParamFlag_Equals_Op = 0x100, - BfGenericParamFlag_Equals_Type = 0x200, - BfGenericParamFlag_Equals_IFace = 0x400 + BfGenericParamFlag_None = 0, + BfGenericParamFlag_Class = 1, + BfGenericParamFlag_Struct = 2, + BfGenericParamFlag_StructPtr = 4, + BfGenericParamFlag_Enum = 8, + BfGenericParamFlag_Interface = 0x10, + BfGenericParamFlag_New = 0x20, + BfGenericParamFlag_Delete = 0x40, + BfGenericParamFlag_Var = 0x80, + BfGenericParamFlag_Const = 0x100, + BfGenericParamFlag_Equals = 0x200, + BfGenericParamFlag_Equals_Op = 0x400, + BfGenericParamFlag_Equals_Type = 0x800, + BfGenericParamFlag_Equals_IFace = 0x1000 }; class BfConstraintDef diff --git a/IDEHelper/Tests/src/Generics.bf b/IDEHelper/Tests/src/Generics.bf index 1d533f15..53008abe 100644 --- a/IDEHelper/Tests/src/Generics.bf +++ b/IDEHelper/Tests/src/Generics.bf @@ -132,16 +132,23 @@ namespace Tests return 1; } - public static int MethodA(T val) where T : ValueType + public static int MethodA(T val) where T : struct { return 2; } - public static int MethodA(T val) where T : Enum + public static int MethodA(T val) where T : enum { + int val2 = (int)val; + T val3 = 0; return 3; } + public static int MethodA(T val) where T : interface + { + return 4; + } + public struct Entry { public static int operator<=>(Entry lhs, Entry rhs) @@ -200,9 +207,12 @@ namespace Tests LibA.LibA0.Alloc(); LibA.LibA0.Alloc(); + IDisposable iDisp = null; + Test.Assert(MethodA("") == 1); Test.Assert(MethodA(1.2f) == 2); Test.Assert(MethodA(TypeCode.Boolean) == 3); + Test.Assert(MethodA(iDisp) == 4); ClassC cc = scope .(); Test.Assert(ClassC.mInstance == cc);