diff options
author | Kevin Cheng <kevin.cheng@arm.com> | 2021-09-01 12:51:58 -0700 |
---|---|---|
committer | Kevin Cheng <kevin.cheng@arm.com> | 2021-09-16 01:06:27 +0100 |
commit | 1533b85d198a1dd2b1ce995b6c9d69456e56eb3f (patch) | |
tree | 9c2926e6f646d82ff72f832fcb383e88a688f66b /reference_model/src/ops/tensor_ops.h | |
parent | 93a1628bc3dd48d9ba099de503b586a561b4751f (diff) | |
download | reference_model-1533b85d198a1dd2b1ce995b6c9d69456e56eb3f.tar.gz |
Implement Conv3D kernel.
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: Ic16e918b1a2423ad563684e29ce70d9efdbf9c02
Diffstat (limited to 'reference_model/src/ops/tensor_ops.h')
-rw-r--r-- | reference_model/src/ops/tensor_ops.h | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index eea351d..2174d62 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -109,6 +109,38 @@ protected: }; template <DType InDtype, DType WeightDtype> +class OpConv3d : public GraphNode +{ +public: + OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpConv3d(); + + virtual int checkTensorAttributes() final; + virtual int eval() final; + + static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value; + + using InEigenType = typename GetEigenType<InDtype>::type; + using WeightEigenType = typename GetEigenType<WeightDtype>::type; + using AccEigenType = typename GetEigenType<AccDtype>::type; + using TIn = Eigen::Tensor<InEigenType, 5>; + using TWeight = Eigen::Tensor<WeightEigenType, 5>; + using TBias = Eigen::Tensor<AccEigenType, 1>; + using TAcc = Eigen::Tensor<AccEigenType, 5>; + + static constexpr int64_t AccQMin = GetQMin<AccDtype>::value; + static constexpr int64_t AccQMax = GetQMax<AccDtype>::value; + +protected: + TosaReference::TensorTemplate<TIn>* input; + TosaReference::TensorTemplate<TWeight>* weight; + TosaReference::TensorTemplate<TBias>* bias; + TosaReference::TensorTemplate<TAcc>* output; + tosa::TosaConvAttribute* attribute; + tosa::TosaConvQuantInfo* qinfo; +}; + +template <DType InDtype, DType WeightDtype> class OpDepthwiseConv2d : public GraphNode { public: |