aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/ewise_unary.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/ewise_unary.cc')
-rw-r--r--reference_model/src/ops/ewise_unary.cc302
1 files changed, 302 insertions, 0 deletions
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc
new file mode 100644
index 0000000..d7bddc0
--- /dev/null
+++ b/reference_model/src/ops/ewise_unary.cc
@@ -0,0 +1,302 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ewise_unary.h"
+#include "quant_util.h"
+#include "template_types.h"
+#include <cmath>
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+UnaryNode<Rank, Dtype>::UnaryNode(const Op& op_, uint64_t id_)
+ : GraphNode(op_, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+
+ fcn = [](InEigenType a) -> OutEigenType { return OutEigenType(); };
+}
+
+template <int Rank, DType Dtype>
+UnaryNode<Rank, Dtype>::~UnaryNode()
+{}
+
+template <int Rank, DType Dtype>
+int UnaryNode<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchRankSize(*outputs[0]))
+ {
+ printNodeValidationError("UnaryNode: input and output rank must match");
+ return 1;
+ }
+
+ a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(a && result);
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int UnaryNode<Rank, Dtype>::eval()
+{
+ this->result->getTensor() = this->a->getTensor().unaryExpr(this->fcn);
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpAbs<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpBitwiseNot<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_AINT8:
+ case DType_INT16:
+ case DType_INT32:
+ this->fcn = [](InEigenType a) -> OutEigenType { return ~a; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpCeil<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return ceilf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpClz<Rank, Dtype>::register_fcn()
+{
+ int32_t num_bits;
+ switch (Dtype)
+ {
+ case DType_INT32:
+ num_bits = 32;
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ this->fcn = [num_bits](int32_t a) -> int32_t {
+ int32_t leading_zeros = 0;
+ for (int bit = num_bits - 1; bit >= 0; bit--)
+ {
+ if (((a >> bit) & 0x1) == 0)
+ {
+ leading_zeros++;
+ }
+ else
+ {
+ break;
+ }
+ }
+ return leading_zeros;
+ };
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpExp<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return expf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpFloor<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return floorf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpLog<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return logf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpLogicalNot<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_BOOL:
+ this->fcn = [](InEigenType a) -> OutEigenType { return !a; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpNegate<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType {
+ InEigenType result = -(a);
+ return result;
+ };
+ break;
+ case DType_INT16:
+ case DType_INT32:
+ this->fcn = [](InEigenType a) -> OutEigenType {
+ InEigenType result = -(a);
+ return result;
+ };
+ break;
+ case DType_AINT8:
+ ASSERT(this->qinfo);
+ this->fcn = [this](InEigenType a) -> OutEigenType {
+ InEigenType result = -(a - this->qinfo->input_zp()) + this->qinfo->output_zp();
+ return result;
+ };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpReciprocal<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / a; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpRsqrt<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / sqrtf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT);