aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/type_conversion.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/type_conversion.cc')
-rw-r--r--reference_model/src/ops/type_conversion.cc33
1 files changed, 32 insertions, 1 deletions
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 52de2e4..50e710a 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
#include "quant_util.h"
#include "template_types.h"
#include <cmath>
+#include "half.hpp"
using namespace TosaReference;
using namespace Eigen;
@@ -287,6 +288,30 @@ CastHelper<DType_BOOL, OutDtype>::CastHelper()
}
template <DType InDtype>
+CastHelper<InDtype, DType_FP16>::CastHelper()
+{
+ fcn = [](InEigenType in) -> float {
+ half_float::half out = half_float::half_cast<half_float::half, InEigenType>(in); // Cast to half_float
+ return half_float::half_cast<float, half_float::half>(out); // Cast to float (underlying FP16 EigenType)
+ };
+}
+
+template <DType OutDtype>
+CastHelper<DType_FP16, OutDtype>::CastHelper()
+{
+ // Assuming InEigenType = float.
+ fcn = [](float in) -> OutEigenType {
+ // Perform initial rounding in half-precision then cast back to float
+ half_float::half h = half_float::half_cast<half_float::half, float>(in);
+ h = std::round(h);
+ OutEigenType out = half_float::half_cast<float, half_float::half>(h);
+ out = std::max<OutEigenType>(out, OutMin);
+ out = std::min<OutEigenType>(out, OutMax);
+ return out;
+ };
+}
+
+template <DType InDtype>
CastHelper<InDtype, DType_FLOAT>::CastHelper()
{
fcn = [](InEigenType in) -> float {
@@ -313,15 +338,21 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32);