From eef866729ee6fe3a6972361afca3e2cda3b162b1 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Wed, 18 Jan 2023 16:23:20 +0000 Subject: Fix for sign extending LOGICAL LEFT/RIGHT SHIFT results Signed-off-by: Jeremy Johnson Change-Id: I04261178694c004409aef2ff5c84c32b04729433 --- reference_model/src/ops/ewise_binary.cc | 53 +++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 22 deletions(-) diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 3abf961..c697db0 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -288,27 +288,32 @@ int OpLogicalAnd::register_fcn() template int OpLogicalLeftShift::register_fcn() { - int32_t num_bits = 0; switch (Dtype) { case DType_INT8: - num_bits = 8; + this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { + REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", + (int32_t)b); + return static_cast(static_cast(a << b)); + }; break; case DType_INT16: - num_bits = 16; + this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { + REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", + (int32_t)b); + return static_cast(static_cast(a << b)); + }; break; case DType_INT32: - num_bits = 32; + this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { + REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", + (int32_t)b); + return static_cast(static_cast(a << b)); + }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } - this->fcn = [this, num_bits](InEigenType a, InEigenType b) -> OutEigenType { - uint32_t mask = ONES_MASK(num_bits); - REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", - (int32_t)b); - return (a << b) & mask; - }; return 0; } @@ -316,29 +321,33 @@ int OpLogicalLeftShift::register_fcn() template int OpLogicalRightShift::register_fcn() { - int32_t num_bits = 0; switch (Dtype) { case DType_INT8: - num_bits = 8; + this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { + REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", + (int32_t)b); + return static_cast(static_cast(a) >> b); + }; break; case DType_INT16: - num_bits = 16; + this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { + REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", + (int32_t)b); + return static_cast(static_cast(a) >> b); + }; break; case DType_INT32: - num_bits = 32; + this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { + REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", + (int32_t)b); + return static_cast(static_cast(a) >> b); + }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } - this->fcn = [this, num_bits](InEigenType a, InEigenType b) -> OutEigenType { - uint32_t mask = ONES_MASK(num_bits) >> b; - REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", - (int32_t)b); - return (a >> b) & mask; - }; - return 0; } -- cgit v1.2.1