Diffstat (limited to 'ethosu/vela/graph_optimiser.py')
-rw-r--r--  ethosu/vela/graph_optimiser.py  |  64
1 file changed, 32 insertions(+), 32 deletions(-)
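
This commit is one mechanical refactor applied across graph_optimiser.py: every per-op and per-tensor rewrite callback gains a trailing nng (network graph) parameter, and the rewrite_graph.rewrite_graph_pre_order call sites in optimise_graph_a and optimise_graph_b now pass nng as the first argument so the driver can hand the whole graph to each callback. A minimal sketch of the new callback shape (the name and body below are hypothetical; the real rewriters are the functions touched in the diff):

def convert_example_op(op, arch, nng):
    # Rewriters now receive the enclosing network graph (nng) in addition to
    # the op (or tensor) being visited and the architecture description.
    if op.type == "Example":
        op.attrs["subgraph_count"] = len(nng.subgraphs)  # nng gives whole-graph context
    return op
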
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index e7c15cdc..4f435dcb 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -68,14 +68,14 @@ activation_ops = set(("Sigmoid", "Tanh")) | relu_ops
memory_only_ops = set(("Reshape",))
-def remove_passthrough_tensor(tens, arch):
+def remove_passthrough_tensor(tens, arch, nng):
if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
assert len(tens.ops[0].inputs) == 1
tens = tens.ops[0].inputs[0]
return tens
-def rewrite_concat(tens, arch):
+def rewrite_concat(tens, arch, nng):
if len(tens.ops) == 1 and tens.ops[0].is_concat_op():
concat_op = tens.ops[0]
if tens != concat_op.outputs[0]:
@@ -114,7 +114,7 @@ def rewrite_concat(tens, arch):
return tens
-def rewrite_split(tens, arch):
+def rewrite_split(tens, arch, nng):
if len(tens.ops) == 1 and tens.ops[0].is_split_op():
split_op = tens.ops[0]
@@ -205,7 +205,7 @@ def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_dim
return padding, skirt
-def fixup_conv2d_backprop(op, arch):
+def fixup_conv2d_backprop(op, arch, nng):
if op.type == "Conv2DBackpropInput":
# flip the inputs
op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
@@ -295,7 +295,7 @@ def convert_resizebilinear_to_2x2_pool(op):
return op
-def fixup_resizebilinear(op, arch):
+def fixup_resizebilinear(op, arch, nng):
if op.type == "ResizeBilinear" and op.run_on_npu:
if op.inputs[0].shape == op.outputs[0].shape:
# Bypass nop resizebilinear
@@ -309,7 +309,7 @@ def fixup_resizebilinear(op, arch):
return op
-def convert_nop_split_to_identity(op, arch):
+def convert_nop_split_to_identity(op, arch, nng):
if op.type == "Split" and op.attrs.get("num_splits") == 1:
# the list comprehension should return a list with a single tensor
# if it shouldn't, remove_passthrough_tensor will fail appropriately
@@ -318,7 +318,7 @@ def convert_nop_split_to_identity(op, arch):
return op
-def fixup_fully_connected_input(op, arch):
+def fixup_fully_connected_input(op, arch, nng):
if op.type == "FullyConnectedAct":
inp = op.inputs[0]
weights = op.inputs[1]
@@ -336,7 +336,7 @@ def fixup_fully_connected_input(op, arch):
return op
-def convert_batched_fc_to_conv(op, arch):
+def convert_batched_fc_to_conv(op, arch, nng):
if op.type == "FullyConnectedAct":
ifm = op.inputs[0]
ofm = op.outputs[0]
@@ -407,7 +407,7 @@ def convert_batched_fc_to_conv(op, arch):
return op
-def fixup_pack_input(op, arch):
+def fixup_pack_input(op, arch, nng):
if op.type == "Pack":
# Pack is also referred to as Stack
# Requires the rewrite_concat function to be called on the op afterwards
@@ -433,7 +433,7 @@ def fixup_pack_input(op, arch):
return op
-def unfuse_activation_function(op, arch):
+def unfuse_activation_function(op, arch, nng):
unfuse_ops = ("ConcatTFLite",)
if op.type in unfuse_ops and op.run_on_npu and op.attrs.get("fused_activation_function", None) is not None:
act = op.attrs["fused_activation_function"]
@@ -448,7 +448,7 @@ def unfuse_activation_function(op, arch):
return op
-def fixup_unpack_output(tens, arch):
+def fixup_unpack_output(tens, arch, nng):
op = tens.ops[0]
if op.type in set(("Unpack", "StridedSlice")):
# Unpack is also referred to as Unstack
@@ -515,7 +515,7 @@ def fixup_unpack_output(tens, arch):
return tens
-def add_padding_fields(op, arch):
+def add_padding_fields(op, arch, nng):
if op.run_on_npu:
if "padding" in op.attrs:
if op.type in conv_op | depthwise_op:
@@ -564,7 +564,7 @@ def get_prepend_op(op):
return None
-def mark_npu_block_type(op, arch):
+def mark_npu_block_type(op, arch, nng):
npu_block_type = NpuBlockType.Default
if op.type in conv_op:
npu_block_type = NpuBlockType.ConvolutionMxN
@@ -583,7 +583,7 @@ def mark_npu_block_type(op, arch):
return op
-def convert_depthwise_to_conv(op, arch):
+def convert_depthwise_to_conv(op, arch, nng):
# Depthwise is equivalent to a single conv2d if the ifm depth is 1 and
# the ofm depth equals the depth multiplier.
# If those conditions are true, then we can perform a simple
@@ -610,7 +610,7 @@ def convert_depthwise_to_conv(op, arch):
return op
-def reorder_depthwise_weights(op, arch):
+def reorder_depthwise_weights(op, arch, nng):
if op.type in depthwise_op:
weight_tensor = op.inputs[1]
weight_tensor.quant_values = np.transpose(weight_tensor.quant_values, (0, 1, 3, 2))
@@ -620,7 +620,7 @@ def reorder_depthwise_weights(op, arch):
return op
-def convert_conv_to_fc(op, arch):
+def convert_conv_to_fc(op, arch, nng):
# Conv 1x1 can be equivalent to Fully Connected.
# By representing certain convs as fully connected layers, Vela can better determine whether or not to use
# caching/double buffering for the weights.
@@ -661,7 +661,7 @@ def convert_conv_to_fc(op, arch):
return op
-def fixup_relus_with_differing_ifm_ofm_scaling(op, arch):
+def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
if op.run_on_npu and op.type in relu_ops:
ifm = op.inputs[0]
ofm = op.outputs[0]
@@ -690,7 +690,7 @@ def fixup_relus_with_differing_ifm_ofm_scaling(op, arch):
# Reorder activation op if it's after the memory only operations
-def fixup_act_reorder(op, arch):
+def fixup_act_reorder(op, arch, nng):
if op.type in activation_ops:
prep_op = get_prepend_op(op)
if prep_op is not None:
@@ -715,7 +715,7 @@ def fixup_act_reorder(op, arch):
return op
-def fixup_elementwise_with_scalars(op, arch):
+def fixup_elementwise_with_scalars(op, arch, nng):
if op.type in binary_elementwise_op:
ifm_tensor, ifm2_tensor, _, _ = op.get_ifm_ifm2_weights_ofm()
if ifm2_tensor.shape != [] and ifm_tensor.shape != []:
@@ -736,7 +736,7 @@ def fixup_elementwise_with_scalars(op, arch):
# Set input/output tensor equivalence to the same id for memory operations
-def set_tensor_equivalence(op, arch):
+def set_tensor_equivalence(op, arch, nng):
if op.type in memory_only_ops:
eid = op.outputs[0].equivalence_id
for inp in op.inputs:
@@ -744,14 +744,14 @@ def set_tensor_equivalence(op, arch):
return op
-def convert_softmax(op, arch):
+def convert_softmax(op, arch, nng):
if op.type == "Softmax" and op.run_on_npu:
softmax = SoftMax(op)
op = softmax.get_graph()
return op
-def convert_mul_max_to_abs_or_lrelu(op, arch):
+def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
r"""Whenever there is a subgraph with this topology:
Input X For X = -1 or X > 0
@@ -958,7 +958,7 @@ def convert_lrelu_to_lut(op, arch):
return convert_to_lut(op, values)
-def convert_lrelu(op, arch):
+def convert_lrelu(op, arch, nng):
# Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
if op.type != "LeakyRelu":
return op
@@ -972,7 +972,7 @@ def convert_lrelu(op, arch):
return convert_lrelu_to_mul_max(op, arch)
-def convert_tanh_sigmoid_to_lut(op, arch):
+def convert_tanh_sigmoid_to_lut(op, arch, nng):
# Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
if op.type == "Sigmoid":
return convert_to_lut8(op, clamp_sigmoid)
@@ -981,7 +981,7 @@ def convert_tanh_sigmoid_to_lut(op, arch):
return op
-def remove_unwanted_reshapes(op, arch):
+def remove_unwanted_reshapes(op, arch, nng):
# Try to remove reshapes enclosing ElementWise operator with only one non-constant input
if not op.run_on_npu or op.attrs["npu_block_type"] != NpuBlockType.ElementWise:
return op
@@ -1016,7 +1016,7 @@ def remove_unwanted_reshapes(op, arch):
return op
-def fuse_activation_function_with_prev(op, arch):
+def fuse_activation_function_with_prev(op, arch, nng):
# if op is a no-op: attempts to move the activation function to the preceding op
if not op.attrs.get("is_nop", False) or op.attrs.get("fused_activation_function", None) is None:
return op
@@ -1049,7 +1049,7 @@ def fuse_activation_function_with_prev(op, arch):
return op
-def add_attrs_to_resizebilinear(op, arch):
+def add_attrs_to_resizebilinear(op, arch, nng):
if op.type == "ResizeBilinear" and op.run_on_npu:
input_tensor = op.inputs[0]
upscaled_shape = [input_tensor.shape[1] * 2, input_tensor.shape[2] * 2]
@@ -1069,7 +1069,7 @@ def add_attrs_to_resizebilinear(op, arch):
return op
-def fixup_bias_tensors(op, arch):
+def fixup_bias_tensors(op, arch, nng):
if op.needs_bias() and not op.inputs[-1]:
# Op has no bias, add bias tensor filled with zeros
nr_biases = op.inputs[1].shape[-1]
@@ -1081,7 +1081,7 @@ def fixup_bias_tensors(op, arch):
return op
-def supported_operator_check(op, arch):
+def supported_operator_check(op, arch, nng):
op.run_on_npu = arch.supported_operators.is_operator_supported(op)
return op
@@ -1121,13 +1121,13 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
for idx, sg in enumerate(nng.subgraphs):
# rewrite graph pass
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- sg, arch, [fixup_unpack_output], op_rewrite_list, rewrite_unsupported=False
+ nng, sg, arch, [fixup_unpack_output], op_rewrite_list, rewrite_unsupported=False
)
for idx, sg in enumerate(nng.subgraphs):
# remove passthrough tensors and attempt further optimizations
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev, add_padding_fields]
+ nng, sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev, add_padding_fields]
)
if verbose_graph:
@@ -1141,7 +1141,7 @@ def optimise_graph_b(nng, arch, verbose_graph=False):
for idx, sg in enumerate(nng.subgraphs):
# combined rewrite graph pass
- nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(sg, arch, [rewrite_concat, rewrite_split], [])
+ nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [rewrite_concat, rewrite_split], [])
if verbose_graph:
nng.print_graph()
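
rewrite_graph.rewrite_graph_pre_order itself is not part of this diff, but the updated call sites imply its signature now starts with the network graph, i.e. rewrite_graph_pre_order(nng, sg, arch, tensor_rewrite_list, op_rewrite_list, ...), and that it forwards nng to every rewriter it invokes. A self-contained toy of that threading pattern, assuming stand-in Op and Graph classes rather than Vela's own:

class Op:
    def __init__(self, op_type):
        self.type = op_type
        self.attrs = {}

class Graph:
    def __init__(self, ops):
        self.ops = ops

def mark_example(op, arch, nng):
    # Rewriter signature after this change: (op, arch, nng).
    if op.type == "Example":
        op.attrs["graph_op_count"] = len(nng.ops)
    return op

def rewrite_pre_order(nng, arch, op_rewrite_list):
    # The driver takes the whole graph first and threads it to each callback.
    for idx, op in enumerate(nng.ops):
        for rewrite in op_rewrite_list:
            op = rewrite(op, arch, nng)
        nng.ops[idx] = op
    return nng

graph = Graph([Op("Example"), Op("Other")])
rewrite_pre_order(graph, arch=None, op_rewrite_list=[mark_example])
print(graph.ops[0].attrs)  # {'graph_op_count': 2}

Passing nng explicitly lets an individual rewriter look beyond the op it is handed, for example at other subgraphs, without relying on module-level state.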