diff options
Diffstat (limited to 'reference_model/src/ops/type_conversion.cc')
-rw-r--r-- | reference_model/src/ops/type_conversion.cc | 33 |
1 files changed, 32 insertions, 1 deletions
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index 52de2e4..50e710a 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ #include "quant_util.h" #include "template_types.h" #include <cmath> +#include "half.hpp" using namespace TosaReference; using namespace Eigen; @@ -287,6 +288,30 @@ CastHelper<DType_BOOL, OutDtype>::CastHelper() } template <DType InDtype> +CastHelper<InDtype, DType_FP16>::CastHelper() +{ + fcn = [](InEigenType in) -> float { + half_float::half out = half_float::half_cast<half_float::half, InEigenType>(in); // Cast to half_float + return half_float::half_cast<float, half_float::half>(out); // Cast to float (underlying FP16 EigenType) + }; +} + +template <DType OutDtype> +CastHelper<DType_FP16, OutDtype>::CastHelper() +{ + // Assuming InEigenType = float. + fcn = [](float in) -> OutEigenType { + // Perform initial rounding in half-precision then cast back to float + half_float::half h = half_float::half_cast<half_float::half, float>(in); + h = std::round(h); + OutEigenType out = half_float::half_cast<float, half_float::half>(h); + out = std::max<OutEigenType>(out, OutMin); + out = std::min<OutEigenType>(out, OutMax); + return out; + }; +} + +template <DType InDtype> CastHelper<InDtype, DType_FLOAT>::CastHelper() { fcn = [](InEigenType in) -> float { @@ -313,15 +338,21 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32); |