diff options
author | James Ward <james.ward@arm.com> | 2023-01-18 14:51:25 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2023-02-02 22:03:25 +0000 |
commit | d34b3fc5eeef48ecc781a02433ce022a28e3373c (patch) | |
tree | 13aa36aa89c618e56eb2f51915a172ff8e4276d9 /reference_model/src/ops/tensor_ops.h | |
parent | 512c1caa8b6d494de81f3ac83a6ebb96e1e0f8e0 (diff) | |
download | reference_model-d34b3fc5eeef48ecc781a02433ce022a28e3373c.tar.gz |
Remove accumulator attributes from all but AVG_POOL2D
Signed-off-by: James Ward <james.ward@arm.com>
Change-Id: If67f503a1848967bc1671646c3011d055b622c52
Diffstat (limited to 'reference_model/src/ops/tensor_ops.h')
-rw-r--r-- | reference_model/src/ops/tensor_ops.h | 60 |
1 files changed, 30 insertions, 30 deletions
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index ed9a55c..0d2b3eb 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -74,7 +74,7 @@ protected: ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right); }; -template <DType InDtype, DType WeightDtype, DType AccDtype> +template <DType InDtype, DType WeightDtype, DType OutDtype> class OpConv2d : public GraphNode { public: @@ -86,15 +86,15 @@ public: using InEigenType = typename GetEigenType<InDtype>::type; using WeightEigenType = typename GetEigenType<WeightDtype>::type; - using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType - using OutEigenType = typename GetEigenType<AccDtype>::type; + using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType<OutDtype>::type; using TIn = Eigen::Tensor<InEigenType, 4>; using TWeight = Eigen::Tensor<WeightEigenType, 4>; using TBias = Eigen::Tensor<OutEigenType, 1>; using TOut = Eigen::Tensor<OutEigenType, 4>; - static constexpr int64_t AccQMin = GetQMin<AccDtype>::value; - static constexpr int64_t AccQMax = GetQMax<AccDtype>::value; + static constexpr int64_t AccQMin = GetQMin<OutDtype>::value; + static constexpr int64_t AccQMax = GetQMax<OutDtype>::value; protected: TosaReference::TensorTemplate<TIn>* input; @@ -104,7 +104,7 @@ protected: tosa::TosaConvAttribute* attribute; }; -template <DType InDtype, DType WeightDtype, DType AccDtype> +template <DType InDtype, DType WeightDtype, DType OutDtype> class OpConv3d : public GraphNode { public: @@ -116,15 +116,15 @@ public: using InEigenType = typename GetEigenType<InDtype>::type; using WeightEigenType = typename GetEigenType<WeightDtype>::type; - using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType - using OutEigenType = typename GetEigenType<AccDtype>::type; + using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType<OutDtype>::type; using TIn = Eigen::Tensor<InEigenType, 5>; using TWeight = Eigen::Tensor<WeightEigenType, 5>; using TBias = Eigen::Tensor<OutEigenType, 1>; using TOut = Eigen::Tensor<OutEigenType, 5>; - static constexpr int64_t AccQMin = GetQMin<AccDtype>::value; - static constexpr int64_t AccQMax = GetQMax<AccDtype>::value; + static constexpr int64_t AccQMin = GetQMin<OutDtype>::value; + static constexpr int64_t AccQMax = GetQMax<OutDtype>::value; protected: TosaReference::TensorTemplate<TIn>* input; @@ -134,7 +134,7 @@ protected: tosa::TosaConvAttribute* attribute; }; -template <DType InDtype, DType WeightDtype, DType AccDtype> +template <DType InDtype, DType WeightDtype, DType OutDtype> class OpDepthwiseConv2d : public GraphNode { public: @@ -146,15 +146,15 @@ public: using InEigenType = typename GetEigenType<InDtype>::type; using WeightEigenType = typename GetEigenType<WeightDtype>::type; - using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType - using OutEigenType = typename GetEigenType<AccDtype>::type; + using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType<OutDtype>::type; using TIn = Eigen::Tensor<InEigenType, 4>; using TWeight = Eigen::Tensor<WeightEigenType, 4>; using TBias = Eigen::Tensor<OutEigenType, 1>; using TOut = Eigen::Tensor<OutEigenType, 4>; - static constexpr int64_t AccQMin = GetQMin<AccDtype>::value; - static constexpr int64_t AccQMax = GetQMax<AccDtype>::value; + static constexpr int64_t AccQMin = GetQMin<OutDtype>::value; + static constexpr int64_t AccQMax = GetQMax<OutDtype>::value; protected: TosaReference::TensorTemplate<TIn>* input; @@ -164,7 +164,7 @@ protected: tosa::TosaConvAttribute* attribute; }; -template <DType InDtype, DType WeightDtype, DType AccDtype> +template <DType InDtype, DType WeightDtype, DType OutDtype> class OpFullyConnected : public GraphNode { public: @@ -176,15 +176,15 @@ public: using InEigenType = typename GetEigenType<InDtype>::type; using WeightEigenType = typename GetEigenType<WeightDtype>::type; - using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType - using OutEigenType = typename GetEigenType<AccDtype>::type; + using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType<OutDtype>::type; using TIn = Eigen::Tensor<InEigenType, 2>; using TWeight = Eigen::Tensor<WeightEigenType, 2>; using TBias = Eigen::Tensor<OutEigenType, 1>; using TOut = Eigen::Tensor<OutEigenType, 2>; - static constexpr int64_t AccQMin = GetQMin<AccDtype>::value; - static constexpr int64_t AccQMax = GetQMax<AccDtype>::value; + static constexpr int64_t AccQMin = GetQMin<OutDtype>::value; + static constexpr int64_t AccQMax = GetQMax<OutDtype>::value; protected: TosaReference::TensorTemplate<TIn>* input; @@ -195,7 +195,7 @@ protected: tosa::TosaFullyConnectedAttribute* attribute; }; -template <DType Dtype, DType AccDtype> +template <DType Dtype, DType OutDtype> class OpMatMul : public GraphNode { public: @@ -206,14 +206,14 @@ public: virtual int eval() final; using InEigenType = typename GetEigenType<Dtype>::type; - using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType - using OutEigenType = typename GetEigenType<AccDtype>::type; + using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType<OutDtype>::type; using TIn = Eigen::Tensor<InEigenType, 3>; using TOut = Eigen::Tensor<OutEigenType, 3>; using TInRank2 = Eigen::Tensor<InEigenType, 2>; using TAccRank2 = Eigen::Tensor<AccEigenType, 2>; - static constexpr int64_t AccQMin = GetQMin<AccDtype>::value; - static constexpr int64_t AccQMax = GetQMax<AccDtype>::value; + static constexpr int64_t AccQMin = GetQMin<OutDtype>::value; + static constexpr int64_t AccQMax = GetQMax<OutDtype>::value; protected: TosaReference::TensorTemplate<TIn>* a; @@ -269,7 +269,7 @@ protected: TosaReference::TensorTemplate<TOut>* out_imag; }; -template <DType InDtype, DType WeightDtype, DType AccDtype> +template <DType InDtype, DType WeightDtype, DType OutDtype> class OpTransposeConv2d : public GraphNode { public: @@ -281,15 +281,15 @@ public: using InEigenType = typename GetEigenType<InDtype>::type; using WeightEigenType = typename GetEigenType<WeightDtype>::type; - using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType - using OutEigenType = typename GetEigenType<AccDtype>::type; + using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType<OutDtype>::type; using TIn = Eigen::Tensor<InEigenType, 4>; using TWeight = Eigen::Tensor<WeightEigenType, 4>; using TBias = Eigen::Tensor<OutEigenType, 1>; using TOut = Eigen::Tensor<OutEigenType, 4>; - static constexpr int64_t AccQMin = GetQMin<AccDtype>::value; - static constexpr int64_t AccQMax = GetQMax<AccDtype>::value; + static constexpr int64_t AccQMin = GetQMin<OutDtype>::value; + static constexpr int64_t AccQMax = GetQMax<OutDtype>::value; protected: TosaReference::TensorTemplate<TIn>* input; |