From 2349d429d926e258e9a61d34c7fd97660ab9fb98 Mon Sep 17 00:00:00 2001 From: Patrik Gustavsson Date: Tue, 1 Dec 2020 16:02:29 +0100 Subject: MLBEDSW-3654 Add/use op ifm/ofm shapes Add ifm/ofm shapes to op Changed to rely on these shapes Signed-off-by: Patrik Gustavsson Change-Id: I571535a1dcadc2bdb04a3c727a8e1c49703b174d --- ethosu/vela/debug_database.py | 2 +- ethosu/vela/graph_optimiser.py | 69 +++++++++++++++++----- ethosu/vela/high_level_command_stream.py | 2 +- ethosu/vela/high_level_command_stream_generator.py | 60 ++++++++++--------- ethosu/vela/high_level_command_to_npu_op.py | 12 ++-- ethosu/vela/insert_dma.py | 2 +- ethosu/vela/live_range.py | 4 +- ethosu/vela/nn_graph.py | 2 + ethosu/vela/npu_performance.py | 10 ++-- ethosu/vela/operation.py | 50 ++++++++++++++-- ethosu/vela/operation_util.py | 3 + ethosu/vela/pass_packing.py | 19 ++++++ ethosu/vela/scheduler.py | 9 +-- ethosu/vela/shared_buffer_allocation.py | 9 +-- ethosu/vela/softmax.py | 5 +- ethosu/vela/tensor.py | 55 ++++++++++------- ethosu/vela/test/test_graph_optimiser.py | 13 ++++ ethosu/vela/test/testutil.py | 5 ++ 18 files changed, 231 insertions(+), 100 deletions(-) diff --git a/ethosu/vela/debug_database.py b/ethosu/vela/debug_database.py index 4f0a50ae..203503f2 100644 --- a/ethosu/vela/debug_database.py +++ b/ethosu/vela/debug_database.py @@ -79,7 +79,7 @@ class DebugDatabase: src_uid = cls._sourceUID[parent] uid = len(cls._optimisedUID) cls._optimisedUID[op] = (uid, src_uid) - ofm_shape = numeric_util.full_shape(3, op.outputs[0].shape, 1) + ofm_shape = op.ofm_shapes[0] if op.ofm_shapes else numeric_util.full_shape(3, op.outputs[0].shape, 1) cls._optimisedTable.append( [uid, src_uid, op.type, op.kernel.width, op.kernel.height, ofm_shape[-2], ofm_shape[-3], ofm_shape[-1]] ) diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index 4806001f..fdb0fae0 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -75,7 +75,7 @@ def rewrite_concat(tens, arch, nng): new_op = Operation(Op.ConcatSliceWrite, concat_op.name + str(idx)) new_op.inputs = [inp] new_op.outputs = [tens] - new_op.attrs["concat_axis"] = axis + new_op.attrs["concat_axis"] = axis + (4 - len(inp.shape)) new_op.attrs["concat_start"] = offset offset += inp.shape[axis] new_op.attrs["concat_end"] = offset @@ -116,21 +116,20 @@ def rewrite_split(tens, arch, nng): # be calculated from the index of the output tensor if axis is not None: # Get the start and end of the split - offset_start = [0] * len(tens.shape) - offset_end = [0] * len(tens.shape) - for out in outputs: + offset_start = [0] * 4 + for idx, out in enumerate(outputs): if out == tens: break - offset_start[axis] += out.shape[axis] + axis_4D = axis + (4 - len(out.shape)) + offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D] # If start offset is not a multiple of 16 in the C-dimension, NHCWB16 need to be avoided in the input if (offset_start[-1] % 16) != 0: inp.avoid_NHCWB16 = True - - offset_end[axis] = offset_start[axis] + tens.shape[axis] + else: + offset_start = full_shape(4, offset_start, 0) new_op.attrs["split_start"] = offset_start - new_op.attrs["split_end"] = offset_end new_op.run_on_npu = True new_op.set_output_tensor(tens) DebugDatabase.add_optimised(split_op, new_op) @@ -217,6 +216,8 @@ def convert_resizebilinear_1x1_to_add(op): # Set the add inputs op.inputs[1] = op.inputs[0] op.inputs[0] = tens + op.ifm_shapes = [] + op.ofm_shapes = [] return op @@ -321,13 +322,16 @@ def convert_batched_fc_shape(op, arch, nng): ifm = op.inputs[0] 
ofm = op.outputs[0] # Check if the FC is 2D and first dimension indicates batching - if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1: + # TOD0 op.ifm_shape[0] > 1 is enough when refactory is complete + if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1 and op.ifm_shapes[0][0] > 1: n = ifm.shape[0] batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)} h, w = batching_split.get(n, (1, n)) prev_op = ifm.ops[0] desired_shape = [1, h, w, ifm.shape[-1]] + op.ifm_shapes[0] = desired_shape + if len(ifm.consumer_list) == 1 and prev_op is not None and prev_op.type == Op.Reshape: # There is a preceding Reshape # Compare input of prev_op and input of op, to see if prev_op can be removed @@ -352,6 +356,8 @@ def convert_batched_fc_shape(op, arch, nng): weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape)) desired_shape = [1, h, w, ofm.shape[-1]] + op.ofm_shapes[0] = desired_shape + if ( len(ofm.consumer_list) == 1 and ofm.consumer_list[0] is not None @@ -451,6 +457,7 @@ def fixup_stridedslice_output(tens, arch, nng): new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, tens.shape) for idx, out_tens in enumerate(op.outputs): + op.ofm_shapes[idx] = new_shape_tens reshape_in = out_tens.clone("_reshaped") reshape_in.set_all_shapes(reshape_input_shape) reshape_in.ops = [op] @@ -489,7 +496,6 @@ def fixup_unpack_output(tens, arch, nng): DebugDatabase.add_optimised(op, reshape_op) op.outputs[idx] = reshape_in - return tens @@ -582,7 +588,7 @@ def convert_conv_to_fc(op, arch, nng): # caching/double buffering for the weights. # (Weights dont need to be reloaded for convs when IFM H and W are 1) if op.type == Op.Conv2DBias: - _, h, w, _ = op.inputs[0].shape + _, h, w, _ = op.ifm_shapes[0] kh, kw, _, _ = op.inputs[1].shape if h == 1 and w == 1 and kh == 1 and kw == 1: # Overwrite this op as a Fully Connected Op @@ -595,6 +601,7 @@ def convert_conv_to_fc(op, arch, nng): weight_tensor = op.inputs[1] weight_tensor.quant_values = weight_tensor.quant_values.squeeze(axis=(0, 1)) weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape)) + # The output from a fully connected is expected to be 2D so we need to add a reshape layer to convert it # back to 4D afterwards as the next layer is expecting that shape orig_ofm_tensor = op.outputs[0] @@ -609,6 +616,7 @@ def convert_conv_to_fc(op, arch, nng): reshape_op.attrs["new_shape"] = orig_ofm_tensor.shape reshape_op.inputs = [fc_ofm_tensor, new_shape_tens] reshape_op.set_output_tensor(orig_ofm_tensor) + # Replace this ops OFM to point to the 2D tensor op.outputs[0] = fc_ofm_tensor # Record optimisation in debug database @@ -651,6 +659,8 @@ def fixup_act_reorder(op, arch, nng): prep_op = get_prepend_op(op) if prep_op is not None: act_op = op.clone("_reordered") + act_op.ifm_shapes = list(op.ifm_shapes) + act_op.ofm_shapes = list(op.ofm_shapes) # There is only one input tensor, overwrite it act_op.set_input_tensor(prep_op.inputs[0], 0) @@ -658,6 +668,8 @@ def fixup_act_reorder(op, arch, nng): act_op_out = act_op.inputs[0].clone("_acted") act_op_out.quantization = op.outputs[0].quantization.clone() act_op.set_output_tensor(act_op_out) + act_op.ifm_shapes[0] = full_shape(4, prep_op.inputs[0].shape, 1) + act_op.ofm_shapes[0] = full_shape(4, act_op_out.shape, 1) # Update the consumer list act_op_out.consumer_list = op.outputs[0].consumer_list.copy() @@ -704,6 +716,15 @@ def set_tensor_equivalence(op, arch, nng): return op +def set_ifm_ofm_op_shapes(op, arch, nng): + if op.run_on_npu and 
op.type.needs_shapes(): + if op.ifm_shapes or op.ofm_shapes: + # Shapes already set + return op + op.set_ifm_ofm_shapes() + return op + + def convert_softmax(op, arch, nng): if op.type == Op.Softmax and op.run_on_npu: softmax = SoftMax(op) @@ -839,7 +860,7 @@ def convert_lrelu_to_mul_max(op, arch): mul_identity.add_input_tensor(identity_tens) fm_id = ofm.clone(op.name + "_id") mul_identity.set_output_tensor(fm_id) - DebugDatabase.add_optimised(op, mul_alpha) + DebugDatabase.add_optimised(op, mul_identity) # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs op.type = Op.Maximum @@ -869,6 +890,8 @@ def convert_to_lut(op, lut_values, lut_name): quantization.zero_point = 0 tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], np.uint8, quantization=quantization) op.add_input_tensor(tens) + op.ifm_shapes.append(full_shape(4, tens.shape, 1)) + # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale), # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions # should be the same as the IFM @@ -1072,10 +1095,20 @@ def optimise_graph_a(nng, arch, verbose_graph=False): if verbose_graph: nng.print_graph() + pre_process_list = [ + supported_operator_check, + set_ifm_ofm_op_shapes, + # TODO: memory-only Op removal + ] + + for idx, sg in enumerate(nng.subgraphs): + # rewrite graph pass + nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( + nng, sg, arch, [], pre_process_list, rewrite_unsupported=False, + ) + op_rewrite_list = [ set_tensor_equivalence, - supported_operator_check, - # then do any rewrites of supported operators convert_depthwise_to_conv, convert_conv_to_fc, convert_softmax, @@ -1106,7 +1139,7 @@ def optimise_graph_a(nng, arch, verbose_graph=False): for idx, sg in enumerate(nng.subgraphs): # remove passthrough tensors and attempt further optimizations nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev, add_padding_fields] + nng, sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev, add_padding_fields], ) # Post-optimisation operator debug tracing @@ -1125,7 +1158,11 @@ def optimise_graph_b(nng, arch, verbose_graph=False): for idx, sg in enumerate(nng.subgraphs): # combined rewrite graph pass nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split], [] + nng, + sg, + arch, + [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split], + [set_ifm_ofm_op_shapes], ) if verbose_graph: diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py index c45bc4e5..bb4f1424 100644 --- a/ethosu/vela/high_level_command_stream.py +++ b/ethosu/vela/high_level_command_stream.py @@ -197,7 +197,7 @@ class NpuStripe(Command): self.pad_top = pad_top self.pad_bottom = pad_bottom for i in range(len(self.ofm_box.end_coord)): - assert self.ofm_box.end_coord[i] <= self.ofm_tensor.shape[i] + assert self.ofm_box.end_coord[i] <= ps.ofm_shapes[0][i] def is_npu_pass_command(self): return True diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py index 905263d6..18a419c0 100644 --- a/ethosu/vela/high_level_command_stream_generator.py +++ b/ethosu/vela/high_level_command_stream_generator.py @@ -56,6 +56,7 @@ def 
generate_high_level_command_stream_for_pass(strat, passes, block_configs, id # Ensure correct ifm and ifm2 order if match_tensor(ps.inputs[0], ps.primary_op.inputs[1]) and match_tensor(ps.inputs[1], ps.primary_op.inputs[0]): ps.ifm_tensor, ps.ifm2_tensor = ps.ifm2_tensor, ps.ifm_tensor + ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0] for op in ps.ops: if op.type == Op.SplitSliceRead: @@ -77,13 +78,20 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id ifm_idx += 1 ifm_tensor = ps.ifm_tensor + ifm_shape = None + if ifm_tensor.shape != []: + ifm_shape = ps.ifm_shapes[0] ifm2_tensor = ps.ifm2_tensor + ifm2_shape = None + if ifm2_tensor is not None and ifm2_tensor.shape != []: + ifm2_shape = ps.ifm_shapes[1] ofm_tensor = ps.ofm_tensor + ofm_shape = ps.ofm_shapes[0] weight_tensor = ps.weight_tensor scale_tensor = ps.scale_tensor - ofm_start = [0] * len(ofm_tensor.shape) - ofm_end = list(ofm_tensor.shape) + ofm_start = [0] * len(ofm_shape) + ofm_end = list(ofm_shape) strides = None skirt = None @@ -92,9 +100,9 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id strides = ps.primary_op.attrs.get("strides", None) skirt = ps.primary_op.attrs.get("skirt", None) if ps.primary_op.type == Op.Conv2DBackpropInputSwitchedBias: - upscaling = ofm_tensor.shape[-3] // ifm_tensor.shape[-3] + upscaling = ofm_shape[-3] // ifm_shape[-3] elif ps.primary_op.type == Op.ResizeBilinear: - upscaling = round_up_divide(ofm_tensor.shape[-3], ifm_tensor.shape[-3]) + upscaling = round_up_divide(ofm_shape[-3], ifm_shape[-3]) concat_axis = 0 concat_offset = 0 @@ -125,7 +133,7 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id ifm_box = None ifm2_box = None - if ifm_tensor.shape != []: + if ifm_shape is not None: ifm_box, _, _ = ofm_box.transform_with_strides_and_skirt( strides, skirt, @@ -138,16 +146,9 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id ) else: ifm_box = Box([], []) - if ifm2_tensor is not None and ifm2_tensor.shape != []: + if ifm2_shape is not None: ifm2_box, _, _ = ofm_box.transform_with_strides_and_skirt( - strides, - skirt, - ifm2_tensor.shape, - npu_block_type, - concat_axis, - concat_offset, - split_offsets[1], - upscaling, + strides, skirt, ifm2_shape, npu_block_type, concat_axis, concat_offset, split_offsets[1], upscaling, ) else: ifm2_box = Box([], []) @@ -212,19 +213,17 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id elif strat == SchedulingStrategy.IfmStream: y_step = block_config[0] - y_start = 0 - y_dim = 1 - if len(ofm_tensor.shape) >= 3: - y_start = ofm_start[-3] - y_dim = ofm_end[-3] + y_start = ofm_start[-3] + y_dim = ofm_end[-3] + if idx > 0: ifm_y_present = 0 prev_pass = passes[idx - 1] prev_pass_gen = generate_high_level_command_stream_for_pass(strat, passes, block_configs, idx - 1) else: ifm_y_present = 1 - if len(ifm_tensor.shape) >= 3: - ifm_y_present = ifm_tensor.shape[-3] + if len(ifm_shape) >= 3: + ifm_y_present = ifm_shape[-3] prev_pass_gen = [] prev_pass = None @@ -243,9 +242,8 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id for start in range(y_start, y_dim, y_step): end = min(start + y_step, y_dim) - if len(ofm_tensor.shape) >= 3: - ofm_start[-3] = start - ofm_end[-3] = end + ofm_start[-3] = start + ofm_end[-3] = end ofm_box = Box(ofm_start, ofm_end) k_height = 1 @@ -259,7 +257,7 @@ def generate_high_level_command_stream_for_pass(strat, passes, 
block_configs, id ifm_box, pad_top, pad_bottom = ofm_box.transform_with_strides_and_skirt( strides, skirt, - ifm_tensor.shape, + ifm_shape, npu_block_type, concat_axis, concat_offset, @@ -381,11 +379,15 @@ def calc_allowed_ofm_ifm_overlap_for_pass_list(strat, passes, block_configs): for cmd in generate_high_level_command_stream_for_pass_list(strat, passes, block_configs): if cmd.is_npu_pass_command(): if cmd.is_first: - ifm_read = cmd.ifm_tensor.address_offset_for_coordinate(cmd.ifm_box.start_coord, is_top_box=False) + ifm_read = cmd.ifm_tensor.address_offset_for_coordinate( + cmd.ifm_box.start_coord, shape=cmd.ps.ifm_shapes[0], is_top_box=False + ) if ifm_read is None: return 0 if cmd.is_last: - write_offset = cmd.ofm_tensor.address_offset_for_coordinate(cmd.ofm_box.end_coord, is_top_box=True) + write_offset = cmd.ofm_tensor.address_offset_for_coordinate( + cmd.ofm_box.end_coord, shape=cmd.ps.ofm_shapes[0], is_top_box=True + ) if write_offset is None: return 0 highest_ofm_write = max(write_offset, highest_ofm_write) @@ -396,7 +398,9 @@ def calc_allowed_ofm_ifm_overlap_for_pass_list(strat, passes, block_configs): min_overlap = min(min_overlap, can_overwrite) if cmd.is_first: - ifm_read = cmd.ifm_tensor.address_offset_for_coordinate(cmd.ifm_box.end_coord, is_top_box=True) + ifm_read = cmd.ifm_tensor.address_offset_for_coordinate( + cmd.ifm_box.end_coord, shape=cmd.ps.ifm_shapes[0], is_top_box=True + ) min_overlap = max(min_overlap, 0) return min_overlap diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py index 096a65cc..9380374e 100644 --- a/ethosu/vela/high_level_command_to_npu_op.py +++ b/ethosu/vela/high_level_command_to_npu_op.py @@ -231,7 +231,7 @@ def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]: return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point) -def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures) -> NpuFeatureMap: +def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, fm_shape: List[int]) -> NpuFeatureMap: """Creates feature map with common fields populated""" fm = NpuFeatureMap() fm.region = get_region(tens, arch) @@ -242,7 +242,7 @@ def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures) -> Np fm.layout = NpuLayout.NHCWB16 else: assert 0, "Incorrect tensor format" - height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(box.start_coord, box.end_coord) + height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(box.start_coord, box.end_coord, fm_shape) for idx, addr in enumerate(addresses): if addr is None: addresses[idx] = 0 @@ -326,12 +326,12 @@ def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: Archit ifm_width = Block.from_shape(cmd.ifm_tensor.shape).width ifm_depth = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box) - npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch) + npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0]) npu_op.ifm.shape = NpuShape3D(height=ifm_height, width=ifm_width, depth=ifm_depth) npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor) out_block = cmd.ofm_box.get_block() - npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch) + npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0]) npu_op.ofm.shape = NpuShape3D(height=out_block.height, width=out_block.width, depth=out_block.depth) npu_op.ofm.quantization = 
get_ofm_quantization(ps, cmd.ofm_tensor) @@ -397,13 +397,15 @@ def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> Npu assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}" elemwise_op = elementwise_op_map[op.type] npu_op = NpuElementWiseOperation(elemwise_op) + if elemwise_op not in UNARY_ELEMWISE_OPS: if not ifm_ifm2_correct_order(cmd.ifm_tensor.shape, cmd.ifm2_tensor.shape): # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box + ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0] npu_op.reversed_operands = True - npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch) + npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1]) npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor) if cmd.ifm2_tensor.shape == []: # scalar diff --git a/ethosu/vela/insert_dma.py b/ethosu/vela/insert_dma.py index fc1e7986..3797f43e 100644 --- a/ethosu/vela/insert_dma.py +++ b/ethosu/vela/insert_dma.py @@ -72,7 +72,7 @@ def insert_dma_cmd(op, arch, nng): tens.purpose == TensorPurpose.FeatureMap and op.type.is_binary_elementwise_op() and tens.shape != [] - and tens.shape != op.outputs[0].shape + and op.ifm_shapes[0] != op.ofm_shapes[0] and tens.storage_size() > max_ifm_shram_avail ): only_vector_product_consumers = True diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py index 14e83a33..0cc89e27 100644 --- a/ethosu/vela/live_range.py +++ b/ethosu/vela/live_range.py @@ -181,12 +181,12 @@ def merge_elementwise_op_ranges(ps, lr_graph, target_mem_area, target_mem_type_s inps.append(elem_op.ifm2) if len(inps) > 0: - for inp in inps: + for i, inp in enumerate(inps): # check input format, dtype, broadcasting or if there are more input consumers if ( inp.format == elem_op.ofm.format and inp.dtype == elem_op.ofm.dtype - and inp.shape == elem_op.ofm.shape + and elem_op.ifm_shapes[i] == elem_op.ofm_shapes[0] and (len(inp.consumer_list) == 1 and len(inp.ops) == 1) ): lr_graph.fuse_ranges(inp, elem_op.ofm) diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py index 0ae3de9a..67925176 100644 --- a/ethosu/vela/nn_graph.py +++ b/ethosu/vela/nn_graph.py @@ -58,6 +58,8 @@ class Pass: self.name = name self.cascade = None self.placement = placement + self.ifm_shapes = [] + self.ofm_shapes = [] # TODO: rename is_element_wise because it is not the same as an ElementWise operator. It is used by the tensor # allocation and requires that the OFM and IFM has the exact same address. Essentially complete overlap. 
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py index 9d83f6fb..c2ec4424 100644 --- a/ethosu/vela/npu_performance.py +++ b/ethosu/vela/npu_performance.py @@ -48,7 +48,7 @@ def rolling_buffer_dims_from_passes(arch, ps1, block_config_ps1, ps2, block_conf if ps2.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct): op = ps2.primary_op - ifm_block_depth = arch.calc_ifm_block_depth(op.ifm.shape[-1], op.ifm.dtype.size_in_bits()) + ifm_block_depth = arch.calc_ifm_block_depth(op.ifm_shapes[0][-1], op.ifm.dtype.size_in_bits()) else: ifm_block_depth = block_config_ps2[-1] @@ -224,8 +224,8 @@ def estimate_conv_pooling_cycles( scale_tensor=None, ): ofm_ublock = Block(arch.config.ofm_ublock.width, arch.config.ofm_ublock.height, arch.config.ofm_ublock.depth) - ifm_tens_shape = numeric_util.full_shape(4, ifm_tensor.shape, 1) - ofm_tens_shape = numeric_util.full_shape(4, ofm_tensor.shape, 1) + ifm_tens_shape = primary_op.ifm_shapes[0] + ofm_tens_shape = primary_op.ofm_shapes[0] if ( arch.config.ofm_ublock.height == 2 @@ -420,8 +420,8 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None, npu_block_type = primary_op.type.npu_block_type ifm_tensor, _, weight_tensor, ofm_tensor = ps.get_primary_op_ifm_ifm2_weights_ofm() - ifm_tensor_shape = numeric_util.full_shape(4, ifm_tensor.shape, 1) - ofm_tensor_shape = numeric_util.full_shape(4, ofm_tensor.shape, 1) + ifm_tensor_shape = list(ps.primary_op.ifm_shapes[0]) + ofm_tensor_shape = list(ps.primary_op.ofm_shapes[0]) if npu_block_type == NpuBlockType.ReduceSum: block_traversal = TensorBlockTraversal.DepthFirst diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index 30c32acc..be26a26b 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -27,6 +27,7 @@ from typing import TYPE_CHECKING from .errors import VelaError from .numeric_util import full_shape + if TYPE_CHECKING: from .tensor import Tensor @@ -129,7 +130,7 @@ class Op(Enum): Concat = OperatorInfo(indices=CONCAT_INDICES) ConcatEmbeddings = OperatorInfo() ConcatSliceWrite = OperatorInfo(indices=IFM_INDICES) - ConcatTFLite = OperatorInfo() + ConcatTFLite = OperatorInfo(indices=CONCAT_INDICES) Const = OperatorInfo() # Constant tensor, only used in CPU subgraphs Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES) Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=CONV2D_BACKPROP_INDICES) @@ -197,7 +198,7 @@ class Op(Enum): NonMaxSuppressionV5 = OperatorInfo() NotEqual = OperatorInfo() OneHot = OperatorInfo() - Pack = OperatorInfo() + Pack = OperatorInfo(indices=IFM_INDICES) PackReshaped = OperatorInfo(indices=IFM_INDICES) Pad = OperatorInfo() PadV2 = OperatorInfo() @@ -260,7 +261,7 @@ class Op(Enum): UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES) UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES) Unique = OperatorInfo() - Unpack = OperatorInfo() + Unpack = OperatorInfo(indices=IFM_INDICES) UnpackReshaped = OperatorInfo(indices=IFM_INDICES) Where = OperatorInfo() While = OperatorInfo() @@ -305,14 +306,17 @@ class Op(Enum): return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT) def is_split_op(self): - return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped) + return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack) def 
is_concat_op(self): - return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped) + return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack) def needs_bias(self): return bool(self.info.indices.biases) + def needs_shapes(self): + return bool(self.info.indices.ifms) + @classmethod def op_set(cls, predicate): # Returns the set of all operator codes that fulfill the given predicate @@ -400,6 +404,8 @@ class Operation: "forced_output_quantization", "activation_lut", "_kernel", + "ifm_shapes", + "ofm_shapes", ) def __init__(self, op_type: Op, name: str): @@ -421,6 +427,8 @@ class Operation: self.op_index = None # input network operator index self.activation_lut = None self._kernel = None + self.ifm_shapes = [] + self.ofm_shapes = [] def clone(self, suffix="_clone"): res = Operation(self.type, self.name + suffix) @@ -697,3 +705,35 @@ class Operation: lines += _print_tensors(self.outputs) raise VelaError("\n".join(lines)) + + def set_ifm_ofm_shapes(self): + ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor = self.get_ifm_ifm2_weights_ofm() + + # set all shapes to op, as 4D + if self.type == Op.FullyConnected: + n_in_elems = weight_tensor.shape[-2] + elms = ifm_tensor.elements() + batch_size = elms // n_in_elems + assert batch_size * n_in_elems == elms + + self.ifm_shapes.append([batch_size, 1, 1, n_in_elems]) + self.ofm_shapes.append(ofm_tensor.get_full_shape()) + elif self.type == Op.Softmax: + self.ifm_shapes.append(ifm_tensor.get_full_shape()) + self.ofm_shapes.append(ofm_tensor.get_full_shape()) + elif self.type.is_split_op or self.type.is_concat_op(): + for inp in self.inputs: + if inp is not None: + self.ifm_shapes.append(full_shape(4, inp.shape, 1)) + else: + self.ifm_shapes.append(None) + for out in self.outputs: + if out is not None: + self.ofm_shapes.append(full_shape(4, out.shape, 1)) + else: + self.ofm_shapes.append(None) + else: + self.ifm_shapes.append(full_shape(4, ifm_tensor.shape, 1)) + if ifm2_tensor is not None: + self.ifm_shapes.append(full_shape(4, ifm2_tensor.shape, 1)) + self.ofm_shapes.append(full_shape(4, ofm_tensor.shape, 1)) diff --git a/ethosu/vela/operation_util.py b/ethosu/vela/operation_util.py index a267b2ad..a55b9548 100644 --- a/ethosu/vela/operation_util.py +++ b/ethosu/vela/operation_util.py @@ -61,6 +61,7 @@ def create_depthwise_maxpool( ofm = Tensor([1, height, 1, 1], ifm.dtype, op.name + "_tens0") ofm.quantization = quantization op.set_output_tensor(ofm) + op.set_ifm_ofm_shapes() return op @@ -81,6 +82,7 @@ def create_reduce_sum( sum_of_exp = Tensor(ofm_shape, DataType.int32, op.name + "_tens0") sum_of_exp.quantization = quantization op.set_output_tensor(sum_of_exp) + op.set_ifm_ofm_shapes() return op @@ -190,4 +192,5 @@ def create_binary_elementwise( ofm = Tensor(ofm_shape, dtype, f"{op.name}_tens0") ofm.quantization = quantization op.set_output_tensor(ofm) + op.set_ifm_ofm_shapes() return op diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py index 9bc04f29..095a78d4 100644 --- a/ethosu/vela/pass_packing.py +++ b/ethosu/vela/pass_packing.py @@ -397,11 +397,28 @@ def pack_into_passes(nng, arch, verbose_packing=False): if len(ps.inputs) > 2: ps.ifm_tensor = ps.inputs[-2] + + # Get the corresponding ifm_shapes + for op in input_ops_list + [primary_op]: + if ps.ifm_tensor == op.ifm: + ps.ifm_shapes.append(op.ifm_shapes[0]) + elif ps.ifm_tensor == op.ifm2: + ps.ifm_shapes.append(op.ifm_shapes[1]) + for op in input_ops_list + [primary_op]: + if ps.ifm2_tensor == op.ifm: + ps.ifm_shapes.append(op.ifm_shapes[0]) + elif 
ps.ifm2_tensor == op.ifm2: + ps.ifm_shapes.append(op.ifm_shapes[1]) else: ps.ifm_tensor = ifm_tensor ps.ifm2_tensor = None + if ps.primary_op is not None: + ps.ifm_shapes.append(ps.primary_op.ifm_shapes[0]) ps.ofm_tensor = ofm_tensor + if ps.primary_op is not None: + ps.ofm_shapes.append(ps.primary_op.ofm_shapes[0]) + assert ps.placement != PassPlacement.Npu or ps.ofm_tensor is not None ps.weight_tensor = ps.get_primary_op_ifm_weights()[1] ps.scale_tensor = ps.get_primary_op_ifm_weights_biases_ofm()[2] @@ -436,6 +453,8 @@ def pack_into_passes(nng, arch, verbose_packing=False): avgpool_out = inp.clone("_avgpooled") avgpool_out.consumer_list.append(op) avgpool_op.set_output_tensor(avgpool_out) + avgpool_op.ifm_shapes = op.ifm_shapes + avgpool_op.ofm_shapes = op.ofm_shapes op.inputs[0] = avgpool_out op_list.insert(0, avgpool_op) diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py index 2c10640b..6cbff500 100644 --- a/ethosu/vela/scheduler.py +++ b/ethosu/vela/scheduler.py @@ -34,7 +34,6 @@ from .npu_performance import make_bandwidth_array from .npu_performance import make_cycles_array from .npu_performance import make_metrics_arrays from .npu_performance import PassCycles -from .numeric_util import full_shape from .operation import NpuBlockType from .operation import Op from .operation import Operation @@ -188,7 +187,7 @@ class StrategySet: def __eq__(self, other): if (self.bws != other.bws).any(): return False - if (self.macs != other.macs).any(): + if self.macs != other.macs: return False if (self.cycles != other.cycles).any(): return False @@ -1000,10 +999,8 @@ class DynamicProgrammingScheduler: rewrites.extend(get_rewrites(op)) # Detect no-op reshapes by comparing their full input and output tensor shapes. - inshape = full_shape(4, op.inputs[0].shape, 1) - compatible_shape = [ - (inshape == full_shape(4, oper.outputs[0].shape, 1)) for oper in get_rewrites(op) - ] + inshape = op.ifm_shapes[0] + compatible_shape = [(inshape == oper.ofm_shapes[0]) for oper in get_rewrites(op)] use_NHCWB16 = compatible_shape and all(compatible_shape) else: use_NHCWB16 = False diff --git a/ethosu/vela/shared_buffer_allocation.py b/ethosu/vela/shared_buffer_allocation.py index 600b3170..1f027d60 100644 --- a/ethosu/vela/shared_buffer_allocation.py +++ b/ethosu/vela/shared_buffer_allocation.py @@ -193,15 +193,16 @@ def shared_buffer_allocation_for_pass(arch, ps) -> SharedBufferAllocation: if ifm_tensor: ifm_resampling_mode = ifm_tensor.resampling_mode ifm_bits = ifm_tensor.dtype.size_in_bits() + ifm_shape = ps.primary_op.ifm_shapes[0] - if ifm_tensor.shape != []: - ifm_depth = ifm_tensor.shape[-1] + if ifm_shape != []: + ifm_depth = ifm_shape[-1] if is_elementwise: ifm_count = 2 if ifm_tensor.shape == []: # Scalar in ifm1 assert ifm2_tensor - ifm_depth = ifm2_tensor.shape[-1] + ifm_depth = ps.primary_op.ifm_shapes[1][-1] ifm_count = 1 elif not ifm2_tensor or ifm2_tensor.shape == []: # Scalar in ifm2 ifm_count = 1 @@ -215,7 +216,7 @@ def shared_buffer_allocation_for_pass(arch, ps) -> SharedBufferAllocation: ifm_bits=ifm_bits, ifm_depth=ifm_depth, ifm_count=ifm_count, - ofm_shape=ofm_tensor.shape, + ofm_shape=ps.primary_op.ofm_shapes[0], ) diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py index 8b061297..98496539 100644 --- a/ethosu/vela/softmax.py +++ b/ethosu/vela/softmax.py @@ -213,7 +213,7 @@ class SoftMax: ofm = self.op.outputs[0] # Reshape ifm/ofm (if needed) - full_shape = ifm.get_full_shape() + full_shape = self.op.ifm_shapes[0] if full_shape[0] > 1: full_shape[1] *= full_shape[0] 
full_shape[0] = 1 @@ -230,9 +230,6 @@ class SoftMax: def get_graph_8bit(self, ifm, ofm): exp_lut = self.generate_exp_table(self.op.attrs.get("beta", 1.0), ifm.quantization.scale_f32) - ifm = create_reshape_tensor(ifm, ifm.get_full_shape()) - DebugDatabase.add_optimised(self.op, ifm.ops[0]) - ofm = create_reshape_tensor(ofm, ofm.get_full_shape(), False) no_scale_quant = ifm.quantization.clone() no_scale_quant.scale_f32 = None no_scale_quant.zero_point = 0 diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index 69618d2c..df8f8868 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -37,6 +37,7 @@ from .data_type import DataType from .errors import UnsupportedFeatureError from .errors import VelaError from .ethos_u55_regs.ethos_u55_regs import resampling_mode +from .numeric_util import full_shape from .operation import Op from .operation import Operation @@ -322,6 +323,8 @@ def create_reshape_tensor(tens, shape, ifm_reshape=True): reshape_op.add_input_tensor(reshape_ifm) reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, shape)) reshape_op.set_output_tensor(reshape_ofm) + reshape_op.ifm_shapes.append(full_shape(4, reshape_ifm.shape, 1)) + reshape_op.ofm_shapes.append(full_shape(4, reshape_ofm.shape, 1)) return reshape_ofm if ifm_reshape else reshape_ifm @@ -605,20 +608,20 @@ class Tensor: def consumers(self) -> List[Operation]: return self.consumer_list - def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape) -> Tuple: + def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, fm_shape: Shape) -> Tuple: # returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] ) - if len(start_coord) < 4: - box_height0 = 1 - box_width = 1 - - if len(start_coord) >= 2: - box_width = end_coord[-2] - start_coord[-2] - - return box_height0, box_height0, box_width, [self.address_for_coordinate(start_coord), None, None, None] + if self.storage_shape == []: + return ( + 1, + 1, + 1, + [self.address_for_coordinate(start_coord, shape=fm_shape), None, None, None], + ) - crossing_y = numeric_util.round_up(start_coord[1] + 1, self.storage_shape[1]) - crossing_x = numeric_util.round_up(start_coord[2] + 1, self.storage_shape[2]) + storage_shape_4D = full_shape(4, self.storage_shape, 1) + crossing_y = numeric_util.round_up(start_coord[1] + 1, storage_shape_4D[1]) + crossing_x = numeric_util.round_up(start_coord[2] + 1, storage_shape_4D[2]) crossing_y = min(crossing_y, end_coord[1]) crossing_x = min(crossing_x, end_coord[2]) @@ -627,20 +630,28 @@ class Tensor: box_width = crossing_x - start_coord[2] addresses: List = [None] * 4 - addresses[0] = self.address_for_coordinate(start_coord) + addresses[0] = self.address_for_coordinate(start_coord, shape=fm_shape) if end_coord[2] > crossing_x: - addresses[1] = self.address_for_coordinate([start_coord[0], start_coord[1], crossing_x, start_coord[3]]) + addresses[1] = self.address_for_coordinate( + [start_coord[0], start_coord[1], crossing_x, start_coord[3]], shape=fm_shape + ) raise UnsupportedFeatureError("Striping in vertical direction is not supported") if end_coord[1] > crossing_y: - addresses[2] = self.address_for_coordinate([start_coord[0], crossing_y, start_coord[2], start_coord[3]]) + addresses[2] = self.address_for_coordinate( + [start_coord[0], crossing_y, start_coord[2], start_coord[3]], shape=fm_shape + ) if end_coord[1] > crossing_y and end_coord[2] > crossing_x: - addresses[3] = self.address_for_coordinate([start_coord[0], 
crossing_y, crossing_x, start_coord[3]]) + addresses[3] = self.address_for_coordinate( + [start_coord[0], crossing_y, crossing_x, start_coord[3]], shape=fm_shape + ) return box_height0, box_height0, box_width, addresses - def address_for_coordinate(self, coord: Shape, is_top_box: bool = False) -> int: - offset = self.address_offset_for_coordinate(coord, is_top_box) + def address_for_coordinate(self, coord: Shape, is_top_box: bool = False, shape: Shape = None) -> int: + if shape is None: + shape = self.shape + offset = self.address_offset_for_coordinate(coord, shape, is_top_box) assert offset is not None return self.address + offset @@ -752,18 +763,18 @@ class Tensor: assert 0 <= index < len(self.compressed_values) return index == len(self.compressed_values) - 1 - def address_offset_for_coordinate(self, orig_coord: Shape, is_top_box: bool = False) -> Optional[int]: + def address_offset_for_coordinate(self, orig_coord: Shape, shape: Shape, is_top_box: bool = False) -> Optional[int]: address_offset = 0 coord = orig_coord coord = coord[-len(self.storage_shape) :] if self.sub_purpose == TensorSubPurpose.Standard: - for idx, c in enumerate(coord): + for idx, c in enumerate(orig_coord): if is_top_box: - assert c > 0 and c <= self.shape[idx] + assert c > 0 and c <= shape[idx] else: - assert c >= 0 and c < self.shape[idx] + assert c >= 0 and c < shape[idx] if self.format == TensorFormat.WeightsCompressed: if len(self.weight_compressed_offsets) == 0: @@ -830,7 +841,7 @@ class Tensor: def get_full_shape(self) -> Shape: d = len(self.shape) if d in (1, 3): - return numeric_util.full_shape(4, self.shape, 1) + return full_shape(4, self.shape, 1) elif d == 2: return [self.shape[0], 1, 1, self.shape[1]] else: diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py index 62a1b763..45377417 100644 --- a/ethosu/vela/test/test_graph_optimiser.py +++ b/ethosu/vela/test/test_graph_optimiser.py @@ -32,9 +32,16 @@ def test_convert_batched_fc(): weights = create_const_tensor("weight_in", shape, np.uint8, np.zeros(shape)) ofm = Tensor(ifm.shape, np.uint8, "test_out") op = testutil.create_op(Op.FullyConnected, [ifm, weights], ofm) + ifm.consumer_list.append(op) + op.ifm_shapes.append([4, 1, 1, 8]) + op.ofm_shapes.append([4, 1, 1, 8]) + prev_op = op.clone() + prev_op.ifm_shapes = op.ifm_shapes + prev_op.ofm_shapes = op.ofm_shapes + conv_op = convert_batched_fc_shape(op, None, None) assert conv_op.ifm != prev_op.ifm @@ -51,7 +58,13 @@ def test_convert_batched_fc(): op = testutil.create_op(Op.FullyConnected, [ifm, weights], ofm) ifm.consumer_list.append(op) + op.ifm_shapes.append([1, 1, 1, 8]) + op.ofm_shapes.append([1, 1, 1, 8]) + prev_op = op.clone() + prev_op.ifm_shapes = op.ifm_shapes + prev_op.ofm_shapes = op.ofm_shapes + conv_op = convert_batched_fc_shape(op, None, None) assert conv_op.ifm == prev_op.ifm diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py index 9ba39bc5..63f841b4 100644 --- a/ethosu/vela/test/testutil.py +++ b/ethosu/vela/test/testutil.py @@ -69,6 +69,8 @@ def create_elemwise_op( ofm = Tensor(ofm_shape, datatype, name + "_ofm") ofm.quantization = ofm_quant op.set_output_tensor(ofm) + op.set_ifm_ofm_shapes() + return op @@ -104,6 +106,8 @@ def create_op_with_quant_tensors( qp.zero_point = np.zeros(bias_shape) bias = create_const_tensor("bias", bias_shape, DataType.int32, np.zeros(bias_shape), np.int32, quantization=qp) op.add_input_tensor(bias) + + op.set_ifm_ofm_shapes() return op @@ -113,6 +117,7 @@ def create_op(op_type, inputs, 
output, attrs=None): op.outputs = [output] if attrs is not None: op.attrs = attrs + op.set_ifm_ofm_shapes() return op -- cgit v1.2.1
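
A minimal standalone sketch of the 4D shape convention that the new op.ifm_shapes / op.ofm_shapes fields rely on. The local full_shape() helper below is an assumption that mirrors how the patch uses numeric_util.full_shape(4, shape, 1), i.e. front-padding a shape to four dimensions; it is illustrative only, not the Vela implementation itself.

    from typing import List

    def full_shape(want_len: int, shape: List[int], fill: int) -> List[int]:
        # Pad 'shape' on the left with 'fill' until it has 'want_len' entries,
        # matching the way the patch builds 4D ifm/ofm shapes.
        return [fill] * (want_len - len(shape)) + list(shape)

    # A 3D feature-map shape gains a leading batch dimension of 1:
    print(full_shape(4, [16, 16, 8], 1))       # -> [1, 16, 16, 8]

    # For FullyConnected, Operation.set_ifm_ofm_shapes() in the patch instead
    # derives the 4D IFM shape from the weight tensor as
    # [batch_size, 1, 1, n_in_elems]:
    ifm_elements = 4 * 8                       # elements of an IFM shaped [4, 8]
    n_in_elems = 8                             # weight_tensor.shape[-2]
    batch_size = ifm_elements // n_in_elems
    assert batch_size * n_in_elems == ifm_elements
    print([batch_size, 1, 1, n_in_elems])      # -> [4, 1, 1, 8]

Downstream passes (scheduler, high-level command stream generation, performance estimation) can then index these 4D shapes directly instead of calling numeric_util.full_shape() on tensor.shape at each use site, which is the main simplification this change enables.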