From 3010d9b5c90628e07c7d0f0c33e7355b8bc3e19d Mon Sep 17 00:00:00 2001
From: Patrik Gustavsson
Date: Thu, 1 Oct 2020 08:22:10 +0200
Subject: MLBEDSW-3060 Adjust check if weights fit in sram

When deciding if weights fit in SRAM:
A compression of the weights is now performed when the weight
compression test limit makes it impossible to fit the weights in a
double buffer in SRAM. The worst compression ratio from that
compression is then used to decide whether the weights can fit in
SRAM.

Signed-off-by: Patrik Gustavsson
Change-Id: I9458769866b3f9fc15659185aae09658ed10fb38
---
 ethosu/vela/graph_optimiser.py   | 64 ++++++++++++++++++++--------------------
 ethosu/vela/insert_dma.py        | 44 +++++++++++++++++----------
 ethosu/vela/mark_tensors.py      |  4 +--
 ethosu/vela/rewrite_graph.py     |  6 ++--
 ethosu/vela/weight_compressor.py |  1 -
 5 files changed, 65 insertions(+), 54 deletions(-)

diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index e7c15cdc..4f435dcb 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -68,14 +68,14 @@ activation_ops = set(("Sigmoid", "Tanh")) | relu_ops
 memory_only_ops = set(("Reshape",))
 
 
-def remove_passthrough_tensor(tens, arch):
+def remove_passthrough_tensor(tens, arch, nng):
     if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
         assert len(tens.ops[0].inputs) == 1
         tens = tens.ops[0].inputs[0]
     return tens
 
 
-def rewrite_concat(tens, arch):
+def rewrite_concat(tens, arch, nng):
     if len(tens.ops) == 1 and tens.ops[0].is_concat_op():
         concat_op = tens.ops[0]
         if tens != concat_op.outputs[0]:
@@ -114,7 +114,7 @@ def rewrite_concat(tens, arch):
     return tens
 
 
-def rewrite_split(tens, arch):
+def rewrite_split(tens, arch, nng):
     if len(tens.ops) == 1 and tens.ops[0].is_split_op():
         split_op = tens.ops[0]
 
@@ -205,7 +205,7 @@ def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_dim
     return padding, skirt
 
 
-def fixup_conv2d_backprop(op, arch):
+def fixup_conv2d_backprop(op, arch, nng):
     if op.type == "Conv2DBackpropInput":
         # flip the inputs
         op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
@@ -295,7 +295,7 @@ def convert_resizebilinear_to_2x2_pool(op):
     return op
 
 
-def fixup_resizebilinear(op, arch):
+def fixup_resizebilinear(op, arch, nng):
     if op.type == "ResizeBilinear" and op.run_on_npu:
         if op.inputs[0].shape == op.outputs[0].shape:
             # Bypass nop resizebilinear
@@ -309,7 +309,7 @@ def fixup_resizebilinear(op, arch):
     return op
 
 
-def convert_nop_split_to_identity(op, arch):
+def convert_nop_split_to_identity(op, arch, nng):
     if op.type == "Split" and op.attrs.get("num_splits") == 1:
         # the list comprehension should return a list with a single tensor
         # if it shouldn't, remove_passthrough_tensor will fail appropriately
@@ -318,7 +318,7 @@ def convert_nop_split_to_identity(op, arch):
     return op
 
 
-def fixup_fully_connected_input(op, arch):
+def fixup_fully_connected_input(op, arch, nng):
     if op.type == "FullyConnectedAct":
         inp = op.inputs[0]
         weights = op.inputs[1]
@@ -336,7 +336,7 @@ def fixup_fully_connected_input(op, arch):
     return op
 
 
-def convert_batched_fc_to_conv(op, arch):
+def convert_batched_fc_to_conv(op, arch, nng):
     if op.type == "FullyConnectedAct":
         ifm = op.inputs[0]
         ofm = op.outputs[0]
@@ -407,7 +407,7 @@ def convert_batched_fc_to_conv(op, arch):
     return op
 
 
-def fixup_pack_input(op, arch):
+def fixup_pack_input(op, arch, nng):
     if op.type == "Pack":
         # Pack is also referred to as Stack
         # Requires the rewrite_concat function to be called on the op afterwards
@@ -433,7 +433,7 @@ def fixup_pack_input(op, arch):
     return op
 
 
-def unfuse_activation_function(op, arch):
+def unfuse_activation_function(op, arch, nng):
     unfuse_ops = ("ConcatTFLite",)
     if op.type in unfuse_ops and op.run_on_npu and op.attrs.get("fused_activation_function", None) is not None:
         act = op.attrs["fused_activation_function"]
@@ -448,7 +448,7 @@ def unfuse_activation_function(op, arch):
     return op
 
 
-def fixup_unpack_output(tens, arch):
+def fixup_unpack_output(tens, arch, nng):
     op = tens.ops[0]
     if op.type in set(("Unpack", "StridedSlice")):
         # Unpack is also referred to as Unstack
@@ -515,7 +515,7 @@ def fixup_unpack_output(tens, arch):
     return tens
 
 
-def add_padding_fields(op, arch):
+def add_padding_fields(op, arch, nng):
     if op.run_on_npu:
         if "padding" in op.attrs:
             if op.type in conv_op | depthwise_op:
@@ -564,7 +564,7 @@ def get_prepend_op(op):
     return None
 
 
-def mark_npu_block_type(op, arch):
+def mark_npu_block_type(op, arch, nng):
     npu_block_type = NpuBlockType.Default
     if op.type in conv_op:
         npu_block_type = NpuBlockType.ConvolutionMxN
@@ -583,7 +583,7 @@ def mark_npu_block_type(op, arch):
     return op
 
 
-def convert_depthwise_to_conv(op, arch):
+def convert_depthwise_to_conv(op, arch, nng):
     # Depthwise is equivalent to a single conv2d if the ifm depth is 1 and
     # the ofm depth equals the depth multipler.
     # If those conditions are true, then we can perform a simple
@@ -610,7 +610,7 @@ def convert_depthwise_to_conv(op, arch):
     return op
 
 
-def reorder_depthwise_weights(op, arch):
+def reorder_depthwise_weights(op, arch, nng):
     if op.type in depthwise_op:
         weight_tensor = op.inputs[1]
         weight_tensor.quant_values = np.transpose(weight_tensor.quant_values, (0, 1, 3, 2))
@@ -620,7 +620,7 @@ def reorder_depthwise_weights(op, arch):
     return op
 
 
-def convert_conv_to_fc(op, arch):
+def convert_conv_to_fc(op, arch, nng):
     # Conv 1x1 can be equivalent to Fully Connected.
     # By representing certain convs as fully connected layers, Vela can better determine wether or not to use
     # caching/double buffering for the weights.
@@ -661,7 +661,7 @@ def convert_conv_to_fc(op, arch):
     return op
 
 
-def fixup_relus_with_differing_ifm_ofm_scaling(op, arch):
+def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
     if op.run_on_npu and op.type in relu_ops:
         ifm = op.inputs[0]
         ofm = op.outputs[0]
@@ -690,7 +690,7 @@ def fixup_relus_with_differing_ifm_ofm_scaling(op, arch):
 
 
 # Reorder activation op if it's after the memory only operations
-def fixup_act_reorder(op, arch):
+def fixup_act_reorder(op, arch, nng):
     if op.type in activation_ops:
         prep_op = get_prepend_op(op)
         if prep_op is not None:
@@ -715,7 +715,7 @@ def fixup_act_reorder(op, arch):
     return op
 
 
-def fixup_elementwise_with_scalars(op, arch):
+def fixup_elementwise_with_scalars(op, arch, nng):
     if op.type in binary_elementwise_op:
         ifm_tensor, ifm2_tensor, _, _ = op.get_ifm_ifm2_weights_ofm()
         if ifm2_tensor.shape != [] and ifm_tensor.shape != []:
@@ -736,7 +736,7 @@ def fixup_elementwise_with_scalars(op, arch):
 
 
 # Set input/output tensor equivalence to the same id for memory operations
-def set_tensor_equivalence(op, arch):
+def set_tensor_equivalence(op, arch, nng):
     if op.type in memory_only_ops:
         eid = op.outputs[0].equivalence_id
         for inp in op.inputs:
@@ -744,14 +744,14 @@ def set_tensor_equivalence(op, arch):
     return op
 
 
-def convert_softmax(op, arch):
+def convert_softmax(op, arch, nng):
     if op.type == "Softmax" and op.run_on_npu:
         softmax = SoftMax(op)
         op = softmax.get_graph()
     return op
 
 
-def convert_mul_max_to_abs_or_lrelu(op, arch):
+def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
     r"""Whenever there is a subgraph with this topology:
 
        Input    X   For X = -1 or X > 0
@@ -958,7 +958,7 @@ def convert_lrelu_to_lut(op, arch):
     return convert_to_lut(op, values)
 
 
-def convert_lrelu(op, arch):
+def convert_lrelu(op, arch, nng):
     # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
     if op.type != "LeakyRelu":
         return op
@@ -972,7 +972,7 @@ def convert_lrelu(op, arch):
     return convert_lrelu_to_mul_max(op, arch)
 
 
-def convert_tanh_sigmoid_to_lut(op, arch):
+def convert_tanh_sigmoid_to_lut(op, arch, nng):
     # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
     if op.type == "Sigmoid":
         return convert_to_lut8(op, clamp_sigmoid)
@@ -981,7 +981,7 @@ def convert_tanh_sigmoid_to_lut(op, arch):
     return op
 
 
-def remove_unwanted_reshapes(op, arch):
+def remove_unwanted_reshapes(op, arch, nng):
     # Try to remove reshapes enclosing ElementWise operator with only one non-constant input
     if not op.run_on_npu or op.attrs["npu_block_type"] != NpuBlockType.ElementWise:
         return op
@@ -1016,7 +1016,7 @@ def remove_unwanted_reshapes(op, arch):
     return op
 
 
-def fuse_activation_function_with_prev(op, arch):
+def fuse_activation_function_with_prev(op, arch, nng):
     # if op is a no-op: attempts to move the activation function to the preceding op
     if not op.attrs.get("is_nop", False) or op.attrs.get("fused_activation_function", None) is None:
         return op
@@ -1049,7 +1049,7 @@ def fuse_activation_function_with_prev(op, arch):
     return op
 
 
-def add_attrs_to_resizebilinear(op, arch):
+def add_attrs_to_resizebilinear(op, arch, nng):
     if op.type == "ResizeBilinear" and op.run_on_npu:
         input_tensor = op.inputs[0]
         upscaled_shape = [input_tensor.shape[1] * 2, input_tensor.shape[2] * 2]
@@ -1069,7 +1069,7 @@ def add_attrs_to_resizebilinear(op, arch):
     return op
 
 
-def fixup_bias_tensors(op, arch):
+def fixup_bias_tensors(op, arch, nng):
     if op.needs_bias() and not op.inputs[-1]:
         # Op has no bias, add bias tensor filled with zeros
         nr_biases = op.inputs[1].shape[-1]
@@ -1081,7 +1081,7 @@ def fixup_bias_tensors(op, arch):
     return op
 
 
-def supported_operator_check(op, arch):
+def supported_operator_check(op, arch, nng):
     op.run_on_npu = arch.supported_operators.is_operator_supported(op)
     return op
 
@@ -1121,13 +1121,13 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
     for idx, sg in enumerate(nng.subgraphs):
         # rewrite graph pass
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            sg, arch, [fixup_unpack_output], op_rewrite_list, rewrite_unsupported=False
+            nng, sg, arch, [fixup_unpack_output], op_rewrite_list, rewrite_unsupported=False
         )
 
     for idx, sg in enumerate(nng.subgraphs):
         # remove passthrough tensors and attempt further optimizations
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev, add_padding_fields]
+            nng, sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev, add_padding_fields]
         )
 
     if verbose_graph:
@@ -1141,7 +1141,7 @@ def optimise_graph_b(nng, arch, verbose_graph=False):
 
     for idx, sg in enumerate(nng.subgraphs):
         # combined rewrite graph pass
-        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(sg, arch, [rewrite_concat, rewrite_split], [])
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [rewrite_concat, rewrite_split], [])
 
     if verbose_graph:
         nng.print_graph()
diff --git a/ethosu/vela/insert_dma.py b/ethosu/vela/insert_dma.py
index 9304526a..99b46c07 100644
--- a/ethosu/vela/insert_dma.py
+++ b/ethosu/vela/insert_dma.py
@@ -21,12 +21,13 @@ from .operation import Operation
 from .tensor import MemArea
 from .tensor import MemType
 from .tensor import TensorPurpose
+from .weight_compressor import compress_weights
 
 
 binary_elementwise_op = set(("AddAct", "MulAct", "SubAct", "Maximum", "Minimum"))
 
 
-def weights_fit_sram(arch, tens):
+def weights_fit_sram(arch, op, tens, nng):
     if tens.purpose != TensorPurpose.Weights:
         return True
 
@@ -36,25 +37,33 @@ def weights_fit_sram(arch, tens):
     elif len(tens.shape) == 2:
         min_weight_size = tens.shape[0] * arch.OFMSplitDepth
 
-    w_compression = 1  # TODO worst compression ratio currently assumed
-
-    # Need to be fit into Sram, as a double buffer
-    if (w_compression * min_weight_size * 2) > arch.sram_size:
-        print(
-            "Weights, {}, are too big to be DMAed to SRAM, estimated minimum size is {} bytes".format(
-                tens.name, (w_compression * min_weight_size * 2)
+    # Only evaluate when the compression test limit will make it impossible to fit
+    w_comp_test_limit = 2
+    if (w_comp_test_limit * min_weight_size * 2) > arch.sram_size:
+        # check worst compression ratio
+        npu_block_type = op.attrs.get("npu_block_type", NpuBlockType.Default)
+        compress_weights(arch, nng, tens, npu_block_type, 16, 16, op.get_dilation_h_w())
+
+        worst_buffer_size = tens.compression_scale_for_worst_weight_stream * min_weight_size * 2
+        if worst_buffer_size > arch.sram_size:
+            print(
+                "Weights, {}, are too big to be DMAed to SRAM, estimated minimum size is {} bytes".format(
+                    tens.name, worst_buffer_size
+                )
             )
-        )
-        return False
+            return False
     return True
 
 
-def insert_dma_cmd(op, arch):
+def insert_dma_cmd(op, arch, nng):
     if op.type == "DMA" or not op.run_on_npu:
         return op
-    is_lut_used = any(inp.purpose == TensorPurpose.LUT for inp in op.inputs)
-    max_ifm_shram_avail = (arch.available_shram_banks(is_lut_used) - arch.shram_reserved_output_banks) * arch.shram_bank_size // 2
+
+    is_lut_used = any(inp.purpose == TensorPurpose.LUT for inp in op.inputs)
+    max_ifm_shram_avail = (
+        (arch.available_shram_banks(is_lut_used) - arch.shram_reserved_output_banks) * arch.shram_bank_size // 2
+    )
 
     for idx, tens in enumerate(op.inputs):
@@ -66,8 +75,11 @@ def insert_dma_cmd(op, arch):
                 and arch.permanent_storage_mem_area != arch.fast_storage_mem_area
             ) or tens.purpose == TensorPurpose.LUT:
                 if tens.purpose in (TensorPurpose.Weights, TensorPurpose.LUT) or (
-                    tens.purpose == TensorPurpose.FeatureMap and op.type in binary_elementwise_op and
-                    tens.shape != [] and tens.shape != op.outputs[0].shape and tens.storage_size() > max_ifm_shram_avail
+                    tens.purpose == TensorPurpose.FeatureMap
+                    and op.type in binary_elementwise_op
+                    and tens.shape != []
+                    and tens.shape != op.outputs[0].shape
+                    and tens.storage_size() > max_ifm_shram_avail
                 ):
                     only_vector_product_consumers = True
                     for oper in tens.consumers():
@@ -79,7 +91,7 @@ def insert_dma_cmd(op, arch):
                     # Other operations re-reads tensors, this is better done from SRAM.
                     # LUTs must be placed in the last 2 blocks of SHRAM.
                     if (
-                        not only_vector_product_consumers and weights_fit_sram(arch, tens)
+                        not only_vector_product_consumers and weights_fit_sram(arch, op, tens, nng)
                     ) or tens.purpose == TensorPurpose.LUT:
                         # Insert a DMA command here, as well as a new tensor situated in SRAM of the same size.
                         new_tens = tens.clone_into_fast_storage(arch)
@@ -98,7 +110,7 @@ def insert_dma_cmd(op, arch):
 
 def insert_dma_commands(nng, arch, verbose_graph=False):
     for idx, sg in enumerate(nng.subgraphs):
-        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(sg, arch, [], [insert_dma_cmd])
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [], [insert_dma_cmd])
         if verbose_graph:
             nng.print_graph()
     return nng
diff --git a/ethosu/vela/mark_tensors.py b/ethosu/vela/mark_tensors.py
index a971ef23..c4496cdc 100644
--- a/ethosu/vela/mark_tensors.py
+++ b/ethosu/vela/mark_tensors.py
@@ -266,7 +266,7 @@ def mark_tensor_purpose(nng, arch, verbose_tensor_purpose=False):
             )  # special case constants, as they must be in permanent storage
             tens.mem_type = MemType.Permanent_NPU
 
-    def rewrite_mark_tensor_purpose(op, arch):
+    def rewrite_mark_tensor_purpose(op, arch, nng):
         # find disconnected outputs and mark as parameters
         for tens in op.outputs:
             if not tens.consumers():
@@ -308,7 +308,7 @@ def mark_tensor_purpose(nng, arch, verbose_tensor_purpose=False):
         return op
 
     for sg in nng.subgraphs:
-        sg = rewrite_graph.rewrite_graph_pre_order(sg, arch, [], [rewrite_mark_tensor_purpose])
+        sg = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [], [rewrite_mark_tensor_purpose])
         for tens in sg.output_tensors:
             mark_tensor_helper(tens, TensorPurpose.FeatureMap)
 
diff --git a/ethosu/vela/rewrite_graph.py b/ethosu/vela/rewrite_graph.py
index e76e9617..e71b228a 100644
--- a/ethosu/vela/rewrite_graph.py
+++ b/ethosu/vela/rewrite_graph.py
@@ -24,7 +24,7 @@
 
 
 # Post-order traversal, this does not support rewrites. Therefore, functions must return the original value.
-def rewrite_graph_pre_order(sg, arch, tensor_rewrite_list, op_rewrite_list, rewrite_unsupported=True):
+def rewrite_graph_pre_order(nng, sg, arch, tensor_rewrite_list, op_rewrite_list, rewrite_unsupported=True):
 
     op_visit_dict = dict()
     tens_visit_dict = dict()
@@ -38,7 +38,7 @@ def rewrite_graph_pre_order(sg, arch, tensor_rewrite_list, op_rewrite_list, rewr
             prev_res = res
             for rewrite in op_rewrite_list:
                 if res.run_on_npu or rewrite_unsupported:
-                    res = rewrite(res, arch)
+                    res = rewrite(res, arch, nng)
 
         op_visit_dict[op] = res
         op_visit_dict[res] = res
@@ -64,7 +64,7 @@ def rewrite_graph_pre_order(sg, arch, tensor_rewrite_list, op_rewrite_list, rewr
         while prev_res != res:
             prev_res = res
             for rewrite in tensor_rewrite_list:
-                res = rewrite(res, arch)
+                res = rewrite(res, arch, nng)
 
         tens_visit_dict[tens] = res
         tens_visit_dict[res] = res
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index c5a3f3fd..8426705a 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -280,7 +280,6 @@ def core_deinterleave(hwio, core, ncores):
 
 # Compress the weights
 def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation):
     assert tens.purpose == TensorPurpose.Weights
-    assert tens.format == TensorFormat.WeightsCompressed
 
     # Check the weight cache
     if nng.weight_cache is None:
-- 
cgit v1.2.1
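
Note for reviewers: the double-buffer fit check that this patch introduces in weights_fit_sram() boils down to the sketch below. This is not code from the patch: the function and parameter names (weights_fit_sram_sketch, ofm_split_depth, worst_compression_scale, sram_size) are illustrative stand-ins for arch.OFMSplitDepth, tens.compression_scale_for_worst_weight_stream and arch.sram_size, and the worst-case compression scale is passed in directly instead of being produced by a compress_weights() call.

# Minimal sketch of the new double-buffer SRAM fit check (assumptions noted above).
def weights_fit_sram_sketch(shape, ofm_split_depth, worst_compression_scale, sram_size):
    # Smallest weight slice that must be resident: one OFM depth-split worth of weights.
    if len(shape) == 4:
        min_weight_size = shape[0] * shape[1] * shape[2] * ofm_split_depth
    elif len(shape) == 2:
        min_weight_size = shape[0] * ofm_split_depth
    else:
        return True

    # Cheap pre-check: only run the real compression when a pessimistic 2x scale
    # would already overflow SRAM when double buffered.
    w_comp_test_limit = 2
    if (w_comp_test_limit * min_weight_size * 2) <= sram_size:
        return True

    # Real check: two buffers of the worst-case compressed weight stream must fit.
    worst_buffer_size = worst_compression_scale * min_weight_size * 2
    return worst_buffer_size <= sram_size


# Illustrative numbers only: 3x3 kernel with 512 input channels, OFM split depth 16,
# worst-case compression scale 0.8, 256 KB of SRAM.
print(weights_fit_sram_sketch([3, 3, 512, 512], 16, 0.8, 256 * 1024))  # True

The point of the w_comp_test_limit pre-check mirrors the patch's comment ("Only evaluate when the compression test limit will make it impossible to fit"): it avoids running the comparatively expensive weight compression for tensors that obviously fit.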