Diffstat (limited to 'ethosu/vela/tosa_graph_optimiser.py')
-rw-r--r--  ethosu/vela/tosa_graph_optimiser.py  60
1 file changed, 59 insertions(+), 1 deletion(-)
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 44e0f8ec..169da40d 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -19,7 +19,10 @@ from . import rewrite_graph
from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
+from .graph_optimiser_util import bypass_reshape_and_squeeze_ops
from .graph_optimiser_util import calc_explicit_padding
+from .graph_optimiser_util import convert_depthwise_to_conv
+from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
@@ -74,6 +77,43 @@ def add_padding_fields(op, arch, nng):
     return op
 
 
+def remove_const_transpose(op, arch, nng):
+    if op.type == Op.Transpose:
+        removed = False
+        if len(op.ifm.ops) == 1:
+            prev_op = op.ifm.ops[0]
+            if prev_op.type == Op.Const:
+                # Transpose the tensor shape and data, then remove the Transpose op
+                # TODO move to Tensor?
+                reorder = op.attrs["perms"]
+                shape = op.ifm.shape.copy()
+                tens = op.ifm
+
+                tens.shape = [shape[idx] for idx in reorder]
+                tens.bandwidth_shape = tens.shape
+                tens.storage_shape = tens.shape
+
+                if tens.values is not None:
+                    tens.values = tens.values.transpose(reorder)
+
+                op.ofm.values = tens.values
+                # Bypass the Transpose op by connecting the Const directly to the output
+                prev_op.set_output_tensor(op.ofm)
+                DebugDatabase.add_optimised(op, prev_op)
+                removed = True
+
+        if not removed:
+            print("Cannot remove Transpose, and handling of Transpose is not supported")
+            assert False
+
+    return op
+
+
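The constant folding above can be reproduced in isolation. A minimal sketch of the same permutation arithmetic, using only numpy (the values and perms below are hypothetical, not taken from the commit):

    import numpy as np

    # A constant feeding a Transpose with perms [0, 2, 3, 1] (NCHW -> NHWC).
    values = np.arange(24).reshape(1, 2, 3, 4)
    perms = [0, 2, 3, 1]

    # Fold the permutation into the constant: reorder both the shape and
    # the data, exactly as remove_const_transpose does for op.ifm.
    folded_shape = [values.shape[idx] for idx in perms]
    folded_values = values.transpose(perms)
    assert list(folded_values.shape) == folded_shape  # [1, 3, 4, 2]

Once the constant holds the pre-permuted data, the Transpose op carries no information and can be bypassed.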
+def remove_reshapes(op, arch):
+    if op.run_on_npu and op.type == Op.Reshape:
+        bypass_reshape_and_squeeze_ops(op)
+
+
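bypass_reshape_and_squeeze_ops is imported from graph_optimiser_util and its body is not part of this diff. The general splice it performs can be sketched with toy graph types (hypothetical classes, not Vela's Tensor/Operation API):

    class Tensor:
        def __init__(self, name):
            self.name = name

    class Operation:
        def __init__(self, kind, inputs, outputs):
            self.kind, self.inputs, self.outputs = kind, inputs, outputs

    def bypass(op, graph_ops):
        # Rewire every consumer of op's output to read op's input
        # directly, so the Reshape drops out of the dataflow. The real
        # helper also reconciles shapes, quantization and debug info.
        ifm, ofm = op.inputs[0], op.outputs[0]
        for other in graph_ops:
            other.inputs = [ifm if t is ofm else t for t in other.inputs]

    a, b, c = Tensor("a"), Tensor("b"), Tensor("c")
    reshape = Operation("Reshape", [a], [b])
    conv = Operation("Conv2D", [b], [c])
    bypass(reshape, [conv])
    assert conv.inputs[0] is a  # the Reshape is now bypassed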
 def rewrite_activation(op, arch, nng):
     if op.type not in (Op.ReluN, Op.Clamp):
         return op
@@ -206,6 +246,7 @@ def fixup_quantization(op, arch, nng):
 def supported_operator_check(op, arch, nng):
     op.run_on_npu = arch.tosa_supported_operators.is_operator_supported(op)
+    assert op.run_on_npu or op.type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
     return op
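Every pass added below follows the same shape: a function taking (op, arch, nng) and returning the (possibly replaced) op, handed to rewrite_graph.rewrite_graph_pre_order in a list. A stripped-down sketch of that dispatch convention, assuming a flat op list rather than Vela's actual graph-order traversal:

    def run_op_rewrites(ops, arch, nng, rewrite_list):
        # Apply each rewrite in order to every op; each rewrite returns
        # the op, mirroring the "return op" convention used in this file.
        for idx, op in enumerate(ops):
            for rewrite in rewrite_list:
                op = rewrite(op, arch, nng)
            ops[idx] = op
        return ops

    identity = lambda op, arch, nng: op
    assert run_op_rewrites(["op"], None, None, [identity]) == ["op"]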
@@ -221,8 +262,25 @@ def tosa_optimise_graph(nng, arch):
             nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
         )
 
+    # Removal of Transpose
+    for idx, sg in enumerate(nng.subgraphs):
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng, sg, arch, [], [remove_const_transpose], rewrite_unsupported=False,
+        )
+
+    # Handle sg input output
+    for idx, sg in enumerate(nng.subgraphs):
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False,
+        )
+
+    # Removal of reshapes
+    for sg in nng.subgraphs:
+        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
+        sg.refresh_after_modification()
+
     # Rewrite Operators step
-    op_rewrite_list = [set_tensor_equivalence, rewrite_rescale]
+    op_rewrite_list = [set_tensor_equivalence, rewrite_rescale, convert_depthwise_to_conv]
     for idx, sg in enumerate(nng.subgraphs):
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(