From 31947ad1aec50b64508bf367cb3e87c93f8c4693 Mon Sep 17 00:00:00 2001 From: Johan Alfven Date: Thu, 4 Apr 2024 15:50:08 +0200 Subject: Fix various pre-commit errors Change-Id: I8e584a036036f35a8883b2a4884cb2d54e675e39 Signed-off-by: Johan Alfven --- ethosu/vela/debug_database.py | 13 +++++++++++-- ethosu/vela/tosa_graph_optimiser.py | 3 ++- ethosu/vela/tosa_reader.py | 2 +- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/ethosu/vela/debug_database.py b/ethosu/vela/debug_database.py index 597f8410..f52cd023 100644 --- a/ethosu/vela/debug_database.py +++ b/ethosu/vela/debug_database.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2020-2022, 2024 Arm Limited and/or its affiliates # # SPDX-License-Identifier: Apache-2.0 # @@ -58,7 +58,16 @@ class DebugDatabase: cls._sourceUID[op] = uid ofm_shape = numeric_util.full_shape(3, op.outputs[0].shape, 1) cls._sourceTable.append( - [uid, str(op.type), op.kernel.width, op.kernel.height, ofm_shape[-2], ofm_shape[-3], ofm_shape[-1], op.op_index] + [ + uid, + str(op.type), + op.kernel.width, + op.kernel.height, + ofm_shape[-2], + ofm_shape[-3], + ofm_shape[-1], + op.op_index, + ] ) # Ops are added when their type changes, and after optimisation. If an op was already diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py index bcb4aac8..19244c27 100644 --- a/ethosu/vela/tosa_graph_optimiser.py +++ b/ethosu/vela/tosa_graph_optimiser.py @@ -357,6 +357,7 @@ def rewrite_activation(op, arch, nng): return op + def rewrite_rescale(op, arch, nng): if op.type == Op.Rescale: ifm = op.ifm @@ -364,7 +365,6 @@ def rewrite_rescale(op, arch, nng): # some error checking assert len(ifm.ops) == 1 - prev_op = ifm.ops[0] input_zp = op.attrs["input_zp"] output_zp = op.attrs["output_zp"] @@ -409,6 +409,7 @@ def rewrite_rescale(op, arch, nng): return op + def convert_pad_in_width(op): """ Rewrites PAD operator to an add that copies the IFM to the OFM diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py index 9ffda801..6d80e10d 100644 --- a/ethosu/vela/tosa_reader.py +++ b/ethosu/vela/tosa_reader.py @@ -349,7 +349,7 @@ class TosaGraph: def check_version(self, tosa_graph): version = tosa_graph.Version() version_str = f"{version._Major()}.{version._Minor()}.{version._Patch()}" - if version_str not in ( "0.80.0", "0.80.1" ): + if version_str not in ("0.80.0", "0.80.1"): print(f"Unsupported TOSA version: {version_str}") assert False -- cgit v1.2.1