aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/tensor_ops.h
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2024-03-14 16:21:29 +0000
committerEric Kunze <eric.kunze@arm.com>2024-03-20 00:02:15 +0000
commitf36f25619cc3a34c75e78637ed244a2ca54ab3f4 (patch)
treeb1aa6a7314ef598561f0259c4d614a4169451031 /reference_model/src/ops/tensor_ops.h
parent0a6d1deef02f2bd76b3068d615565f20c46075a5 (diff)
downloadreference_model-f36f25619cc3a34c75e78637ed244a2ca54ab3f4.tar.gz
[ref model] Add acc_type to Conv Ops
This patch implements changes required by the new acc_type field in ConvAttribute and TransposeConvAttribute Signed-off-by: Tai Ly <tai.ly@arm.com> Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Ib13dbeec4d8920e0ddbcca02b727e7277f2c8d62
Diffstat (limited to 'reference_model/src/ops/tensor_ops.h')
-rw-r--r--reference_model/src/ops/tensor_ops.h16
1 files changed, 8 insertions, 8 deletions
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index e2bb811..2e65548 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -75,7 +75,7 @@ protected:
int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right);
};
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
class OpConv2d : public GraphNode
{
public:
@@ -87,7 +87,7 @@ public:
using InEigenType = typename GetEigenType<InDtype>::type;
using WeightEigenType = typename GetEigenType<WeightDtype>::type;
- using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType
+ using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
using OutEigenType = typename GetEigenType<OutDtype>::type;
using TIn = Eigen::Tensor<InEigenType, 4>;
using TWeight = Eigen::Tensor<WeightEigenType, 4>;
@@ -105,7 +105,7 @@ protected:
tosa::TosaConvAttribute* attribute;
};
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
class OpConv3d : public GraphNode
{
public:
@@ -117,7 +117,7 @@ public:
using InEigenType = typename GetEigenType<InDtype>::type;
using WeightEigenType = typename GetEigenType<WeightDtype>::type;
- using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType
+ using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
using OutEigenType = typename GetEigenType<OutDtype>::type;
using TIn = Eigen::Tensor<InEigenType, 5>;
using TWeight = Eigen::Tensor<WeightEigenType, 5>;
@@ -135,7 +135,7 @@ protected:
tosa::TosaConvAttribute* attribute;
};
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
class OpDepthwiseConv2d : public GraphNode
{
public:
@@ -147,7 +147,7 @@ public:
using InEigenType = typename GetEigenType<InDtype>::type;
using WeightEigenType = typename GetEigenType<WeightDtype>::type;
- using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType
+ using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
using OutEigenType = typename GetEigenType<OutDtype>::type;
using TIn = Eigen::Tensor<InEigenType, 4>;
using TWeight = Eigen::Tensor<WeightEigenType, 4>;
@@ -294,7 +294,7 @@ protected:
tosa::TosaRFFTAttribute* attribute;
};
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
class OpTransposeConv2d : public GraphNode
{
public:
@@ -306,7 +306,7 @@ public:
using InEigenType = typename GetEigenType<InDtype>::type;
using WeightEigenType = typename GetEigenType<WeightDtype>::type;
- using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType
+ using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
using OutEigenType = typename GetEigenType<OutDtype>::type;
using TIn = Eigen::Tensor<InEigenType, 4>;
using TWeight = Eigen::Tensor<WeightEigenType, 4>;