// Copyright (c) 2020, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "ewise_unary.h" #include "quant_util.h" #include "template_types.h" #include using namespace TosaReference; using namespace Eigen; using namespace tosa; template UnaryNode::UnaryNode(const Op& op_, uint64_t id_) : GraphNode(op_, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); fcn = [](InEigenType a) -> OutEigenType { return OutEigenType(); }; } template UnaryNode::~UnaryNode() {} template int UnaryNode::checkTensorAttributes() { if (validateRequiredOperands()) return 1; if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) { return 1; } // output and input must be the same types if (inputs[0]->matchRankSize(*outputs[0])) { printNodeValidationError("UnaryNode: input and output rank must match"); return 1; } a = dynamic_cast*>(inputs[0]); result = dynamic_cast*>(outputs[0]); ASSERT_MEM(a && result); return 0; } template int UnaryNode::eval() { this->result->getTensor() = this->a->getTensor().unaryExpr(this->fcn); return GraphNode::eval(); } template int OpAbs::register_fcn() { switch (Dtype) { case DType_FLOAT: case DType_INT32: this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpBitwiseNot::register_fcn() { switch (Dtype) { case DType_AINT8: case DType_INT16: case DType_INT32: this->fcn = [](InEigenType a) -> OutEigenType { return ~a; }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpCeil::register_fcn() { switch (Dtype) { case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { return ceilf(a); }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpClz::register_fcn() { int32_t num_bits; switch (Dtype) { case DType_INT32: num_bits = 32; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } this->fcn = [num_bits](int32_t a) -> int32_t { int32_t leading_zeros = 0; for (int bit = num_bits - 1; bit >= 0; bit--) { if (((a >> bit) & 0x1) == 0) { leading_zeros++; } else { break; } } return leading_zeros; }; return 0; } template int OpExp::register_fcn() { switch (Dtype) { case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { return expf(a); }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpFloor::register_fcn() { switch (Dtype) { case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { return floorf(a); }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpLog::register_fcn() { switch (Dtype) { case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { return logf(a); }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpLogicalNot::register_fcn() { switch (Dtype) { case DType_BOOL: this->fcn = [](InEigenType a) -> OutEigenType { return !a; }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpNegate::register_fcn() { switch (Dtype) { case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { InEigenType result = -(a); return result; }; break; case DType_INT16: case DType_INT32: this->fcn = [](InEigenType a) -> OutEigenType { InEigenType result = -(a); return result; }; break; case DType_AINT8: ASSERT(this->qinfo); this->fcn = [this](InEigenType a) -> OutEigenType { InEigenType result = -(a - this->qinfo->input_zp()) + this->qinfo->output_zp(); return result; }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpReciprocal::register_fcn() { switch (Dtype) { case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / a; }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } template int OpRsqrt::register_fcn() { switch (Dtype) { case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / sqrtf(a); }; break; default: FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; } // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, AINT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, AINT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT);