diff --git a/IDEHelper/Compiler/BfStmtEvaluator.cpp b/IDEHelper/Compiler/BfStmtEvaluator.cpp index c3345d27..75fa0018 100644 --- a/IDEHelper/Compiler/BfStmtEvaluator.cpp +++ b/IDEHelper/Compiler/BfStmtEvaluator.cpp @@ -5824,9 +5824,11 @@ void BfModule::Visit(BfForEachStatement* forEachStmt) if (isVarEnumerator) varType = GetPrimitiveType(BfTypeCode_Var); + BfGenericParamInstance* genericParamInst = NULL; + if (target.mType->IsGenericParam()) { - auto genericParamInst = GetGenericParamInstance((BfGenericParamType*)target.mType); + genericParamInst = GetGenericParamInstance((BfGenericParamType*)target.mType); if ((genericParamInst->mGenericParamFlags & BfGenericParamFlag_Var) != 0) { varType = GetPrimitiveType(BfTypeCode_Var); @@ -5863,41 +5865,44 @@ void BfModule::Visit(BfForEachStatement* forEachStmt) { // Generic method or mixin decl } - else if (!target.mType->IsTypeInstance()) + else if ((!target.mType->IsTypeInstance()) && (genericParamInst == NULL)) { Fail(StrFormat("Type '%s' cannot be used in enumeration", TypeToString(target.mType).c_str()), forEachStmt->mCollectionExpression); } else if (forEachStmt->mCollectionExpression != NULL) { auto targetTypeInstance = target.mType->ToTypeInstance(); - - PopulateType(targetTypeInstance, BfPopulateType_DataAndMethods); - + itr = target; bool hadGetEnumeratorType; - auto getEnumeratorMethod = GetMethodByName(targetTypeInstance, "GetEnumerator", 0, true); - if (!getEnumeratorMethod) + + if (targetTypeInstance != NULL) { - hadGetEnumeratorType = false; - } - else if (getEnumeratorMethod.mMethodInstance->mMethodDef->mIsStatic) - { - hadGetEnumeratorType = true; - Fail(StrFormat("Type '%s' does not contain a non-static 'GetEnumerator' method", TypeToString(targetTypeInstance).c_str()), forEachStmt->mCollectionExpression); - } - else if (getEnumeratorMethod.mMethodInstance->mMethodDef->mIsConcrete) - { - hadGetEnumeratorType = true; - Fail(StrFormat("Iteration requires a concrete implementation of '%s'", TypeToString(targetTypeInstance).c_str()), forEachStmt->mCollectionExpression); - } - else - { - hadGetEnumeratorType = true; - BfExprEvaluator exprEvaluator(this); - SizedArray args; - auto castedTarget = Cast(forEachStmt->mCollectionExpression, target, getEnumeratorMethod.mMethodInstance->GetOwner()); - exprEvaluator.PushThis(forEachStmt->mCollectionExpression, castedTarget, getEnumeratorMethod.mMethodInstance, args); - itr = exprEvaluator.CreateCall(getEnumeratorMethod.mMethodInstance, mCompiler->IsSkippingExtraResolveChecks() ? BfIRValue() : getEnumeratorMethod.mFunc, false, args); + PopulateType(targetTypeInstance, BfPopulateType_DataAndMethods); + auto getEnumeratorMethod = GetMethodByName(targetTypeInstance, "GetEnumerator", 0, true); + if (!getEnumeratorMethod) + { + hadGetEnumeratorType = false; + } + else if (getEnumeratorMethod.mMethodInstance->mMethodDef->mIsStatic) + { + hadGetEnumeratorType = true; + Fail(StrFormat("Type '%s' does not contain a non-static 'GetEnumerator' method", TypeToString(targetTypeInstance).c_str()), forEachStmt->mCollectionExpression); + } + else if (getEnumeratorMethod.mMethodInstance->mMethodDef->mIsConcrete) + { + hadGetEnumeratorType = true; + Fail(StrFormat("Iteration requires a concrete implementation of '%s'", TypeToString(targetTypeInstance).c_str()), forEachStmt->mCollectionExpression); + } + else + { + hadGetEnumeratorType = true; + BfExprEvaluator exprEvaluator(this); + SizedArray args; + auto castedTarget = Cast(forEachStmt->mCollectionExpression, target, getEnumeratorMethod.mMethodInstance->GetOwner()); + exprEvaluator.PushThis(forEachStmt->mCollectionExpression, castedTarget, getEnumeratorMethod.mMethodInstance, args); + itr = exprEvaluator.CreateCall(getEnumeratorMethod.mMethodInstance, mCompiler->IsSkippingExtraResolveChecks() ? BfIRValue() : getEnumeratorMethod.mFunc, false, args); + } } if (itr) @@ -5905,32 +5910,38 @@ void BfModule::Visit(BfForEachStatement* forEachStmt) PopulateType(itr.mType, BfPopulateType_DataAndMethods); BfGenericTypeInstance* genericItrInterface = NULL; + + auto _CheckInterface = [&](BfTypeInstance* interface) + { + if (interface->mTypeDef == (isRefExpression ? mCompiler->mGenericIRefEnumeratorTypeDef : mCompiler->mGenericIEnumeratorTypeDef)) + { + if (genericItrInterface != NULL) + { + Fail(StrFormat("Type '%s' implements multiple %s interfaces", TypeToString(itr.mType).c_str(), isRefExpression ? "IRefEnumerator" : "IEnumerator"), forEachStmt->mCollectionExpression); + } + + itrInterface = interface; + genericItrInterface = itrInterface->ToGenericTypeInstance(); + if (inferVarType) + { + varType = genericItrInterface->mTypeGenericArguments[0]; + if (isRefExpression) + { + if (varType->IsPointer()) + varType = CreateRefType(varType->GetUnderlyingType()); + } + + } + } + }; + auto enumeratorTypeInst = itr.mType->ToTypeInstance(); if (enumeratorTypeInst != NULL) { for (auto& interfaceRef : enumeratorTypeInst->mInterfaces) { BfTypeInstance* interface = interfaceRef.mInterfaceType; - if (interface->mTypeDef == (isRefExpression ? mCompiler->mGenericIRefEnumeratorTypeDef : mCompiler->mGenericIEnumeratorTypeDef)) - { - if (genericItrInterface != NULL) - { - Fail(StrFormat("Type '%s' implements multiple %s interfaces", TypeToString(itr.mType).c_str(), isRefExpression ? "IRefEnumerator" : "IEnumerator"), forEachStmt->mCollectionExpression); - } - - itrInterface = interface; - genericItrInterface = itrInterface->ToGenericTypeInstance(); - if (inferVarType) - { - varType = genericItrInterface->mTypeGenericArguments[0]; - if (isRefExpression) - { - if (varType->IsPointer()) - varType = CreateRefType(varType->GetUnderlyingType()); - } - - } - } + _CheckInterface(interface); } if (enumeratorTypeInst->mTypeDef == (isRefExpression ? mCompiler->mGenericIRefEnumeratorTypeDef : mCompiler->mGenericIEnumeratorTypeDef)) @@ -5949,11 +5960,17 @@ void BfModule::Visit(BfForEachStatement* forEachStmt) } } + if ((genericItrInterface == NULL) && (genericParamInst != NULL)) + { + for (auto interface : genericParamInst->mInterfaceConstraints) + _CheckInterface(interface); + } + if (genericItrInterface == NULL) { if (!hadGetEnumeratorType) { - Fail(StrFormat("Type '%s' must contain a 'GetEnumerator' method or implement an IEnumerator interface", TypeToString(targetTypeInstance).c_str()), forEachStmt->mCollectionExpression); + Fail(StrFormat("Type '%s' must contain a 'GetEnumerator' method or implement an IEnumerator interface", TypeToString(target.mType).c_str()), forEachStmt->mCollectionExpression); } else Fail(StrFormat("Enumerator type '%s' must implement an %s interface", TypeToString(itr.mType).c_str(), isRefExpression ? "IRefEnumerator" : "IEnumerator"), forEachStmt->mCollectionExpression);