// Copyright (c) 2020, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef OPS_EWISE_BINARY_H #define OPS_EWISE_BINARY_H #include "graph_node.h" using namespace tosa; namespace TosaReference { // class BinaryNodeBase: hold common functions of all the binary nodes // when an binary op is created, the virtual OpXXX::register_fcn() will be called // and 'fcn' will be register with lambda function which has two inputs // class BinaryNode: the level of indirection to partially specialize template for rank 0 // eval() from toplevel called should call the .binaryExpr(dims, fcn) here // this needs to be partially specialize or // compiler will statically fail when trying to broadcast rank0 tensor // class OpXXX: implement per-element lambda function based on different data type // unlike BinaryNode, this doesn't need to be partially specialized // Eigen::Tensor does support some binary element-wise natively (e.g. CWiseMax, or '+', etc.) // which might be faster since it could be implemented with SIMD instructions // the way of registering lambda + .binaryExpr() might sacrifice performance here // but it can avoid partially specialization for combination of {rankN, rank0} x {FLOAT/INT32, QU8, ...} // needs to revisit if performance becomes a bottleneck here template class BinaryNodeBase : public GraphNode { public: BinaryNodeBase(const Op& nodeType, TosaQuantInfoBase* qinfo_, const uint64_t id_); virtual ~BinaryNodeBase(); virtual int checkTensorAttributes() final; virtual int eval() = 0; virtual int register_fcn() = 0; using InEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TOut = Eigen::Tensor; protected: int broadcast(); protected: std::function fcn; Eigen::array bcast_a; Eigen::array bcast_b; TosaReference::TensorTemplate* a; TosaReference::TensorTemplate* b; TosaReference::TensorTemplate>* a_rank0; TosaReference::TensorTemplate>* b_rank0; TosaReference::TensorTemplate* result; int a_rank; int b_rank; int max_input_rank; }; // primary class template class BinaryNode : public BinaryNodeBase { public: BinaryNode(const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_) : BinaryNodeBase(op_, qinfo_, id_) {} virtual ~BinaryNode() {} virtual int eval(); using InEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TOut = Eigen::Tensor; }; // partial specialization for rank 0 template class BinaryNode<0, InDtype, OutDtype> : public BinaryNodeBase<0, InDtype, OutDtype> { public: BinaryNode(const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_) : BinaryNodeBase<0, InDtype, OutDtype>(op_, qinfo_, id_) {} virtual ~BinaryNode() {} virtual int eval(); }; #define DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Opname, OPNAME) \ template \ class Op##Opname : public BinaryNode \ { \ public: \ Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ : BinaryNode(Op_##OPNAME, qinfo_, id_) \ { \ register_fcn(); \ } \ static constexpr DType InDtype = Dtype; \ static constexpr DType OutDtype = Dtype; \ using InEigenType = typename GetEigenType::type; \ using OutEigenType = typename GetEigenType::type; \ virtual int register_fcn(); \ }; #define DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Opname, OPNAME) \ template \ class Op##Opname : public BinaryNode \ { \ public: \ Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ : BinaryNode(Op_##OPNAME, qinfo_, id_) \ { \ register_fcn(); \ } \ static constexpr int32_t QMin = GetQMin::value; \ static constexpr int32_t QMax = GetQMax::value; \ using InEigenType = typename GetEigenType::type; \ using OutEigenType = typename GetEigenType::type; \ virtual int register_fcn(); \ }; DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Add, ADD) DEF_TEMPLATE_BINARY_OP_ONE_TYPE(ArithmeticRightShift, ARITHMETIC_RIGHT_SHIFT) DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseAnd, BITWISE_AND) DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseOr, BITWISE_OR) DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseXor, BITWISE_XOR) DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalAnd, LOGICAL_AND) DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalLeftShift, LOGICAL_LEFT_SHIFT) DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalRightShift, LOGICAL_RIGHT_SHIFT) DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalOr, LOGICAL_OR) DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalXor, LOGICAL_XOR) DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Maximum, MAXIMUM) DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Minimum, MINIMUM) DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Mul, MUL) DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Pow, POW) DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Sub, SUB) #undef DEF_TEMPLATE_BINARY_OP_ONE_TYPE #undef DEF_TEMPLATE_BINARY_OP_TWO_TYPE template class OpTable : public GraphNode { public: OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpTable(); virtual int checkTensorAttributes(); virtual int eval(); static constexpr DType InDtype = DType_INT16; static constexpr DType TableDtype = DType_INT16; static constexpr DType OutDtype = DType_INT32; using InEigenType = typename GetEigenType::type; using TableEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TTable = Eigen::Tensor; using TOut = Eigen::Tensor; static constexpr int32_t IntegerBits = 9; static constexpr int32_t FractionBits = 7; static constexpr int32_t NumTableEntries = (1 << IntegerBits); static constexpr int32_t QInMin = GetQMin::value; static constexpr int32_t QInMax = GetQMax::value; static constexpr int32_t QOutMin = GetQMin::value; static constexpr int32_t QOutMax = GetQMax::value; protected: TosaReference::TensorTemplate* in; TosaReference::TensorTemplate* table; TosaReference::TensorTemplate* out; }; }; // namespace TosaReference #endif