diff options
Diffstat (limited to 'reference_model/src/ops/type_conversion.cc')
-rw-r--r-- | reference_model/src/ops/type_conversion.cc | 175 |
1 files changed, 175 insertions, 0 deletions
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 484f768..5dbc7bd 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -15,6 +15,7 @@
 
 #include "type_conversion.h"
 #include "arith_util.h"
+#include "float_utils.h"
 #include "half.hpp"
 #include "quant_util.h"
 #include "template_types.h"
@@ -24,6 +25,12 @@ using namespace TosaReference;
 using namespace Eigen;
 using namespace tosa;
 
+using fp16 = tosa::reference::internal::float_t<int16_t, 5, true, true, true>;
+using bf16 = tosa::reference::internal::float_t<int16_t, 8, true, true, true>;
+using fp32 = tosa::reference::internal::float_t<int32_t, 8, true, true, true>;
+using fp8e4m3 = tosa::reference::internal::float_t<int8_t, 4, true, true, false>;
+using fp8e5m2 = tosa::reference::internal::float_t<int8_t, 5, true, true, true>;
+
 template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
     : GraphNode(sgt_, Op_RESCALE, id_)
@@ -527,6 +534,162 @@ CastHelper<TOSA_REF_TYPE_FP32, OutDtype>::CastHelper()
 }
 
 template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_FP8E4M3, OutDtype>::CastHelper()
+{
+    // fp8e4m3 (held in fp32 storage) -> integer: saturate at the output limits, otherwise round via std::rint
+    fcn = [](float value) -> OutEigenType {
+        // Both limits are tested before rounding; a value at or beyond a limit yields that limit
+        if (value <= float(OutMin))
+            return OutMin;
+        if (value >= float(OutMax))
+            return OutMax;
+        OutEigenType rounded = std::rint(value);
+        return rounded;
+    };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP16>::CastHelper()
+{
+    // fp8e4m3 -> fp16 (both held in fp32 storage): the value is passed through half_float::half
+    fcn = [](float value) -> float {
+        half_float::half as_half(value);
+        float result = half_float::half_cast<half_float::half, float>(as_half);
+        return result;
+    };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_BF16>::CastHelper()
+{
+    // fp8e4m3 -> bf16 (both held in fp32 storage): the value is returned as-is
+    fcn = [](float value) -> float { return value; };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP32>::CastHelper()
+{
+    // fp8e4m3 (held in fp32 storage) -> fp32: the value is returned as-is
+    fcn = [](InEigenType value) -> OutEigenType { return value; };
+}
+
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_FP8E5M2, OutDtype>::CastHelper()
+{
+    // fp8e5m2 (held in fp32 storage) -> integer: saturate at the output limits, otherwise round via std::rint
+    fcn = [](float value) -> OutEigenType {
+        // Both limits are tested before rounding; a value at or beyond a limit yields that limit
+        if (value <= float(OutMin))
+            return OutMin;
+        if (value >= float(OutMax))
+            return OutMax;
+        OutEigenType rounded = std::rint(value);
+        return rounded;
+    };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP16>::CastHelper()
+{
+    // fp8e5m2 -> fp16 (both held in fp32 storage): the value is passed through half_float::half
+    fcn = [](float value) -> float {
+        half_float::half as_half(value);
+        float result = half_float::half_cast<half_float::half, float>(as_half);
+        return result;
+    };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_BF16>::CastHelper()
+{
+    // fp8e5m2 -> bf16 (both held in fp32 storage): the value is returned as-is
+    fcn = [](float value) -> float { return value; };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP32>::CastHelper()
+{
+    // fp8e5m2 (held in fp32 storage) -> fp32: the value is returned as-is
+    fcn = [](InEigenType value) -> OutEigenType { return value; };
+}
+
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
+{
+    // Integer input -> fp8e4m3 (held in fp32 storage): float -> fp8e4m3 -> fp32 -> float cast chain
+    fcn = [](InEigenType value) -> float {
+        fp8e4m3 narrowed = static_cast<fp8e4m3>(float(value));
+        fp32 widened = static_cast<fp32>(narrowed);
+        return static_cast<float>(widened);
+    };
+}
+
+CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
+{
+    // fp16 -> fp8e4m3 (both held in fp32 storage): float -> fp8e4m3 -> fp32 -> float cast chain
+    fcn = [](float value) -> float {
+        fp8e4m3 narrowed = static_cast<fp8e4m3>(value);
+        fp32 widened = static_cast<fp32>(narrowed);
+        return static_cast<float>(widened);
+    };
+}
+
+CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
+{
+    // bf16 -> fp8e4m3 (both held in fp32 storage): float -> fp8e4m3 -> fp32 -> float cast chain
+    fcn = [](float value) -> float {
+        fp8e4m3 narrowed = static_cast<fp8e4m3>(value);
+        fp32 widened = static_cast<fp32>(narrowed);
+        return static_cast<float>(widened);
+    };
+}
+
+CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
+{
+    // fp32 -> fp8e4m3 (held in fp32 storage): float -> fp8e4m3 -> fp32 -> float cast chain
+    fcn = [](float value) -> float {
+        fp8e4m3 narrowed = static_cast<fp8e4m3>(value);
+        fp32 widened = static_cast<fp32>(narrowed);
+        return static_cast<float>(widened);
+    };
+}
+
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
+{
+    // Integer input -> fp8e5m2 (held in fp32 storage): float -> fp8e5m2 -> fp32 -> float cast chain
+    fcn = [](InEigenType value) -> float {
+        fp8e5m2 narrowed = static_cast<fp8e5m2>(float(value));
+        fp32 widened = static_cast<fp32>(narrowed);
+        return static_cast<float>(widened);
+    };
+}
+
+CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
+{
+    // fp16 -> fp8e5m2 (both held in fp32 storage): float -> fp8e5m2 -> fp32 -> float cast chain
+    fcn = [](float value) -> float {
+        fp8e5m2 narrowed = static_cast<fp8e5m2>(value);
+        fp32 widened = static_cast<fp32>(narrowed);
+        return static_cast<float>(widened);
+    };
+}
+
+CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
+{
+    // bf16 -> fp8e5m2 (both held in fp32 storage): float -> fp8e5m2 -> fp32 -> float cast chain
+    fcn = [](float value) -> float {
+        fp8e5m2 narrowed = static_cast<fp8e5m2>(value);
+        fp32 widened = static_cast<fp32>(narrowed);
+        return static_cast<float>(widened);
+    };
+}
+
+CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
+{
+    // fp32 -> fp8e5m2 (held in fp32 storage): float -> fp8e5m2 -> fp32 -> float cast chain
+    fcn = [](float value) -> float {
+        fp8e5m2 narrowed = static_cast<fp8e5m2>(value);
+        fp32 widened = static_cast<fp32>(narrowed);
+        return static_cast<float>(widened);
+    };
+}
+
+template <TOSA_REF_TYPE OutDtype>
 CastHelper<TOSA_REF_TYPE_FP64, OutDtype>::CastHelper()
 {
     switch (OutDtype)
@@ -597,6 +760,18 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E4M3);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E5M2);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, BF16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, BF16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E4M3);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E5M2);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E4M3);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E5M2);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16);