aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/tensor_ops.h
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/tensor_ops.h')
-rw-r--r--reference_model/src/ops/tensor_ops.h32
1 files changed, 32 insertions, 0 deletions
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index eea351d..2174d62 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -109,6 +109,38 @@ protected:
};
template <DType InDtype, DType WeightDtype>
+class OpConv3d : public GraphNode
+{
+public:
+ OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpConv3d();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using WeightEigenType = typename GetEigenType<WeightDtype>::type;
+ using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 5>;
+ using TWeight = Eigen::Tensor<WeightEigenType, 5>;
+ using TBias = Eigen::Tensor<AccEigenType, 1>;
+ using TAcc = Eigen::Tensor<AccEigenType, 5>;
+
+ static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* input;
+ TosaReference::TensorTemplate<TWeight>* weight;
+ TosaReference::TensorTemplate<TBias>* bias;
+ TosaReference::TensorTemplate<TAcc>* output;
+ tosa::TosaConvAttribute* attribute;
+ tosa::TosaConvQuantInfo* qinfo;
+};
+
+template <DType InDtype, DType WeightDtype>
class OpDepthwiseConv2d : public GraphNode
{
public: