author      Patrik Gustavsson <patrik.gustavsson@arm.com>   2021-10-06 14:46:46 +0200
committer   Patrik Gustavsson <patrik.gustavsson@arm.com>   2021-10-08 10:10:09 +0200
commit      1bf0f1976ae9d9ae8ef1e2c94af885a62276af43 (patch)
tree        9b800539f13c82b1886424af94b8fd3f57948ff6
parent      6f87be40a97a46a97c52a81e6e46eda0bdb73f9e (diff)
download    ethos-u-vela-1bf0f1976ae9d9ae8ef1e2c94af885a62276af43.tar.gz
TOSA: Added support for Const output
Added support for a Const operator generating network output.

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: Ia81990a94cc497a58535914124a29e7dbb511247
-rw-r--r--  ethosu/vela/tosa_graph_optimiser.py  117
1 file changed, 74 insertions(+), 43 deletions(-)
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index e27dbed6..9e72a6c1 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -151,6 +151,33 @@ def remove_const_transpose(op, arch, nng):
return op
+def insert_add_copy_for_const(op, ifm_ofm_shape):
+ assert op.type == Op.Const
+ ofm = op.ofm
+ copy_tens = ofm.clone()
+ op.set_output_tensor(copy_tens)
+
+ name = ofm.name + "_add"
+ ifm2 = create_const_tensor(
+ name + "_zero_scalar",
+ [1],
+ copy_tens.dtype,
+ [0],
+ copy_tens.dtype.as_numpy_type(),
+ quantization=copy_tens.quantization,
+ )
+ copy_op = create_add_nop(name)
+ copy_op.add_input_tensor(copy_tens)
+ copy_op.add_input_tensor(ifm2)
+ copy_op.set_output_tensor(ofm)
+ copy_op.ifm_shapes.append(ifm_ofm_shape)
+ copy_op.ifm_shapes.append(Shape4D(ifm2.shape))
+ copy_op.ofm_shapes.append(ifm_ofm_shape)
+ copy_op.run_on_npu = True
+
+ DebugDatabase.add_optimised(op, copy_op)
+
+
# TODO can we change to add for both TFLite and TOSA?
def insert_add_copy_op_after_tens(tens, ifm_ofm_shape):
tens_cons_list_copy = tens.consumer_list.copy()
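The copy inserted by insert_add_copy_for_const above is realised as an elementwise Add with a constant zero scalar as the second input (via create_add_nop). A minimal standalone sketch of the numeric idea, using numpy in place of vela's tensor types; all names below are illustrative, not vela API:

    import numpy as np

    # Values of a hypothetical Const tensor that feeds a network output.
    const_values = np.array([[12, -7], [3, 88]], dtype=np.int8)

    # The copy op's second input (ifm2) is a zero scalar created with the
    # same dtype and quantization as the cloned tensor, so the elementwise
    # Add is the identity: it simply materialises the Const data into the
    # original output tensor, which can then persist as a subgraph output.
    zero_scalar = np.zeros([1], dtype=np.int8)
    copied = const_values + zero_scalar

    assert np.array_equal(copied, const_values)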
@@ -184,51 +211,55 @@ def insert_add_copy_op_after_tens(tens, ifm_ofm_shape):
DebugDatabase.add_optimised(tens.ops[0], copy_op)
-def fix_sg_input_output_tosa(op, arch, nng):
- if not op.run_on_npu or op.type not in (Op.Reshape, Op.Identity):
- return op
+def get_shape_for_copy_op(shape):
+ # remove dimensions that are set to 1
+ new_shape = []
+ for dim in shape:
+ if dim != 1:
+ new_shape.append(dim)
+ if not new_shape:
+ new_shape = [1]
+
+ rank = len(new_shape)
+ if rank > 3:
+ # Reshape so that batch becomes 1, by moving elements to H dimension
+ n = rank - 2
+ h = 1
+ for i in range(n):
+ h *= shape[i]
+ new_shape = Shape4D(new_shape[n:]).with_height(h)
+ else:
+ new_shape = Shape4D(new_shape)
+ return new_shape
- # For the Reshape operators we want to remove, tensors are removed.
- # But in order to to do this, they cannot be outputs of the sg,
- # this need to be fixed prior to the removal.
- # Solution is to add a copy op, to maintain the original tensor.
- # This is also valid when reshape ifm/ofm is produced respectively
- # consumed by CPU
-
- # Check if operator ifm/ofm are sg ifm/ofm
- ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
- ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
- ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
- # Check if ifm/ofm is produced repectivly consumed by CPU
- ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
- ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
-
- if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
- # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the Reshape
-
- # Decide on ifm/ofm shapes for the copy op based on ifm
- shape = op.ifm.shape.copy()
- # remove dimensions that are set to 1
- new_shape = []
- for dim in shape:
- if dim != 1:
- new_shape.append(dim)
- if not new_shape:
- new_shape = [1]
-
- rank = len(new_shape)
- if rank > 3:
- # Reshape so that batch becomes 1, by moving elements to H dimension
- n = rank - 2
- h = 1
- for i in range(n):
- h *= shape[i]
- new_shape = Shape4D(new_shape[n:]).with_height(h)
- else:
- new_shape = Shape4D(new_shape)
- insert_add_copy_op_after_tens(op.ifm, new_shape)
+def fix_sg_input_output_tosa(op, arch, nng):
+ if op.type == Op.Const and any(ofm_cons is None for ofm_cons in op.ofm.consumer_list):
+ # Const operator with sg output, insert copy op before the ofm
+ new_shape = get_shape_for_copy_op(op.ofm.shape.copy())
+ insert_add_copy_for_const(op, new_shape)
+ elif op.run_on_npu and op.type in (Op.Reshape, Op.Identity):
+ # For the Reshape and Identity operators we want to remove, tensors are
+ # also removed. But in order to do this, they cannot be outputs of the sg;
+ # this needs to be fixed prior to the removal.
+ # The solution is to add a copy op, to maintain the original tensor.
+ # This is also valid when the reshape ifm/ofm is produced respectively
+ # consumed by the CPU.
+
+ # Check if operator ifm/ofm are sg ifm/ofm
+ ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
+ ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
+ ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
+ # Check if the ifm/ofm is produced respectively consumed by the CPU
+ ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
+ ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
+
+ if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
+ # Both ifm and ofm need to persist, but only the ifm needs a copy, in order to remove the operator
+ # Decide on ifm/ofm shapes for the copy op based on ifm
+ new_shape = get_shape_for_copy_op(op.ifm.shape.copy())
+ insert_add_copy_op_after_tens(op.ifm, new_shape)
return op
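The factored-out get_shape_for_copy_op squeezes away unit dimensions and, for ranks above three, folds the leading dimensions into the H dimension so that batch becomes 1. A simplified standalone sketch with plain tuples instead of vela's Shape4D, assuming Shape4D left-pads shorter shapes with ones and computing H from the squeezed shape (which matches the code above whenever no unit dimensions were dropped):

    def copy_op_shape(shape):
        # Drop dimensions that are set to 1, keeping at least one dim.
        squeezed = [d for d in shape if d != 1] or [1]
        rank = len(squeezed)
        if rank > 3:
            # Fold everything but the last two dims into H, so batch N = 1;
            # the total element count is preserved.
            n = rank - 2
            h = 1
            for d in squeezed[:n]:
                h *= d
            return (1, h) + tuple(squeezed[n:])
        # Left-pad with ones up to four dimensions.
        return (1,) * (4 - rank) + tuple(squeezed)

    # [2, 3, 4, 5] -> (1, 6, 4, 5): the leading 2*3 elements move into H.
    assert copy_op_shape([2, 3, 4, 5]) == (1, 6, 4, 5)
    # [1, 8, 1] -> (1, 1, 1, 8) after squeezing and padding.
    assert copy_op_shape([1, 8, 1]) == (1, 1, 1, 8)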
@@ -862,7 +893,7 @@ def tosa_optimise_graph(nng, arch):
# Handle sg input output
for idx, sg in enumerate(nng.subgraphs):
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [], [fix_sg_input_output_tosa], rewrite_unsupported=False,
+ nng, sg, arch, [], [fix_sg_input_output_tosa], rewrite_unsupported=True,
)
# Removal of reshapes
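The flag change in the final hunk follows from the restructuring above: the new Const branch in fix_sg_input_output_tosa intentionally does not require op.run_on_npu (that check moved into the Reshape/Identity branch), so the pass presumably needs rewrite_unsupported=True in order to also be handed operators that will not run on the NPU, such as the Const producers of subgraph outputs.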