diff options
Diffstat (limited to 'reference_model/src/ops/type_conversion.cc')
-rw-r--r-- | reference_model/src/ops/type_conversion.cc | 33 |
1 files changed, 23 insertions, 10 deletions
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index 17abaf7..484f768 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -459,10 +459,15 @@ CastHelper<TOSA_REF_TYPE_FP16, OutDtype>::CastHelper() fcn = [](float in) -> OutEigenType { // Cast from float representation back to half_float before rounding half_float::half h = half_float::half(in); - h = std::rint(h); - OutEigenType out = half_float::half_cast<OutEigenType, half_float::half>(h); - out = std::max<OutEigenType>(out, OutMin); - out = std::min<OutEigenType>(out, OutMax); + if (h >= half_float::half(float(OutMax))) + return OutMax; + + if (h <= half_float::half(float(OutMin))) + return OutMin; + + h = std::rint(h); + OutEigenType out = half_float::half_cast<OutEigenType, half_float::half>(h); + return out; }; } @@ -478,9 +483,13 @@ CastHelper<TOSA_REF_TYPE_BF16, OutDtype>::CastHelper() { // bf16 data (stored as fp32) converted to integer fcn = [](float in) -> OutEigenType { - OutEigenType out = std::round(in); - out = std::max<OutEigenType>(out, OutMin); - out = std::min<OutEigenType>(out, OutMax); + if (in >= float(OutMax)) + return OutMax; + + if (in <= float(OutMin)) + return OutMin; + + OutEigenType out = std::rint(in); return out; }; } @@ -527,9 +536,13 @@ CastHelper<TOSA_REF_TYPE_FP64, OutDtype>::CastHelper() case TOSA_REF_TYPE_INT32: // fp64 data converted to integer fcn = [](InEigenType in) -> OutEigenType { + if (in >= double(OutMax)) + return OutMax; + + if (in <= double(OutMin)) + return OutMin; + OutEigenType out = std::rint(in); - out = std::max<OutEigenType>(out, OutMin); - out = std::min<OutEigenType>(out, OutMax); return out; }; break; |