aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/op_factory.h
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/op_factory.h')
-rw-r--r--reference_model/src/ops/op_factory.h82
1 files changed, 45 insertions, 37 deletions
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h
index 9117df4..f276e03 100644
--- a/reference_model/src/ops/op_factory.h
+++ b/reference_model/src/ops/op_factory.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -23,19 +23,19 @@
#define DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) \
case RANK: \
- return new OP<RANK, DType_##DTYPE>(sgt, attribute, id);
+ return new OP<RANK, TOSA_REF_TYPE_##DTYPE>(sgt, attribute, id);
#define DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \
case RANK: \
- return new OP<RANK, DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id);
+ return new OP<RANK, TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>(sgt, attribute, id);
#define DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \
case RANK2: \
- return new OP<RANK1, RANK2, DType_##DTYPE>(sgt, attribute, id);
+ return new OP<RANK1, RANK2, TOSA_REF_TYPE_##DTYPE>(sgt, attribute, id);
#define DEF_FACTORY_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \
case RANK2: \
- return new OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id);
+ return new OP<RANK1, RANK2, TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>(sgt, attribute, id);
#define DEF_FACTORY_ONE_RANK_0_6(OP) \
switch (inputRank) \
@@ -57,40 +57,42 @@
}
#define DEF_FACTORY_ONE_TYPE(OP, DTYPE) \
- if (inputDType == DType_##DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE) \
{ \
- return new OP<DType_##DTYPE>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE>(sgt, attribute, id); \
}
#define DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE, ACCUM_DTYPE) \
- if (inputDType == DType_##DTYPE && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == TOSA_REF_TYPE_##ACCUM_DTYPE) \
{ \
- return new OP<DType_##DTYPE, DType_##ACCUM_DTYPE>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE, TOSA_REF_TYPE_##ACCUM_DTYPE>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && weightDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE_IN_OUT(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE1, DTYPE2, ACCUM_DTYPE) \
- if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 \
- && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && weightDTYPE == TOSA_REF_TYPE_##DTYPE2 && \
+ ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == TOSA_REF_TYPE_##ACCUM_DTYPE) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, DType_##ACCUM_DTYPE>(sgt, attribute, id); \
- } \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##ACCUM_DTYPE>(sgt, attribute, \
+ id); \
+ }
#define DEF_FACTORY_THREE_TYPE(OP, DTYPE1, DTYPE2, DTYPE3) \
- if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 && outputDType == DType_##DTYPE3) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && weightDTYPE == TOSA_REF_TYPE_##DTYPE2 && \
+ outputDTYPE == TOSA_REF_TYPE_##DTYPE3) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, DType_##DTYPE3>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##DTYPE3>(sgt, attribute, id); \
}
// Statement-expression to evaluate accumulate attribute in-place
@@ -108,35 +110,41 @@
FATAL_ERROR("Can't initialize Tosa" #ATTRIBUTE_NAME "Attribute.\nPre-initialization " \
"of this attribute is required in order to determine the accumulate type."); \
} \
- accumDType; \
- }) \
+ ConvertDType(accumDType); \
+ })
#define DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, int16_t>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, int16_t>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, half_float::half>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, half_float::half>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
+ { \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, Eigen::bfloat16>(sgt, attribute, id); \
+ }
+
+#define DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OP, DTYPE1, DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, Eigen::bfloat16>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, float>(sgt, attribute, id); \
}
-#define DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+#define DEF_FACTORY_TWO_TYPE_RESIZE_FP64(OP, DTYPE1, DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, float>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, double>(sgt, attribute, id); \
}
#define DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
- if (inputDType == DType_##DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE) \
{ \
switch (inputRank) \
{ \
@@ -151,7 +159,7 @@
}
#define DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
- if (inputDType == DType_##DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE) \
{ \
switch (inputRank) \
{ \
@@ -165,7 +173,7 @@
}
#define DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
switch (inputRank) \
{ \
@@ -180,7 +188,7 @@
}
#define DEF_FACTORY_RESHAPE(OP, DTYPE) \
- if (inputDType == DType_##DTYPE && outputDType == DType_##DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE && outputDTYPE == TOSA_REF_TYPE_##DTYPE) \
{ \
switch (inputRank) \
{ \
@@ -292,11 +300,11 @@ public:
tosa::Op opType,
tosa::TosaAttributeBase* attribute,
uint64_t id,
- tosa::DType inputDType,
+ TOSA_REF_TYPE inputDTYPE,
int inputRank,
- tosa::DType outputDType,
+ TOSA_REF_TYPE outputDTYPE,
int outputRank,
- tosa::DType weightDType,
+ TOSA_REF_TYPE weightDTYPE,
int weightRank);
};
}; // namespace TosaReference