diff options
Diffstat (limited to 'reference_model/src/ops/template_types.h')
-rw-r--r-- | reference_model/src/ops/template_types.h | 96 |
1 files changed, 50 insertions, 46 deletions
diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h index ece14b1..6dd6e76 100644 --- a/reference_model/src/ops/template_types.h +++ b/reference_model/src/ops/template_types.h @@ -16,11 +16,10 @@ #ifndef OP_TEMPLATE_TYPES_H #define OP_TEMPLATE_TYPES_H -#include "tosa_generated.h" -#include <Eigen/CXX11/Tensor> +#include "dtype.h" #include "half.hpp" +#include <Eigen/CXX11/Tensor> #include <Eigen/Core> -#include "arith_util.h" using namespace tosa; @@ -64,213 +63,218 @@ using Tensor5 = TensorTemplate<ETensor5<T>>; template <typename T> using Tensor6 = TensorTemplate<ETensor6<T>>; -template <DType type> +template <TOSA_REF_TYPE type> struct GetEigenType; template <> -struct GetEigenType<DType_FP32> +struct GetEigenType<TOSA_REF_TYPE_FP64> +{ + using type = double; +}; +template <> +struct GetEigenType<TOSA_REF_TYPE_FP32> { using type = float; }; template <> -struct GetEigenType<DType_FP16> +struct GetEigenType<TOSA_REF_TYPE_FP16> { // NOTE: full precision used using type = float; }; template <> -struct GetEigenType<DType_BF16> +struct GetEigenType<TOSA_REF_TYPE_BF16> { // NOTE: full precision used using type = float; }; template <> -struct GetEigenType<DType_INT32> +struct GetEigenType<TOSA_REF_TYPE_INT32> { using type = int32_t; }; template <> -struct GetEigenType<DType_INT48> +struct GetEigenType<TOSA_REF_TYPE_INT48> { using type = int64_t; }; template <> -struct GetEigenType<DType_BOOL> +struct GetEigenType<TOSA_REF_TYPE_BOOL> { using type = bool; }; template <> -struct GetEigenType<DType_UINT8> +struct GetEigenType<TOSA_REF_TYPE_UINT8> { using type = int32_t; }; template <> -struct GetEigenType<DType_UINT16> +struct GetEigenType<TOSA_REF_TYPE_UINT16> { using type = int32_t; }; template <> -struct GetEigenType<DType_INT4> +struct GetEigenType<TOSA_REF_TYPE_INT4> { using type = int32_t; }; template <> -struct GetEigenType<DType_INT8> +struct GetEigenType<TOSA_REF_TYPE_INT8> { using type = int32_t; }; template <> -struct GetEigenType<DType_INT16> +struct GetEigenType<TOSA_REF_TYPE_INT16> { using type = int32_t; }; /* Get Accumulate Eigen Type: -Same behaviour as GetEigenType for all DTypes except the -single specialised case of DType_FP16. */ -template <DType Dtype> +Same behaviour as GetEigenType for all DTYPEs except the +single specialised case of TOSA_REF_TYPE_FP16. */ +template <TOSA_REF_TYPE Dtype> struct GetAccEigenType; template <> -struct GetAccEigenType<DType_FP16> +struct GetAccEigenType<TOSA_REF_TYPE_FP16> { using type = half_float::half; }; -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> struct GetAccEigenType { using type = typename GetEigenType<Dtype>::type; }; // Meta function to get number of bits -template <DType T> +template <TOSA_REF_TYPE T> struct GetNumBits { static constexpr int32_t value = 0; }; template <> -struct GetNumBits<DType_BOOL> +struct GetNumBits<TOSA_REF_TYPE_BOOL> { static constexpr int32_t value = 1; }; template <> -struct GetNumBits<DType_UINT8> +struct GetNumBits<TOSA_REF_TYPE_UINT8> { static constexpr int32_t value = 8; }; template <> -struct GetNumBits<DType_UINT16> +struct GetNumBits<TOSA_REF_TYPE_UINT16> { static constexpr int32_t value = 16; }; template <> -struct GetNumBits<DType_INT4> +struct GetNumBits<TOSA_REF_TYPE_INT4> { static constexpr int32_t value = 4; }; template <> -struct GetNumBits<DType_INT8> +struct GetNumBits<TOSA_REF_TYPE_INT8> { static constexpr int32_t value = 8; }; template <> -struct GetNumBits<DType_INT16> +struct GetNumBits<TOSA_REF_TYPE_INT16> { static constexpr int32_t value = 16; }; template <> -struct GetNumBits<DType_INT32> +struct GetNumBits<TOSA_REF_TYPE_INT32> { static constexpr int32_t value = 32; }; template <> -struct GetNumBits<DType_INT48> +struct GetNumBits<TOSA_REF_TYPE_INT48> { static constexpr int32_t value = 48; }; template <> -struct GetNumBits<DType_FP16> +struct GetNumBits<TOSA_REF_TYPE_FP16> { static constexpr int32_t value = 16; }; // Meta function to get quantized min/max in compile time -template <DType T> +template <TOSA_REF_TYPE T> struct GetQMin { static constexpr int64_t value = INT64_C(0); }; template <> -struct GetQMin<DType_UINT8> +struct GetQMin<TOSA_REF_TYPE_UINT8> { static constexpr int64_t value = INT64_C(0); }; template <> -struct GetQMin<DType_UINT16> +struct GetQMin<TOSA_REF_TYPE_UINT16> { static constexpr int64_t value = INT64_C(0); }; template <> -struct GetQMin<DType_INT4> +struct GetQMin<TOSA_REF_TYPE_INT4> { static constexpr int64_t value = INT64_C(-8); }; template <> -struct GetQMin<DType_INT8> +struct GetQMin<TOSA_REF_TYPE_INT8> { static constexpr int64_t value = INT64_C(-128); }; template <> -struct GetQMin<DType_INT16> +struct GetQMin<TOSA_REF_TYPE_INT16> { static constexpr int64_t value = INT64_C(-32768); }; template <> -struct GetQMin<DType_INT32> +struct GetQMin<TOSA_REF_TYPE_INT32> { static constexpr int64_t value = -(INT64_C(1) << 31); }; template <> -struct GetQMin<DType_INT48> +struct GetQMin<TOSA_REF_TYPE_INT48> { static constexpr int64_t value = -(INT64_C(1) << 47); }; -template <DType T> +template <TOSA_REF_TYPE T> struct GetQMax { static constexpr int64_t value = INT64_C(0); }; template <> -struct GetQMax<DType_UINT8> +struct GetQMax<TOSA_REF_TYPE_UINT8> { static constexpr int64_t value = INT64_C(255); }; template <> -struct GetQMax<DType_UINT16> +struct GetQMax<TOSA_REF_TYPE_UINT16> { static constexpr int64_t value = INT64_C(65535); }; template <> -struct GetQMax<DType_INT4> +struct GetQMax<TOSA_REF_TYPE_INT4> { static constexpr int64_t value = INT64_C(7); }; template <> -struct GetQMax<DType_INT8> +struct GetQMax<TOSA_REF_TYPE_INT8> { static constexpr int64_t value = INT64_C(127); }; template <> -struct GetQMax<DType_INT16> +struct GetQMax<TOSA_REF_TYPE_INT16> { static constexpr int64_t value = INT64_C(32767); }; template <> -struct GetQMax<DType_INT32> +struct GetQMax<TOSA_REF_TYPE_INT32> { static constexpr int64_t value = (INT64_C(1) << 31) - 1; }; template <> -struct GetQMax<DType_INT48> +struct GetQMax<TOSA_REF_TYPE_INT48> { static constexpr int64_t value = (INT64_C(1) << 47) - 1; }; |