diff options
Diffstat (limited to 'ethosu/vela/tosa_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tosa_graph_optimiser.py | 54 |
1 files changed, 11 insertions, 43 deletions
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py index df6b575..c068937 100644 --- a/ethosu/vela/tosa_graph_optimiser.py +++ b/ethosu/vela/tosa_graph_optimiser.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com> +# SPDX-FileCopyrightText: Copyright 2021-2024 Arm Limited and/or its affiliates <open-source-office@arm.com> # # SPDX-License-Identifier: Apache-2.0 # @@ -337,7 +337,7 @@ def rewrite_concat(op): def remove_memory_ops(op, arch): if op.run_on_npu and op.type in (Op.Reshape, Op.Identity): - bypass_memory_only_ops(op) + bypass_memory_only_ops(op, arch, None) def rewrite_activation(op, arch, nng): @@ -357,7 +357,6 @@ def rewrite_activation(op, arch, nng): return op - def rewrite_rescale(op, arch, nng): if op.type == Op.Rescale: ifm = op.ifm @@ -367,9 +366,6 @@ def rewrite_rescale(op, arch, nng): assert len(ifm.ops) == 1 prev_op = ifm.ops[0] - # TODO currently not supported - assert len(ifm.consumer_list) == 1 - input_zp = op.attrs["input_zp"] output_zp = op.attrs["output_zp"] multiplier = op.attrs["multiplier"] @@ -390,6 +386,7 @@ def rewrite_rescale(op, arch, nng): assert False ifm.quantization.zero_point = input_zp ofm.quantization.zero_point = output_zp + for s, m in zip(shift, multiplier): # TODO these are the TOSA limitations assert m >= 0 @@ -403,45 +400,16 @@ def rewrite_rescale(op, arch, nng): else: rounding_mode = RoundingMode.HalfUp - if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected: - assert len(multiplier) == len(shift) == len(prev_op.bias.values) - - if ifm.dtype == DataType.int32 and per_channel: - prev_op.explicit_scaling = explicit_scaling - prev_op.rounding_mode = rounding_mode - - # Bypass op - prev_op.set_output_tensor(ofm) - DebugDatabase.add_optimised(op, prev_op) - return op - else: - print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type) - assert False - # TODO which are the cases we need to and can do standalone Rescale? - # TODO should we try to identify a conversion uint8<->int8 accomplished by 2 RESCALE ops? - # origin might be TFLite op QUANTIZE, should we look to see if they can be translated to QUANTIZE? - # limited to these at the moment: - elif ( - (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8) - or (ifm.dtype == DataType.uint8 and ofm.dtype == DataType.int8) - or (ifm.dtype == DataType.int8 and ofm.dtype == DataType.uint8) - ): - # Create NOP performing the RESCALE - avgpool_op = replace_rescale_with_avg_pool(op) - avgpool_op.rounding_mode = rounding_mode - - if per_channel: - # TODO - avgpool_op.explicit_scaling = explicit_scaling - print("Warning, unsupported TOSA Rescale") - assert False - else: - avgpool_op.explicit_scaling = explicit_scaling + # Generate Rescale behaviour attached to a compatible NOP + avgpool_op = replace_rescale_with_avg_pool(op) + avgpool_op.rounding_mode = rounding_mode + + if per_channel: + assert False, "per_channel rescale not supported" else: - print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type) - assert False - return op + avgpool_op.explicit_scaling = explicit_scaling + return op def convert_pad_in_width(op): """ |