diff options
Diffstat (limited to 'reference_model/src/ops/tensor_ops.h')
-rw-r--r-- | reference_model/src/ops/tensor_ops.h | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index e2bb811..2e65548 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -75,7 +75,7 @@ protected: int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right); }; -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> class OpConv2d : public GraphNode { public: @@ -87,7 +87,7 @@ public: using InEigenType = typename GetEigenType<InDtype>::type; using WeightEigenType = typename GetEigenType<WeightDtype>::type; - using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType + using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType using OutEigenType = typename GetEigenType<OutDtype>::type; using TIn = Eigen::Tensor<InEigenType, 4>; using TWeight = Eigen::Tensor<WeightEigenType, 4>; @@ -105,7 +105,7 @@ protected: tosa::TosaConvAttribute* attribute; }; -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> class OpConv3d : public GraphNode { public: @@ -117,7 +117,7 @@ public: using InEigenType = typename GetEigenType<InDtype>::type; using WeightEigenType = typename GetEigenType<WeightDtype>::type; - using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType + using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType using OutEigenType = typename GetEigenType<OutDtype>::type; using TIn = Eigen::Tensor<InEigenType, 5>; using TWeight = Eigen::Tensor<WeightEigenType, 5>; @@ -135,7 +135,7 @@ protected: tosa::TosaConvAttribute* attribute; }; -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> class OpDepthwiseConv2d : public GraphNode { public: @@ -147,7 +147,7 @@ public: using InEigenType = typename GetEigenType<InDtype>::type; using WeightEigenType = typename GetEigenType<WeightDtype>::type; - using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType + using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType using OutEigenType = typename GetEigenType<OutDtype>::type; using TIn = Eigen::Tensor<InEigenType, 4>; using TWeight = Eigen::Tensor<WeightEigenType, 4>; @@ -294,7 +294,7 @@ protected: tosa::TosaRFFTAttribute* attribute; }; -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> class OpTransposeConv2d : public GraphNode { public: @@ -306,7 +306,7 @@ public: using InEigenType = typename GetEigenType<InDtype>::type; using WeightEigenType = typename GetEigenType<WeightDtype>::type; - using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType + using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType using OutEigenType = typename GetEigenType<OutDtype>::type; using TIn = Eigen::Tensor<InEigenType, 4>; using TWeight = Eigen::Tensor<WeightEigenType, 4>; |