From ef3ebddb6b4fd951046cbe7799721f6f0ed2fc87 Mon Sep 17 00:00:00 2001 From: Patrik Gustavsson Date: Fri, 1 Oct 2021 11:10:25 +0200 Subject: TOSA: Add support for Identity operation Added support for Identity operation. Signed-off-by: Patrik Gustavsson Change-Id: If00b30528932f7531807ce3914d6c1875ab72fa4 --- ethosu/vela/graph_optimiser_util.py | 1 + ethosu/vela/operation.py | 2 +- ethosu/vela/tosa_graph_optimiser.py | 8 ++++---- ethosu/vela/tosa_mapping.py | 3 +-- ethosu/vela/tosa_supported_operators.py | 12 ++++++++++-- 5 files changed, 17 insertions(+), 9 deletions(-) diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py index d2d3d833..73fbf6c7 100644 --- a/ethosu/vela/graph_optimiser_util.py +++ b/ethosu/vela/graph_optimiser_util.py @@ -35,6 +35,7 @@ memory_only_ops = ( Op.QuantizedReshape, Op.Squeeze, Op.ExpandDims, + Op.Identity, ) diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index b4267926..1e733d56 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -189,7 +189,7 @@ class Op(Enum): GreaterEqual = OperatorInfo() HardSwish = OperatorInfo(indices=NNG_IFM_INDICES) HashtableLookup = OperatorInfo() - Identity = OperatorInfo() + Identity = OperatorInfo(indices=NNG_IFM_INDICES) If = OperatorInfo() L2Norm = OperatorInfo() L2Pool2D = OperatorInfo() diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py index d32955d5..954ac68f 100644 --- a/ethosu/vela/tosa_graph_optimiser.py +++ b/ethosu/vela/tosa_graph_optimiser.py @@ -182,7 +182,7 @@ def insert_add_copy_op_after_tens(tens, ifm_ofm_shape): def fix_sg_input_output_tosa(op, arch, nng): - if not op.run_on_npu or op.type != Op.Reshape: + if not op.run_on_npu or op.type not in (Op.Reshape, Op.Identity): return op # For the Reshape operators we want to remove, tensors are removed. @@ -306,8 +306,8 @@ def rewrite_concat(op): assert op.ofm_shapes[0][axis_4D] == offset -def remove_reshapes(op, arch): - if op.run_on_npu and op.type == Op.Reshape: +def remove_memory_ops(op, arch): + if op.run_on_npu and op.type in (Op.Reshape, Op.Identity): bypass_memory_only_ops(op) @@ -820,7 +820,7 @@ def tosa_optimise_graph(nng, arch): # Removal of reshapes for sg in nng.subgraphs: - rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes]) + rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_ops]) sg.refresh_after_modification() # Decomposing of elementwise diff --git a/ethosu/vela/tosa_mapping.py b/ethosu/vela/tosa_mapping.py index f80a9156..75f1c9c5 100644 --- a/ethosu/vela/tosa_mapping.py +++ b/ethosu/vela/tosa_mapping.py @@ -223,7 +223,6 @@ unsupported_tosa_operators = { TosaOp.SCATTER, TosaOp.RESIZE, TosaOp.CAST, - TosaOp.IDENTITY, TosaOp.CUSTOM, TosaOp.COND_IF, TosaOp.WHILE_LOOP, @@ -316,7 +315,7 @@ tosa_operator_map = { # TODO TosaOp.CAST TosaOp.RESCALE: (Op.Rescale, rescale_attrs, None, TOSA_IFM_INDICES), TosaOp.CONST: (Op.Const, None, None, TOSA_NO_INDICES), - # TODO TosaOp.IDENTITY + TosaOp.IDENTITY: (Op.Identity, None, None, TOSA_IFM_INDICES), # TODO TosaOp.CUSTOM # TODO TosaOp.COND_IF # TODO TosaOp.WHILE_LOOP diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py index 2692c05f..5a85b0eb 100644 --- a/ethosu/vela/tosa_supported_operators.py +++ b/ethosu/vela/tosa_supported_operators.py @@ -46,13 +46,21 @@ class TosaSupportedOperators: activation_ops = relu_ops | set((Op.Table,)) pad_ops = set((Op.Pad,)) - rank_unlimited_ops = set((Op.Concat, Op.Reshape)) + rank_unlimited_ops = set((Op.Concat, Op.Reshape, Op.Identity)) rank6_limited_ops = elem_wise_ops batch_enabled_ops = rank6_limited_ops | rank_unlimited_ops large_tens_dims_enabled_ops = batch_enabled_ops | set((Op.SplitSliceRead,)) npu_post_ops = activation_ops - supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | elem_wise_ops | pad_ops + supported_operators = ( + mac_main_ops + | type_conversion_ops + | npu_post_ops + | memory_only_ops + | elem_wise_ops + | pad_ops + | set((Op.Identity,)) + ) # Supported data types # TODO will differ compared to TensorFlow Lite, currently set to the same -- cgit v1.2.1