aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--verif/tosa_test_gen.py24
1 files changed, 16 insertions, 8 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index cf9e06a..4820503 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -1090,11 +1090,17 @@ class TosaTestGen:
return result_tens
def build_table(self, op, a):
- # Constant size, random values
- table_arr = self.getRandTensor([513], DType.INT16)
- table_tens = self.ser.addConst(table_arr.shape, DType.INT16, table_arr)
+ # Constant size depending on type, random values
+ if a.dtype == DType.INT16:
+ table_dtype = DType.INT16
+ table_arr = self.getRandTensor([513], table_dtype)
+ else:
+ assert a.dtype == DType.INT8
+ table_dtype = DType.INT8
+ table_arr = self.getRandTensor([256], table_dtype)
- result_tens = OutputShaper.tableOp(self.ser, a, table_tens)
+ table_tens = self.ser.addConst(table_arr.shape, table_dtype, table_arr)
+ result_tens = OutputShaper.tableOp(self.ser, a, table_dtype)
self.ser.addOperator(op, [a.name, table_tens.name], [result_tens.name], None)
return result_tens
@@ -2176,7 +2182,7 @@ class TosaTestGen:
# a different type from the input
"operands": (1, 0),
"build_fcn": (build_table, TosaTensorGen.tgBasic, None),
- "types": [DType.INT16],
+ "types": [DType.INT8, DType.INT16],
},
# Elementwise Unary operators
"abs": {
@@ -2760,9 +2766,11 @@ class OutputShaper:
return ser.addOutput(output_shape, values_in.dtype)
@staticmethod
- def tableOp(ser, input, table):
- # Same shape as the input, but with the type of the table.
- return ser.addOutput(input.shape, DType.INT32)
+ def tableOp(ser, input, table_dtype):
+ # Same shape as the input, but dtype dependent on table dtype
+ assert table_dtype == DType.INT16 or table_dtype == DType.INT8
+ output_dtype = DType.INT32 if table_dtype == DType.INT16 else DType.INT8
+ return ser.addOutput(input.shape, output_dtype)
@staticmethod
def resizeOp(