diff options
author | Kevin Cheng <kevin.cheng@arm.com> | 2021-08-31 16:14:03 -0700 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2021-09-15 23:58:06 +0000 |
commit | 93a1628bc3dd48d9ba099de503b586a561b4751f (patch) | |
tree | adab77805c3d78cf3b30b00684e8a76316e11477 /verif/tosa_test_gen.py | |
parent | e3d6a8ffe0fffaf9d29167b03509b85a2f4d8308 (diff) | |
download | reference_model-93a1628bc3dd48d9ba099de503b586a561b4751f.tar.gz |
Rename attribute: Pool2d, Conv2d, TransposeConv2d -> Pool, Conv, TransposeConv
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: I466dd1dcf5230e8e07df202ba88515e775e04a1e
Diffstat (limited to 'verif/tosa_test_gen.py')
-rw-r--r-- | verif/tosa_test_gen.py | 17 |
1 files changed, 8 insertions, 9 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index 3e14c5b..44582ac 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -381,7 +381,7 @@ class TosaTensorGen: new_shapeList = [shape.copy()] length_on_axis = shape[axis] remaining_length = length_on_axis - for i in range(len(shapeList)-2): + for i in range(len(shapeList) - 2): # Calculate split on axis and remaining value split_shape_val = int(shape[axis] / 2) remaining_length = remaining_length - split_shape_val @@ -396,7 +396,6 @@ class TosaTensorGen: return new_shapeList - class TosaArgGen: """Argument generators create exhaustive or random lists of attributes for operators that take attributes or other parameters. The return value is a list of (descriptive_name, [arglist]) @@ -1339,7 +1338,7 @@ class TosaTestGen: result_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad) attr = ts.TosaSerializerAttribute() - attr.Pool2dAttribute(kernel, stride, pad) + attr.PoolAttribute(kernel, stride, pad) self.ser.addOperator(op, [input.name], [result_tens.name], attr, qinfo) return result_tens @@ -1351,7 +1350,7 @@ class TosaTestGen: ) attr = ts.TosaSerializerAttribute() - attr.Conv2dAttribute(padding, strides, dilations) + attr.ConvAttribute(padding, strides, dilations) self.ser.addOperator( op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo @@ -1365,7 +1364,7 @@ class TosaTestGen: result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape) attr = ts.TosaSerializerAttribute() - attr.TransposeConv2DAttribute(outpad, stride, dilation, output_shape) + attr.TransposeConvAttribute(outpad, stride, dilation, output_shape) self.ser.addOperator( op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo @@ -1380,7 +1379,7 @@ class TosaTestGen: ) attr = ts.TosaSerializerAttribute() - attr.Conv2dAttribute(padding, strides, dilations) + attr.ConvAttribute(padding, strides, dilations) self.ser.addOperator( op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo @@ -1463,7 +1462,7 @@ class TosaTestGen: return result_tens def build_concat(self, op, *a): - assert (type(a[-1]) == int) + assert type(a[-1]) == int # To store variable length list of input tensors we need to store axis along with it axis = a[-1] @@ -1944,12 +1943,12 @@ class TosaTestGen: if isinstance(dtype_or_dtypeList, list): dtypeList = dtype_or_dtypeList - elif op['op'] == Op.CONCAT: + elif op["op"] == Op.CONCAT: dtypeList = [dtype_or_dtypeList] * len(shapeList) else: dtypeList = [dtype_or_dtypeList] * (num_operands) - if op['op'] != Op.CONCAT: + if op["op"] != Op.CONCAT: assert ( len(shapeList) == num_operands ), "shapeList length {} must match number of operands {}".format( |