diff options
author | Tai Ly <tai.ly@arm.com> | 2024-03-01 20:59:32 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2024-03-06 18:27:07 +0000 |
commit | 6e1e2bc06bff785e87577f24064bbc846300f8fd (patch) | |
tree | 0a96aeac6f88799fbc297e5937cc0ffc44adcfff /reference_model/src/ops/type_conversion.h | |
parent | 1d5ddeda5d853642fe3b2eade7d765386727021f (diff) | |
download | reference_model-6e1e2bc06bff785e87577f24064bbc846300f8fd.tar.gz |
[ref model] Change RescaleOp attrs to inputs
This patch implements changes required for RescaleOp's
multiplier and shift changing from attributes to inputs
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I178919727e3220c749dad0ebce141e695868fee0
Diffstat (limited to 'reference_model/src/ops/type_conversion.h')
-rw-r--r-- | reference_model/src/ops/type_conversion.h | 17 |
1 files changed, 13 insertions, 4 deletions
diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h index 75f244d..a06dccc 100644 --- a/reference_model/src/ops/type_conversion.h +++ b/reference_model/src/ops/type_conversion.h @@ -32,10 +32,16 @@ public: virtual int checkTensorAttributes() final; virtual int eval() final; - using InEigenType = typename GetEigenType<InDtype>::type; - using OutEigenType = typename GetEigenType<OutDtype>::type; - using TIn = Eigen::Tensor<InEigenType, Rank>; - using TOut = Eigen::Tensor<OutEigenType, Rank>; + using InEigenType = typename GetEigenType<InDtype>::type; + using OutEigenType = typename GetEigenType<OutDtype>::type; + using TIn = Eigen::Tensor<InEigenType, Rank>; + using TOut = Eigen::Tensor<OutEigenType, Rank>; + using I8EigenType = typename GetEigenType<TOSA_REF_TYPE::TOSA_REF_TYPE_INT8>::type; + using I16EigenType = typename GetEigenType<TOSA_REF_TYPE::TOSA_REF_TYPE_INT16>::type; + using I32EigenType = typename GetEigenType<TOSA_REF_TYPE::TOSA_REF_TYPE_INT32>::type; + using TMultiplierI16 = Eigen::Tensor<I16EigenType, 1>; + using TMultiplierI32 = Eigen::Tensor<I32EigenType, 1>; + using TShift = Eigen::Tensor<I8EigenType, 1>; static constexpr int32_t QMin = GetQMin<OutDtype>::value; static constexpr int32_t QMax = GetQMax<OutDtype>::value; @@ -44,6 +50,9 @@ protected: TosaRescaleAttribute* attribute; TosaReference::TensorTemplate<TIn>* in; TosaReference::TensorTemplate<TOut>* out; + TosaReference::TensorTemplate<TMultiplierI16>* multiplierI16; + TosaReference::TensorTemplate<TMultiplierI32>* multiplierI32; + TosaReference::TensorTemplate<TShift>* shift; }; template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> |