aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Rob Elliott <robert.elliott@arm.com> 2024-01-25 13:05:16 +0000
committer tim.hall <tim.hall@arm.com> 2024-02-27 12:41:06 +0000
commit 78b9412b07e0a46e58e8ecb9da8d661399c006a5 (patch)
tree d766d872c2d56981046da85599b604a67018a9ed
parent cc82c36d30346b7495c40d04e78f9a9fbc46c6f3 (diff)
download ethos-u-vela-78b9412b07e0a46e58e8ecb9da8d661399c006a5.tar.gz
Modifications of rescale to enable basic form quantized network support.
Minor fixes for TOSA 0.80.0 and 0.80.1 field naming following from the 0.2 to 0.8 conversion.

Change-Id: I2ac1b3ac1ec60cf765edf54030cd2338bf001289
Signed-off-by: Rob Elliott <Robert.Elliott@arm.com>
-rw-r--r-- ethosu/vela/tosa_graph_optimiser.py 54
-rw-r--r-- ethosu/vela/tosa_mapping.py 4
-rw-r--r-- ethosu/vela/tosa_reader.py 4
3 files changed, 15 insertions, 47 deletions
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index df6b575..c068937 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -337,7 +337,7 @@ def rewrite_concat(op):
def remove_memory_ops(op, arch):
if op.run_on_npu and op.type in (Op.Reshape, Op.Identity):
- bypass_memory_only_ops(op)
+ bypass_memory_only_ops(op, arch, None)
def rewrite_activation(op, arch, nng):
@@ -357,7 +357,6 @@ def rewrite_activation(op, arch, nng):
return op
-
def rewrite_rescale(op, arch, nng):
if op.type == Op.Rescale:
ifm = op.ifm
@@ -367,9 +366,6 @@ def rewrite_rescale(op, arch, nng):
assert len(ifm.ops) == 1
prev_op = ifm.ops[0]
- # TODO currently not supported
- assert len(ifm.consumer_list) == 1
-
input_zp = op.attrs["input_zp"]
output_zp = op.attrs["output_zp"]
multiplier = op.attrs["multiplier"]
@@ -390,6 +386,7 @@ def rewrite_rescale(op, arch, nng):
assert False
ifm.quantization.zero_point = input_zp
ofm.quantization.zero_point = output_zp
+
for s, m in zip(shift, multiplier):
# TODO these are the TOSA limitations
assert m >= 0
@@ -403,45 +400,16 @@ def rewrite_rescale(op, arch, nng):
else:
rounding_mode = RoundingMode.HalfUp
- if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected:
- assert len(multiplier) == len(shift) == len(prev_op.bias.values)
-
- if ifm.dtype == DataType.int32 and per_channel:
- prev_op.explicit_scaling = explicit_scaling
- prev_op.rounding_mode = rounding_mode
-
- # Bypass op
- prev_op.set_output_tensor(ofm)
- DebugDatabase.add_optimised(op, prev_op)
- return op
- else:
- print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type)
- assert False
- # TODO which are the cases we need to and can do standalone Rescale?
- # TODO should we try to identify a conversion uint8<->int8 accomplished by 2 RESCALE ops?
- # origin might be TFLite op QUANTIZE, should we look to see if they can be translated to QUANTIZE?
- # limited to these at the moment:
- elif (
- (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8)
- or (ifm.dtype == DataType.uint8 and ofm.dtype == DataType.int8)
- or (ifm.dtype == DataType.int8 and ofm.dtype == DataType.uint8)
- ):
- # Create NOP performing the RESCALE
- avgpool_op = replace_rescale_with_avg_pool(op)
- avgpool_op.rounding_mode = rounding_mode
-
- if per_channel:
- # TODO
- avgpool_op.explicit_scaling = explicit_scaling
- print("Warning, unsupported TOSA Rescale")
- assert False
- else:
- avgpool_op.explicit_scaling = explicit_scaling
+ # Generate Rescale behaviour attached to a compatible NOP
+ avgpool_op = replace_rescale_with_avg_pool(op)
+ avgpool_op.rounding_mode = rounding_mode
+
+ if per_channel:
+ assert False, "per_channel rescale not supported"
else:
- print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type)
- assert False
- return op
+ avgpool_op.explicit_scaling = explicit_scaling
+ return op
def convert_pad_in_width(op):
"""
diff --git a/ethosu/vela/tosa_mapping.py b/ethosu/vela/tosa_mapping.py
index 2dafd81..0ec2f1d 100644
--- a/ethosu/vela/tosa_mapping.py
+++ b/ethosu/vela/tosa_mapping.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -148,7 +148,7 @@ transpose_conv_attrs = AttrSerializer(
)
transpose_attrs = AttrSerializer("TransposeAttribute", (("perms", is_vec),))
axis_attrs = AttrSerializer("AxisAttribute", ("axis",))
-reshape_attrs = AttrSerializer("ReshapeAttribute", (("shape", is_vec),))
+reshape_attrs = AttrSerializer("ReshapeAttribute", (("new_shape", is_vec),))
slice_attrs = AttrSerializer("SliceAttribute", (("start", is_vec), ("size", is_vec)))
tile_attrs = AttrSerializer("TileAttribute", (("multiplies", is_vec),))
resize_attrs = AttrSerializer(
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py
index 56af59d..2f37478 100644
--- a/ethosu/vela/tosa_reader.py
+++ b/ethosu/vela/tosa_reader.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -294,7 +294,7 @@ class TosaGraph:
def check_version(self, tosa_graph):
version = tosa_graph.Version()
version_str = f"{version._Major()}.{version._Minor()}.{version._Patch()}"
- if version_str != "0.80.0":
+ if version_str not in ( "0.80.0", "0.80.1" ):
print(f"Unsupported TOSA version: {version_str}")
assert False