aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/template_types.h
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/template_types.h')
-rw-r--r--reference_model/src/ops/template_types.h96
1 files changed, 50 insertions, 46 deletions
diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h
index ece14b1..6dd6e76 100644
--- a/reference_model/src/ops/template_types.h
+++ b/reference_model/src/ops/template_types.h
@@ -16,11 +16,10 @@
#ifndef OP_TEMPLATE_TYPES_H
#define OP_TEMPLATE_TYPES_H
-#include "tosa_generated.h"
-#include <Eigen/CXX11/Tensor>
+#include "dtype.h"
#include "half.hpp"
+#include <Eigen/CXX11/Tensor>
#include <Eigen/Core>
-#include "arith_util.h"
using namespace tosa;
@@ -64,213 +63,218 @@ using Tensor5 = TensorTemplate<ETensor5<T>>;
template <typename T>
using Tensor6 = TensorTemplate<ETensor6<T>>;
-template <DType type>
+template <TOSA_REF_TYPE type>
struct GetEigenType;
template <>
-struct GetEigenType<DType_FP32>
+struct GetEigenType<TOSA_REF_TYPE_FP64>
+{
+ using type = double;
+};
+template <>
+struct GetEigenType<TOSA_REF_TYPE_FP32>
{
using type = float;
};
template <>
-struct GetEigenType<DType_FP16>
+struct GetEigenType<TOSA_REF_TYPE_FP16>
{
// NOTE: full precision used
using type = float;
};
template <>
-struct GetEigenType<DType_BF16>
+struct GetEigenType<TOSA_REF_TYPE_BF16>
{
// NOTE: full precision used
using type = float;
};
template <>
-struct GetEigenType<DType_INT32>
+struct GetEigenType<TOSA_REF_TYPE_INT32>
{
using type = int32_t;
};
template <>
-struct GetEigenType<DType_INT48>
+struct GetEigenType<TOSA_REF_TYPE_INT48>
{
using type = int64_t;
};
template <>
-struct GetEigenType<DType_BOOL>
+struct GetEigenType<TOSA_REF_TYPE_BOOL>
{
using type = bool;
};
template <>
-struct GetEigenType<DType_UINT8>
+struct GetEigenType<TOSA_REF_TYPE_UINT8>
{
using type = int32_t;
};
template <>
-struct GetEigenType<DType_UINT16>
+struct GetEigenType<TOSA_REF_TYPE_UINT16>
{
using type = int32_t;
};
template <>
-struct GetEigenType<DType_INT4>
+struct GetEigenType<TOSA_REF_TYPE_INT4>
{
using type = int32_t;
};
template <>
-struct GetEigenType<DType_INT8>
+struct GetEigenType<TOSA_REF_TYPE_INT8>
{
using type = int32_t;
};
template <>
-struct GetEigenType<DType_INT16>
+struct GetEigenType<TOSA_REF_TYPE_INT16>
{
using type = int32_t;
};
/* Get Accumulate Eigen Type:
-Same behaviour as GetEigenType for all DTypes except the
-single specialised case of DType_FP16. */
-template <DType Dtype>
+Same behaviour as GetEigenType for all DTYPEs except the
+single specialised case of TOSA_REF_TYPE_FP16. */
+template <TOSA_REF_TYPE Dtype>
struct GetAccEigenType;
template <>
-struct GetAccEigenType<DType_FP16>
+struct GetAccEigenType<TOSA_REF_TYPE_FP16>
{
using type = half_float::half;
};
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
struct GetAccEigenType
{
using type = typename GetEigenType<Dtype>::type;
};
// Meta function to get number of bits
-template <DType T>
+template <TOSA_REF_TYPE T>
struct GetNumBits
{
static constexpr int32_t value = 0;
};
template <>
-struct GetNumBits<DType_BOOL>
+struct GetNumBits<TOSA_REF_TYPE_BOOL>
{
static constexpr int32_t value = 1;
};
template <>
-struct GetNumBits<DType_UINT8>
+struct GetNumBits<TOSA_REF_TYPE_UINT8>
{
static constexpr int32_t value = 8;
};
template <>
-struct GetNumBits<DType_UINT16>
+struct GetNumBits<TOSA_REF_TYPE_UINT16>
{
static constexpr int32_t value = 16;
};
template <>
-struct GetNumBits<DType_INT4>
+struct GetNumBits<TOSA_REF_TYPE_INT4>
{
static constexpr int32_t value = 4;
};
template <>
-struct GetNumBits<DType_INT8>
+struct GetNumBits<TOSA_REF_TYPE_INT8>
{
static constexpr int32_t value = 8;
};
template <>
-struct GetNumBits<DType_INT16>
+struct GetNumBits<TOSA_REF_TYPE_INT16>
{
static constexpr int32_t value = 16;
};
template <>
-struct GetNumBits<DType_INT32>
+struct GetNumBits<TOSA_REF_TYPE_INT32>
{
static constexpr int32_t value = 32;
};
template <>
-struct GetNumBits<DType_INT48>
+struct GetNumBits<TOSA_REF_TYPE_INT48>
{
static constexpr int32_t value = 48;
};
template <>
-struct GetNumBits<DType_FP16>
+struct GetNumBits<TOSA_REF_TYPE_FP16>
{
static constexpr int32_t value = 16;
};
// Meta function to get quantized min/max in compile time
-template <DType T>
+template <TOSA_REF_TYPE T>
struct GetQMin
{
static constexpr int64_t value = INT64_C(0);
};
template <>
-struct GetQMin<DType_UINT8>
+struct GetQMin<TOSA_REF_TYPE_UINT8>
{
static constexpr int64_t value = INT64_C(0);
};
template <>
-struct GetQMin<DType_UINT16>
+struct GetQMin<TOSA_REF_TYPE_UINT16>
{
static constexpr int64_t value = INT64_C(0);
};
template <>
-struct GetQMin<DType_INT4>
+struct GetQMin<TOSA_REF_TYPE_INT4>
{
static constexpr int64_t value = INT64_C(-8);
};
template <>
-struct GetQMin<DType_INT8>
+struct GetQMin<TOSA_REF_TYPE_INT8>
{
static constexpr int64_t value = INT64_C(-128);
};
template <>
-struct GetQMin<DType_INT16>
+struct GetQMin<TOSA_REF_TYPE_INT16>
{
static constexpr int64_t value = INT64_C(-32768);
};
template <>
-struct GetQMin<DType_INT32>
+struct GetQMin<TOSA_REF_TYPE_INT32>
{
static constexpr int64_t value = -(INT64_C(1) << 31);
};
template <>
-struct GetQMin<DType_INT48>
+struct GetQMin<TOSA_REF_TYPE_INT48>
{
static constexpr int64_t value = -(INT64_C(1) << 47);
};
-template <DType T>
+template <TOSA_REF_TYPE T>
struct GetQMax
{
static constexpr int64_t value = INT64_C(0);
};
template <>
-struct GetQMax<DType_UINT8>
+struct GetQMax<TOSA_REF_TYPE_UINT8>
{
static constexpr int64_t value = INT64_C(255);
};
template <>
-struct GetQMax<DType_UINT16>
+struct GetQMax<TOSA_REF_TYPE_UINT16>
{
static constexpr int64_t value = INT64_C(65535);
};
template <>
-struct GetQMax<DType_INT4>
+struct GetQMax<TOSA_REF_TYPE_INT4>
{
static constexpr int64_t value = INT64_C(7);
};
template <>
-struct GetQMax<DType_INT8>
+struct GetQMax<TOSA_REF_TYPE_INT8>
{
static constexpr int64_t value = INT64_C(127);
};
template <>
-struct GetQMax<DType_INT16>
+struct GetQMax<TOSA_REF_TYPE_INT16>
{
static constexpr int64_t value = INT64_C(32767);
};
template <>
-struct GetQMax<DType_INT32>
+struct GetQMax<TOSA_REF_TYPE_INT32>
{
static constexpr int64_t value = (INT64_C(1) << 31) - 1;
};
template <>
-struct GetQMax<DType_INT48>
+struct GetQMax<TOSA_REF_TYPE_INT48>
{
static constexpr int64_t value = (INT64_C(1) << 47) - 1;
};