1
0
Fork 0
mirror of https://github.com/beefytech/Beef.git synced 2025-06-10 20:42:21 +02:00

Fixed enumeration capabilities of a generic param with ienumerator iface

This commit is contained in:
Brian Fiete 2020-06-03 05:23:20 -07:00
parent 78bb60cddc
commit a328334ab3

View file

@ -5824,9 +5824,11 @@ void BfModule::Visit(BfForEachStatement* forEachStmt)
if (isVarEnumerator)
varType = GetPrimitiveType(BfTypeCode_Var);
BfGenericParamInstance* genericParamInst = NULL;
if (target.mType->IsGenericParam())
{
auto genericParamInst = GetGenericParamInstance((BfGenericParamType*)target.mType);
genericParamInst = GetGenericParamInstance((BfGenericParamType*)target.mType);
if ((genericParamInst->mGenericParamFlags & BfGenericParamFlag_Var) != 0)
{
varType = GetPrimitiveType(BfTypeCode_Var);
@ -5863,41 +5865,44 @@ void BfModule::Visit(BfForEachStatement* forEachStmt)
{
// Generic method or mixin decl
}
else if (!target.mType->IsTypeInstance())
else if ((!target.mType->IsTypeInstance()) && (genericParamInst == NULL))
{
Fail(StrFormat("Type '%s' cannot be used in enumeration", TypeToString(target.mType).c_str()), forEachStmt->mCollectionExpression);
}
else if (forEachStmt->mCollectionExpression != NULL)
{
auto targetTypeInstance = target.mType->ToTypeInstance();
PopulateType(targetTypeInstance, BfPopulateType_DataAndMethods);
itr = target;
bool hadGetEnumeratorType;
auto getEnumeratorMethod = GetMethodByName(targetTypeInstance, "GetEnumerator", 0, true);
if (!getEnumeratorMethod)
if (targetTypeInstance != NULL)
{
hadGetEnumeratorType = false;
}
else if (getEnumeratorMethod.mMethodInstance->mMethodDef->mIsStatic)
{
hadGetEnumeratorType = true;
Fail(StrFormat("Type '%s' does not contain a non-static 'GetEnumerator' method", TypeToString(targetTypeInstance).c_str()), forEachStmt->mCollectionExpression);
}
else if (getEnumeratorMethod.mMethodInstance->mMethodDef->mIsConcrete)
{
hadGetEnumeratorType = true;
Fail(StrFormat("Iteration requires a concrete implementation of '%s'", TypeToString(targetTypeInstance).c_str()), forEachStmt->mCollectionExpression);
}
else
{
hadGetEnumeratorType = true;
BfExprEvaluator exprEvaluator(this);
SizedArray<BfIRValue, 1> args;
auto castedTarget = Cast(forEachStmt->mCollectionExpression, target, getEnumeratorMethod.mMethodInstance->GetOwner());
exprEvaluator.PushThis(forEachStmt->mCollectionExpression, castedTarget, getEnumeratorMethod.mMethodInstance, args);
itr = exprEvaluator.CreateCall(getEnumeratorMethod.mMethodInstance, mCompiler->IsSkippingExtraResolveChecks() ? BfIRValue() : getEnumeratorMethod.mFunc, false, args);
PopulateType(targetTypeInstance, BfPopulateType_DataAndMethods);
auto getEnumeratorMethod = GetMethodByName(targetTypeInstance, "GetEnumerator", 0, true);
if (!getEnumeratorMethod)
{
hadGetEnumeratorType = false;
}
else if (getEnumeratorMethod.mMethodInstance->mMethodDef->mIsStatic)
{
hadGetEnumeratorType = true;
Fail(StrFormat("Type '%s' does not contain a non-static 'GetEnumerator' method", TypeToString(targetTypeInstance).c_str()), forEachStmt->mCollectionExpression);
}
else if (getEnumeratorMethod.mMethodInstance->mMethodDef->mIsConcrete)
{
hadGetEnumeratorType = true;
Fail(StrFormat("Iteration requires a concrete implementation of '%s'", TypeToString(targetTypeInstance).c_str()), forEachStmt->mCollectionExpression);
}
else
{
hadGetEnumeratorType = true;
BfExprEvaluator exprEvaluator(this);
SizedArray<BfIRValue, 1> args;
auto castedTarget = Cast(forEachStmt->mCollectionExpression, target, getEnumeratorMethod.mMethodInstance->GetOwner());
exprEvaluator.PushThis(forEachStmt->mCollectionExpression, castedTarget, getEnumeratorMethod.mMethodInstance, args);
itr = exprEvaluator.CreateCall(getEnumeratorMethod.mMethodInstance, mCompiler->IsSkippingExtraResolveChecks() ? BfIRValue() : getEnumeratorMethod.mFunc, false, args);
}
}
if (itr)
@ -5905,32 +5910,38 @@ void BfModule::Visit(BfForEachStatement* forEachStmt)
PopulateType(itr.mType, BfPopulateType_DataAndMethods);
BfGenericTypeInstance* genericItrInterface = NULL;
auto _CheckInterface = [&](BfTypeInstance* interface)
{
if (interface->mTypeDef == (isRefExpression ? mCompiler->mGenericIRefEnumeratorTypeDef : mCompiler->mGenericIEnumeratorTypeDef))
{
if (genericItrInterface != NULL)
{
Fail(StrFormat("Type '%s' implements multiple %s<T> interfaces", TypeToString(itr.mType).c_str(), isRefExpression ? "IRefEnumerator" : "IEnumerator"), forEachStmt->mCollectionExpression);
}
itrInterface = interface;
genericItrInterface = itrInterface->ToGenericTypeInstance();
if (inferVarType)
{
varType = genericItrInterface->mTypeGenericArguments[0];
if (isRefExpression)
{
if (varType->IsPointer())
varType = CreateRefType(varType->GetUnderlyingType());
}
}
}
};
auto enumeratorTypeInst = itr.mType->ToTypeInstance();
if (enumeratorTypeInst != NULL)
{
for (auto& interfaceRef : enumeratorTypeInst->mInterfaces)
{
BfTypeInstance* interface = interfaceRef.mInterfaceType;
if (interface->mTypeDef == (isRefExpression ? mCompiler->mGenericIRefEnumeratorTypeDef : mCompiler->mGenericIEnumeratorTypeDef))
{
if (genericItrInterface != NULL)
{
Fail(StrFormat("Type '%s' implements multiple %s<T> interfaces", TypeToString(itr.mType).c_str(), isRefExpression ? "IRefEnumerator" : "IEnumerator"), forEachStmt->mCollectionExpression);
}
itrInterface = interface;
genericItrInterface = itrInterface->ToGenericTypeInstance();
if (inferVarType)
{
varType = genericItrInterface->mTypeGenericArguments[0];
if (isRefExpression)
{
if (varType->IsPointer())
varType = CreateRefType(varType->GetUnderlyingType());
}
}
}
_CheckInterface(interface);
}
if (enumeratorTypeInst->mTypeDef == (isRefExpression ? mCompiler->mGenericIRefEnumeratorTypeDef : mCompiler->mGenericIEnumeratorTypeDef))
@ -5949,11 +5960,17 @@ void BfModule::Visit(BfForEachStatement* forEachStmt)
}
}
if ((genericItrInterface == NULL) && (genericParamInst != NULL))
{
for (auto interface : genericParamInst->mInterfaceConstraints)
_CheckInterface(interface);
}
if (genericItrInterface == NULL)
{
if (!hadGetEnumeratorType)
{
Fail(StrFormat("Type '%s' must contain a 'GetEnumerator' method or implement an IEnumerator<T> interface", TypeToString(targetTypeInstance).c_str()), forEachStmt->mCollectionExpression);
Fail(StrFormat("Type '%s' must contain a 'GetEnumerator' method or implement an IEnumerator<T> interface", TypeToString(target.mType).c_str()), forEachStmt->mCollectionExpression);
}
else
Fail(StrFormat("Enumerator type '%s' must implement an %s<T> interface", TypeToString(itr.mType).c_str(), isRefExpression ? "IRefEnumerator" : "IEnumerator"), forEachStmt->mCollectionExpression);