diff options
author | James Ward <james.ward@arm.com> | 2022-08-12 20:48:56 +0100 |
---|---|---|
committer | James Ward <james.ward@arm.com> | 2022-10-11 11:56:02 +0100 |
commit | 8b39043c70332e1e4c95ee6a9616aec40dd3baf1 (patch) | |
tree | fea519246b698eb944b9d58537fc90bc30481d11 /reference_model/src/ops/template_types.h | |
parent | ba5fad356a926d5e1c6e0fe6b546a310230cc5a8 (diff) | |
download | reference_model-8b39043c70332e1e4c95ee6a9616aec40dd3baf1.tar.gz |
Reference model changes for fp16 support
Change-Id: I72f21fcfa153046274969d327313e3349981dbe6
Signed-off-by: James Ward <james.ward@arm.com>
Diffstat (limited to 'reference_model/src/ops/template_types.h')
-rw-r--r-- | reference_model/src/ops/template_types.h | 41 |
1 files changed, 40 insertions, 1 deletions
```diff
diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h
index 2bc7e04..9511c31 100644
--- a/reference_model/src/ops/template_types.h
+++ b/reference_model/src/ops/template_types.h
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@
 
 #include "tosa_generated.h"
 #include <Eigen/CXX11/Tensor>
+#include "half.hpp"
 
 using namespace tosa;
 
@@ -69,6 +70,12 @@ struct GetEigenType<DType_FLOAT>
     using type = float;
 };
 template <>
+struct GetEigenType<DType_FP16>
+{
+    // NOTE: full precision used
+    using type = float;
+};
+template <>
 struct GetEigenType<DType_INT32>
 {
     using type = int32_t;
@@ -109,6 +116,28 @@ struct GetEigenType<DType_INT16>
     using type = int32_t;
 };
 
+/* Get Accumulate Eigen Type:
+Same behaviour as GetEigenType for all DTypes except the
+single specialised case of DType_FP16. */
+template <DType Dtype>
+struct GetAccEigenType;
+template <>
+struct GetAccEigenType<DType_FP16>
+{
+    using type = half_float::half;
+};
+template <DType Dtype>
+struct GetAccEigenType
+{
+    using type = typename GetEigenType<Dtype>::type;
+};
+
+template <DType Dtype>
+struct GetHalfEigenType
+{
+    using type = half_float::half;
+};
+
 // Meta function to get number of bits
 template <DType T>
 struct GetNumBits
@@ -155,6 +184,11 @@ struct GetNumBits<DType_INT48>
 {
     static constexpr int32_t value = 48;
 };
+template <>
+struct GetNumBits<DType_FP16>
+{
+    static constexpr int32_t value = 16;
+};
 
 // Meta function to get quantized min/max in compile time
 template <DType T>
@@ -262,6 +296,11 @@ struct GetAccDType<DType_INT16, DType_INT16>
     static constexpr DType value = DType_INT48;
 };
 template <>
+struct GetAccDType<DType_FP16, DType_FP16>
+{
+    static constexpr DType value = DType_FP16;
+};
+template <>
 struct GetAccDType<DType_FLOAT, DType_FLOAT>
 {
     static constexpr DType value = DType_FLOAT;
```