aboutsummaryrefslogtreecommitdiff
path: root/verif/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/tosa_test_gen.py')
-rw-r--r--verif/tosa_test_gen.py17
1 files changed, 8 insertions, 9 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 3e14c5b..44582ac 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -381,7 +381,7 @@ class TosaTensorGen:
new_shapeList = [shape.copy()]
length_on_axis = shape[axis]
remaining_length = length_on_axis
- for i in range(len(shapeList)-2):
+ for i in range(len(shapeList) - 2):
# Calculate split on axis and remaining value
split_shape_val = int(shape[axis] / 2)
remaining_length = remaining_length - split_shape_val
@@ -396,7 +396,6 @@ class TosaTensorGen:
return new_shapeList
-
class TosaArgGen:
"""Argument generators create exhaustive or random lists of attributes for operators that take
attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
@@ -1339,7 +1338,7 @@ class TosaTestGen:
result_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)
attr = ts.TosaSerializerAttribute()
- attr.Pool2dAttribute(kernel, stride, pad)
+ attr.PoolAttribute(kernel, stride, pad)
self.ser.addOperator(op, [input.name], [result_tens.name], attr, qinfo)
return result_tens
@@ -1351,7 +1350,7 @@ class TosaTestGen:
)
attr = ts.TosaSerializerAttribute()
- attr.Conv2dAttribute(padding, strides, dilations)
+ attr.ConvAttribute(padding, strides, dilations)
self.ser.addOperator(
op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
@@ -1365,7 +1364,7 @@ class TosaTestGen:
result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)
attr = ts.TosaSerializerAttribute()
- attr.TransposeConv2DAttribute(outpad, stride, dilation, output_shape)
+ attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)
self.ser.addOperator(
op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
@@ -1380,7 +1379,7 @@ class TosaTestGen:
)
attr = ts.TosaSerializerAttribute()
- attr.Conv2dAttribute(padding, strides, dilations)
+ attr.ConvAttribute(padding, strides, dilations)
self.ser.addOperator(
op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
@@ -1463,7 +1462,7 @@ class TosaTestGen:
return result_tens
def build_concat(self, op, *a):
- assert (type(a[-1]) == int)
+ assert type(a[-1]) == int
# To store variable length list of input tensors we need to store axis along with it
axis = a[-1]
@@ -1944,12 +1943,12 @@ class TosaTestGen:
if isinstance(dtype_or_dtypeList, list):
dtypeList = dtype_or_dtypeList
- elif op['op'] == Op.CONCAT:
+ elif op["op"] == Op.CONCAT:
dtypeList = [dtype_or_dtypeList] * len(shapeList)
else:
dtypeList = [dtype_or_dtypeList] * (num_operands)
- if op['op'] != Op.CONCAT:
+ if op["op"] != Op.CONCAT:
assert (
len(shapeList) == num_operands
), "shapeList length {} must match number of operands {}".format(