diff options
Diffstat (limited to 'reference_model/src/ops/ewise_unary.cc')
-rw-r--r-- | reference_model/src/ops/ewise_unary.cc | 302 |
1 files changed, 302 insertions, 0 deletions
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc new file mode 100644 index 0000000..d7bddc0 --- /dev/null +++ b/reference_model/src/ops/ewise_unary.cc @@ -0,0 +1,302 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ewise_unary.h" +#include "quant_util.h" +#include "template_types.h" +#include <cmath> + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +template <int Rank, DType Dtype> +UnaryNode<Rank, Dtype>::UnaryNode(const Op& op_, uint64_t id_) + : GraphNode(op_, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(0, 6); + + fcn = [](InEigenType a) -> OutEigenType { return OutEigenType(); }; +} + +template <int Rank, DType Dtype> +UnaryNode<Rank, Dtype>::~UnaryNode() +{} + +template <int Rank, DType Dtype> +int UnaryNode<Rank, Dtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + // output and input must be the same types + if (inputs[0]->matchRankSize(*outputs[0])) + { + printNodeValidationError("UnaryNode: input and output rank must match"); + return 1; + } + + a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + + ASSERT_MEM(a && result); + + return 0; +} + +template <int Rank, DType Dtype> +int UnaryNode<Rank, Dtype>::eval() +{ + this->result->getTensor() = this->a->getTensor().unaryExpr(this->fcn); + + return GraphNode::eval(); +} + +template <int Rank, DType Dtype> +int OpAbs<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + case DType_INT32: + this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpBitwiseNot<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_AINT8: + case DType_INT16: + case DType_INT32: + this->fcn = [](InEigenType a) -> OutEigenType { return ~a; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpCeil<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { return ceilf(a); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpClz<Rank, Dtype>::register_fcn() +{ + int32_t num_bits; + switch (Dtype) + { + case DType_INT32: + num_bits = 32; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + this->fcn = [num_bits](int32_t a) -> int32_t { + int32_t leading_zeros = 0; + for (int bit = num_bits - 1; bit >= 0; bit--) + { + if (((a >> bit) & 0x1) == 0) + { + leading_zeros++; + } + else + { + break; + } + } + return leading_zeros; + }; + + return 0; +} + +template <int Rank, DType Dtype> +int OpExp<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { return expf(a); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpFloor<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { return floorf(a); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpLog<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { return logf(a); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpLogicalNot<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_BOOL: + this->fcn = [](InEigenType a) -> OutEigenType { return !a; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpNegate<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { + InEigenType result = -(a); + return result; + }; + break; + case DType_INT16: + case DType_INT32: + this->fcn = [](InEigenType a) -> OutEigenType { + InEigenType result = -(a); + return result; + }; + break; + case DType_AINT8: + ASSERT(this->qinfo); + this->fcn = [this](InEigenType a) -> OutEigenType { + InEigenType result = -(a - this->qinfo->input_zp()) + this->qinfo->output_zp(); + return result; + }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpReciprocal<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / a; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpRsqrt<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / sqrtf(a); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +// template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT); |