diff options
Diffstat (limited to 'ethosu/vela/tosa_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tosa_graph_optimiser.py | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py index f3cddadd..2d1245b0 100644 --- a/ethosu/vela/tosa_graph_optimiser.py +++ b/ethosu/vela/tosa_graph_optimiser.py @@ -15,6 +15,8 @@ # limitations under the License. # Description: # Early optimisation of the TOSA based network graph, using the rewrite_graph module to do the traversal of the graph. +import numpy as np + from . import rewrite_graph from .api import NpuRoundingMode from .data_type import DataType @@ -80,6 +82,39 @@ def add_padding_fields(op, arch, nng): return op +# Counts leading zeroes for a (int32) +def count_leading_zeros(a): + lz = int(32) + if a != 0: + mask = 1 << (32 - 1) + lz = 0 + while (mask & a) == 0: + mask = mask >> 1 + lz = lz + 1 + return lz + + +def calc_scaling_avgpool(op, arch, nng): + if op.type == Op.AvgPool: + top, left, _, _ = op.attrs["explicit_padding"] + # TODO Only support for when global scaling can be used. + # That is when there is no padding + assert top == 0 and left == 0 + assert op.explicit_scaling is None + multiplier = [] + shift = [] + + kernel_wh = op.kernel.elements_wh() + k = 32 - count_leading_zeros(kernel_wh - 1) + numerator = np.int64(((1 << 30) + 1) << k) + multiplier.append(numerator // kernel_wh) + shift.append(30 + k) + + op.rounding_mode = NpuRoundingMode.NATURAL + op.explicit_scaling = ExplicitScaling(False, shift, multiplier) + return op + + def remove_const_transpose(op, arch, nng): if op.type == Op.Transpose: removed = False @@ -432,6 +467,12 @@ def tosa_optimise_graph(nng, arch): rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes]) sg.refresh_after_modification() + # TODO, when and where to best handle calc_scaling_avgpool + for idx, sg in enumerate(nng.subgraphs): + nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( + nng, sg, arch, [], [calc_scaling_avgpool], rewrite_unsupported=False, + ) + # Rewite Operators step op_rewrite_list = [set_tensor_equivalence, rewrite_rescale, convert_depthwise_to_conv] |