diff options
Diffstat (limited to 'reference_model/src/ops/ewise_unary.cc')
-rw-r--r-- | reference_model/src/ops/ewise_unary.cc | 33 |
1 files changed, 19 insertions, 14 deletions
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc index 95a1102..041bbdb 100644 --- a/reference_model/src/ops/ewise_unary.cc +++ b/reference_model/src/ops/ewise_unary.cc @@ -23,8 +23,8 @@ using namespace Eigen; using namespace tosa; template <int Rank, DType Dtype> -UnaryNode<Rank, Dtype>::UnaryNode(const Op& op_, uint64_t id_) - : GraphNode(op_, id_) +UnaryNode<Rank, Dtype>::UnaryNode(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_) + : GraphNode(sgt_, op_, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); @@ -80,7 +80,7 @@ int OpAbs<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -97,7 +97,7 @@ int OpBitwiseNot<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return ~a; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -112,7 +112,7 @@ int OpCeil<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return ceilf(a); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -128,7 +128,7 @@ int OpClz<Rank, Dtype>::register_fcn() num_bits = 32; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } this->fcn = [num_bits](int32_t a) -> int32_t { @@ -159,7 +159,7 @@ int OpExp<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return expf(a); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -174,7 +174,7 @@ int OpFloor<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return floorf(a); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -189,7 +189,7 @@ int OpLog<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return logf(a); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -204,7 +204,7 @@ int OpLogicalNot<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return !a; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -213,6 +213,12 @@ int OpLogicalNot<Rank, Dtype>::register_fcn() template <int Rank, DType Dtype> int OpNegate<Rank, Dtype>::register_fcn() { + if (Dtype != DType_INT8 && this->qinfo) + { + ERROR_IF(this->qinfo->input_zp() != 0, "OpNegate: zeropoint only for int8_t"); + ERROR_IF(this->qinfo->output_zp() != 0, "OpNegate: zeropoint only for int8_t"); + } + switch (Dtype) { case DType_FLOAT: @@ -229,7 +235,6 @@ int OpNegate<Rank, Dtype>::register_fcn() }; break; case DType_INT8: - ASSERT(this->qinfo); this->fcn = [this](InEigenType a) -> OutEigenType { InEigenType result = -(a - this->qinfo->input_zp()) + this->qinfo->output_zp(); result = std::min(std::max(result, static_cast<InEigenType>(QMin)), static_cast<InEigenType>(QMax)); @@ -237,7 +242,7 @@ int OpNegate<Rank, Dtype>::register_fcn() }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -252,7 +257,7 @@ int OpReciprocal<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / a; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -267,7 +272,7 @@ int OpRsqrt<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / sqrtf(a); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; |