diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2022-03-23 15:32:34 +0000 |
---|---|---|
committer | Jeremy Johnson <jeremy.johnson@arm.com> | 2022-03-24 09:13:35 +0000 |
commit | 81ee53d65a6e3e7d454eda967e6f9f157cae69f1 (patch) | |
tree | a35d19a9fb64345f561277880fe476a2491e6dcb /verif/generator/tosa_test_gen.py | |
parent | 25669b31bae45b16d4e96ec13fa9cdeb417975f6 (diff) | |
download | reference_model-81ee53d65a6e3e7d454eda967e6f9f157cae69f1.tar.gz |
Add missing REQUIRE to NEGATE op
And update test generation to create values in predictable range
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I4ba1ff445bf6caeec9f8782902fc45929fe0ee77
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 38 |
1 files changed, 34 insertions, 4 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 83081ee..b1f9938 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -5704,13 +5704,15 @@ class TosaTestGen: # Build the random tensor operands and the test tens = [] - tens = self.generate_tensors(op, dtypeList, shapeList, testArgs, error_name) - if qgen is not None: qinfo = qgen(self, op, dtype_or_dtypeList, error_name) else: qinfo = None + tens = self.generate_tensors( + op, dtypeList, shapeList, testArgs, qinfo, error_name + ) + try: if error_if_validators is None: if qinfo is not None: @@ -5748,11 +5750,39 @@ class TosaTestGen: # The test is not valid print(f"Invalid ERROR_IF test created: {opName} {testStr}") - def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None): + def generate_tensors( + self, op, dtypeList, shapeList, testArgs, qinfo, error_name=None + ): pCount, cCount = op["operands"] tens = [] - if ( + if op["op"] == Op.NEGATE and dtypeList[0] != DType.FLOAT and error_name is None: + assert ( + pCount == 1 and cCount == 0 + ), "Op.NEGATE must have 1 placeholders, 0 consts" + # Must create tensors with values within negatable ranges + if dtypeList[0] == DType.INT8: + # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp + max_val = 127 + qinfo.ints[0][1] + min_val = -127 + qinfo.ints[0][1] + elif dtypeList[0] == DType.INT16: + max_val = 32767 + min_val = -max_val + else: + assert ( + dtypeList[0] == DType.INT32 + ), "Op.NEGATE found with unsupported input type" + max_val = (1 << 31) - 1 + min_val = -max_val + arr = np.int32( + self.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0]) + ) + placeholders = [] + placeholders.append( + self.ser.addPlaceholder(shapeList[0], dtypeList[0], arr) + ) + tens.extend(placeholders) + elif ( (op["op"] == Op.ADD or op["op"] == Op.SUB) and dtypeList[0] == DType.INT32 and error_name is None |