From f54d8a2d341e1ebc5f37465a2e1b08d7e9c9785c Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Tue, 20 Jul 2021 16:01:06 +0100 Subject: Add INT8 table operator support to test generator. Signed-off-by: Jeremy Johnson Change-Id: I5f01fa589692f7c6d556a4c22a44caec7c906b9d --- verif/tosa_test_gen.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index cf9e06a..4820503 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -1090,11 +1090,17 @@ class TosaTestGen: return result_tens def build_table(self, op, a): - # Constant size, random values - table_arr = self.getRandTensor([513], DType.INT16) - table_tens = self.ser.addConst(table_arr.shape, DType.INT16, table_arr) + # Constant size depending on type, random values + if a.dtype == DType.INT16: + table_dtype = DType.INT16 + table_arr = self.getRandTensor([513], table_dtype) + else: + assert a.dtype == DType.INT8 + table_dtype = DType.INT8 + table_arr = self.getRandTensor([256], table_dtype) - result_tens = OutputShaper.tableOp(self.ser, a, table_tens) + table_tens = self.ser.addConst(table_arr.shape, table_dtype, table_arr) + result_tens = OutputShaper.tableOp(self.ser, a, table_dtype) self.ser.addOperator(op, [a.name, table_tens.name], [result_tens.name], None) return result_tens @@ -2176,7 +2182,7 @@ class TosaTestGen: # a different type from the input "operands": (1, 0), "build_fcn": (build_table, TosaTensorGen.tgBasic, None), - "types": [DType.INT16], + "types": [DType.INT8, DType.INT16], }, # Elementwise Unary operators "abs": { @@ -2760,9 +2766,11 @@ class OutputShaper: return ser.addOutput(output_shape, values_in.dtype) @staticmethod - def tableOp(ser, input, table): - # Same shape as the input, but with the type of the table. - return ser.addOutput(input.shape, DType.INT32) + def tableOp(ser, input, table_dtype): + # Same shape as the input, but dtype dependent on table dtype + assert table_dtype == DType.INT16 or table_dtype == DType.INT8 + output_dtype = DType.INT32 if table_dtype == DType.INT16 else DType.INT8 + return ser.addOutput(input.shape, output_dtype) @staticmethod def resizeOp( -- cgit v1.2.1