aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2021-07-20 16:01:06 +0100
committerJeremy Johnson <jeremy.johnson@arm.com>2021-07-20 16:01:06 +0100
commitf54d8a2d341e1ebc5f37465a2e1b08d7e9c9785c (patch)
tree91bf763d180c9bf22caf9c2094a79ae328df317a
parent16aac579b5d26b8efee57c6d1feb4695c265ce53 (diff)
downloadreference_model-f54d8a2d341e1ebc5f37465a2e1b08d7e9c9785c.tar.gz
Add INT8 table operator support to test generator.
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I5f01fa589692f7c6d556a4c22a44caec7c906b9d
-rw-r--r--verif/tosa_test_gen.py24
1 files changed, 16 insertions, 8 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index cf9e06a..4820503 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -1090,11 +1090,17 @@ class TosaTestGen:
return result_tens
def build_table(self, op, a):
- # Constant size, random values
- table_arr = self.getRandTensor([513], DType.INT16)
- table_tens = self.ser.addConst(table_arr.shape, DType.INT16, table_arr)
+ # Constant size depending on type, random values
+ if a.dtype == DType.INT16:
+ table_dtype = DType.INT16
+ table_arr = self.getRandTensor([513], table_dtype)
+ else:
+ assert a.dtype == DType.INT8
+ table_dtype = DType.INT8
+ table_arr = self.getRandTensor([256], table_dtype)
- result_tens = OutputShaper.tableOp(self.ser, a, table_tens)
+ table_tens = self.ser.addConst(table_arr.shape, table_dtype, table_arr)
+ result_tens = OutputShaper.tableOp(self.ser, a, table_dtype)
self.ser.addOperator(op, [a.name, table_tens.name], [result_tens.name], None)
return result_tens
@@ -2176,7 +2182,7 @@ class TosaTestGen:
# a different type from the input
"operands": (1, 0),
"build_fcn": (build_table, TosaTensorGen.tgBasic, None),
- "types": [DType.INT16],
+ "types": [DType.INT8, DType.INT16],
},
# Elementwise Unary operators
"abs": {
@@ -2760,9 +2766,11 @@ class OutputShaper:
return ser.addOutput(output_shape, values_in.dtype)
@staticmethod
- def tableOp(ser, input, table):
- # Same shape as the input, but with the type of the table.
- return ser.addOutput(input.shape, DType.INT32)
+ def tableOp(ser, input, table_dtype):
+ # Same shape as the input, but dtype dependent on table dtype
+ assert table_dtype == DType.INT16 or table_dtype == DType.INT8
+ output_dtype = DType.INT32 if table_dtype == DType.INT16 else DType.INT8
+ return ser.addOutput(input.shape, output_dtype)
@staticmethod
def resizeOp(