diff options
Diffstat (limited to 'reference_model/src/ops/type_conversion.h')
-rw-r--r-- | reference_model/src/ops/type_conversion.h | 104 |
1 files changed, 104 insertions, 0 deletions
diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h index b0de30c..e2fc6e2 100644 --- a/reference_model/src/ops/type_conversion.h +++ b/reference_model/src/ops/type_conversion.h @@ -136,6 +136,76 @@ private: FcnType fcn; }; +template <> +class CastHelper<DType_FP32, DType_FP16> +{ +public: + using InEigenType = typename GetEigenType<DType_FP32>::type; + using OutEigenType = typename GetEigenType<DType_FP16>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <DType InDtype> +class CastHelper<InDtype, DType_BF16> +{ +public: + using InEigenType = typename GetEigenType<InDtype>::type; + using OutEigenType = typename GetEigenType<DType_BF16>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <DType OutDtype> +class CastHelper<DType_BF16, OutDtype> +{ +public: + using InEigenType = typename GetEigenType<DType_BF16>::type; + using OutEigenType = typename GetEigenType<OutDtype>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + static constexpr int32_t OutMin = GetQMin<OutDtype>::value; + static constexpr int32_t OutMax = GetQMax<OutDtype>::value; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper<DType_FP32, DType_BF16> +{ +public: + using InEigenType = typename GetEigenType<DType_FP32>::type; + using OutEigenType = typename GetEigenType<DType_BF16>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + template <DType InDtype> class CastHelper<InDtype, DType_FP32> { @@ -153,6 +223,40 @@ private: FcnType fcn; }; +template <> +class CastHelper<DType_FP16, DType_FP32> +{ +public: + using InEigenType = typename GetEigenType<DType_FP16>::type; + using OutEigenType = typename GetEigenType<DType_FP32>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper<DType_BF16, DType_FP32> +{ +public: + using InEigenType = typename GetEigenType<DType_BF16>::type; + using OutEigenType = typename GetEigenType<DType_FP32>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + template <DType OutDtype> class CastHelper<DType_FP32, OutDtype> { |