aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-01-18 16:23:20 +0000
committerEric Kunze <eric.kunze@arm.com>2023-01-19 17:28:19 +0000
commiteef866729ee6fe3a6972361afca3e2cda3b162b1 (patch)
treed2c04c60b6eddaa453769058c7b54b1d97c1bc81
parent3407125ce2a470216bd54f8bb3ab5216f617c1be (diff)
downloadreference_model-eef866729ee6fe3a6972361afca3e2cda3b162b1.tar.gz
Fix for sign extending LOGICAL LEFT/RIGHT SHIFT results
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I04261178694c004409aef2ff5c84c32b04729433
-rw-r--r--reference_model/src/ops/ewise_binary.cc53
1 files changed, 31 insertions, 22 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 3abf961..c697db0 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -288,27 +288,32 @@ int OpLogicalAnd<Rank, Dtype>::register_fcn()
template <int Rank, DType Dtype>
int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
{
- int32_t num_bits = 0;
switch (Dtype)
{
case DType_INT8:
- num_bits = 8;
+ this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
+ REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
+ (int32_t)b);
+ return static_cast<OutEigenType>(static_cast<int8_t>(a << b));
+ };
break;
case DType_INT16:
- num_bits = 16;
+ this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
+ REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
+ (int32_t)b);
+ return static_cast<OutEigenType>(static_cast<int16_t>(a << b));
+ };
break;
case DType_INT32:
- num_bits = 32;
+ this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
+ REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
+ (int32_t)b);
+ return static_cast<OutEigenType>(static_cast<int32_t>(a << b));
+ };
break;
default:
ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
- this->fcn = [this, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
- uint32_t mask = ONES_MASK(num_bits);
- REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
- (int32_t)b);
- return (a << b) & mask;
- };
return 0;
}
@@ -316,29 +321,33 @@ int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
template <int Rank, DType Dtype>
int OpLogicalRightShift<Rank, Dtype>::register_fcn()
{
- int32_t num_bits = 0;
switch (Dtype)
{
case DType_INT8:
- num_bits = 8;
+ this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
+ REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
+ (int32_t)b);
+ return static_cast<OutEigenType>(static_cast<int8_t>(a) >> b);
+ };
break;
case DType_INT16:
- num_bits = 16;
+ this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
+ REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
+ (int32_t)b);
+ return static_cast<OutEigenType>(static_cast<int16_t>(a) >> b);
+ };
break;
case DType_INT32:
- num_bits = 32;
+ this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
+ REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
+ (int32_t)b);
+ return static_cast<OutEigenType>(static_cast<int32_t>(a) >> b);
+ };
break;
default:
ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
- this->fcn = [this, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
- uint32_t mask = ONES_MASK(num_bits) >> b;
- REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
- (int32_t)b);
- return (a >> b) & mask;
- };
-
return 0;
}