aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tosa_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tosa_graph_optimiser.py')
-rw-r--r--ethosu/vela/tosa_graph_optimiser.py54
1 files changed, 11 insertions, 43 deletions
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index df6b575..c068937 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -337,7 +337,7 @@ def rewrite_concat(op):
def remove_memory_ops(op, arch):
if op.run_on_npu and op.type in (Op.Reshape, Op.Identity):
- bypass_memory_only_ops(op)
+ bypass_memory_only_ops(op, arch, None)
def rewrite_activation(op, arch, nng):
@@ -357,7 +357,6 @@ def rewrite_activation(op, arch, nng):
return op
-
def rewrite_rescale(op, arch, nng):
if op.type == Op.Rescale:
ifm = op.ifm
@@ -367,9 +366,6 @@ def rewrite_rescale(op, arch, nng):
assert len(ifm.ops) == 1
prev_op = ifm.ops[0]
- # TODO currently not supported
- assert len(ifm.consumer_list) == 1
-
input_zp = op.attrs["input_zp"]
output_zp = op.attrs["output_zp"]
multiplier = op.attrs["multiplier"]
@@ -390,6 +386,7 @@ def rewrite_rescale(op, arch, nng):
assert False
ifm.quantization.zero_point = input_zp
ofm.quantization.zero_point = output_zp
+
for s, m in zip(shift, multiplier):
# TODO these are the TOSA limitations
assert m >= 0
@@ -403,45 +400,16 @@ def rewrite_rescale(op, arch, nng):
else:
rounding_mode = RoundingMode.HalfUp
- if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected:
- assert len(multiplier) == len(shift) == len(prev_op.bias.values)
-
- if ifm.dtype == DataType.int32 and per_channel:
- prev_op.explicit_scaling = explicit_scaling
- prev_op.rounding_mode = rounding_mode
-
- # Bypass op
- prev_op.set_output_tensor(ofm)
- DebugDatabase.add_optimised(op, prev_op)
- return op
- else:
- print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type)
- assert False
- # TODO which are the cases we need to and can do standalone Rescale?
- # TODO should we try to identify a conversion uint8<->int8 accomplished by 2 RESCALE ops?
- # origin might be TFLite op QUANTIZE, should we look to see if they can be translated to QUANTIZE?
- # limited to these at the moment:
- elif (
- (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8)
- or (ifm.dtype == DataType.uint8 and ofm.dtype == DataType.int8)
- or (ifm.dtype == DataType.int8 and ofm.dtype == DataType.uint8)
- ):
- # Create NOP performing the RESCALE
- avgpool_op = replace_rescale_with_avg_pool(op)
- avgpool_op.rounding_mode = rounding_mode
-
- if per_channel:
- # TODO
- avgpool_op.explicit_scaling = explicit_scaling
- print("Warning, unsupported TOSA Rescale")
- assert False
- else:
- avgpool_op.explicit_scaling = explicit_scaling
+ # Generate Rescale behaviour attached to a compatible NOP
+ avgpool_op = replace_rescale_with_avg_pool(op)
+ avgpool_op.rounding_mode = rounding_mode
+
+ if per_channel:
+ assert False, "per_channel rescale not supported"
else:
- print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type)
- assert False
- return op
+ avgpool_op.explicit_scaling = explicit_scaling
+ return op
def convert_pad_in_width(op):
"""