// Copyright (c) 2020-2024, ARM Limited.
//
//    Licensed under the Apache License, Version 2.0 (the "License");
//    you may not use this file except in compliance with the License.
//    You may obtain a copy of the License at
//
//         http://www.apache.org/licenses/LICENSE-2.0
//
//    Unless required by applicable law or agreed to in writing, software
//    distributed under the License is distributed on an "AS IS" BASIS,
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//    See the License for the specific language governing permissions and
//    limitations under the License.

// Reference implementations of the element-wise unary operators.
// Each OpXxx::register_fcn() stores a per-element lambda in `fcn` for the
// instantiated Dtype; UnaryNode::eval() then applies `fcn` to every element
// of the input tensor.
//
// NOTE(review): in this copy the header name of the last #include and every
// template parameter/argument list (the text between < and >) appear to be
// missing -- e.g. `template UnaryNode::UnaryNode`, `dynamic_cast*>`,
// `static_cast(`, `std::numeric_limits::max()`. Restore them from upstream;
// the text as shown will not compile.

#include "ewise_unary.h"
#include "quant_util.h"
#include "template_types.h"
#include

using namespace TosaReference;
using namespace Eigen;
using namespace tosa;

// Base constructor shared by all unary ops. Declares the operand counts
// (setRequiredOperands(1, 1) -- presumably one input and one output) and
// installs a placeholder `fcn` that asserts if a derived op never registered
// its own element function.
template
UnaryNode::UnaryNode(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_)
    : GraphNode(sgt_, op_, id_)
{
    setRequiredOperands(1, 1);

    fcn = [](InEigenType a) -> OutEigenType {
        ASSERT_MSG(0, "In default UnaryNode function, missing function registration");
        return OutEigenType();
    };
}

template
UnaryNode::~UnaryNode()
{}

// Validates the node before evaluation.
// Returns 0 on success and 1 on a validation failure. On success the typed
// tensor pointers `a` (input) and `result` (output) are bound.
template
int UnaryNode::checkTensorAttributes()
{
    // Check Tosa Level: the tensor rank may not exceed the level's MAX_RANK.
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");

    if (validateRequiredOperands())
        return 1;

    // Output and input must agree in rank, type and shape; a non-zero return
    // from matchRankTypeShape() is treated as a mismatch.
    if (inputs[0]->matchRankTypeShape(*outputs[0]))
    {
        printNodeValidationError("UnaryNode: input and output rank/type/shape must match");
        return 1;
    }

    // Downcast the generic tensors to the concrete templated tensor type.
    a      = dynamic_cast*>(inputs[0]);
    result = dynamic_cast*>(outputs[0]);

    ASSERT_MEM(a && result);

    return 0;
}

// Applies the registered element function to every input element (Eigen
// unaryExpr) and assigns the result to the output tensor, then runs the
// common GraphNode::eval() step.
template
int UnaryNode::eval()
{
    this->result->getTensor() = this->a->getTensor().unaryExpr(this->fcn);

    return GraphNode::eval();
}

// ABS: |a|. FP16/BF16 results are passed through fpTrunc (presumably rounding
// to the Dtype's precision; see the FP32 comment below).
template
int OpAbs::register_fcn()
{
    switch (Dtype)
    {
        case TOSA_REF_TYPE_FP32:    // No fpTrunc for FP32 as it is a no-op
        case TOSA_REF_TYPE_FP64:
        case TOSA_REF_TYPE_INT32:
            // NOTE(review): for INT32, if InEigenType is int32_t then `-a` with
            // a == INT32_MIN is signed overflow (undefined behaviour); this
            // input should be rejected or computed in a wider type.
            this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); };
            break;
        case TOSA_REF_TYPE_FP16:
        case TOSA_REF_TYPE_BF16:
            this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(a > (InEigenType)0 ? a : (-a)); };
            break;
        default:
            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}

