aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/ewise_unary.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/ewise_unary.cc')
-rw-r--r--reference_model/src/ops/ewise_unary.cc34
1 files changed, 26 insertions, 8 deletions
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc
index 8b83a50..8ef1e3c 100644
--- a/reference_model/src/ops/ewise_unary.cc
+++ b/reference_model/src/ops/ewise_unary.cc
@@ -29,7 +29,10 @@ UnaryNode<Rank, Dtype>::UnaryNode(SubgraphTraverser* sgt_, const Op& op_, uint64
setRequiredOperands(1, 1);
setRequiredRank(0, 6);
- fcn = [](InEigenType a) -> OutEigenType { return OutEigenType(); };
+ fcn = [](InEigenType a) -> OutEigenType {
+ ASSERT_MSG(0, "In default UnaryNode function, missing function registration");
+ return OutEigenType();
+ };
}
template <int Rank, DType Dtype>
@@ -211,13 +214,28 @@ int OpLogicalNot<Rank, Dtype>::register_fcn()
}
template <int Rank, DType Dtype>
+OpNegate<Rank, Dtype>::OpNegate(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ uint64_t id_)
+ : UnaryNode<Rank, Dtype>(sgt_, Op_NEGATE, id_)
+{
+ INIT_ATTRIBUTE(Negate);
+
+ register_fcn();
+}
+
+template <int Rank, DType Dtype>
+OpNegate<Rank, Dtype>::~OpNegate()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType Dtype>
int OpNegate<Rank, Dtype>::register_fcn()
{
- if (Dtype != DType_INT8 && this->qinfo)
- {
- ERROR_IF(this->qinfo->input_zp() != 0, "OpNegate: zeropoint only for int8_t");
- ERROR_IF(this->qinfo->output_zp() != 0, "OpNegate: zeropoint only for int8_t");
- }
+ ERROR_IF(Dtype != DType_INT8 && attribute->input1_zp() != 0, "OpNegate: zeropoint only for int8_t");
+ ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpNegate: zeropoint only for int8_t");
switch (Dtype)
{
@@ -251,11 +269,11 @@ int OpNegate<Rank, Dtype>::register_fcn()
break;
case DType_INT8:
this->fcn = [this](InEigenType a) -> OutEigenType {
- int64_t res_in_64 = 0 - (a - this->qinfo->input_zp());
+ int64_t res_in_64 = 0 - (a - attribute->input1_zp());
int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpNegate: result not in acc type range (int32)");
- res_in_64 += this->qinfo->output_zp();
+ res_in_64 += attribute->output_zp();
InEigenType result = static_cast<InEigenType>(std::min(std::max(res_in_64, static_cast<int64_t>(QMin)), static_cast<int64_t>(QMax)));
return result;
};