aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py38
1 files changed, 34 insertions, 4 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 83081ee..b1f9938 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -5704,13 +5704,15 @@ class TosaTestGen:
# Build the random tensor operands and the test
tens = []
- tens = self.generate_tensors(op, dtypeList, shapeList, testArgs, error_name)
-
if qgen is not None:
qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
else:
qinfo = None
+ tens = self.generate_tensors(
+ op, dtypeList, shapeList, testArgs, qinfo, error_name
+ )
+
try:
if error_if_validators is None:
if qinfo is not None:
@@ -5748,11 +5750,39 @@ class TosaTestGen:
# The test is not valid
print(f"Invalid ERROR_IF test created: {opName} {testStr}")
- def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None):
+ def generate_tensors(
+ self, op, dtypeList, shapeList, testArgs, qinfo, error_name=None
+ ):
pCount, cCount = op["operands"]
tens = []
- if (
+ if op["op"] == Op.NEGATE and dtypeList[0] != DType.FLOAT and error_name is None:
+ assert (
+ pCount == 1 and cCount == 0
+ ), "Op.NEGATE must have 1 placeholders, 0 consts"
+ # Must create tensors with values within negatable ranges
+ if dtypeList[0] == DType.INT8:
+ # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
+ max_val = 127 + qinfo.ints[0][1]
+ min_val = -127 + qinfo.ints[0][1]
+ elif dtypeList[0] == DType.INT16:
+ max_val = 32767
+ min_val = -max_val
+ else:
+ assert (
+ dtypeList[0] == DType.INT32
+ ), "Op.NEGATE found with unsupported input type"
+ max_val = (1 << 31) - 1
+ min_val = -max_val
+ arr = np.int32(
+ self.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
+ )
+ placeholders = []
+ placeholders.append(
+ self.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
+ )
+ tens.extend(placeholders)
+ elif (
(op["op"] == Op.ADD or op["op"] == Op.SUB)
and dtypeList[0] == DType.INT32
and error_name is None