// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "ewise_binary.h" #include "arith_util.h" #include "quant_util.h" #include "template_types.h" using namespace TosaReference; using namespace Eigen; using namespace tosa; template BinaryNodeBase::BinaryNodeBase(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_) : GraphNode(sgt_, op_, id_) { setRequiredOperands(2, 1); setRequiredRank(0, 6); a = b = nullptr; result = nullptr; fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); }; } template BinaryNodeBase::~BinaryNodeBase() {} template int BinaryNodeBase::checkTensorAttributes() { if (validateRequiredOperands()) return 1; if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0])) { return 1; } // A & B must be the same rank and types if (inputs[0]->matchRankType(*inputs[1])) { printNodeValidationError("Binary operator input types must match"); return 1; } if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */)) { std::string err = "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " lhs input and output rank/shape must match"; printNodeValidationError(err.c_str()); return 1; } if (inputs[1]->matchRankShape(*outputs[0], true /* broadcastOk */)) { std::string err = "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " rhs input and output rank/shape must match"; printNodeValidationError(err.c_str()); return 1; } 
ERROR_IF(outputs[0]->getDtype() != OutDtype, "Binary operator type doesn't match"); a = dynamic_cast*>(inputs[0]); b = dynamic_cast*>(inputs[1]); result = dynamic_cast*>(outputs[0]); ASSERT_MEM(a && b && result); return 0; } template int BinaryNodeBase::broadcast() { const std::vector& a_shape = a->getShape(); const std::vector& b_shape = b->getShape(); const std::vector& output_shape = result->getShape(); for (int i = 0; i < Rank; i++) { bcast_a[i] = (a_shape[i] != output_shape[i] && a_shape[i] == 1) ? output_shape[i] : 1; bcast_b[i] = (b_shape[i] != output_shape[i] && b_shape[i] == 1) ? output_shape[i] : 1; } return 0; } template int BinaryNode::eval() { this->broadcast(); Eigen::array reshaper; reshaper.fill(1); TIn ia, ib; ia = this->a->getTensor().broadcast(this->bcast_a); ib = this->b->getTensor().broadcast(this->bcast_b); this->result->getTensor() = ia.binaryExpr(ib, this->fcn); return GraphNode::eval(); } // still need to partial specialize this, or Eigen will throw static assertion template int BinaryNode<0, InDtype, OutDtype>::eval() { this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn); return GraphNode::eval(); } template int OpAdd::register_fcn() { switch (InDtype) { case DType_INT32: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { int64_t res_in_64 = static_cast(a) + b; int64_t i32_max_in_64 = static_cast(std::numeric_limits::max()); int64_t i32_min_in_64 = static_cast(std::numeric_limits::min()); REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpAdd: result not in i32 range"); return static_cast(res_in_64); }; break; case DType_FP16: case DType_BF16: case DType_FP32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc(a + b); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); } return 0; } template int OpArithmeticRightShift::register_fcn() { bool round = attribute->round(); int32_t num_bits = 0; switch (Dtype) { 
case DType_INT8: num_bits = 8; break; case DType_INT16: num_bits = 16; break; case DType_INT32: num_bits = 32; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]", (int32_t)b, num_bits); InEigenType acc = a >> b; if (round && b > 0 && (a >> (b - 1) & 1) != 0) { acc++; } return acc; }; return 0; } template OpArithmeticRightShift::~OpArithmeticRightShift() { if (attribute) delete attribute; } template int OpBitwiseAnd::register_fcn() { switch (Dtype) { case DType_INT8: case DType_INT16: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpBitwiseOr::register_fcn() { switch (Dtype) { case DType_INT8: case DType_INT16: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpBitwiseXor::register_fcn() { switch (Dtype) { case DType_INT8: case DType_INT16: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpIntdiv::register_fcn() { switch (InDtype) { case DType_INT32: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b != 0, "OpIntDiv: divisor must be non-zero value"); int64_t res_in_64 = static_cast(a) / b; int64_t i32_max_in_64 = static_cast(std::numeric_limits::max()); int64_t i32_min_in_64 = static_cast(std::numeric_limits::min()); REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpIntDiv: result not in i32 range"); return static_cast(res_in_64); }; break; 
default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); } return 0; } template int OpLogicalAnd::register_fcn() { switch (Dtype) { case DType_BOOL: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpLogicalLeftShift::register_fcn() { switch (Dtype) { case DType_INT8: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", (int32_t)b); return static_cast(static_cast(a << b)); }; break; case DType_INT16: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", (int32_t)b); return static_cast(static_cast(a << b)); }; break; case DType_INT32: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", (int32_t)b); return static_cast(static_cast(a << b)); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpLogicalRightShift::register_fcn() { switch (Dtype) { case DType_INT8: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", (int32_t)b); return static_cast(static_cast(a) >> b); }; break; case DType_INT16: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", (int32_t)b); return static_cast(static_cast(a) >> b); }; break; case DType_INT32: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", (int32_t)b); return 
static_cast(static_cast(a) >> b); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpLogicalOr::register_fcn() { switch (Dtype) { case DType_BOOL: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpLogicalXor::register_fcn() { switch (Dtype) { case DType_BOOL: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpMaximum::register_fcn() { switch (Dtype) { case DType_FP16: case DType_BF16: case DType_FP32: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpMinimum::register_fcn() { switch (Dtype) { case DType_FP16: case DType_BF16: case DType_FP32: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? 
a : b; }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpMul::register_fcn() { int32_t shift = attribute->shift(); switch (InDtype) { case DType_FP16: case DType_BF16: case DType_FP32: this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc(a * b); }; break; case DType_INT32: this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType { int64_t result; if (shift > 0) { int64_t round = 1L << (shift - 1); result = static_cast(a) * static_cast(b) + round; result = result >> shift; REQUIRE(result >= QMin && result <= QMax, "OpMul: result %ld exceeds valid range [%ld, %ld]", result, QMin, QMax); } else { result = static_cast(a) * b; int64_t i32_max_in_64 = static_cast(std::numeric_limits::max()); int64_t i32_min_in_64 = static_cast(std::numeric_limits::min()); REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range"); return static_cast(result); } return static_cast(result); }; break; case DType_INT8: case DType_INT16: this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType { OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs; OutEigenType clamped_output = std::min(QMax, std::max(raw_output, QMin)); return clamped_output; }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); } return 0; } template OpMul::~OpMul() { if (attribute) delete attribute; } template int OpPow::register_fcn() { switch (Dtype) { case DType_FP16: case DType_BF16: case DType_FP32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc(powf(a, b)); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpSub::register_fcn() { switch (InDtype) { case DType_INT32: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { int64_t res_in_64 = static_cast(a) - b; int64_t i32_max_in_64 = 
static_cast(std::numeric_limits::max()); int64_t i32_min_in_64 = static_cast(std::numeric_limits::min()); REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpSub: result not in i32 range"); return static_cast(res_in_64); }; break; case DType_FP16: case DType_BF16: case DType_FP32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc(a - b); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); } return 0; } template OpTable::OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_TABLE, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); INIT_ATTRIBUTE(Table); } template OpTable::~OpTable() { if (attribute) delete attribute; } template int OpTable::checkTensorAttributes() { if (validateRequiredOperands()) return 1; if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) { return 1; } ERROR_IF(inputs[0]->getDtype() != InDtype, "OpTable: Unexpected input type"); ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpTable: Unexpected output type"); ERROR_IF(attribute->table().size() != TableNumEntries, "OpTable: table attribute size must be %u", TableNumEntries); for (uint32_t i = 0; i < TableNumEntries; i++) { table[i] = (TableEigenType)attribute->table()[i]; } in = dynamic_cast*>(inputs[0]); out = dynamic_cast*>(outputs[0]); ASSERT_MEM(in && out); return 0; } template int OpTable::eval() { switch (InDtype) { case DType_INT8: this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType { int32_t input_truncated = std::min(std::max(in, QInMin), QInMax); int32_t index = input_truncated - QInMin; int32_t value = table[index]; return value; }); break; case DType_INT16: this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType { // 1. make sure input is int16 range int32_t input_truncated = std::min(std::max(in, QInMin), QInMax); // 2. 
calculate index and interpolation fraction int32_t index = (input_truncated >> FractionBits) + (1 << (IntegerBits - 1)); index = std::min(std::max(index, 0), NumTableEntries - 1); // 9-bit index int32_t frac = (input_truncated)&0x7F; // 7-bit fraction // 3. Add REQUIRE CHECK for extreme large/small slopes int32_t base = table[index]; int32_t next = table[index + 1]; int32_t slope = next - base; REQUIRE(slope <= std::numeric_limits::max() && slope >= std::numeric_limits::min(), "OpTable: slope out of int16_t range"); // 4. interpolate, generate 16.7 (23-bit) output int32_t value = (base << 7) + (slope) * frac; return value; }); break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); } return GraphNode::eval(); } // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP16, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, BF16, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP32, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT8, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT16, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT32, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT8, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT16, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, BOOL, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP16, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, BF16, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP32, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT32, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); 
// Explicit instantiations of the element-wise operators, one invocation per
// line. The DEF_INSTANTIATE_* macros are defined outside this file
// (presumably in template_types.h -- confirm).
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIntdiv, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);

// OpMul takes separate input and output types.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, BF16, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);

// Instantiation of nodes for comparison operators opEqual, opGreater
// and opGreaterEqual
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP16, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, BF16, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP32, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);