diff options
author | Tai Ly <tai.ly@arm.com> | 2024-03-14 16:21:29 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2024-03-20 00:02:15 +0000 |
commit | f36f25619cc3a34c75e78637ed244a2ca54ab3f4 (patch) | |
tree | b1aa6a7314ef598561f0259c4d614a4169451031 /reference_model/src/ops/tensor_ops.h | |
parent | 0a6d1deef02f2bd76b3068d615565f20c46075a5 (diff) | |
download | reference_model-f36f25619cc3a34c75e78637ed244a2ca54ab3f4.tar.gz |
[ref model] Add acc_type to Conv Ops
This patch implements changes required by the new acc_type field in
ConvAttribute and TransposeConvAttribute
Signed-off-by: Tai Ly <tai.ly@arm.com>
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: Ib13dbeec4d8920e0ddbcca02b727e7277f2c8d62
Diffstat (limited to 'reference_model/src/ops/tensor_ops.h')
-rw-r--r-- | reference_model/src/ops/tensor_ops.h | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index e2bb811..2e65548 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -75,7 +75,7 @@ protected: int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right); }; -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> class OpConv2d : public GraphNode { public: @@ -87,7 +87,7 @@ public: using InEigenType = typename GetEigenType<InDtype>::type; using WeightEigenType = typename GetEigenType<WeightDtype>::type; - using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType + using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType using OutEigenType = typename GetEigenType<OutDtype>::type; using TIn = Eigen::Tensor<InEigenType, 4>; using TWeight = Eigen::Tensor<WeightEigenType, 4>; @@ -105,7 +105,7 @@ protected: tosa::TosaConvAttribute* attribute; }; -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> class OpConv3d : public GraphNode { public: @@ -117,7 +117,7 @@ public: using InEigenType = typename GetEigenType<InDtype>::type; using WeightEigenType = typename GetEigenType<WeightDtype>::type; - using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType + using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType using OutEigenType = typename GetEigenType<OutDtype>::type; using TIn = Eigen::Tensor<InEigenType, 5>; using TWeight = Eigen::Tensor<WeightEigenType, 5>; @@ -135,7 +135,7 @@ protected: tosa::TosaConvAttribute* attribute; }; -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> class OpDepthwiseConv2d : public GraphNode { public: @@ -147,7 +147,7 @@ public: using InEigenType = typename GetEigenType<InDtype>::type; using WeightEigenType = typename GetEigenType<WeightDtype>::type; - using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType + using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType using OutEigenType = typename GetEigenType<OutDtype>::type; using TIn = Eigen::Tensor<InEigenType, 4>; using TWeight = Eigen::Tensor<WeightEigenType, 4>; @@ -294,7 +294,7 @@ protected: tosa::TosaRFFTAttribute* attribute; }; -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> class OpTransposeConv2d : public GraphNode { public: @@ -306,7 +306,7 @@ public: using InEigenType = typename GetEigenType<InDtype>::type; using WeightEigenType = typename GetEigenType<WeightDtype>::type; - using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType + using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType using OutEigenType = typename GetEigenType<OutDtype>::type; using TIn = Eigen::Tensor<InEigenType, 4>; using TWeight = Eigen::Tensor<WeightEigenType, 4>; |