1
0
Fork 0
mirror of https://github.com/beefytech/Beef.git synced 2025-06-17 15:46:05 +02:00

Fixed enumeration capabilities of a generic param with ienumerator iface

This commit is contained in:
Brian Fiete 2020-06-03 05:23:20 -07:00
parent 78bb60cddc
commit a328334ab3

View file

@ -5824,9 +5824,11 @@ void BfModule::Visit(BfForEachStatement* forEachStmt)
if (isVarEnumerator) if (isVarEnumerator)
varType = GetPrimitiveType(BfTypeCode_Var); varType = GetPrimitiveType(BfTypeCode_Var);
BfGenericParamInstance* genericParamInst = NULL;
if (target.mType->IsGenericParam()) if (target.mType->IsGenericParam())
{ {
auto genericParamInst = GetGenericParamInstance((BfGenericParamType*)target.mType); genericParamInst = GetGenericParamInstance((BfGenericParamType*)target.mType);
if ((genericParamInst->mGenericParamFlags & BfGenericParamFlag_Var) != 0) if ((genericParamInst->mGenericParamFlags & BfGenericParamFlag_Var) != 0)
{ {
varType = GetPrimitiveType(BfTypeCode_Var); varType = GetPrimitiveType(BfTypeCode_Var);
@ -5863,7 +5865,7 @@ void BfModule::Visit(BfForEachStatement* forEachStmt)
{ {
// Generic method or mixin decl // Generic method or mixin decl
} }
else if (!target.mType->IsTypeInstance()) else if ((!target.mType->IsTypeInstance()) && (genericParamInst == NULL))
{ {
Fail(StrFormat("Type '%s' cannot be used in enumeration", TypeToString(target.mType).c_str()), forEachStmt->mCollectionExpression); Fail(StrFormat("Type '%s' cannot be used in enumeration", TypeToString(target.mType).c_str()), forEachStmt->mCollectionExpression);
} }
@ -5871,10 +5873,12 @@ void BfModule::Visit(BfForEachStatement* forEachStmt)
{ {
auto targetTypeInstance = target.mType->ToTypeInstance(); auto targetTypeInstance = target.mType->ToTypeInstance();
PopulateType(targetTypeInstance, BfPopulateType_DataAndMethods);
itr = target; itr = target;
bool hadGetEnumeratorType; bool hadGetEnumeratorType;
if (targetTypeInstance != NULL)
{
PopulateType(targetTypeInstance, BfPopulateType_DataAndMethods);
auto getEnumeratorMethod = GetMethodByName(targetTypeInstance, "GetEnumerator", 0, true); auto getEnumeratorMethod = GetMethodByName(targetTypeInstance, "GetEnumerator", 0, true);
if (!getEnumeratorMethod) if (!getEnumeratorMethod)
{ {
@ -5899,18 +5903,16 @@ void BfModule::Visit(BfForEachStatement* forEachStmt)
exprEvaluator.PushThis(forEachStmt->mCollectionExpression, castedTarget, getEnumeratorMethod.mMethodInstance, args); exprEvaluator.PushThis(forEachStmt->mCollectionExpression, castedTarget, getEnumeratorMethod.mMethodInstance, args);
itr = exprEvaluator.CreateCall(getEnumeratorMethod.mMethodInstance, mCompiler->IsSkippingExtraResolveChecks() ? BfIRValue() : getEnumeratorMethod.mFunc, false, args); itr = exprEvaluator.CreateCall(getEnumeratorMethod.mMethodInstance, mCompiler->IsSkippingExtraResolveChecks() ? BfIRValue() : getEnumeratorMethod.mFunc, false, args);
} }
}
if (itr) if (itr)
{ {
PopulateType(itr.mType, BfPopulateType_DataAndMethods); PopulateType(itr.mType, BfPopulateType_DataAndMethods);
BfGenericTypeInstance* genericItrInterface = NULL; BfGenericTypeInstance* genericItrInterface = NULL;
auto enumeratorTypeInst = itr.mType->ToTypeInstance();
if (enumeratorTypeInst != NULL) auto _CheckInterface = [&](BfTypeInstance* interface)
{ {
for (auto& interfaceRef : enumeratorTypeInst->mInterfaces)
{
BfTypeInstance* interface = interfaceRef.mInterfaceType;
if (interface->mTypeDef == (isRefExpression ? mCompiler->mGenericIRefEnumeratorTypeDef : mCompiler->mGenericIEnumeratorTypeDef)) if (interface->mTypeDef == (isRefExpression ? mCompiler->mGenericIRefEnumeratorTypeDef : mCompiler->mGenericIEnumeratorTypeDef))
{ {
if (genericItrInterface != NULL) if (genericItrInterface != NULL)
@ -5931,6 +5933,15 @@ void BfModule::Visit(BfForEachStatement* forEachStmt)
} }
} }
};
auto enumeratorTypeInst = itr.mType->ToTypeInstance();
if (enumeratorTypeInst != NULL)
{
for (auto& interfaceRef : enumeratorTypeInst->mInterfaces)
{
BfTypeInstance* interface = interfaceRef.mInterfaceType;
_CheckInterface(interface);
} }
if (enumeratorTypeInst->mTypeDef == (isRefExpression ? mCompiler->mGenericIRefEnumeratorTypeDef : mCompiler->mGenericIEnumeratorTypeDef)) if (enumeratorTypeInst->mTypeDef == (isRefExpression ? mCompiler->mGenericIRefEnumeratorTypeDef : mCompiler->mGenericIEnumeratorTypeDef))
@ -5949,11 +5960,17 @@ void BfModule::Visit(BfForEachStatement* forEachStmt)
} }
} }
if ((genericItrInterface == NULL) && (genericParamInst != NULL))
{
for (auto interface : genericParamInst->mInterfaceConstraints)
_CheckInterface(interface);
}
if (genericItrInterface == NULL) if (genericItrInterface == NULL)
{ {
if (!hadGetEnumeratorType) if (!hadGetEnumeratorType)
{ {
Fail(StrFormat("Type '%s' must contain a 'GetEnumerator' method or implement an IEnumerator<T> interface", TypeToString(targetTypeInstance).c_str()), forEachStmt->mCollectionExpression); Fail(StrFormat("Type '%s' must contain a 'GetEnumerator' method or implement an IEnumerator<T> interface", TypeToString(target.mType).c_str()), forEachStmt->mCollectionExpression);
} }
else else
Fail(StrFormat("Enumerator type '%s' must implement an %s<T> interface", TypeToString(itr.mType).c_str(), isRefExpression ? "IRefEnumerator" : "IEnumerator"), forEachStmt->mCollectionExpression); Fail(StrFormat("Enumerator type '%s' must implement an %s<T> interface", TypeToString(itr.mType).c_str(), isRefExpression ? "IRefEnumerator" : "IEnumerator"), forEachStmt->mCollectionExpression);