// BITWISE_NOT: one's complement (~a) for INT8/INT16/INT32.
template
int OpBitwiseNot::register_fcn()
{
    switch (Dtype)
    {
        case TOSA_REF_TYPE_INT8:
        case TOSA_REF_TYPE_INT16:
        case TOSA_REF_TYPE_INT32:
            this->fcn = [](InEigenType a) -> OutEigenType { return ~a; };
            break;
        default:
            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}

// CEIL: FP16/BF16/FP32 use the single-precision ceilf followed by fpTrunc;
// FP64 uses double-precision ceil.
template
int OpCeil::register_fcn()
{
    switch (Dtype)
    {
        case TOSA_REF_TYPE_FP16:
        case TOSA_REF_TYPE_BF16:
        case TOSA_REF_TYPE_FP32:
            this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(ceilf(a)); };
            break;
        case TOSA_REF_TYPE_FP64:
            this->fcn = [](InEigenType a) -> OutEigenType { return ceil(a); };
            break;
        default:
            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}

// CLZ: count leading zero bits, INT32 only. The lambda scans from bit
// (num_bits - 1) downward and stops at the first set bit, so 0 yields 32 and
// any negative input yields 0 (its bit 31 is set).
template
int OpClz::register_fcn()
{
    // num_bits is only assigned in the INT32 case; the default branch relies on
    // ERROR_IF not returning normally (presumably it aborts/throws), otherwise
    // num_bits would be read uninitialised below.
    int32_t num_bits;
    switch (Dtype)
    {
        case TOSA_REF_TYPE_INT32:
            num_bits = 32;
            break;
        default:
            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    this->fcn = [num_bits](int32_t a) -> int32_t {
        int32_t leading_zeros = 0;
        for (int bit = num_bits - 1; bit >= 0; bit--)
        {
            if (((a >> bit) & 0x1) == 0)
            {
                leading_zeros++;
            }
            else
            {
                break;
            }
        }
        return leading_zeros;
    };

    return 0;
}

// COS: FP16/BF16/FP32 compute cos then fpTrunc. For FP64, when
// g_func_config.abs_mode is set the function returns the constant 1.0 instead
// of cos(a) (per the comment below this is the ABS_ERROR bound).
template
int OpCos::register_fcn()
{
    switch (Dtype)
    {
        case TOSA_REF_TYPE_FP16:
        case TOSA_REF_TYPE_BF16:
        case TOSA_REF_TYPE_FP32:
            this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(cos(a)); };
            break;
        case TOSA_REF_TYPE_FP64:
            if (g_func_config.abs_mode)
            {
                // ABS_ERROR bounds return 1.0
                this->fcn = [](InEigenType a) -> OutEigenType { return 1.0; };
            }
            else
            {
                this->fcn = [](InEigenType a) -> OutEigenType { return cos(a); };
            };    // (the trailing ';' is an empty statement and has no effect)
            break;
        default:
            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}

// EXP: FP16/BF16/FP32 use expf then fpTrunc. For FP64 in abs_mode the function
// returns 1 + |a| (the ABS_ERROR bound) instead of exp(a).
template
int
OpExp::register_fcn()
{
    switch (Dtype)
    {
        case TOSA_REF_TYPE_FP16:
        case TOSA_REF_TYPE_BF16:
        case TOSA_REF_TYPE_FP32:
            this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(expf(a)); };
            break;
        case TOSA_REF_TYPE_FP64:
            if (g_func_config.abs_mode)
            {
                // ABS_ERROR bounds return (1+abs(a))
                this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 + (a > (InEigenType)0 ? a : (-a)); };
            }
            else
            {
                this->fcn = [](InEigenType a) -> OutEigenType { return exp(a); };
            }
            break;
        default:
            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}

// FLOOR: floorf + fpTrunc for FP16/BF16/FP32, double floor for FP64.
template
int OpFloor::register_fcn()
{
    switch (Dtype)
    {
        case TOSA_REF_TYPE_FP16:
        case TOSA_REF_TYPE_BF16:
        case TOSA_REF_TYPE_FP32:
            this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(floorf(a)); };
            break;
        case TOSA_REF_TYPE_FP64:
            this->fcn = [](InEigenType a) -> OutEigenType { return floor(a); };
            break;
        default:
            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}

// LOG: natural logarithm; logf + fpTrunc for FP16/BF16/FP32, double log for FP64.
template
int OpLog::register_fcn()
{
    switch (Dtype)
    {
        case TOSA_REF_TYPE_FP16:
        case TOSA_REF_TYPE_BF16:
        case TOSA_REF_TYPE_FP32:
            this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(logf(a)); };
            break;
        case TOSA_REF_TYPE_FP64:
            this->fcn = [](InEigenType a) -> OutEigenType { return log(a); };
            break;
        default:
            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}

// LOGICAL_NOT: boolean negation, BOOL only.
template
int OpLogicalNot::register_fcn()
{
    switch (Dtype)
    {
        case TOSA_REF_TYPE_BOOL:
            this->fcn = [](InEigenType a) -> OutEigenType { return !a; };
            break;
        default:
            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}

// NEGATE carries an attribute (zero points), so it has its own constructor.
// INIT_ATTRIBUTE(Negate) runs before register_fcn(), which reads `attribute`.
template
OpNegate::OpNegate(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : UnaryNode(sgt_, Op_NEGATE, id_)
{
    INIT_ATTRIBUTE(Negate);

    register_fcn();
}

// Releases the attribute object (presumably allocated by INIT_ATTRIBUTE).
template
OpNegate::~OpNegate()
{
    if (attribute)
        delete attribute;
}

// NEGATE: -a. Non-zero zero points are only accepted for INT8.
// INT16/INT32 compute in int64, REQUIRE the result to fit int32, then clip
// to the Dtype's range. INT8 subtracts input1_zp, negates, adds output_zp,
// and saturates to [QMin, QMax]. The integer lambdas capture `this` and read
// `attribute` / Dtype at evaluation time.
template
int OpNegate::register_fcn()
{
    ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->input1_zp() != 0, "OpNegate: zeropoint only for int8_t");
    ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->output_zp() != 0, "OpNegate: zeropoint only for int8_t");

    switch (Dtype)
    {
        case TOSA_REF_TYPE_FP16:
        case TOSA_REF_TYPE_BF16:
        case TOSA_REF_TYPE_FP32:
            this->fcn = [](InEigenType a) -> OutEigenType {
                InEigenType result = -(a);
                return fpTrunc(result);
            };
            break;
        case TOSA_REF_TYPE_FP64:
            this->fcn = [](InEigenType a) -> OutEigenType {
                OutEigenType result = -(a);
                return result;
            };
            break;
        case TOSA_REF_TYPE_INT16:
        case TOSA_REF_TYPE_INT32:
            this->fcn = [this](InEigenType a) -> OutEigenType {
                // NOTE(review): `0L` is a `long`, which is only 32 bits on LLP64
                // targets (e.g. Windows). There the subtraction happens in 32
                // bits before widening; for a == INT32_MIN it overflows (UB)
                // and can silently pass the REQUIRE below. Writing
                // `int64_t{0} - a` would make the int64 arithmetic portable.
                int64_t res_in_64     = 0L - a;
                int64_t i32_max_in_64 = static_cast(std::numeric_limits::max());
                int64_t i32_min_in_64 = static_cast(std::numeric_limits::min());
                REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64,
                        "OpNegate: result not in acc type range (int32)");

                // Clip to the output type: int16 limits for INT16, int32 otherwise.
                int64_t max_clip_in_64, min_clip_in_64;
                if (Dtype == TOSA_REF_TYPE_INT16)
                {
                    max_clip_in_64 = static_cast(std::numeric_limits::max());
                    min_clip_in_64 = static_cast(std::numeric_limits::min());
                }
                else
                {
                    max_clip_in_64 = i32_max_in_64;
                    min_clip_in_64 = i32_min_in_64;
                }
                return static_cast(
                    std::min(max_clip_in_64, std::max(min_clip_in_64, res_in_64)));
            };
            break;
        case TOSA_REF_TYPE_INT8:
            this->fcn = [this](InEigenType a) -> OutEigenType {
                // NOTE(review): `0 - (a - input1_zp())` is evaluated in the
                // operands' own type (presumably 32-bit) before being widened
                // to int64, so an overflow there would not be caught by the
                // int32 range REQUIRE -- confirm zp ranges or widen first.
                int64_t res_in_64     = 0 - (a - attribute->input1_zp());
                int64_t i32_max_in_64 = static_cast(std::numeric_limits::max());
                int64_t i32_min_in_64 = static_cast(std::numeric_limits::min());
                REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64,
                        "OpNegate: result not in acc type range (int32)");
                res_in_64 += attribute->output_zp();
                // Saturate to the quantized range [QMin, QMax].
                InEigenType result = static_cast(
                    std::min(std::max(res_in_64, static_cast(QMin)), static_cast(QMax)));
                return result;
            };
            break;
        default:
            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}

