aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/op_factory.h
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-03-28 22:06:56 +0000
committerTai Ly <tai.ly@arm.com>2023-05-05 19:23:15 +0000
commita4d748b08accce06fab93e2d2b96e499b35ae89b (patch)
tree20a3957e1f45f65f35d5d67ecce1618659e388f0 /reference_model/src/ops/op_factory.h
parent0c71686875618b2e11290273b7a05b88ef8a8aae (diff)
downloadreference_model-a4d748b08accce06fab93e2d2b96e499b35ae89b.tar.gz
[reference model] Add precise mode
This adds --precise_mode=1 option to tosa_referece_model, which will cause reference model to convert all floating point tensors to FP64 tensors and compute all operators accordingly. Also adds optional -p arguments to test runners tosa_verif_run_tests.py and tosa_verif_framework_compiler_runner.py to run tests in precise mode Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I156055216ad61710096497a8fa1a653be2a602a3
Diffstat (limited to 'reference_model/src/ops/op_factory.h')
-rw-r--r--reference_model/src/ops/op_factory.h82
1 files changed, 45 insertions, 37 deletions
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h
index 9117df4..f276e03 100644
--- a/reference_model/src/ops/op_factory.h
+++ b/reference_model/src/ops/op_factory.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -23,19 +23,19 @@
#define DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) \
case RANK: \
- return new OP<RANK, DType_##DTYPE>(sgt, attribute, id);
+ return new OP<RANK, TOSA_REF_TYPE_##DTYPE>(sgt, attribute, id);
#define DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \
case RANK: \
- return new OP<RANK, DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id);
+ return new OP<RANK, TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>(sgt, attribute, id);
#define DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \
case RANK2: \
- return new OP<RANK1, RANK2, DType_##DTYPE>(sgt, attribute, id);
+ return new OP<RANK1, RANK2, TOSA_REF_TYPE_##DTYPE>(sgt, attribute, id);
#define DEF_FACTORY_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \
case RANK2: \
- return new OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id);
+ return new OP<RANK1, RANK2, TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>(sgt, attribute, id);
#define DEF_FACTORY_ONE_RANK_0_6(OP) \
switch (inputRank) \
@@ -57,40 +57,42 @@
}
#define DEF_FACTORY_ONE_TYPE(OP, DTYPE) \
- if (inputDType == DType_##DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE) \
{ \
- return new OP<DType_##DTYPE>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE>(sgt, attribute, id); \
}
#define DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE, ACCUM_DTYPE) \
- if (inputDType == DType_##DTYPE && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == TOSA_REF_TYPE_##ACCUM_DTYPE) \
{ \
- return new OP<DType_##DTYPE, DType_##ACCUM_DTYPE>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE, TOSA_REF_TYPE_##ACCUM_DTYPE>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && weightDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE_IN_OUT(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE1, DTYPE2, ACCUM_DTYPE) \
- if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 \
- && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && weightDTYPE == TOSA_REF_TYPE_##DTYPE2 && \
+ ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == TOSA_REF_TYPE_##ACCUM_DTYPE) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, DType_##ACCUM_DTYPE>(sgt, attribute, id); \
- } \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##ACCUM_DTYPE>(sgt, attribute, \
+ id); \
+ }
#define DEF_FACTORY_THREE_TYPE(OP, DTYPE1, DTYPE2, DTYPE3) \
- if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 && outputDType == DType_##DTYPE3) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && weightDTYPE == TOSA_REF_TYPE_##DTYPE2 && \
+ outputDTYPE == TOSA_REF_TYPE_##DTYPE3) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, DType_##DTYPE3>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##DTYPE3>(sgt, attribute, id); \
}
// Statement-expression to evaluate accumulate attribute in-place
@@ -108,35 +110,41 @@
FATAL_ERROR("Can't initialize Tosa" #ATTRIBUTE_NAME "Attribute.\nPre-initialization " \
"of this attribute is required in order to determine the accumulate type."); \
} \
- accumDType; \
- }) \
+ ConvertDType(accumDType); \
+ })
#define DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, int16_t>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, int16_t>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, half_float::half>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, half_float::half>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
+ { \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, Eigen::bfloat16>(sgt, attribute, id); \
+ }
+
+#define DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OP, DTYPE1, DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, Eigen::bfloat16>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, float>(sgt, attribute, id); \
}
-#define DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+#define DEF_FACTORY_TWO_TYPE_RESIZE_FP64(OP, DTYPE1, DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, float>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, double>(sgt, attribute, id); \
}
#define DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
- if (inputDType == DType_##DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE) \
{ \
switch (inputRank) \
{ \
@@ -151,7 +159,7 @@
}
#define DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
- if (inputDType == DType_##DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE) \
{ \
switch (inputRank) \
{ \
@@ -165,7 +173,7 @@
}
#define DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
switch (inputRank) \
{ \
@@ -180,7 +188,7 @@
}
#define DEF_FACTORY_RESHAPE(OP, DTYPE) \
- if (inputDType == DType_##DTYPE && outputDType == DType_##DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE && outputDTYPE == TOSA_REF_TYPE_##DTYPE) \
{ \
switch (inputRank) \
{ \
@@ -292,11 +300,11 @@ public:
tosa::Op opType,
tosa::TosaAttributeBase* attribute,
uint64_t id,
- tosa::DType inputDType,
+ TOSA_REF_TYPE inputDTYPE,
int inputRank,
- tosa::DType outputDType,
+ TOSA_REF_TYPE outputDTYPE,
int outputRank,
- tosa::DType weightDType,
+ TOSA_REF_TYPE weightDTYPE,
int weightRank);
};
}; // namespace TosaReference