diff options
Diffstat (limited to 'reference_model')
-rw-r--r-- | reference_model/src/ops/ewise_unary.cc | 32 |
1 files changed, 18 insertions, 14 deletions
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc index 0f38056..8b83a50 100644 --- a/reference_model/src/ops/ewise_unary.cc +++ b/reference_model/src/ops/ewise_unary.cc @@ -231,28 +231,32 @@ int OpNegate<Rank, Dtype>::register_fcn() case DType_INT32: this->fcn = [this](InEigenType a) -> OutEigenType { int64_t res_in_64 = 0L - a; - int64_t max_in_64, min_in_64; - if (Dtype == DType_INT16) { - max_in_64 = static_cast<int64_t>(std::numeric_limits<int16_t>::max()); - min_in_64 = static_cast<int64_t>(std::numeric_limits<int16_t>::min()); + int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max()); + int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min()); + REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpNegate: result not in acc type range (int32)"); + + int64_t max_clip_in_64, min_clip_in_64; + if (Dtype == DType_INT16) + { + max_clip_in_64 = static_cast<int64_t>(std::numeric_limits<int16_t>::max()); + min_clip_in_64 = static_cast<int64_t>(std::numeric_limits<int16_t>::min()); } else { - max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max()); - min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min()); + max_clip_in_64 = i32_max_in_64; + min_clip_in_64 = i32_min_in_64; } - REQUIRE(res_in_64 <= max_in_64 && res_in_64 >= min_in_64, "OpNegate: result not in input type range"); - return static_cast<InEigenType>(res_in_64); + return static_cast<InEigenType>(std::min<int64_t>(max_clip_in_64, std::max<int64_t>(min_clip_in_64, res_in_64))); }; break; case DType_INT8: this->fcn = [this](InEigenType a) -> OutEigenType { - int32_t res_in_32 = 0 - (a - this->qinfo->input_zp()); - int32_t max_in_32 = static_cast<int32_t>(std::numeric_limits<int8_t>::max()); - int32_t min_in_32 = static_cast<int32_t>(std::numeric_limits<int8_t>::min()); - REQUIRE(res_in_32 <= max_in_32 && res_in_32 >= min_in_32, "OpNegate: result not in i8 range"); - res_in_32 += this->qinfo->output_zp(); - InEigenType result = static_cast<InEigenType>(std::min(std::max(res_in_32, static_cast<int32_t>(QMin)), static_cast<int32_t>(QMax))); + int64_t res_in_64 = 0 - (a - this->qinfo->input_zp()); + int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max()); + int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min()); + REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpNegate: result not in acc type range (int32)"); + res_in_64 += this->qinfo->output_zp(); + InEigenType result = static_cast<InEigenType>(std::min(std::max(res_in_64, static_cast<int64_t>(QMin)), static_cast<int64_t>(QMax))); return result; }; break; |