From c5b549b599ff459a29115a48e8f067eaa5891638 Mon Sep 17 00:00:00 2001
From: Michael McGeagh
Date: Fri, 7 Aug 2020 11:54:28 +0100
Subject: MLBEDSW-2637 Utilise new tensor and operator funcs

add_input_tensor, set_output_tensor, create_const_tensor and
create_reshape_tensor have recently been added.
This replaces all found existing instances with these new helper functions

Signed-off-by: Michael McGeagh
Change-Id: If33be8dbf237b2087b562b03cdeb51da1f99a786
---
 ethosu/vela/extract_npu_subgraphs.py |  3 +-
 ethosu/vela/graph_optimiser.py       | 82 +++++++++++-------------------------
 ethosu/vela/insert_dma.py            |  3 +-
 ethosu/vela/npu_serialisation.py     |  3 +-
 ethosu/vela/operation.py             |  3 +-
 ethosu/vela/pass_packing.py          |  3 +-
 ethosu/vela/tflite_reader.py         | 13 ++----
 7 files changed, 33 insertions(+), 77 deletions(-)

(limited to 'ethosu/vela')

diff --git a/ethosu/vela/extract_npu_subgraphs.py b/ethosu/vela/extract_npu_subgraphs.py
index 6747ec98..4adddc17 100644
--- a/ethosu/vela/extract_npu_subgraphs.py
+++ b/ethosu/vela/extract_npu_subgraphs.py
@@ -80,9 +80,8 @@ def rewrite_tensor_cpu_producer_npu_consumers(
         op_type = "Const"
     op = Operation(op_type, orig_tens.name + "_input")
     op.attrs["npu_block_type"] = NpuBlockType.Default
-    op.outputs = [new_tens]
     op.scheduled_pass = startup_init_ps
-    new_tens.ops = [op]
+    op.set_output_tensor(new_tens)
     startup_init_ps.ops.append(op)
     startup_init_ps.outputs.append(new_tens)
 
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index a9d5cce5..582924c4 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -28,6 +28,8 @@ from .numeric_util import full_shape
 from .operation import NpuBlockType
 from .operation import Operation
 from .softmax import SoftMax
+from .tensor import create_const_tensor
+from .tensor import create_reshape_tensor
 from .tensor import QuantizationParameters
 from .tensor import Tensor
 
@@ -84,7 +86,6 @@ def rewrite_split(tens, arch):
         tens.ops = []
         new_op = Operation("SplitSliceRead", split_op.name)
         new_op.inputs = [inp]
-        new_op.outputs = [tens]
 
         # For Split the offset cannot be extracted from the tensor so it has to
         # be calculated from the index of the output tensor
@@ -102,7 +103,7 @@ def rewrite_split(tens, arch):
         new_op.attrs["split_start"] = offset_start
         new_op.attrs["split_end"] = offset_end
         new_op.run_on_npu = True
-        tens.ops.append(new_op)
+        new_op.set_output_tensor(tens)
 
     return tens
 
@@ -168,14 +169,12 @@ def fixup_conv2d_backprop(op, arch):
 
         if len(op.inputs) < 4:
             # Add bias/scale tensor filled with zeros
-            scale_op = Operation("Const", op.name + "_bias")
             scale_tens = Tensor([weight_sets], DataType.int32, op.name + "_bias_tens")
             scale_tens.values = [0] * weight_sets
             scale_tens.quant_values = [0] * weight_sets
-            scale_tens.ops = [scale_op]
-            scale_op.outputs = [scale_tens]
-            scale_tens.consumer_list = [op]
-            op.inputs.append(scale_tens)
+            scale_op = Operation("Const", op.name + "_bias")
+            scale_op.set_output_tensor(scale_tens)
+            op.add_input_tensor(scale_tens)
 
         # Update strides
         op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})
@@ -199,8 +198,7 @@ def convert_resizebilinear_1x1_to_add(op):
     tens.quantization.zero_point = 0
     tens.consumer_list = [op]
     tens_op = op.inputs[1].ops[0]
-    tens_op.outputs = [tens]
-    tens.ops = [tens_op]
+    tens_op.set_output_tensor(tens)
     # Set the add inputs
     op.inputs[1] = op.inputs[0]
     op.inputs[0] = tens
@@ -233,22 +231,7 @@ def fixup_fully_connected_input(op, arch):
         desired_shape = [batch_size, n_in_elems]
         if inp.shape != desired_shape:
             # mismatch, insert a reshape to fix this.
-            reshape_name = op.name + "_reshape"
-            new_shape_tens = Tensor([1], DataType.int32, reshape_name + "_shape")
-            new_shape_tens.values = np.array(desired_shape)
-            new_shape_tens_const = Operation("Const", new_shape_tens.name + "_const")
-            new_shape_tens.ops = [new_shape_tens_const]
-            new_shape_tens_const.outputs = [new_shape_tens]
-
-            reshape_op = Operation("Reshape", reshape_name)
-            reshape_op.inputs = [inp, new_shape_tens]
-            reshape_op.attrs["new_shape"] = desired_shape
-            reshape_out = inp.clone("_reshaped")
-            reshape_out.set_all_shapes(desired_shape)
-            reshape_out.ops = [reshape_op]
-            reshape_op.outputs = [reshape_out]
-
-            op.inputs[0] = reshape_out
+            op.inputs[0] = create_reshape_tensor(inp, desired_shape)
 
     return op
 
@@ -261,22 +244,16 @@ def fixup_pack_input(op, arch):
         desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]
 
         # Construct 1 shape tensor to be used by all inserted reshape ops
-        new_shape_name = op.name + "_reshape_shape"
-        new_shape_tens = Tensor([1], DataType.int32, new_shape_name)
-        new_shape_tens.values = np.array(desired_shape)
-        new_shape_tens_const = Operation("Const", new_shape_tens.name + "_const")
-        new_shape_tens.ops = [new_shape_tens_const]
-        new_shape_tens_const.outputs = [new_shape_tens]
+        new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, desired_shape)
 
         for idx, inp in enumerate(op.inputs):
-            reshape_name = op.name + str(idx) + "_reshape"
-            reshape_op = Operation("Reshape", reshape_name)
-            reshape_op.inputs = [inp, new_shape_tens]
-            reshape_op.attrs["new_shape"] = desired_shape
             reshape_out = inp.clone("_reshaped")
             reshape_out.set_all_shapes(desired_shape)
-            reshape_out.ops = [reshape_op]
-            reshape_op.outputs = [reshape_out]
+
+            reshape_op = Operation("Reshape", "{}{}_reshape".format(op.name, idx))
+            reshape_op.attrs["new_shape"] = desired_shape
+            reshape_op.inputs = [inp, new_shape_tens]
+            reshape_op.set_output_tensor(reshape_out)
 
             op.inputs[idx] = reshape_out
 
@@ -335,22 +312,17 @@ def fixup_unpack_output(tens, arch):
         reshape_input_shape = tens.shape[:axis] + [1] + tens.shape[axis:]
 
         # Construct 1 shape tensor to be used by all inserted reshape ops
-        new_shape_name = op.name + "_reshape_shape"
-        new_shape_tens = Tensor([1], DataType.int32, new_shape_name)
-        new_shape_tens.values = np.array(tens.shape)
-        new_shape_tens_const = Operation("Const", new_shape_tens.name + "_const")
-        new_shape_tens.ops = [new_shape_tens_const]
-        new_shape_tens_const.outputs = [new_shape_tens]
+        new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, tens.shape)
 
         for idx, out_tens in enumerate(op.outputs):
-            reshape_name = op.name + str(idx) + "_reshape"
-            reshape_op = Operation("Reshape", reshape_name)
-            reshape_op.outputs = [out_tens]
             reshape_in = out_tens.clone("_reshaped")
             reshape_in.set_all_shapes(reshape_input_shape)
             reshape_in.ops = [op]
-            out_tens.ops = [reshape_op]
+
+            reshape_op = Operation("Reshape", "{}{}_reshape".format(op.name, idx))
+            reshape_op.attrs["new_shape"] = reshape_input_shape
             reshape_op.inputs = [reshape_in, new_shape_tens]
+            reshape_op.set_output_tensor(out_tens)
 
             op.outputs[idx] = reshape_in
 
@@ -517,17 +489,12 @@ def convert_conv_to_fc(op, arch):
             fc_ofm_tensor.set_all_shapes([1, fc_ofm_tensor.shape[-1]])
             fc_ofm_tensor.ops = [op]
             # Add a reshape after the new OFM to convert it back to the original 4D shape
-            reshape_name = op.name + "_reshape_post"
-            new_shape_tens = Tensor([1], DataType.int32, reshape_name + "_shape")
-            new_shape_tens.values = np.array(orig_ofm_tensor.shape)
-            new_shape_tens_const = Operation("Const", new_shape_tens.name + "_const")
-            new_shape_tens.ops = [new_shape_tens_const]
-            new_shape_tens_const.outputs = [new_shape_tens]
+            reshape_name = op.name + "_reshape"
+            new_shape_tens = create_const_tensor(reshape_name + "_shape", [1], DataType.int32, orig_ofm_tensor.shape)
             reshape_op = Operation("Reshape", reshape_name)
-            reshape_op.inputs = [fc_ofm_tensor, new_shape_tens]
             reshape_op.attrs["new_shape"] = orig_ofm_tensor.shape
-            orig_ofm_tensor.ops = [reshape_op]
-            reshape_op.outputs = [orig_ofm_tensor]
+            reshape_op.inputs = [fc_ofm_tensor, new_shape_tens]
+            reshape_op.set_output_tensor(orig_ofm_tensor)
             # Replace this ops OFM to point to the 2D tensor
             op.outputs[0] = fc_ofm_tensor
     return op
@@ -542,8 +509,7 @@ def fixup_act_reorder(op, arch):
             act_op.inputs = [prep_op.inputs[0]]
             act_op_out = act_op.inputs[0].clone("_acted")
             act_op_out.quantization = op.outputs[0].quantization.clone()
-            act_op_out.ops = [act_op]
-            act_op.outputs = [act_op_out]
+            act_op.set_output_tensor(act_op_out)
             prep_op.inputs[0] = act_op_out
             prep_op.outputs[0].quantization = act_op_out.quantization.clone()
 
diff --git a/ethosu/vela/insert_dma.py b/ethosu/vela/insert_dma.py
index 76016f1f..6c5c8031 100644
--- a/ethosu/vela/insert_dma.py
+++ b/ethosu/vela/insert_dma.py
@@ -56,11 +56,10 @@ def insert_dma_cmd(op, arch):
                     new_tens = tens.clone_into_fast_storage(arch)
                     dma_cmd = Operation("DMA", tens.ops[0].name + "_dma")
                     dma_cmd.inputs = [tens]
-                    dma_cmd.outputs = [new_tens]
+                    dma_cmd.set_output_tensor(new_tens)
                     dma_cmd.attrs["source"] = tens.mem_area
                     dma_cmd.attrs["destination"] = new_tens.mem_area
                     dma_cmd.run_on_npu = True
-                    new_tens.ops = [dma_cmd]
                     if tens.purpose == TensorPurpose.LUT:
                         # TODO: Add support more than one LUT at a time
                         # Reserve last 2 blocks for LUT
diff --git a/ethosu/vela/npu_serialisation.py b/ethosu/vela/npu_serialisation.py
index 030503de..c6b0d877 100644
--- a/ethosu/vela/npu_serialisation.py
+++ b/ethosu/vela/npu_serialisation.py
@@ -149,8 +149,7 @@ def serialise_npu_subgraph_into_tensors(nng, sg, arch, scratch_tens, scratch_fas
 
 def add_const_tens_to_startup_cascaded_pass(startup_cps, tens):
     op = Operation("Const", tens.name + "_const")
-    op.outputs = [tens]
-    tens.ops = [op]
+    op.set_output_tensor(tens)
     startup_cps.passes[0].ops.insert(0, op)
     startup_cps.passes[0].outputs.insert(0, tens)
     startup_cps.outputs.insert(0, tens)
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index adbbff51..0290e811 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -307,10 +307,9 @@ input and output tensors, as well as an attribute dictionary."""
         return input_tens, outputs, axis, offset_start, offset_end
 
     def set_activation_lut(self, lut_tensor):
-        lut_tensor.consumer_list.append(self)
         self.attrs["fused_activation_function"] = "LUT"
         self.activation_lut = lut_tensor
-        self.inputs.append(lut_tensor)
+        self.add_input_tensor(lut_tensor)
 
     def add_input_tensor(self, tens):
         self.inputs.append(tens)
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index fab00e00..8e108dbf 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -456,8 +456,7 @@ def pack_into_passes(nng, arch, verbose_packing=False):
                 avgpool_op.attrs["explicit_padding"] = [0, 0, 0, 0]
                 avgpool_out = inp.clone("_avgpooled")
                 avgpool_out.consumer_list.append(op)
-                avgpool_out.ops = [avgpool_op]
-                avgpool_op.outputs = [avgpool_out]
+                avgpool_op.set_output_tensor(avgpool_out)
                 op.inputs[0] = avgpool_out
                 ops_list.insert(0, avgpool_op)
 
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index bf3fe950..5e966d1b 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -54,9 +54,7 @@ def clone_and_reshape_tensor(src_tens, reorder):
         tens.quant_values = tens.quant_values.transpose(reorder)
 
     op = Operation("Const", tens.name)
-    op.outputs = [tens]
-    tens.ops = [op]
-
+    op.set_output_tensor(tens)
     return tens
 
 
@@ -81,14 +79,12 @@ class TFLiteSubgraph:
                 TensorError(tens, "This subgraph input tensor has unexpected driving operators.")
 
             op = Operation("Placeholder", tens.name)
-            op.outputs = [tens]
-            tens.ops = [op]
+            op.set_output_tensor(tens)
 
         for tens in self.tensors:
             if not tens.ops:
                 op = Operation("Const", tens.name)
-                op.outputs = [tens]
-                tens.ops = [op]
+                op.set_output_tensor(tens)
 
     def get_tensors_from_indices_remove_duplicates(self, indices, warning_str):
         tensors = []
@@ -190,8 +186,7 @@ class TFLiteSubgraph:
             act_op = Operation(activation_function_to_split_out, name + activation_function_to_split_out)
             out_tens = op.outputs[0]
             intermediate_tens = out_tens.clone("_act_intermediate")
-            out_tens.ops = [act_op]
-            act_op.outputs = [out_tens]
+            act_op.set_output_tensor(out_tens)
             intermediate_tens.ops = [op]
             op.outputs[0] = intermediate_tens
             act_op.inputs = [intermediate_tens]
-- 
cgit v1.2.1
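
The four helpers this patch adopts are defined elsewhere in the tree and are not shown
in the diff. The sketch below is a minimal stand-in, inferred from the inline wiring
code the patch removes; it is not the verbatim implementation from ethosu/vela (the
Tensor and Operation classes here are reduced to the fields the helpers touch).

import numpy as np


class Tensor:
    def __init__(self, shape, dtype, name):
        self.shape, self.dtype, self.name = shape, dtype, name
        self.values = None
        self.ops = []            # operations producing this tensor
        self.consumer_list = []  # operations consuming this tensor


class Operation:
    def __init__(self, op_type, name):
        self.type, self.name = op_type, name
        self.attrs, self.inputs, self.outputs = {}, [], []

    def add_input_tensor(self, tens):
        # Append tens to this op's inputs and register the op as a consumer.
        self.inputs.append(tens)
        tens.consumer_list.append(self)

    def set_output_tensor(self, tens):
        # Make this op the sole producer of tens, and tens the op's only output.
        tens.ops = [self]
        self.outputs = [tens]


def create_const_tensor(name, shape, dtype, values):
    # Build a tensor holding constant values, driven by a "Const" op.
    const_tens = Tensor(shape, dtype, name)
    const_tens.values = np.array(values)
    const_op = Operation("Const", name + "_const")
    const_op.set_output_tensor(const_tens)
    return const_tens


def create_reshape_tensor(tens, shape):
    # Insert a Reshape op after tens and return the reshaped output tensor.
    shape_tens = create_const_tensor(tens.name + "_shape", [1], "int32", shape)
    reshape_op = Operation("Reshape", tens.name + "_reshape")
    reshape_op.attrs["new_shape"] = shape
    reshape_op.add_input_tensor(tens)
    reshape_op.add_input_tensor(shape_tens)
    reshape_out = Tensor(shape, tens.dtype, tens.name + "_reshaped")
    reshape_op.set_output_tensor(reshape_out)
    return reshape_out

With wiring like this in place, a single call such as op.set_output_tensor(new_tens)
replaces the two-line pattern new_tens.ops = [op]; op.outputs = [new_tens] that the
diff deletes throughout.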