aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tosa_graph_optimiser.py
diff options
context:
space:
mode:
authorPatrik Gustavsson <patrik.gustavsson@arm.com>2021-09-07 13:30:29 +0200
committerPatrik Gustavsson <patrik.gustavsson@arm.com>2021-09-08 13:21:24 +0200
commitf366fb1fdffcc9d0eb2e6daf60dc89a2bd442ce6 (patch)
tree7048f09f9dcdd5bac4070760e9100406a39b1bdc /ethosu/vela/tosa_graph_optimiser.py
parentf1580f0167d7e9539a17ac8e33b0b595300f8090 (diff)
downloadethos-u-vela-f366fb1fdffcc9d0eb2e6daf60dc89a2bd442ce6.tar.gz
TOSA: Fix AVGPOOL scaling
-Only support for avgpool when there is no padding. For this case, global scaling can be used. Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com> Change-Id: I026b83b05f02c57c79f49935f5ec501a6d28bb91
Diffstat (limited to 'ethosu/vela/tosa_graph_optimiser.py')
-rw-r--r--ethosu/vela/tosa_graph_optimiser.py41
1 files changed, 41 insertions, 0 deletions
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index f3cddadd..2d1245b0 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -15,6 +15,8 @@
# limitations under the License.
# Description:
# Early optimisation of the TOSA based network graph, using the rewrite_graph module to do the traversal of the graph.
+import numpy as np
+
from . import rewrite_graph
from .api import NpuRoundingMode
from .data_type import DataType
@@ -80,6 +82,39 @@ def add_padding_fields(op, arch, nng):
return op
def count_leading_zeros(a):
    """Return the number of leading zero bits of ``a`` viewed as a 32-bit word.

    Only the low 32 bits of ``a`` are inspected, so a negative value is
    treated as its two's-complement bit pattern (a set bit 31 gives 0).

    Args:
        a: Integer value to inspect.

    Returns:
        An int in the range 0..32; 32 when the low 32 bits are all zero.
    """
    # Restricting the value to 32 bits guarantees termination: walking a mask
    # down from bit 31 never finishes for a non-zero value whose low 32 bits
    # are all zero (e.g. 1 << 32), because the mask eventually becomes 0.
    word = int(a) & 0xFFFFFFFF
    return 32 - word.bit_length()
+
+
def calc_scaling_avgpool(op, arch, nng):
    """Set explicit scaling and natural rounding on an AvgPool operator.

    Operators of any other type are returned untouched. For AvgPool a single
    multiplier/shift pair is derived from the kernel element count so that
    multiplier / 2**shift approximates 1 / (kernel width * kernel height).

    Args:
        op: Operator to inspect; modified in place when it is an AvgPool.
        arch: Not used here (presumably required by the graph-rewrite
            callback signature - confirm against rewrite_graph).
        nng: Not used here (same assumption as ``arch``).

    Returns:
        The same ``op`` object that was passed in.

    Raises:
        AssertionError: If the first two explicit padding values are non-zero
            or the operator already carries explicit scaling.
    """
    if op.type != Op.AvgPool:
        return op

    pad_top, pad_left, _, _ = op.attrs["explicit_padding"]
    # TODO Only support for when global scaling can be used.
    # That is when there is no padding
    # NOTE(review): only the first two padding entries are checked; the last
    # two are ignored - confirm that is intended.
    assert pad_top == 0 and pad_left == 0
    assert op.explicit_scaling is None

    n_elems = op.kernel.elements_wh()
    # k is the bit length of (n_elems - 1), i.e. the smallest k with 2**k >= n_elems
    k = 32 - count_leading_zeros(n_elems - 1)
    scaled_one = np.int64(((1 << 30) + 1) << k)
    multiplier = [scaled_one // n_elems]
    shift = [30 + k]

    op.rounding_mode = NpuRoundingMode.NATURAL
    op.explicit_scaling = ExplicitScaling(False, shift, multiplier)
    return op
+
+
def remove_const_transpose(op, arch, nng):
if op.type == Op.Transpose:
removed = False
@@ -432,6 +467,12 @@ def tosa_optimise_graph(nng, arch):
rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
sg.refresh_after_modification()
+ # TODO, when and where to best handle calc_scaling_avgpool
+ for idx, sg in enumerate(nng.subgraphs):
+ nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+ nng, sg, arch, [], [calc_scaling_avgpool], rewrite_unsupported=False,
+ )
+
# Rewrite Operators step
op_rewrite_list = [set_tensor_equivalence, rewrite_rescale, convert_depthwise_to_conv]