aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Haddon <matthew.haddon@arm.com>2021-07-22 14:30:53 +0100
committerEric Kunze <eric.kunze@arm.com>2021-08-12 16:00:26 +0000
commitcac4ee9575a9bae4f6502f8ba7f86e294b92edff (patch)
tree92e617680ee42b210f7485e7027ea5de78209f76
parent2a29dc69170630775523366f29c5914a7981d264 (diff)
downloadreference_model-cac4ee9575a9bae4f6502f8ba7f86e294b92edff.tar.gz
Add support for UINT8
* RESCALE can now produce tests with UINT8 as the input/output datatype. Signed-off-by: Matthew Haddon <matthew.haddon@arm.com> Change-Id: I0a5d7b3c6dd7c2501d14e5d558b1f18e5e967fa9
-rw-r--r--verif/tosa_test_gen.py24
1 files changed, 20 insertions, 4 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index e08add3..c05abc0 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -547,7 +547,14 @@ class TosaArgGen:
arg_list = []
# Enumerate the output types here
- for dtype in [DType.INT8, DType.INT16, DType.INT32]:
+ for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
+ if inDtype == DType.UINT8 and dtype != DType.INT8:
+ # The only output dtype for UINT8 is INT8, skip all other combinations
+ continue
+ if inDtype != DType.INT8 and dtype == DType.UINT8:
+ # The only input dtype for UINT8 is INT8, skip all other combinations
+ continue
+
for scale32 in [False, True]:
for double_round in [False, True]:
for per_channel in [False, True]:
@@ -555,6 +562,9 @@ class TosaArgGen:
if inDtype == DType.INT48 and scale32:
# Illegal condition. Must be scale32=False
continue
+ if double_round and not scale32:
+ # Illegal condition. ERROR_IF(!scale32 && double_round)
+ continue
arg_list.append(
(
@@ -1426,13 +1436,19 @@ class TosaTestGen:
out_type_width = self.typeWidth(out_dtype)
if val.dtype == DType.INT8:
- input_zp = self.randInt(-128, 127)
+ input_zp = self.randInt(-128, 128)
+ in_type_width = in_type_width + 1
+ elif val.dtype == DType.UINT8:
+ input_zp = self.randInt(0, 256)
in_type_width = in_type_width + 1
else:
input_zp = 0
if out_dtype == DType.INT8:
- output_zp = self.randInt(-128, 127)
+ output_zp = self.randInt(-128, 128)
+ out_type_width = out_type_width + 1
+ elif out_dtype == DType.UINT8:
+ output_zp = self.randInt(0, 256)
out_type_width = out_type_width + 1
else:
output_zp = 0
@@ -2415,7 +2431,7 @@ class TosaTestGen:
"op": Op.RESCALE,
"operands": (1, 0),
"build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
- "types": [DType.INT8, DType.INT16, DType.INT32, DType.INT48],
+ "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
},
# Custom
# Not implemented.