diff options
Diffstat (limited to 'reference_model/src/ops/op_factory.h')
-rw-r--r-- | reference_model/src/ops/op_factory.h | 39 |
1 files changed, 38 insertions, 1 deletions
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h index 341d7dc..25dfc6e 100644 --- a/reference_model/src/ops/op_factory.h +++ b/reference_model/src/ops/op_factory.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -62,18 +62,55 @@ return new OP<DType_##DTYPE>(sgt, attribute, id); \ } +#define DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE, ACCUM_DTYPE) \ + if (inputDType == DType_##DTYPE && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \ + { \ + return new OP<DType_##DTYPE, DType_##ACCUM_DTYPE>(sgt, attribute, id); \ + } + #define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \ if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2) \ { \ return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); \ } +#define DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE1, DTYPE2, ACCUM_DTYPE) \ + if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 \ + && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \ + { \ + return new OP<DType_##DTYPE1, DType_##DTYPE2, DType_##ACCUM_DTYPE>(sgt, attribute, id); \ + } \ + +// Statement-expression to evaluate accumulate attribute in-place +#define ACCUM_FROM_ATTRIBUTE(ATTRIBUTE_NAME) \ + ({ \ + tosa::DType accumDType = tosa::DType_UNKNOWN; \ + if (auto p = dynamic_cast<tosa::Tosa##ATTRIBUTE_NAME##Attribute*>(attribute)) \ + { \ + auto attr = new tosa::Tosa##ATTRIBUTE_NAME##Attribute(p); \ + ASSERT_MEM(attr); \ + accumDType = tosa::EnumValuesDType()[attr->accum_dtype()]; \ + } \ + else \ + { \ + FATAL_ERROR("Can't initialize Tosa" #ATTRIBUTE_NAME "Attribute.\nPre-initialization " \ + "of this attribute is required in order to determine the accumulate type."); \ + } \ + accumDType; \ + }) \ + #define DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OP, DTYPE1, DTYPE2) \ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ { \ return new OP<DType_##DTYPE1, DType_##DTYPE2, int16_t>(sgt, attribute, id); \ } +#define DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OP, DTYPE1, DTYPE2) \ + if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ + { \ + return new OP<DType_##DTYPE1, DType_##DTYPE2, float>(sgt, attribute, id); \ + } + #define DEF_FACTORY_TWO_TYPE_RESIZE_FLOAT(OP, DTYPE1, DTYPE2) \ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ { \ |