aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/type_conversion.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/type_conversion.cc')
-rw-r--r--reference_model/src/ops/type_conversion.cc175
1 files changed, 175 insertions, 0 deletions
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 484f768..5dbc7bd 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -15,6 +15,7 @@
#include "type_conversion.h"
#include "arith_util.h"
+#include "float_utils.h"
#include "half.hpp"
#include "quant_util.h"
#include "template_types.h"
@@ -24,6 +25,12 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
+using fp16 = tosa::reference::internal::float_t<int16_t, 5, true, true, true>;       // NOTE(review): presumably 16-bit storage with 5 exponent bits - confirm against the float_t declaration
+using bf16 = tosa::reference::internal::float_t<int16_t, 8, true, true, true>;       // NOTE(review): presumably 16-bit storage with 8 exponent bits - confirm
+using fp32 = tosa::reference::internal::float_t<int32_t, 8, true, true, true>;       // used by the fp8 cast helpers to widen an fp8 value before converting it to a native float
+using fp8e4m3 = tosa::reference::internal::float_t<int8_t, 4, true, true, false>;    // last flag differs from the other aliases; presumably E4M3 has no infinity encoding - TODO confirm
+using fp8e5m2 = tosa::reference::internal::float_t<int8_t, 5, true, true, true>;     // NOTE(review): presumably 8-bit storage with 5 exponent bits - confirm
+
template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: GraphNode(sgt_, Op_RESCALE, id_)
@@ -527,6 +534,162 @@ CastHelper<TOSA_REF_TYPE_FP32, OutDtype>::CastHelper()
}
template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_FP8E4M3, OutDtype>::CastHelper()
+{
+    // Integer cast for fp8e4m3 inputs, which arrive stored as fp32.
+    // Values at or beyond OutMax / OutMin saturate to that bound; anything
+    // in between is rounded to an integer with std::rint.
+    fcn = [](float value) -> OutEigenType {
+        if (value >= float(OutMax))
+        {
+            return OutMax;
+        }
+        if (value <= float(OutMin))
+        {
+            return OutMin;
+        }
+        return static_cast<OutEigenType>(std::rint(value));
+    };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP16>::CastHelper()
+{
+    // fp8e4m3 -> fp16, both sides stored as fp32: the value is passed through
+    // half_float::half and handed back as a float.
+    fcn = [](float value) -> float {
+        half_float::half asHalf = half_float::half(value);
+        return half_float::half_cast<half_float::half, float>(asHalf);
+    };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_BF16>::CastHelper()
+{
+    // fp8e4m3 -> bf16, both sides stored as fp32: the input float is
+    // forwarded unchanged.
+    fcn = [](float value) -> float {
+        return value;
+    };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP32>::CastHelper()
+{
+    // fp8e4m3 (stored as fp32) -> fp32: the value is returned as-is.
+    fcn = [](InEigenType value) -> OutEigenType {
+        return value;
+    };
+}
+
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_FP8E5M2, OutDtype>::CastHelper()
+{
+    // Integer cast for fp8e5m2 inputs, which arrive stored as fp32.
+    // Values at or beyond OutMax / OutMin saturate to that bound; anything
+    // in between is rounded to an integer with std::rint.
+    fcn = [](float value) -> OutEigenType {
+        if (value >= float(OutMax))
+        {
+            return OutMax;
+        }
+        if (value <= float(OutMin))
+        {
+            return OutMin;
+        }
+        return static_cast<OutEigenType>(std::rint(value));
+    };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP16>::CastHelper()
+{
+    // fp8e5m2 -> fp16, both sides stored as fp32: the value is passed through
+    // half_float::half and handed back as a float.
+    fcn = [](float value) -> float {
+        half_float::half asHalf = half_float::half(value);
+        return half_float::half_cast<half_float::half, float>(asHalf);
+    };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_BF16>::CastHelper()
+{
+    // fp8e5m2 -> bf16, both sides stored as fp32: the input float is
+    // forwarded unchanged.
+    fcn = [](float value) -> float {
+        return value;
+    };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP32>::CastHelper()
+{
+    // fp8e5m2 (stored as fp32) -> fp32: the value is returned as-is.
+    fcn = [](InEigenType value) -> OutEigenType {
+        return value;
+    };
+}
+
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
+{
+    // Integer input -> fp8e4m3 (stored as fp32): convert to float, narrow to
+    // fp8e4m3, then widen through fp32 so a native float can be returned.
+    fcn = [](InEigenType value) -> float {
+        fp32 widened = static_cast<fp32>(static_cast<fp8e4m3>(float(value)));
+        return static_cast<float>(widened);
+    };
+}
+
+CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
+{
+    // fp16 (stored as fp32) -> fp8e4m3 (stored as fp32): narrow to fp8e4m3,
+    // then widen through fp32 so a native float can be returned.
+    fcn = [](float value) -> float {
+        fp32 widened = static_cast<fp32>(static_cast<fp8e4m3>(value));
+        return static_cast<float>(widened);
+    };
+}
+
+CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
+{
+    // bf16 (stored as fp32) -> fp8e4m3 (stored as fp32): narrow to fp8e4m3,
+    // then widen through fp32 so a native float can be returned.
+    fcn = [](float value) -> float {
+        fp32 widened = static_cast<fp32>(static_cast<fp8e4m3>(value));
+        return static_cast<float>(widened);
+    };
+}
+
+CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
+{
+    // fp32 -> fp8e4m3 (stored as fp32): narrow to fp8e4m3, then widen through
+    // fp32 so a native float can be returned.
+    fcn = [](float value) -> float {
+        fp32 widened = static_cast<fp32>(static_cast<fp8e4m3>(value));
+        return static_cast<float>(widened);
+    };
+}
+
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
+{
+    // Integer input -> fp8e5m2 (stored as fp32): convert to float, narrow to
+    // fp8e5m2, then widen through fp32 so a native float can be returned.
+    fcn = [](InEigenType value) -> float {
+        fp32 widened = static_cast<fp32>(static_cast<fp8e5m2>(float(value)));
+        return static_cast<float>(widened);
+    };
+}
+
+CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
+{
+    // fp16 (stored as fp32) -> fp8e5m2 (stored as fp32): narrow to fp8e5m2,
+    // then widen through fp32 so a native float can be returned.
+    fcn = [](float value) -> float {
+        fp32 widened = static_cast<fp32>(static_cast<fp8e5m2>(value));
+        return static_cast<float>(widened);
+    };
+}
+
+CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
+{
+    // bf16 (stored as fp32) -> fp8e5m2 (stored as fp32): narrow to fp8e5m2,
+    // then widen through fp32 so a native float can be returned.
+    fcn = [](float value) -> float {
+        fp32 widened = static_cast<fp32>(static_cast<fp8e5m2>(value));
+        return static_cast<float>(widened);
+    };
+}
+
+CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
+{
+    // fp32 -> fp8e5m2 (stored as fp32): narrow to fp8e5m2, then widen through
+    // fp32 so a native float can be returned.
+    fcn = [](float value) -> float {
+        fp32 widened = static_cast<fp32>(static_cast<fp8e5m2>(value));
+        return static_cast<float>(widened);
+    };
+}
+
+template <TOSA_REF_TYPE OutDtype>
CastHelper<TOSA_REF_TYPE_FP64, OutDtype>::CastHelper()
{
switch (OutDtype)
@@ -597,6 +760,18 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E4M3);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E5M2);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, BF16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, BF16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E4M3);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E5M2);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E4M3);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E5M2);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16);