diff options
Diffstat (limited to 'reference_model/src/ops/template_types.h')
-rw-r--r-- | reference_model/src/ops/template_types.h | 41 |
1 files changed, 40 insertions, 1 deletions
diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h index 2bc7e04..9511c31 100644 --- a/reference_model/src/ops/template_types.h +++ b/reference_model/src/ops/template_types.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ #include "tosa_generated.h" #include <Eigen/CXX11/Tensor> +#include "half.hpp" using namespace tosa; @@ -69,6 +70,12 @@ struct GetEigenType<DType_FLOAT> using type = float; }; template <> +struct GetEigenType<DType_FP16> +{ + // NOTE: full precision used + using type = float; +}; +template <> struct GetEigenType<DType_INT32> { using type = int32_t; @@ -109,6 +116,28 @@ struct GetEigenType<DType_INT16> using type = int32_t; }; +/* Get Accumulate Eigen Type: +Same behaviour as GetEigenType for all DTypes except the +single specialised case of DType_FP16. */ +template <DType Dtype> +struct GetAccEigenType; +template <> +struct GetAccEigenType<DType_FP16> +{ + using type = half_float::half; +}; +template <DType Dtype> +struct GetAccEigenType +{ + using type = typename GetEigenType<Dtype>::type; +}; + +template <DType Dtype> +struct GetHalfEigenType +{ + using type = half_float::half; +}; + // Meta function to get number of bits template <DType T> struct GetNumBits @@ -155,6 +184,11 @@ struct GetNumBits<DType_INT48> { static constexpr int32_t value = 48; }; +template <> +struct GetNumBits<DType_FP16> +{ + static constexpr int32_t value = 16; +}; // Meta function to get quantized min/max in compile time template <DType T> @@ -262,6 +296,11 @@ struct GetAccDType<DType_INT16, DType_INT16> static constexpr DType value = DType_INT48; }; template <> +struct GetAccDType<DType_FP16, DType_FP16> +{ + static constexpr DType value = DType_FP16; +}; +template <> struct GetAccDType<DType_FLOAT, DType_FLOAT> { static constexpr DType value = DType_FLOAT; |