// Copyright (c) 2020, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef OPS_DATA_LAYOUT_H #define OPS_DATA_LAYOUT_H #include "graph_node.h" using namespace tosa; namespace TosaReference { template class OpConcat : public GraphNode { public: OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpConcat(); virtual int checkTensorAttributes(); virtual int eval(); using InEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TOut = Eigen::Tensor; protected: Eigen::array reverser; TosaReference::TensorTemplate* lhs; TosaReference::TensorTemplate* rhs; TosaAxisAttribute* attribute; TosaReference::TensorTemplate* out; }; template class OpPad : public GraphNode { public: OpPad(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpPad(); virtual int checkTensorAttributes(); virtual int eval(); using InEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TOut = Eigen::Tensor; protected: Eigen::array, Rank> paddings_array; TosaReference::TensorTemplate* in; TosaReference::TensorTemplate* out; TosaPadQuantInfo* qinfo; }; template class OpReshape : public GraphNode { public: OpReshape(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpReshape(); virtual int checkTensorAttributes(); virtual int eval(); using InEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TOut = Eigen::Tensor; protected: Eigen::array array_shape; Eigen::array in_reverser; Eigen::array out_reverser; TosaReference::TensorTemplate* in; TosaReshapeAttribute* attribute; TosaReference::TensorTemplate* out; }; template class OpReverse : public GraphNode { public: OpReverse(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpReverse(); virtual int checkTensorAttributes(); virtual int eval(); using InEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TOut = Eigen::Tensor; protected: TosaAxisAttribute* attribute; TosaReference::TensorTemplate* in; TosaReference::TensorTemplate* out; Eigen::array reverse_array; }; template class OpSlice : public GraphNode { public: OpSlice(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpSlice(); virtual int checkTensorAttributes(); virtual int eval(); using InEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TOut = Eigen::Tensor; protected: TosaSliceAttribute* attribute; Eigen::array begin_array; Eigen::array size_array; TosaReference::TensorTemplate* in; TosaReference::TensorTemplate* out; }; template class OpTileBase : public GraphNode { public: OpTileBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpTileBase(); virtual int checkTensorAttributes(); using InEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TOut = Eigen::Tensor; protected: TosaTileAttribute* attribute; TosaReference::TensorTemplate* in; TosaReference::TensorTemplate* out; }; // primary template for op tile template class OpTile : public OpTileBase { public: OpTile(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) : OpTileBase(attribute_, qinfo_, id_) {} protected: virtual int eval(); }; // partial specialization for specific rank #define DEF_OP_TILE_RANK(N) \ template \ class OpTile : public OpTileBase \ { \ public: \ OpTile(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ : OpTileBase(attribute_, qinfo_, id_) \ {} \ \ protected: \ virtual int eval(); \ }; DEF_OP_TILE_RANK(1) DEF_OP_TILE_RANK(2) DEF_OP_TILE_RANK(3) DEF_OP_TILE_RANK(4) #undef DEF_OP_TILE_RANK template class OpTranspose : public GraphNode { public: OpTranspose(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpTranspose(); virtual int checkTensorAttributes(); virtual int eval(); using InEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TOut = Eigen::Tensor; protected: Eigen::array perm_array; TosaReference::TensorTemplate* in; TosaReference::TensorTemplate>* perm_tensor; TosaReference::TensorTemplate* out; }; }; // namespace TosaReference #endif