diff options
author | Tai Ly <tai.ly@arm.com> | 2023-03-28 22:06:56 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2023-05-05 19:23:15 +0000 |
commit | a4d748b08accce06fab93e2d2b96e499b35ae89b (patch) | |
tree | 20a3957e1f45f65f35d5d67ecce1618659e388f0 /reference_model/src/ops/template_types.h | |
parent | 0c71686875618b2e11290273b7a05b88ef8a8aae (diff) | |
download | reference_model-a4d748b08accce06fab93e2d2b96e499b35ae89b.tar.gz |
[reference model] Add precise mode
This adds --precise_mode=1 option to tosa_referece_model,
which will cause reference model to convert all floating point tensors
to FP64 tensors and compute all operators accordingly.
Also adds optional -p arguments to test runners tosa_verif_run_tests.py
and tosa_verif_framework_compiler_runner.py to run tests in precise mode
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I156055216ad61710096497a8fa1a653be2a602a3
Diffstat (limited to 'reference_model/src/ops/template_types.h')
-rw-r--r-- | reference_model/src/ops/template_types.h | 96 |
1 files changed, 50 insertions, 46 deletions
diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h index ece14b1..6dd6e76 100644 --- a/reference_model/src/ops/template_types.h +++ b/reference_model/src/ops/template_types.h @@ -16,11 +16,10 @@ #ifndef OP_TEMPLATE_TYPES_H #define OP_TEMPLATE_TYPES_H -#include "tosa_generated.h" -#include <Eigen/CXX11/Tensor> +#include "dtype.h" #include "half.hpp" +#include <Eigen/CXX11/Tensor> #include <Eigen/Core> -#include "arith_util.h" using namespace tosa; @@ -64,213 +63,218 @@ using Tensor5 = TensorTemplate<ETensor5<T>>; template <typename T> using Tensor6 = TensorTemplate<ETensor6<T>>; -template <DType type> +template <TOSA_REF_TYPE type> struct GetEigenType; template <> -struct GetEigenType<DType_FP32> +struct GetEigenType<TOSA_REF_TYPE_FP64> +{ + using type = double; +}; +template <> +struct GetEigenType<TOSA_REF_TYPE_FP32> { using type = float; }; template <> -struct GetEigenType<DType_FP16> +struct GetEigenType<TOSA_REF_TYPE_FP16> { // NOTE: full precision used using type = float; }; template <> -struct GetEigenType<DType_BF16> +struct GetEigenType<TOSA_REF_TYPE_BF16> { // NOTE: full precision used using type = float; }; template <> -struct GetEigenType<DType_INT32> +struct GetEigenType<TOSA_REF_TYPE_INT32> { using type = int32_t; }; template <> -struct GetEigenType<DType_INT48> +struct GetEigenType<TOSA_REF_TYPE_INT48> { using type = int64_t; }; template <> -struct GetEigenType<DType_BOOL> +struct GetEigenType<TOSA_REF_TYPE_BOOL> { using type = bool; }; template <> -struct GetEigenType<DType_UINT8> +struct GetEigenType<TOSA_REF_TYPE_UINT8> { using type = int32_t; }; template <> -struct GetEigenType<DType_UINT16> +struct GetEigenType<TOSA_REF_TYPE_UINT16> { using type = int32_t; }; template <> -struct GetEigenType<DType_INT4> +struct GetEigenType<TOSA_REF_TYPE_INT4> { using type = int32_t; }; template <> -struct GetEigenType<DType_INT8> +struct GetEigenType<TOSA_REF_TYPE_INT8> { using type = int32_t; }; template <> -struct GetEigenType<DType_INT16> +struct GetEigenType<TOSA_REF_TYPE_INT16> { using type = int32_t; }; /* Get Accumulate Eigen Type: -Same behaviour as GetEigenType for all DTypes except the -single specialised case of DType_FP16. */ -template <DType Dtype> +Same behaviour as GetEigenType for all DTYPEs except the +single specialised case of TOSA_REF_TYPE_FP16. */ +template <TOSA_REF_TYPE Dtype> struct GetAccEigenType; template <> -struct GetAccEigenType<DType_FP16> +struct GetAccEigenType<TOSA_REF_TYPE_FP16> { using type = half_float::half; }; -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> struct GetAccEigenType { using type = typename GetEigenType<Dtype>::type; }; // Meta function to get number of bits -template <DType T> +template <TOSA_REF_TYPE T> struct GetNumBits { static constexpr int32_t value = 0; }; template <> -struct GetNumBits<DType_BOOL> +struct GetNumBits<TOSA_REF_TYPE_BOOL> { static constexpr int32_t value = 1; }; template <> -struct GetNumBits<DType_UINT8> +struct GetNumBits<TOSA_REF_TYPE_UINT8> { static constexpr int32_t value = 8; }; template <> -struct GetNumBits<DType_UINT16> +struct GetNumBits<TOSA_REF_TYPE_UINT16> { static constexpr int32_t value = 16; }; template <> -struct GetNumBits<DType_INT4> +struct GetNumBits<TOSA_REF_TYPE_INT4> { static constexpr int32_t value = 4; }; template <> -struct GetNumBits<DType_INT8> +struct GetNumBits<TOSA_REF_TYPE_INT8> { static constexpr int32_t value = 8; }; template <> -struct GetNumBits<DType_INT16> +struct GetNumBits<TOSA_REF_TYPE_INT16> { static constexpr int32_t value = 16; }; template <> -struct GetNumBits<DType_INT32> +struct GetNumBits<TOSA_REF_TYPE_INT32> { static constexpr int32_t value = 32; }; template <> -struct GetNumBits<DType_INT48> +struct GetNumBits<TOSA_REF_TYPE_INT48> { static constexpr int32_t value = 48; }; template <> -struct GetNumBits<DType_FP16> +struct GetNumBits<TOSA_REF_TYPE_FP16> { static constexpr int32_t value = 16; }; // Meta function to get quantized min/max in compile time -template <DType T> +template <TOSA_REF_TYPE T> struct GetQMin { static constexpr int64_t value = INT64_C(0); }; template <> -struct GetQMin<DType_UINT8> +struct GetQMin<TOSA_REF_TYPE_UINT8> { static constexpr int64_t value = INT64_C(0); }; template <> -struct GetQMin<DType_UINT16> +struct GetQMin<TOSA_REF_TYPE_UINT16> { static constexpr int64_t value = INT64_C(0); }; template <> -struct GetQMin<DType_INT4> +struct GetQMin<TOSA_REF_TYPE_INT4> { static constexpr int64_t value = INT64_C(-8); }; template <> -struct GetQMin<DType_INT8> +struct GetQMin<TOSA_REF_TYPE_INT8> { static constexpr int64_t value = INT64_C(-128); }; template <> -struct GetQMin<DType_INT16> +struct GetQMin<TOSA_REF_TYPE_INT16> { static constexpr int64_t value = INT64_C(-32768); }; template <> -struct GetQMin<DType_INT32> +struct GetQMin<TOSA_REF_TYPE_INT32> { static constexpr int64_t value = -(INT64_C(1) << 31); }; template <> -struct GetQMin<DType_INT48> +struct GetQMin<TOSA_REF_TYPE_INT48> { static constexpr int64_t value = -(INT64_C(1) << 47); }; -template <DType T> +template <TOSA_REF_TYPE T> struct GetQMax { static constexpr int64_t value = INT64_C(0); }; template <> -struct GetQMax<DType_UINT8> +struct GetQMax<TOSA_REF_TYPE_UINT8> { static constexpr int64_t value = INT64_C(255); }; template <> -struct GetQMax<DType_UINT16> +struct GetQMax<TOSA_REF_TYPE_UINT16> { static constexpr int64_t value = INT64_C(65535); }; template <> -struct GetQMax<DType_INT4> +struct GetQMax<TOSA_REF_TYPE_INT4> { static constexpr int64_t value = INT64_C(7); }; template <> -struct GetQMax<DType_INT8> +struct GetQMax<TOSA_REF_TYPE_INT8> { static constexpr int64_t value = INT64_C(127); }; template <> -struct GetQMax<DType_INT16> +struct GetQMax<TOSA_REF_TYPE_INT16> { static constexpr int64_t value = INT64_C(32767); }; template <> -struct GetQMax<DType_INT32> +struct GetQMax<TOSA_REF_TYPE_INT32> { static constexpr int64_t value = (INT64_C(1) << 31) - 1; }; template <> -struct GetQMax<DType_INT48> +struct GetQMax<TOSA_REF_TYPE_INT48> { static constexpr int64_t value = (INT64_C(1) << 47) - 1; }; |