diff options
Diffstat (limited to 'reference_model/src')
-rw-r--r-- | reference_model/src/ops/ewise_binary.cc | 29 |
1 files changed, 25 insertions, 4 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index d6a95e1..c33f646 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -201,8 +201,16 @@ int OpAdd<Rank, Dtype>::register_fcn() { switch (InDtype) { - case DType_FLOAT: case DType_INT32: + this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { + int64_t res_in_64 = static_cast<int64_t>(a) + b; + int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max()); + int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min()); + REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpAdd: result not in i32 range"); + return static_cast<InEigenType>(res_in_64); + }; + break; + case DType_FLOAT: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; }; break; default: @@ -310,7 +318,8 @@ int OpIntdiv<Rank, Dtype>::register_fcn() REQUIRE(b != 0, "OpIntDiv: divisor must be non-zero value"); int64_t res_in_64 = static_cast<int64_t>(a) / b; int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max()); - REQUIRE(a <= i32_max_in_64, "OpIntDiv: result not in i32 range"); + int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min()); + REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpIntDiv: result not in i32 range"); return static_cast<InEigenType>(res_in_64); }; break; @@ -466,7 +475,11 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn() } else { - result = a * b; + result = static_cast<int64_t>(a) * b; + int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max()); + int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min()); + REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range"); + return static_cast<InEigenType>(result); } return static_cast<OutEigenType>(result); @@ -509,8 +522,16 @@ int OpSub<Rank, Dtype>::register_fcn() { switch (InDtype) { - case DType_FLOAT: case DType_INT32: + this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { + int64_t res_in_64 = static_cast<int64_t>(a) - b; + int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max()); + int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min()); + REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpSub: result not in i32 range"); + return static_cast<InEigenType>(res_in_64); + }; + break; + case DType_FLOAT: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; }; break; default: |