From 78b9412b07e0a46e58e8ecb9da8d661399c006a5 Mon Sep 17 00:00:00 2001 From: Rob Elliott Date: Thu, 25 Jan 2024 13:05:16 +0000 Subject: Modifications of rescale to enable basic form quantized network support. Minor fixes for TOSA 0.80.0 and 0.80.1 field naming following from the 0.2 to 0.8 conversion. Change-Id: I2ac1b3ac1ec60cf765edf54030cd2338bf001289 Signed-off-by: Rob Elliott --- ethosu/vela/tosa_graph_optimiser.py | 54 ++++++++----------------------------- ethosu/vela/tosa_mapping.py | 4 +-- ethosu/vela/tosa_reader.py | 4 +-- 3 files changed, 15 insertions(+), 47 deletions(-) diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py index df6b575..c068937 100644 --- a/ethosu/vela/tosa_graph_optimiser.py +++ b/ethosu/vela/tosa_graph_optimiser.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2021-2024 Arm Limited and/or its affiliates # # SPDX-License-Identifier: Apache-2.0 # @@ -337,7 +337,7 @@ def rewrite_concat(op): def remove_memory_ops(op, arch): if op.run_on_npu and op.type in (Op.Reshape, Op.Identity): - bypass_memory_only_ops(op) + bypass_memory_only_ops(op, arch, None) def rewrite_activation(op, arch, nng): @@ -357,7 +357,6 @@ def rewrite_activation(op, arch, nng): return op - def rewrite_rescale(op, arch, nng): if op.type == Op.Rescale: ifm = op.ifm @@ -367,9 +366,6 @@ def rewrite_rescale(op, arch, nng): assert len(ifm.ops) == 1 prev_op = ifm.ops[0] - # TODO currently not supported - assert len(ifm.consumer_list) == 1 - input_zp = op.attrs["input_zp"] output_zp = op.attrs["output_zp"] multiplier = op.attrs["multiplier"] @@ -390,6 +386,7 @@ def rewrite_rescale(op, arch, nng): assert False ifm.quantization.zero_point = input_zp ofm.quantization.zero_point = output_zp + for s, m in zip(shift, multiplier): # TODO these are the TOSA limitations assert m >= 0 @@ -403,45 +400,16 @@ def rewrite_rescale(op, arch, nng): else: rounding_mode = RoundingMode.HalfUp - if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected: - assert len(multiplier) == len(shift) == len(prev_op.bias.values) - - if ifm.dtype == DataType.int32 and per_channel: - prev_op.explicit_scaling = explicit_scaling - prev_op.rounding_mode = rounding_mode - - # Bypass op - prev_op.set_output_tensor(ofm) - DebugDatabase.add_optimised(op, prev_op) - return op - else: - print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type) - assert False - # TODO which are the cases we need to and can do standalone Rescale? - # TODO should we try to identify a conversion uint8<->int8 accomplished by 2 RESCALE ops? - # origin might be TFLite op QUANTIZE, should we look to see if they can be translated to QUANTIZE? - # limited to these at the moment: - elif ( - (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8) - or (ifm.dtype == DataType.uint8 and ofm.dtype == DataType.int8) - or (ifm.dtype == DataType.int8 and ofm.dtype == DataType.uint8) - ): - # Create NOP performing the RESCALE - avgpool_op = replace_rescale_with_avg_pool(op) - avgpool_op.rounding_mode = rounding_mode - - if per_channel: - # TODO - avgpool_op.explicit_scaling = explicit_scaling - print("Warning, unsupported TOSA Rescale") - assert False - else: - avgpool_op.explicit_scaling = explicit_scaling + # Generate Rescale behaviour attached to a compatible NOP + avgpool_op = replace_rescale_with_avg_pool(op) + avgpool_op.rounding_mode = rounding_mode + + if per_channel: + assert False, "per_channel rescale not supported" else: - print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type) - assert False - return op + avgpool_op.explicit_scaling = explicit_scaling + return op def convert_pad_in_width(op): """ diff --git a/ethosu/vela/tosa_mapping.py b/ethosu/vela/tosa_mapping.py index 2dafd81..0ec2f1d 100644 --- a/ethosu/vela/tosa_mapping.py +++ b/ethosu/vela/tosa_mapping.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2021-2024 Arm Limited and/or its affiliates # # SPDX-License-Identifier: Apache-2.0 # @@ -148,7 +148,7 @@ transpose_conv_attrs = AttrSerializer( ) transpose_attrs = AttrSerializer("TransposeAttribute", (("perms", is_vec),)) axis_attrs = AttrSerializer("AxisAttribute", ("axis",)) -reshape_attrs = AttrSerializer("ReshapeAttribute", (("shape", is_vec),)) +reshape_attrs = AttrSerializer("ReshapeAttribute", (("new_shape", is_vec),)) slice_attrs = AttrSerializer("SliceAttribute", (("start", is_vec), ("size", is_vec))) tile_attrs = AttrSerializer("TileAttribute", (("multiplies", is_vec),)) resize_attrs = AttrSerializer( diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py index 56af59d..2f37478 100644 --- a/ethosu/vela/tosa_reader.py +++ b/ethosu/vela/tosa_reader.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2021-2024 Arm Limited and/or its affiliates # # SPDX-License-Identifier: Apache-2.0 # @@ -294,7 +294,7 @@ class TosaGraph: def check_version(self, tosa_graph): version = tosa_graph.Version() version_str = f"{version._Major()}.{version._Minor()}.{version._Patch()}" - if version_str != "0.80.0": + if version_str not in ( "0.80.0", "0.80.1" ): print(f"Unsupported TOSA version: {version_str}") assert False -- cgit v1.2.1