From 8b39043c70332e1e4c95ee6a9616aec40dd3baf1 Mon Sep 17 00:00:00 2001 From: James Ward Date: Fri, 12 Aug 2022 20:48:56 +0100 Subject: Reference model changes for fp16 support Change-Id: I72f21fcfa153046274969d327313e3349981dbe6 Signed-off-by: James Ward --- reference_model/src/ops/template_types.h | 41 +++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) (limited to 'reference_model/src/ops/template_types.h') diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h index 2bc7e04..9511c31 100644 --- a/reference_model/src/ops/template_types.h +++ b/reference_model/src/ops/template_types.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ #include "tosa_generated.h" #include +#include "half.hpp" using namespace tosa; @@ -69,6 +70,12 @@ struct GetEigenType using type = float; }; template <> +struct GetEigenType +{ + // NOTE: full precision used + using type = float; +}; +template <> struct GetEigenType { using type = int32_t; @@ -109,6 +116,28 @@ struct GetEigenType using type = int32_t; }; +/* Get Accumulate Eigen Type: +Same behaviour as GetEigenType for all DTypes except the +single specialised case of DType_FP16. */ +template +struct GetAccEigenType; +template <> +struct GetAccEigenType +{ + using type = half_float::half; +}; +template +struct GetAccEigenType +{ + using type = typename GetEigenType::type; +}; + +template +struct GetHalfEigenType +{ + using type = half_float::half; +}; + // Meta function to get number of bits template struct GetNumBits @@ -155,6 +184,11 @@ struct GetNumBits { static constexpr int32_t value = 48; }; +template <> +struct GetNumBits +{ + static constexpr int32_t value = 16; +}; // Meta function to get quantized min/max in compile time template @@ -262,6 +296,11 @@ struct GetAccDType static constexpr DType value = DType_INT48; }; template <> +struct GetAccDType +{ + static constexpr DType value = DType_FP16; +}; +template <> struct GetAccDType { static constexpr DType value = DType_FLOAT; -- cgit v1.2.1