diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2022-05-03 12:10:23 +0100 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2022-05-04 15:41:44 +0000 |
commit | 0e46364dfea65a7898639dd381250014ccca3efa (patch) | |
tree | b87868f377fa8a3391dcd11910374e208f1eb80a /verif/generator/tosa_arg_gen.py | |
parent | 5860df6fb7dfe4850133c213fe3276cdfb740baa (diff) | |
download | reference_model-0e46364dfea65a7898639dd381250014ccca3efa.tar.gz |
Fix for NEGATE using 32-bit accumulator
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: Ie5d119dc317303a0d2a71d018ac94ce6800ecbf5
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 22 |
1 files changed, 5 insertions, 17 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 9f02489..b1f8942 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -563,27 +563,15 @@ class TosaTensorValuesGen: @staticmethod def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): - if dtypeList[0] != DType.FLOAT and error_name is None: + if dtypeList[0] == DType.INT32 and error_name is None: pCount, cCount = op["operands"] assert ( pCount == 1 and cCount == 0 ), "Op.NEGATE must have 1 placeholders, 0 consts" - # Must create tensors with values within negatable ranges - if dtypeList[0] == DType.INT8: - # Must be within int8, adjustable by input_zp and then negatable - # and be within int8 - # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp - max_val = min(127, 127 + qinfo.ints[0][1]) - min_val = max(-127, -127 + qinfo.ints[0][1]) - elif dtypeList[0] == DType.INT16: - max_val = 32767 - min_val = -max_val - else: - assert ( - dtypeList[0] == DType.INT32 - ), "Op.NEGATE found with unsupported input type" - max_val = (1 << 31) - 1 - min_val = -max_val + # Must create tensors with values within accumulator (int32) negatable + # range + max_val = (1 << 31) - 1 + min_val = -max_val arr = np.int32( testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0]) ) |