diff options
Diffstat (limited to 'reference_model')
-rw-r--r-- | reference_model/src/ops/ewise_binary.cc | 53 |
1 files changed, 31 insertions, 22 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 3abf961..c697db0 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -288,27 +288,32 @@ int OpLogicalAnd<Rank, Dtype>::register_fcn() template <int Rank, DType Dtype> int OpLogicalLeftShift<Rank, Dtype>::register_fcn() { - int32_t num_bits = 0; switch (Dtype) { case DType_INT8: - num_bits = 8; + this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { + REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", + (int32_t)b); + return static_cast<OutEigenType>(static_cast<int8_t>(a << b)); + }; break; case DType_INT16: - num_bits = 16; + this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { + REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", + (int32_t)b); + return static_cast<OutEigenType>(static_cast<int16_t>(a << b)); + }; break; case DType_INT32: - num_bits = 32; + this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { + REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", + (int32_t)b); + return static_cast<OutEigenType>(static_cast<int32_t>(a << b)); + }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } - this->fcn = [this, num_bits](InEigenType a, InEigenType b) -> OutEigenType { - uint32_t mask = ONES_MASK(num_bits); - REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", - (int32_t)b); - return (a << b) & mask; - }; return 0; } @@ -316,29 +321,33 @@ int OpLogicalLeftShift<Rank, Dtype>::register_fcn() template <int Rank, DType Dtype> int OpLogicalRightShift<Rank, Dtype>::register_fcn() { - int32_t num_bits = 0; switch (Dtype) { case DType_INT8: - num_bits = 8; + this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { + REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", + (int32_t)b); + return static_cast<OutEigenType>(static_cast<int8_t>(a) >> b); + }; break; case DType_INT16: - num_bits = 16; + this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { + REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", + (int32_t)b); + return static_cast<OutEigenType>(static_cast<int16_t>(a) >> b); + }; break; case DType_INT32: - num_bits = 32; + this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { + REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", + (int32_t)b); + return static_cast<OutEigenType>(static_cast<int32_t>(a) >> b); + }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } - this->fcn = [this, num_bits](InEigenType a, InEigenType b) -> OutEigenType { - uint32_t mask = ONES_MASK(num_bits) >> b; - REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", - (int32_t)b); - return (a >> b) & mask; - }; - return 0; } |