// Copyright (c) 2020, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "ewise_binary.h" #include "arith_util.h" #include "quant_util.h" #include "template_types.h" using namespace TosaReference; using namespace Eigen; using namespace tosa; template BinaryNodeBase::BinaryNodeBase(const Op& op_, TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(op_, id_) { setRequiredOperands(2, 1); setRequiredRank(0, 6); a_rank = b_rank = max_input_rank = -1; a = b = nullptr; a_rank0 = b_rank0 = nullptr; result = nullptr; fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); }; } template BinaryNodeBase::~BinaryNodeBase() {} template int BinaryNodeBase::checkTensorAttributes() { if (validateRequiredOperands()) return 1; if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0])) { return 1; } a_rank = inputs[0]->getRank(); b_rank = inputs[1]->getRank(); if (a_rank != 0 && b_rank != 0 && a_rank != b_rank) { printNodeValidationError("Binary operator input ranks must match"); return 1; } max_input_rank = a_rank >= b_rank ? 
a_rank : b_rank; // A & B must be the same types if (inputs[0]->matchType(*inputs[1])) { printNodeValidationError("Binary operator input types must match"); return 1; } // Result's geometry must match, but the type may be wider if (outputs[0]->getRank() != max_input_rank) { printNodeValidationError("Binary operator input and output genometry must match"); return 1; } if (a_rank == max_input_rank) { a = dynamic_cast*>(inputs[0]); } else { a_rank0 = dynamic_cast>*>(inputs[0]); } if (b_rank == max_input_rank) { b = dynamic_cast*>(inputs[1]); } else { b_rank0 = dynamic_cast>*>(inputs[1]); } result = dynamic_cast*>(outputs[0]); // either a or b can be rank0 // a_rank0 and b_rank0 can't be valid at the same time. // if a and be are both rank0, they should be evaulated as 'a' and 'b', instead of 'a_rank0' and 'b_rank0' ASSERT_MEM((a || a_rank0) && (b || b_rank0) && !(a_rank0 && b_rank0) && result); return 0; } template int BinaryNodeBase::broadcast() { auto output_shape = result->getTensor().dimensions(); std::vector a_shape, b_shape; if (a_rank == max_input_rank) { a_shape = a->getShape(); } else { a_shape.assign(max_input_rank, 1); } if (b_rank == max_input_rank) { b_shape = b->getShape(); } else { b_shape.assign(max_input_rank, 1); } for (int i = 0; i < max_input_rank; i++) { if (a_shape[i] != output_shape[i] && a_shape[i] == 1) { bcast_a[i] = output_shape[i]; } else { bcast_a[i] = 1; } if (b_shape[i] != output_shape[i] && b_shape[i] == 1) { bcast_b[i] = output_shape[i]; } else { bcast_b[i] = 1; } } return 0; } template int BinaryNode::eval() { this->broadcast(); Eigen::array reshaper; reshaper.fill(1); TIn ia, ib; if (this->a_rank == this->max_input_rank) { ia = this->a->getTensor().broadcast(this->bcast_a); } else { ia = this->a_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_a); } if (this->b_rank == this->max_input_rank) { ib = this->b->getTensor().broadcast(this->bcast_b); } else { ib = 
this->b_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_b); } this->result->getTensor() = ia.binaryExpr(ib, this->fcn); return GraphNode::eval(); } // still need to partial specialize this, or Eigen will throw static assertion template int BinaryNode<0, InDtype, OutDtype>::eval() { this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn); return GraphNode::eval(); } template int OpAdd::register_fcn() { switch (InDtype) { case DType_FLOAT: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]); } return 0; } template int OpArithmeticRightShift::register_fcn() { int32_t num_bits = 0; switch (Dtype) { case DType_INT8: num_bits = 8; break; case DType_INT16: num_bits = 16; break; case DType_INT32: num_bits = 32; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType { uint32_t sign = a & (1 << (num_bits - 1)); uint32_t ones_mask = ONES_MASK(b) << (num_bits - b); if (sign) return ones_mask | (a >> b); else return (~ones_mask) & (a >> b); }; return 0; } template int OpBitwiseAnd::register_fcn() { switch (Dtype) { case DType_AINT8: case DType_INT16: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpBitwiseOr::register_fcn() { switch (Dtype) { case DType_AINT8: case DType_INT16: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpBitwiseXor::register_fcn() { switch (Dtype) { case DType_AINT8: case DType_INT16: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { 
return a ^ b; }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpLogicalAnd::register_fcn() { switch (Dtype) { case DType_BOOL: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpLogicalLeftShift::register_fcn() { switch (Dtype) { case DType_INT8: case DType_INT16: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a << b; }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpLogicalRightShift::register_fcn() { int32_t num_bits = 0; switch (Dtype) { case DType_INT8: num_bits = 8; break; case DType_INT16: num_bits = 16; break; case DType_INT32: num_bits = 32; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType { uint32_t mask = ONES_MASK(num_bits) >> b; return (a >> b) & mask; }; return 0; } template int OpLogicalOr::register_fcn() { switch (Dtype) { case DType_BOOL: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpLogicalXor::register_fcn() { switch (Dtype) { case DType_BOOL: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpMaximum::register_fcn() { switch (Dtype) { case DType_FLOAT: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? 
a : b; }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpMinimum::register_fcn() { switch (Dtype) { case DType_FLOAT: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpMul::register_fcn() { switch (InDtype) { case DType_FLOAT: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a * b; }; break; case DType_INT8: case DType_INT16: this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType { OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs; OutEigenType clamped_output = std::min(QMax, std::max(raw_output, QMin)); return clamped_output; }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]); } return 0; } template int OpPow::register_fcn() { switch (Dtype) { case DType_FLOAT: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpSub::register_fcn() { switch (InDtype) { case DType_FLOAT: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]); } return 0; } template OpTable::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(Op_TABLE, id_) { setRequiredOperands(2, 1); setRequiredRank(0, 6); } template OpTable::~OpTable() {} template int OpTable::checkTensorAttributes() { if (validateRequiredOperands()) return 1; if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) { return 1; } if (inputs[1]->getRank() != 1 || inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16) { FATAL_ERROR_NODE("OpTable: must have 
INT16 table with 513 entries"); return 1; } in = dynamic_cast*>(inputs[0]); table = dynamic_cast*>(inputs[1]); out = dynamic_cast*>(outputs[0]); ASSERT_MEM(in && table && out); return 0; } template int OpTable::eval() { this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType { // 1. make sure input is int16 range int32_t input_truncated = std::min(std::max(in, QInMin), QInMax); // 2. calculate index and interpolation fraction int32_t index = (input_truncated >> 7) + (1 << (IntegerBits - 1)); index = std::min(std::max(index, 0), NumTableEntries - 1); // 9-bit index int32_t frac = (input_truncated)&0x7F; // 7-bit fraction // 3. interpolate, generate 16.7 (23-bit) output int32_t base = this->table->getTensor()(index); int32_t next = this->table->getTensor()(index + 1); int32_t value = (base << 7) + (next - base) * frac; return value; }); return GraphNode::eval(); } // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, AINT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, AINT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, AINT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8); 
// Explicit instantiations (via project macros) for the remaining element-wise
// binary ops. The _ONE_TYPE variants take a single dtype; the _TWO_TYPE
// variants take an (input, output) dtype pair.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);

// INT8/INT16 multiplies produce INT32 outputs.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);

// OpTable is templated on rank only.
DEF_INSTANTIATE_ONE_RANK_0_6(OpTable);

// BinaryNode bases with a BOOL output.
// NOTE(review): these presumably back the comparison ops declared elsewhere
// (e.g. in the comparison sources) — confirm against ewise_binary.h.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);