diff options
Diffstat (limited to 'reference_model/src/ops/tensor_ops.h')
-rw-r--r-- | reference_model/src/ops/tensor_ops.h | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index eea351d..2174d62 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -109,6 +109,38 @@ protected: }; template <DType InDtype, DType WeightDtype> +class OpConv3d : public GraphNode +{ +public: + OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpConv3d(); + + virtual int checkTensorAttributes() final; + virtual int eval() final; + + static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value; + + using InEigenType = typename GetEigenType<InDtype>::type; + using WeightEigenType = typename GetEigenType<WeightDtype>::type; + using AccEigenType = typename GetEigenType<AccDtype>::type; + using TIn = Eigen::Tensor<InEigenType, 5>; + using TWeight = Eigen::Tensor<WeightEigenType, 5>; + using TBias = Eigen::Tensor<AccEigenType, 1>; + using TAcc = Eigen::Tensor<AccEigenType, 5>; + + static constexpr int64_t AccQMin = GetQMin<AccDtype>::value; + static constexpr int64_t AccQMax = GetQMax<AccDtype>::value; + +protected: + TosaReference::TensorTemplate<TIn>* input; + TosaReference::TensorTemplate<TWeight>* weight; + TosaReference::TensorTemplate<TBias>* bias; + TosaReference::TensorTemplate<TAcc>* output; + tosa::TosaConvAttribute* attribute; + tosa::TosaConvQuantInfo* qinfo; +}; + +template <DType InDtype, DType WeightDtype> class OpDepthwiseConv2d : public GraphNode { public: |