diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2023-01-18 16:23:20 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2023-01-19 17:28:19 +0000 |
commit | eef866729ee6fe3a6972361afca3e2cda3b162b1 (patch) | |
tree | d2c04c60b6eddaa453769058c7b54b1d97c1bc81 | |
parent | 3407125ce2a470216bd54f8bb3ab5216f617c1be (diff) | |
download | reference_model-eef866729ee6fe3a6972361afca3e2cda3b162b1.tar.gz |
Fix for sign extending LOGICAL LEFT/RIGHT SHIFT results
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I04261178694c004409aef2ff5c84c32b04729433
-rw-r--r-- | reference_model/src/ops/ewise_binary.cc | 53 |
1 files changed, 31 insertions, 22 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 3abf961..c697db0 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -288,27 +288,32 @@ int OpLogicalAnd<Rank, Dtype>::register_fcn() template <int Rank, DType Dtype> int OpLogicalLeftShift<Rank, Dtype>::register_fcn() { - int32_t num_bits = 0; switch (Dtype) { case DType_INT8: - num_bits = 8; + this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { + REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", + (int32_t)b); + return static_cast<OutEigenType>(static_cast<int8_t>(a << b)); + }; break; case DType_INT16: - num_bits = 16; + this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { + REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", + (int32_t)b); + return static_cast<OutEigenType>(static_cast<int16_t>(a << b)); + }; break; case DType_INT32: - num_bits = 32; + this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { + REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", + (int32_t)b); + return static_cast<OutEigenType>(static_cast<int32_t>(a << b)); + }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } - this->fcn = [this, num_bits](InEigenType a, InEigenType b) -> OutEigenType { - uint32_t mask = ONES_MASK(num_bits); - REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", - (int32_t)b); - return (a << b) & mask; - }; return 0; } @@ -316,29 +321,33 @@ int OpLogicalLeftShift<Rank, Dtype>::register_fcn() template <int Rank, DType Dtype> int OpLogicalRightShift<Rank, Dtype>::register_fcn() { - int32_t num_bits = 0; switch (Dtype) { case DType_INT8: - num_bits = 8; + this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { + REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", + (int32_t)b); + return static_cast<OutEigenType>(static_cast<int8_t>(a) >> b); + }; break; case DType_INT16: - num_bits = 16; + this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { + REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", + (int32_t)b); + return static_cast<OutEigenType>(static_cast<int16_t>(a) >> b); + }; break; case DType_INT32: - num_bits = 32; + this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { + REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", + (int32_t)b); + return static_cast<OutEigenType>(static_cast<int32_t>(a) >> b); + }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } - this->fcn = [this, num_bits](InEigenType a, InEigenType b) -> OutEigenType { - uint32_t mask = ONES_MASK(num_bits) >> b; - REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", - (int32_t)b); - return (a >> b) & mask; - }; - return 0; } |