diff options
-rw-r--r-- | reference_model/src/ops/ewise_binary.cc | 17 | ||||
-rw-r--r-- | verif/generator/tosa_test_gen.py | 14 |
2 files changed, 29 insertions, 2 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 287ad92..7f30e30 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -281,16 +281,27 @@ int OpLogicalAnd<Rank, Dtype>::register_fcn() template <int Rank, DType Dtype> int OpLogicalLeftShift<Rank, Dtype>::register_fcn() { + int32_t num_bits = 0; switch (Dtype) { case DType_INT8: + num_bits = 8; + break; case DType_INT16: + num_bits = 16; + break; case DType_INT32: - this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a << b; }; + num_bits = 32; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } + this->fcn = [this, num_bits](InEigenType a, InEigenType b) -> OutEigenType { + uint32_t mask = ONES_MASK(num_bits); + REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", + (int32_t)b); + return (a << b) & mask; + }; return 0; } @@ -314,8 +325,10 @@ int OpLogicalRightShift<Rank, Dtype>::register_fcn() ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } - this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType { + this->fcn = [this, num_bits](InEigenType a, InEigenType b) -> OutEigenType { uint32_t mask = ONES_MASK(num_bits) >> b; + REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", + (int32_t)b); return (a >> b) & mask; }; diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 239a64e..cb97acb 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -5952,6 +5952,20 @@ class TosaTestGen: self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count]) ) tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:])) + elif op["op"] == Op.LOGICAL_LEFT_SHIFT or op["op"] == Op.LOGICAL_RIGHT_SHIFT: + assert ( + pCount == 2 and cCount == 0 + ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts" + values_arr = self.getRandTensor(shapeList[0], dtypeList[0]) + shift_arr = np.int32(self.rng.integers(low=0, high=32, size=shapeList[1])) + placeholders = [] + placeholders.append( + self.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr) + ) + placeholders.append( + self.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr) + ) + tens.extend(placeholders) else: tens.extend( self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount]) |