aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py22
1 files changed, 5 insertions, 17 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 9f02489..b1f8942 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -563,27 +563,15 @@ class TosaTensorValuesGen:
@staticmethod
def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
- if dtypeList[0] != DType.FLOAT and error_name is None:
+ if dtypeList[0] == DType.INT32 and error_name is None:
pCount, cCount = op["operands"]
assert (
pCount == 1 and cCount == 0
), "Op.NEGATE must have 1 placeholders, 0 consts"
- # Must create tensors with values within negatable ranges
- if dtypeList[0] == DType.INT8:
- # Must be within int8, adjustable by input_zp and then negatable
- # and be within int8
- # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
- max_val = min(127, 127 + qinfo.ints[0][1])
- min_val = max(-127, -127 + qinfo.ints[0][1])
- elif dtypeList[0] == DType.INT16:
- max_val = 32767
- min_val = -max_val
- else:
- assert (
- dtypeList[0] == DType.INT32
- ), "Op.NEGATE found with unsupported input type"
- max_val = (1 << 31) - 1
- min_val = -max_val
+ # Must create tensors with values within accumulator (int32) negatable
+ # range
+ max_val = (1 << 31) - 1
+ min_val = -max_val
arr = np.int32(
testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
)