aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tosa_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tosa_graph_optimiser.py')
-rw-r--r--ethosu/vela/tosa_graph_optimiser.py41
1 files changed, 41 insertions, 0 deletions
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index f3cddadd..2d1245b0 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -15,6 +15,8 @@
# limitations under the License.
# Description:
# Early optimisation of the TOSA based network graph, using the rewrite_graph module to do the traversal of the graph.
+import numpy as np
+
from . import rewrite_graph
from .api import NpuRoundingMode
from .data_type import DataType
@@ -80,6 +82,39 @@ def add_padding_fields(op, arch, nng):
return op
def count_leading_zeros(a):
    """Return the number of leading zero bits of ``a`` viewed as a 32-bit word.

    Only the low 32 bits of ``a`` are examined, so negative values are read in
    their two's-complement form and any bits above bit 31 are ignored.

    Args:
        a: Integer value (Python int or an integer-like scalar accepted by
            ``int()``).

    Returns:
        An int in the range 0..32; 32 when the low 32 bits are all zero.
    """
    # Masking to 32 bits keeps the result well defined, and the call
    # terminating, for any integer input, including values wider than 32 bits.
    word = int(a) & 0xFFFFFFFF
    return 32 - word.bit_length()
+
+
def calc_scaling_avgpool(op, arch, nng):
    """Set a single explicit scale on AvgPool ops; return ``op`` in all cases.

    Ops whose type is not ``Op.AvgPool`` are returned untouched. For AvgPool,
    the kernel element count ``kernel_wh`` is turned into one multiplier/shift
    pair with ``multiplier / 2**shift`` approximately ``1 / kernel_wh``:

        k          = bit length of (kernel_wh - 1), within 32 bits
        multiplier = (((1 << 30) + 1) << k) // kernel_wh
        shift      = 30 + k

    The op's rounding mode is set to ``NpuRoundingMode.NATURAL`` and the pair
    is stored in ``op.explicit_scaling``. ``arch`` and ``nng`` are unused here.

    Raises:
        AssertionError: if top or left explicit padding is non-zero, or if the
            op already has explicit scaling.
    """
    if op.type != Op.AvgPool:
        return op

    pad_top, pad_left, _, _ = op.attrs["explicit_padding"]
    # TODO Only support for when global scaling can be used.
    # That is when there is no padding
    # NOTE(review): only the first two padding entries are checked; the
    # remaining two are ignored - confirm that is intended.
    assert pad_top == 0 and pad_left == 0
    assert op.explicit_scaling is None

    kernel_wh = op.kernel.elements_wh()
    k = 32 - count_leading_zeros(kernel_wh - 1)
    numerator = np.int64(((1 << 30) + 1) << k)

    # One-element lists: a single scale for the whole op.
    multiplier = [numerator // kernel_wh]
    shift = [30 + k]

    op.rounding_mode = NpuRoundingMode.NATURAL
    op.explicit_scaling = ExplicitScaling(False, shift, multiplier)
    return op
+
+
def remove_const_transpose(op, arch, nng):
if op.type == Op.Transpose:
removed = False
@@ -432,6 +467,12 @@ def tosa_optimise_graph(nng, arch):
rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
sg.refresh_after_modification()
+ # TODO, when and where to best handle calc_scaling_avgpool
+ for idx, sg in enumerate(nng.subgraphs):
+ nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+ nng, sg, arch, [], [calc_scaling_avgpool], rewrite_unsupported=False,
+ )
+
# Rewrite Operators step
op_rewrite_list = [set_tensor_equivalence, rewrite_rescale, convert_depthwise_to_conv]