// RECIPROCAL: 1 / a. FP16/BF16/FP32 divide in double then fpTrunc; FP64
// divides in long double and converts back to OutEigenType on return.
template
int OpReciprocal::register_fcn()
{
    switch (Dtype)
    {
        case TOSA_REF_TYPE_FP16:
        case TOSA_REF_TYPE_BF16:
        case TOSA_REF_TYPE_FP32:
            this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(1.0 / a); };
            break;
        case TOSA_REF_TYPE_FP64:
            this->fcn = [](InEigenType a) -> OutEigenType { return (1.0L / a); };
            break;
        default:
            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}

// RSQRT: 1 / sqrt(a). FP16/BF16/FP32 take sqrtf, divide in double, then
// fpTrunc; FP64 uses sqrt and a long double division.
template
int OpRsqrt::register_fcn()
{
    switch (Dtype)
    {
        case TOSA_REF_TYPE_FP16:
        case TOSA_REF_TYPE_BF16:
        case TOSA_REF_TYPE_FP32:
            this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(1.0 / sqrtf(a)); };
            break;
        case TOSA_REF_TYPE_FP64:
            this->fcn = [](InEigenType a) -> OutEigenType { return (1.0L / sqrt(a)); };
            break;
        default:
            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}

// SIN: mirrors OpCos -- sin + fpTrunc for FP16/BF16/FP32; for FP64 in
// abs_mode the constant 1.0 (ABS_ERROR bound) is returned instead of sin(a).
template
int OpSin::register_fcn()
{
    switch (Dtype)
    {
        case TOSA_REF_TYPE_FP16:
        case TOSA_REF_TYPE_BF16:
        case TOSA_REF_TYPE_FP32:
            this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(sin(a)); };
            break;
        case TOSA_REF_TYPE_FP64:
            if (g_func_config.abs_mode)
            {
                // ABS_ERROR bounds return 1.0
                this->fcn = [](InEigenType a) -> OutEigenType { return 1.0; };
            }
            else
            {
                this->fcn = [](InEigenType a) -> OutEigenType { return sin(a); };
            };    // (the trailing ';' is an empty statement and has no effect)
            break;
        default:
            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}

// template explicit instantiation
// (the macro presumably instantiates ranks 0 through 6, per its name)
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, FP64);

DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
// Remaining explicit instantiations, grouped per operator in the same order
// as the register_fcn() definitions in this file. Each operator is
// instantiated only for the Dtypes its register_fcn() switch accepts.
// The macro presumably expands to one instantiation per rank 0..6 (per its
// name) -- TODO confirm against template_types.h.

// OpAbs: the FP16/BF16/FP32/INT32 variants are instantiated earlier in the file.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP64);

// OpBitwiseNot: integer Dtypes only.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32);

// OpCeil: floating-point Dtypes.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP64);

// OpClz: INT32 only.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);

// OpCos: floating-point Dtypes.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP64);

// OpExp: floating-point Dtypes.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP64);

// OpFloor: floating-point Dtypes.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP64);

// OpLog: floating-point Dtypes.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP64);

// OpLogicalNot: BOOL only.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);

// OpNegate: floating-point and integer Dtypes (zero points only for INT8).
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP64);

// OpReciprocal: floating-point Dtypes.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP64);

// OpRsqrt: floating-point Dtypes.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP64);

// OpSin: floating-point Dtypes.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP64);