aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/op_factory.h
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/op_factory.h')
-rw-r--r--reference_model/src/ops/op_factory.h39
1 files changed, 38 insertions, 1 deletions
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h
index 341d7dc..25dfc6e 100644
--- a/reference_model/src/ops/op_factory.h
+++ b/reference_model/src/ops/op_factory.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -62,18 +62,55 @@
return new OP<DType_##DTYPE>(sgt, attribute, id); \
}
+#define DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE, ACCUM_DTYPE) \
+ if (inputDType == DType_##DTYPE && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \
+ { \
+ return new OP<DType_##DTYPE, DType_##ACCUM_DTYPE>(sgt, attribute, id); \
+ }
+
#define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \
if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2) \
{ \
return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); \
}
+#define DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE1, DTYPE2, ACCUM_DTYPE) \
+ if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 \
+ && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \
+ { \
+ return new OP<DType_##DTYPE1, DType_##DTYPE2, DType_##ACCUM_DTYPE>(sgt, attribute, id); \
+ } \
+
+// Statement-expression to evaluate accumulate attribute in-place
+#define ACCUM_FROM_ATTRIBUTE(ATTRIBUTE_NAME) \
+ ({ \
+ tosa::DType accumDType = tosa::DType_UNKNOWN; \
+ if (auto p = dynamic_cast<tosa::Tosa##ATTRIBUTE_NAME##Attribute*>(attribute)) \
+ { \
+ auto attr = new tosa::Tosa##ATTRIBUTE_NAME##Attribute(p); \
+ ASSERT_MEM(attr); \
+ accumDType = tosa::EnumValuesDType()[attr->accum_dtype()]; \
+ } \
+ else \
+ { \
+ FATAL_ERROR("Can't initialize Tosa" #ATTRIBUTE_NAME "Attribute.\nPre-initialization " \
+ "of this attribute is required in order to determine the accumulate type."); \
+ } \
+ accumDType; \
+ }) \
+
#define DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OP, DTYPE1, DTYPE2) \
if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
{ \
return new OP<DType_##DTYPE1, DType_##DTYPE2, int16_t>(sgt, attribute, id); \
}
+#define DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OP, DTYPE1, DTYPE2) \
+ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ { \
+ return new OP<DType_##DTYPE1, DType_##DTYPE2, float>(sgt, attribute, id); \
+ }
+
#define DEF_FACTORY_TWO_TYPE_RESIZE_FLOAT(OP, DTYPE1, DTYPE2) \
if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
{ \