1
0
Fork 0
mirror of https://github.com/beefytech/Beef.git synced 2025-06-16 23:34:10 +02:00

Fixed enumeration capabilities of a generic param with ienumerator iface

This commit is contained in:
Brian Fiete 2020-06-03 05:23:20 -07:00
parent 78bb60cddc
commit a328334ab3

View file

@ -5824,9 +5824,11 @@ void BfModule::Visit(BfForEachStatement* forEachStmt)
if (isVarEnumerator) if (isVarEnumerator)
varType = GetPrimitiveType(BfTypeCode_Var); varType = GetPrimitiveType(BfTypeCode_Var);
BfGenericParamInstance* genericParamInst = NULL;
if (target.mType->IsGenericParam()) if (target.mType->IsGenericParam())
{ {
auto genericParamInst = GetGenericParamInstance((BfGenericParamType*)target.mType); genericParamInst = GetGenericParamInstance((BfGenericParamType*)target.mType);
if ((genericParamInst->mGenericParamFlags & BfGenericParamFlag_Var) != 0) if ((genericParamInst->mGenericParamFlags & BfGenericParamFlag_Var) != 0)
{ {
varType = GetPrimitiveType(BfTypeCode_Var); varType = GetPrimitiveType(BfTypeCode_Var);
@ -5863,41 +5865,44 @@ void BfModule::Visit(BfForEachStatement* forEachStmt)
{ {
// Generic method or mixin decl // Generic method or mixin decl
} }
else if (!target.mType->IsTypeInstance()) else if ((!target.mType->IsTypeInstance()) && (genericParamInst == NULL))
{ {
Fail(StrFormat("Type '%s' cannot be used in enumeration", TypeToString(target.mType).c_str()), forEachStmt->mCollectionExpression); Fail(StrFormat("Type '%s' cannot be used in enumeration", TypeToString(target.mType).c_str()), forEachStmt->mCollectionExpression);
} }
else if (forEachStmt->mCollectionExpression != NULL) else if (forEachStmt->mCollectionExpression != NULL)
{ {
auto targetTypeInstance = target.mType->ToTypeInstance(); auto targetTypeInstance = target.mType->ToTypeInstance();
PopulateType(targetTypeInstance, BfPopulateType_DataAndMethods);
itr = target; itr = target;
bool hadGetEnumeratorType; bool hadGetEnumeratorType;
auto getEnumeratorMethod = GetMethodByName(targetTypeInstance, "GetEnumerator", 0, true);
if (!getEnumeratorMethod) if (targetTypeInstance != NULL)
{ {
hadGetEnumeratorType = false; PopulateType(targetTypeInstance, BfPopulateType_DataAndMethods);
} auto getEnumeratorMethod = GetMethodByName(targetTypeInstance, "GetEnumerator", 0, true);
else if (getEnumeratorMethod.mMethodInstance->mMethodDef->mIsStatic) if (!getEnumeratorMethod)
{ {
hadGetEnumeratorType = true; hadGetEnumeratorType = false;
Fail(StrFormat("Type '%s' does not contain a non-static 'GetEnumerator' method", TypeToString(targetTypeInstance).c_str()), forEachStmt->mCollectionExpression); }
} else if (getEnumeratorMethod.mMethodInstance->mMethodDef->mIsStatic)
else if (getEnumeratorMethod.mMethodInstance->mMethodDef->mIsConcrete) {
{ hadGetEnumeratorType = true;
hadGetEnumeratorType = true; Fail(StrFormat("Type '%s' does not contain a non-static 'GetEnumerator' method", TypeToString(targetTypeInstance).c_str()), forEachStmt->mCollectionExpression);
Fail(StrFormat("Iteration requires a concrete implementation of '%s'", TypeToString(targetTypeInstance).c_str()), forEachStmt->mCollectionExpression); }
} else if (getEnumeratorMethod.mMethodInstance->mMethodDef->mIsConcrete)
else {
{ hadGetEnumeratorType = true;
hadGetEnumeratorType = true; Fail(StrFormat("Iteration requires a concrete implementation of '%s'", TypeToString(targetTypeInstance).c_str()), forEachStmt->mCollectionExpression);
BfExprEvaluator exprEvaluator(this); }
SizedArray<BfIRValue, 1> args; else
auto castedTarget = Cast(forEachStmt->mCollectionExpression, target, getEnumeratorMethod.mMethodInstance->GetOwner()); {
exprEvaluator.PushThis(forEachStmt->mCollectionExpression, castedTarget, getEnumeratorMethod.mMethodInstance, args); hadGetEnumeratorType = true;
itr = exprEvaluator.CreateCall(getEnumeratorMethod.mMethodInstance, mCompiler->IsSkippingExtraResolveChecks() ? BfIRValue() : getEnumeratorMethod.mFunc, false, args); BfExprEvaluator exprEvaluator(this);
SizedArray<BfIRValue, 1> args;
auto castedTarget = Cast(forEachStmt->mCollectionExpression, target, getEnumeratorMethod.mMethodInstance->GetOwner());
exprEvaluator.PushThis(forEachStmt->mCollectionExpression, castedTarget, getEnumeratorMethod.mMethodInstance, args);
itr = exprEvaluator.CreateCall(getEnumeratorMethod.mMethodInstance, mCompiler->IsSkippingExtraResolveChecks() ? BfIRValue() : getEnumeratorMethod.mFunc, false, args);
}
} }
if (itr) if (itr)
@ -5905,32 +5910,38 @@ void BfModule::Visit(BfForEachStatement* forEachStmt)
PopulateType(itr.mType, BfPopulateType_DataAndMethods); PopulateType(itr.mType, BfPopulateType_DataAndMethods);
BfGenericTypeInstance* genericItrInterface = NULL; BfGenericTypeInstance* genericItrInterface = NULL;
auto _CheckInterface = [&](BfTypeInstance* interface)
{
if (interface->mTypeDef == (isRefExpression ? mCompiler->mGenericIRefEnumeratorTypeDef : mCompiler->mGenericIEnumeratorTypeDef))
{
if (genericItrInterface != NULL)
{
Fail(StrFormat("Type '%s' implements multiple %s<T> interfaces", TypeToString(itr.mType).c_str(), isRefExpression ? "IRefEnumerator" : "IEnumerator"), forEachStmt->mCollectionExpression);
}
itrInterface = interface;
genericItrInterface = itrInterface->ToGenericTypeInstance();
if (inferVarType)
{
varType = genericItrInterface->mTypeGenericArguments[0];
if (isRefExpression)
{
if (varType->IsPointer())
varType = CreateRefType(varType->GetUnderlyingType());
}
}
}
};
auto enumeratorTypeInst = itr.mType->ToTypeInstance(); auto enumeratorTypeInst = itr.mType->ToTypeInstance();
if (enumeratorTypeInst != NULL) if (enumeratorTypeInst != NULL)
{ {
for (auto& interfaceRef : enumeratorTypeInst->mInterfaces) for (auto& interfaceRef : enumeratorTypeInst->mInterfaces)
{ {
BfTypeInstance* interface = interfaceRef.mInterfaceType; BfTypeInstance* interface = interfaceRef.mInterfaceType;
if (interface->mTypeDef == (isRefExpression ? mCompiler->mGenericIRefEnumeratorTypeDef : mCompiler->mGenericIEnumeratorTypeDef)) _CheckInterface(interface);
{
if (genericItrInterface != NULL)
{
Fail(StrFormat("Type '%s' implements multiple %s<T> interfaces", TypeToString(itr.mType).c_str(), isRefExpression ? "IRefEnumerator" : "IEnumerator"), forEachStmt->mCollectionExpression);
}
itrInterface = interface;
genericItrInterface = itrInterface->ToGenericTypeInstance();
if (inferVarType)
{
varType = genericItrInterface->mTypeGenericArguments[0];
if (isRefExpression)
{
if (varType->IsPointer())
varType = CreateRefType(varType->GetUnderlyingType());
}
}
}
} }
if (enumeratorTypeInst->mTypeDef == (isRefExpression ? mCompiler->mGenericIRefEnumeratorTypeDef : mCompiler->mGenericIEnumeratorTypeDef)) if (enumeratorTypeInst->mTypeDef == (isRefExpression ? mCompiler->mGenericIRefEnumeratorTypeDef : mCompiler->mGenericIEnumeratorTypeDef))
@ -5949,11 +5960,17 @@ void BfModule::Visit(BfForEachStatement* forEachStmt)
} }
} }
if ((genericItrInterface == NULL) && (genericParamInst != NULL))
{
for (auto interface : genericParamInst->mInterfaceConstraints)
_CheckInterface(interface);
}
if (genericItrInterface == NULL) if (genericItrInterface == NULL)
{ {
if (!hadGetEnumeratorType) if (!hadGetEnumeratorType)
{ {
Fail(StrFormat("Type '%s' must contain a 'GetEnumerator' method or implement an IEnumerator<T> interface", TypeToString(targetTypeInstance).c_str()), forEachStmt->mCollectionExpression); Fail(StrFormat("Type '%s' must contain a 'GetEnumerator' method or implement an IEnumerator<T> interface", TypeToString(target.mType).c_str()), forEachStmt->mCollectionExpression);
} }
else else
Fail(StrFormat("Enumerator type '%s' must implement an %s<T> interface", TypeToString(itr.mType).c_str(), isRefExpression ? "IRefEnumerator" : "IEnumerator"), forEachStmt->mCollectionExpression); Fail(StrFormat("Enumerator type '%s' must implement an %s<T> interface", TypeToString(itr.mType).c_str(), isRefExpression ? "IRefEnumerator" : "IEnumerator"), forEachStmt->mCollectionExpression);