// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef OPS_EWISE_BINARY_H #define OPS_EWISE_BINARY_H #include "graph_node.h" using namespace tosa; namespace TosaReference { // class BinaryNodeBase: hold common functions of all the binary nodes // when an binary op is created, the virtual OpXXX::register_fcn() will be called // and 'fcn' will be register with lambda function which has two inputs // class BinaryNode: the level of indirection to partially specialize template for rank 0 // eval() from toplevel called should call the .binaryExpr(dims, fcn) here // this needs to be partially specialize or // compiler will statically fail when trying to broadcast rank0 tensor // class OpXXX: implement per-element lambda function based on different data type // unlike BinaryNode, this doesn't need to be partially specialized // Eigen::Tensor does support some binary element-wise natively (e.g. CWiseMax, or '+', etc.) // which might be faster since it could be implemented with SIMD instructions // the way of registering lambda + .binaryExpr() might sacrifice performance here // but it can avoid partially specialization for combination of {rankN, rank0} x {FP32/INT32, QU8, ...} // needs to revisit if performance becomes a bottleneck here template class BinaryNodeBase : public GraphNode { public: BinaryNodeBase(SubgraphTraverser* sgt_, const Op& nodeType, const uint64_t id_); virtual ~BinaryNodeBase(); virtual int checkTensorAttributes() final; virtual int eval() = 0; virtual int register_fcn() = 0; using InEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TOut = Eigen::Tensor; protected: int broadcast(); protected: std::function fcn; Eigen::array bcast_a; Eigen::array bcast_b; TosaReference::TensorTemplate* a; TosaReference::TensorTemplate* b; TosaReference::TensorTemplate* result; }; // primary class template class BinaryNode : public BinaryNodeBase { public: BinaryNode(SubgraphTraverser* sgt_, const Op& op_, const uint64_t id_) : BinaryNodeBase(sgt_, op_, id_) {} virtual ~BinaryNode() {} virtual int eval(); using InEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TOut = Eigen::Tensor; }; // partial specialization for rank 0 template class BinaryNode<0, InDtype, OutDtype> : public BinaryNodeBase<0, InDtype, OutDtype> { public: BinaryNode(SubgraphTraverser* sgt_, const Op& op_, const uint64_t id_) : BinaryNodeBase<0, InDtype, OutDtype>(sgt_, op_, id_) {} virtual ~BinaryNode() {} virtual int eval(); }; #define DEF_TEMPLATE_BINARY_OP_DEFAULT(Opname, OPNAME) \ template \ class Op##Opname : public BinaryNode \ { \ public: \ Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \ : BinaryNode(sgt_, Op_##OPNAME, id_) \ { \ register_fcn(); \ } \ static constexpr DType InDtype = Dtype; \ static constexpr DType OutDtype = Dtype; \ using InEigenType = typename GetEigenType::type; \ using OutEigenType = typename GetEigenType::type; \ virtual int register_fcn(); \ }; DEF_TEMPLATE_BINARY_OP_DEFAULT(Add, ADD) DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseAnd, BITWISE_AND) DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseOr, BITWISE_OR) DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseXor, BITWISE_XOR) DEF_TEMPLATE_BINARY_OP_DEFAULT(Intdiv, INTDIV) DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalAnd, LOGICAL_AND) DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalLeftShift, LOGICAL_LEFT_SHIFT) DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalRightShift, LOGICAL_RIGHT_SHIFT) DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalOr, LOGICAL_OR) DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalXor, LOGICAL_XOR) DEF_TEMPLATE_BINARY_OP_DEFAULT(Maximum, MAXIMUM) DEF_TEMPLATE_BINARY_OP_DEFAULT(Minimum, MINIMUM) DEF_TEMPLATE_BINARY_OP_DEFAULT(Pow, POW) DEF_TEMPLATE_BINARY_OP_DEFAULT(Sub, SUB) #undef DEF_TEMPLATE_BINARY_OP_DEFAULT template class OpArithmeticRightShift : public BinaryNode { public: OpArithmeticRightShift(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : BinaryNode(sgt_, Op_ARITHMETIC_RIGHT_SHIFT, id_) { INIT_ATTRIBUTE(ArithmeticRightShift); register_fcn(); } using InEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; virtual int register_fcn(); protected: TosaArithmeticRightShiftAttribute* attribute; }; template class OpMul : public BinaryNode { public: OpMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : BinaryNode(sgt_, Op_MUL, id_) { INIT_ATTRIBUTE(Mul); register_fcn(); } static constexpr int64_t QMin = GetQMin::value; static constexpr int64_t QMax = GetQMax::value; using InEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; virtual int register_fcn(); protected: TosaMulAttribute* attribute; }; template class OpTable : public GraphNode { public: OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpTable(); virtual int checkTensorAttributes(); virtual int eval(); static constexpr DType TableDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT16; static constexpr DType OutDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT32; static constexpr uint32_t TableNumEntries = (InDtype == DType_INT8) ? 256 : 513; using InEigenType = typename GetEigenType::type; using TableEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TTable = Eigen::Tensor; using TOut = Eigen::Tensor; static constexpr int32_t IntegerBits = 9; static constexpr int32_t FractionBits = 7; static constexpr int32_t NumTableEntries = (1 << IntegerBits); static constexpr int32_t QInMin = GetQMin::value; static constexpr int32_t QInMax = GetQMax::value; static constexpr int32_t QOutMin = GetQMin::value; static constexpr int32_t QOutMax = GetQMax::value; protected: TosaReference::TensorTemplate* in; TosaReference::TensorTemplate* out; TosaTableAttribute* attribute; std::array table; }; }; // namespace TosaReference #endif