// Copyright (c) 2020, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef OPS_TENSOR_OPS_H #define OPS_TENSOR_OPS_H #include "graph_node.h" #include "quant_util.h" using namespace tosa; namespace TosaReference { template class OpArgMax : public GraphNode { public: OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpArgMax(); virtual int checkTensorAttributes(); virtual int eval(); using InEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TOut = Eigen::Tensor; protected: TosaAxisAttribute* attribute; TosaReference::TensorTemplate* input; TosaReference::TensorTemplate* output; }; template class OpAvgPool2d : public GraphNode { public: OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpAvgPool2d(); virtual int checkTensorAttributes(); virtual int eval(); static constexpr DType AccDtype = GetAccDType::value; using InEigenType = typename GetEigenType::type; using AccEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TOut = Eigen::Tensor; static constexpr int64_t QMin = GetQMin::value; static constexpr int64_t QMax = GetQMax::value; protected: TosaReference::TensorTemplate* in; TosaReference::TensorTemplate* out; tosa::TosaPoolAttribute* attribute; tosa::TosaUnaryQuantInfo* qinfo; protected: // return a 1D [N] tensor that describes a how many valid elements covered in the input space ETensor1 calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right); }; template class OpConv2d : public GraphNode { public: OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpConv2d(); virtual int checkTensorAttributes() final; virtual int eval() final; static constexpr DType AccDtype = GetAccDType::value; using InEigenType = typename GetEigenType::type; using WeightEigenType = typename GetEigenType::type; using AccEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TWeight = Eigen::Tensor; using TBias = Eigen::Tensor; using TAcc = Eigen::Tensor; static constexpr int64_t AccQMin = GetQMin::value; static constexpr int64_t AccQMax = GetQMax::value; protected: TosaReference::TensorTemplate* input; TosaReference::TensorTemplate* weight; TosaReference::TensorTemplate* bias; TosaReference::TensorTemplate* output; tosa::TosaConvAttribute* attribute; tosa::TosaConvQuantInfo* qinfo; }; template class OpConv3d : public GraphNode { public: OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpConv3d(); virtual int checkTensorAttributes() final; virtual int eval() final; static constexpr DType AccDtype = GetAccDType::value; using InEigenType = typename GetEigenType::type; using WeightEigenType = typename GetEigenType::type; using AccEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TWeight = Eigen::Tensor; using TBias = Eigen::Tensor; using TAcc = Eigen::Tensor; static constexpr int64_t AccQMin = GetQMin::value; static constexpr int64_t AccQMax = GetQMax::value; protected: TosaReference::TensorTemplate* input; TosaReference::TensorTemplate* weight; TosaReference::TensorTemplate* bias; TosaReference::TensorTemplate* output; tosa::TosaConvAttribute* attribute; tosa::TosaConvQuantInfo* qinfo; }; template class OpDepthwiseConv2d : public GraphNode { public: OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpDepthwiseConv2d(); virtual int checkTensorAttributes() final; virtual int eval() final; static constexpr DType AccDtype = GetAccDType::value; using InEigenType = typename GetEigenType::type; using WeightEigenType = typename GetEigenType::type; using AccEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TWeight = Eigen::Tensor; using TBias = Eigen::Tensor; using TAcc = Eigen::Tensor; static constexpr int64_t AccQMin = GetQMin::value; static constexpr int64_t AccQMax = GetQMax::value; protected: TosaReference::TensorTemplate* input; TosaReference::TensorTemplate* weight; TosaReference::TensorTemplate* bias; TosaReference::TensorTemplate* output; tosa::TosaConvAttribute* attribute; tosa::TosaConvQuantInfo* qinfo; }; template class OpFullyConnected : public GraphNode { public: OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpFullyConnected(); virtual int checkTensorAttributes() final; virtual int eval() final; static constexpr DType AccDtype = GetAccDType::value; using InEigenType = typename GetEigenType::type; using WeightEigenType = typename GetEigenType::type; using AccEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TWeight = Eigen::Tensor; using TBias = Eigen::Tensor; using TAcc = Eigen::Tensor; static constexpr int64_t AccQMin = GetQMin::value; static constexpr int64_t AccQMax = GetQMax::value; protected: TosaReference::TensorTemplate* input; TosaReference::TensorTemplate* weight; TosaReference::TensorTemplate* bias; TosaReference::TensorTemplate* output; tosa::TosaConvQuantInfo* qinfo; }; template class OpMatMul : public GraphNode { public: OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpMatMul(); virtual int checkTensorAttributes() final; virtual int eval() final; static constexpr DType AccDtype = GetAccDType::value; using InEigenType = typename GetEigenType::type; using AccEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TAcc = Eigen::Tensor; using TInRank2 = Eigen::Tensor; using TAccRank2 = Eigen::Tensor; static constexpr int64_t AccQMin = GetQMin::value; static constexpr int64_t AccQMax = GetQMax::value; protected: TosaReference::TensorTemplate* a; TosaReference::TensorTemplate* b; TosaReference::TensorTemplate* output; int64_t N; int64_t H; int64_t W; int64_t C; tosa::TosaMatMulQuantInfo* qinfo; }; template class OpMaxPool2d : public GraphNode { public: OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpMaxPool2d(); virtual int checkTensorAttributes(); virtual int eval(); using InEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TOut = Eigen::Tensor; protected: TosaReference::TensorTemplate* in; TosaReference::TensorTemplate* out; tosa::TosaPoolAttribute* attribute; }; template class OpTransposeConv2d : public GraphNode { public: OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpTransposeConv2d(); virtual int checkTensorAttributes() final; virtual int eval() final; static constexpr DType AccDtype = GetAccDType::value; using InEigenType = typename GetEigenType::type; using WeightEigenType = typename GetEigenType::type; using AccEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TWeight = Eigen::Tensor; using TBias = Eigen::Tensor; using TAcc = Eigen::Tensor; static constexpr int64_t AccQMin = GetQMin::value; static constexpr int64_t AccQMax = GetQMax::value; protected: TosaReference::TensorTemplate* input; TosaReference::TensorTemplate* weight; TosaReference::TensorTemplate* bias; TosaReference::TensorTemplate* output; TosaTransposeConvAttribute* attribute; TosaConvQuantInfo* qinfo; }; }; // namespace TosaReference #endif