From 3c07c97e0202c1cf01eba06c24b37a8f15ff7a7c Mon Sep 17 00:00:00 2001
From: Louis Verhaard
Date: Thu, 7 May 2020 08:12:58 +0200
Subject: MLBEDSW-1941: Bug fix shared weights

If the same weight tensor was used with different block configs, errors
would occur. Fixed by always cloning weight tensors, using a global
weight compression cache, and modifying the linear allocator to detect
multiple uses of the same weight compression.

Change-Id: I91ca59176e1c59c66e0ac7a4227f2b5f0b47053f
Signed-off-by: Louis Verhaard
---
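Note (reviewer aid, not part of the commit message): the core observation
behind this fix is that two weight tensors compress to identical streams
exactly when their compression parameters and underlying values match, so a
hashable config can serve as a cache key. A minimal sketch of that idea
follows; the namedtuple fields mirror the patch, but get_or_compress is an
illustrative helper, not vela code (the real flow is in weight_compressor.py
below):

    from collections import namedtuple

    WeightCompressionConfig = namedtuple(
        "WeightCompressionConfig", ["npu_block_type", "ofm_block_depth", "ofm_depth_step", "equivalence_id"]
    )

    class CompressedWeightCache:
        def __init__(self):
            self.cache = {}  # WeightCompressionConfig -> compressed weights

        def get_or_compress(self, wcc, compress_fn):
            # Identical configs yield identical compressed weights, so each
            # distinct config is compressed only once per graph.
            if wcc not in self.cache:
                self.cache[wcc] = compress_fn()
            return self.cache[wcc]
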
 ethosu/vela/compiler_driver.py                     |   5 +-
 ethosu/vela/high_level_command_stream_generator.py |   8 +-
 ethosu/vela/mark_tensors.py                        |  17 +--
 ethosu/vela/nn_graph.py                            |   1 +
 ethosu/vela/tensor.py                              |  34 ++++--
 ethosu/vela/tensor_allocation.py                   |  16 ++-
 ethosu/vela/tflite_reader.py                       |  34 +++---
 ethosu/vela/weight_compressor.py                   | 127 ++++++++++++---------
 8 files changed, 138 insertions(+), 104 deletions(-)

diff --git a/ethosu/vela/compiler_driver.py b/ethosu/vela/compiler_driver.py
index 64aff06b..b6a98a64 100644
--- a/ethosu/vela/compiler_driver.py
+++ b/ethosu/vela/compiler_driver.py
@@ -144,13 +144,14 @@ def compiler_driver(nng, arch, options, scheduler_options):
         # processed first during serialization into tensors
         first_npu_sg = nng.subgraphs[1]
         assert first_npu_sg.placement == PassPlacement.Npu
+        # Use the linear allocator for constant tensors
         tensor_allocation.allocate_tensors(
             nng,
             first_npu_sg,
             arch,
             permanent_storage,
             scheduler_options.use_ifm_ofm_overlap,
-            options.tensor_allocator,
+            TensorAllocator.LinearAlloc,
             options.verbose_allocation,
             options.show_minimum_possible_allocation,
             lr_graph_flash,
@@ -195,7 +196,7 @@ def compiler_driver(nng, arch, options, scheduler_options):
         arch,
         permanent_storage,
         scheduler_options.use_ifm_ofm_overlap,
-        options.tensor_allocator,
+        TensorAllocator.LinearAlloc,
        options.verbose_allocation,
        options.show_minimum_possible_allocation,
    )

diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index 0cc70a7f..3b968dc8 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -27,12 +27,8 @@ from .operation import NpuBlockType
 from .tensor import TensorPurpose
 
 
-def need_dma(tens):
-    return len(tens.ops) == 1 and tens.ops[0].type == "DMA"
-
-
 def dma_if_necessary(ps, box, tensor):
-    if need_dma(tensor):
+    if tensor.needs_dma():
         dma_op = tensor.ops[0]
         in_tensor = dma_op.inputs[0]
         yield DMA(in_tensor, tensor, box)
@@ -93,7 +89,7 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
     if strat == SchedulingStrategy.WeightStream:
         ofm_step = block_config[-1]
         ofm_stop = ofm_end[-1]
-        if weight_tensor is None or not need_dma(weight_tensor):
+        if weight_tensor is None or not weight_tensor.needs_dma():
             ofm_step = ofm_stop
         for start in range(ofm_start[-1], ofm_stop, ofm_step):
             end = min(start + ofm_step, ofm_stop)

diff --git a/ethosu/vela/mark_tensors.py b/ethosu/vela/mark_tensors.py
index e7b3e50f..cd70446b 100644
--- a/ethosu/vela/mark_tensors.py
+++ b/ethosu/vela/mark_tensors.py
@@ -17,7 +17,6 @@
 # Mark purpose and select formats for Tensors. Also compresses the weights.
 from . import rewrite_graph
 from . import weight_compressor
-from .architecture_features import Block
 from .operation import NpuBlockType
 from .tensor import TensorFormat
 from .tensor import TensorPurpose
@@ -348,18 +347,12 @@ def mark_tensor_format(nng, arch, verbose_tensor_format=False):
     for tens, fmt in formats_for_tensor.items():
         tens.set_format(fmt, arch)
         if fmt == TensorFormat.WeightsCompressed and tens.values is not None:
-            npu_block_type = find_npu_usage_of_tensor(tens)
-            if len(tens.ops) == 1 and tens.ops[0].type == "DMA":
-                weight_compressor.compress_weights(tens, arch, npu_block_type, Block(32, 32, 32), 32)
+            src_tens = tens.get_dma_src_tensor()
+            if src_tens is not None:
+                npu_block_type = find_npu_usage_of_tensor(tens)
+                weight_compressor.compress_weights(arch, nng, tens, npu_block_type, 32, 32)
                 # Alias compressed weights back into source tensor
-                src_tens = tens.ops[0].inputs[0]
-                src_tens.compressed_values = tens.compressed_values
-                src_tens.storage_shape = tens.storage_shape
-                src_tens.brick_size = tens.brick_size
-                src_tens.weight_compression_scales = tens.weight_compression_scales
-                src_tens.weight_compressed_offsets = tens.weight_compressed_offsets
-                src_tens.compression_scale_for_worst_weight_stream = tens.compression_scale_for_worst_weight_stream
-                src_tens.storage_compression_scale = tens.storage_compression_scale
+                src_tens.copy_compressed_weight_info(tens)
 
     if verbose_tensor_format:
         nng.print_passes_with_tensors()

diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index ed2ab322..ea35c087 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -485,6 +485,7 @@ class Graph:
         self.bits_per_element = {}
         self.total_size = {}
         self.total_elements = {}
+        self.weight_cache = None  # See CompressedWeightCache
 
     def get_root_subgraph(self):
         return self.subgraphs[0]

diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 160cf630..2f91f61c 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -225,7 +225,6 @@ class Tensor:
         "quantization",
         "weight_compressed_offsets",
         "element_size_bytes",
-        "reshaped",
         "block_traversal",
         "offset",
         "cpu_tensor",
@@ -273,8 +272,6 @@ class Tensor:
 
         # quantization parameters
         self.quantization = None
-
-        self.reshaped = False
         self.block_traversal = TensorBlockTraversal.Default
         self.resampling_mode = resampling_mode.NONE
 
@@ -294,20 +291,13 @@ class Tensor:
         res.values = self.values
         res.quant_values = self.quant_values
-        res.compressed_values = self.compressed_values
         res.mem_area = self.mem_area
         res.format = self.format
         res.purpose = self.purpose
         res.sub_purpose = self.sub_purpose
         res.alignment = self.alignment
-        res.weight_transpose_depthwise = self.weight_transpose_depthwise
-
-        res.storage_compression_scale = self.storage_compression_scale
         res.bandwidth_compression_scale = self.bandwidth_compression_scale
-        res.compression_scale_for_worst_weight_stream = self.compression_scale_for_worst_weight_stream
-        res.weight_compression_scales = self.weight_compression_scales
         res.storage_rounding_quantum = self.storage_rounding_quantum
-        res.brick_size = self.brick_size
         res.address = 0
 
         if self.quantization is not None:
@@ -317,6 +307,7 @@ class Tensor:
 
         res.resampling_mode = self.resampling_mode
 
+        res.copy_compressed_weight_info(self)
         return res
 
     def clone_into_fast_storage(self, arch):
@@ -324,6 +315,19 @@ class Tensor:
         res.mem_area = arch.fast_storage_mem_area
         return res
 
+    def copy_compressed_weight_info(self, src_tens):
+        # Copies compressed values + all related weight compression info from the given tensor
+        self.compressed_values = src_tens.compressed_values
+        self.storage_shape = src_tens.storage_shape
+        self.brick_size = src_tens.brick_size
+        self.weight_compression_scales = src_tens.weight_compression_scales
+        self.weight_compressed_offsets = src_tens.weight_compressed_offsets
+        self.weight_transpose_depthwise = src_tens.weight_transpose_depthwise
+        self.compression_scale_for_worst_weight_stream = src_tens.compression_scale_for_worst_weight_stream
+        self.storage_compression_scale = src_tens.storage_compression_scale
+        self.block_traversal = src_tens.block_traversal
+        self.weight_compression_config = src_tens.weight_compression_config
+
     def set_format(self, fmt, arch):
         self.format = fmt
         shape_len = 0
@@ -527,6 +531,14 @@ class Tensor:
 
         return strides
 
+    def needs_dma(self):
+        return len(self.ops) == 1 and self.ops[0].type == "DMA"
+
+    def get_dma_src_tensor(self):
+        # For weight tensors that need DMA: returns the source tensor in Flash, else None
+        # Note: for DMA ops, Pass.weight_tensor is referring to the SRAM weight tensor
+        return self.ops[0].inputs[0] if self.needs_dma() else None
+
     def compressed_stream_index_from_coord(self, coord):
         assert self.format == TensorFormat.WeightsCompressed
         assert len(self.compressed_values) > 0
@@ -575,7 +587,7 @@ class Tensor:
         if len(self.weight_compressed_offsets) == 0:
             return 0
 
-        if len(self.ops) == 1 and self.ops[0].type == "DMA" and self.sub_purpose == TensorSubPurpose.DoubleBuffer:
+        if self.needs_dma() and self.sub_purpose == TensorSubPurpose.DoubleBuffer:
             depth = orig_coord[-1]
             brick_depth = self.brick_size[-1]
             # Clamp position at final element index

diff --git a/ethosu/vela/tensor_allocation.py b/ethosu/vela/tensor_allocation.py
index cd2b570f..e3952df3 100644
--- a/ethosu/vela/tensor_allocation.py
+++ b/ethosu/vela/tensor_allocation.py
@@ -27,18 +27,26 @@ from .nn_graph import TensorAllocator
 from .tensor import MemArea
 
 
-def linear_allocate_live_ranges(live_ranges, alloc_granularity=256):
+def linear_allocate_live_ranges(live_ranges, alloc_granularity=16):
+    # Allocates using increasing addresses. Duplicate constant tensors will be allocated to the same address
     total_sz = 0
     allocated_tensors = []
-    # just assign increasing addresses
+    # just assign increasing addresses, except for duplicates
     for tens, lr in live_ranges.ranges.items():
         if tens in allocated_tensors:
             continue
-        lr.set_address(total_sz)
+        address = total_sz
+        if tens.weight_compression_config is not None:
+            for allocated_tens in allocated_tensors:
+                if allocated_tens.weight_compression_config == tens.weight_compression_config:
+                    address = allocated_tens.address
+                    break
+        lr.set_address(address)
         allocated_tensors += lr.tensors
-        total_sz += numeric_util.round_up(int(math.ceil(lr.size)), alloc_granularity)
+        if address == total_sz:
+            total_sz += numeric_util.round_up(int(math.ceil(lr.size)), alloc_granularity)
 
     return total_sz

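A quick illustration of the allocator change above (a standalone sketch with
made-up sizes; linear_allocate is a simplified stand-in for
linear_allocate_live_ranges, not vela code): a duplicate resolves to the
first copy's address, and the arena only grows when a genuinely new
allocation is made.

    def linear_allocate(sizes_and_configs, granularity=16):
        # Takes (size, weight_compression_config) pairs, returns addresses + total size
        total_sz = 0
        placed = []  # (config, address) of everything allocated so far
        addresses = []
        for size, cfg in sizes_and_configs:
            address = total_sz
            if cfg is not None:
                for placed_cfg, placed_addr in placed:
                    if placed_cfg == cfg:
                        address = placed_addr  # reuse the duplicate's address
                        break
            addresses.append(address)
            placed.append((cfg, address))
            if address == total_sz:  # only new allocations grow the arena
                total_sz += -(-size // granularity) * granularity
        return addresses, total_sz

    # Two identical 100-byte weight blobs share one slot:
    # linear_allocate([(100, "cfgA"), (100, "cfgA")]) == ([0, 0], 112)
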
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 109ae0ec..5ab90f04 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -39,24 +39,24 @@ def decode_str(s):
     return s.decode("utf-8")
 
 
-def reshape_tensor_add_const_op(tens, reorder):
-    if not tens.reshaped:
-        original_shape = tens.shape
-        tens.name = tens.name + "_reshape"
-        tens.shape = [original_shape[idx] for idx in reorder]
-        tens.bandwidth_shape = tens.shape
-        tens.storage_shape = tens.shape
+def clone_and_reshape_tensor(src_tens, reorder):
 
-        if tens.values is not None:
-            tens.values = tens.values.transpose(reorder)
+    tens = src_tens.clone("_reshape")
+    tens.shape = [src_tens.shape[idx] for idx in reorder]
+    tens.bandwidth_shape = tens.shape
+    tens.storage_shape = tens.shape
 
-        if tens.quant_values is not None:
-            tens.quant_values = tens.quant_values.transpose(reorder)
+    if tens.values is not None:
+        tens.values = tens.values.transpose(reorder)
 
-        op = Operation("Const", tens.name)
-        op.outputs = [tens]
-        tens.ops = [op]
-        tens.reshaped = True
+    if tens.quant_values is not None:
+        tens.quant_values = tens.quant_values.transpose(reorder)
+
+    op = Operation("Const", tens.name)
+    op.outputs = [tens]
+    tens.ops = [op]
+
+    return tens
 
 
 class TFLiteSubgraph:
@@ -137,10 +137,10 @@ class TFLiteSubgraph:
         activation_function_to_split_out = None
 
         if op_type.startswith("DepthwiseConv2d") or op_type.startswith("Conv2D"):
-            reshape_tensor_add_const_op(inputs[1], (1, 2, 3, 0))
+            inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0))
 
         if op_type.startswith("FullyConnected"):
-            reshape_tensor_add_const_op(inputs[1], (1, 0))
+            inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0))
 
         if opt_serializer is not None:
             op.attrs = opt_serializer.deserialize(op_data.BuiltinOptions(), op_data.CustomOptionsAsNumpy())

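The function this replaces mutated the shared weight tensor in place and used
a reshaped flag to avoid reshaping twice, so state could leak between
consumers of the same constant. A toy illustration of that hazard (FakeTensor
and both helpers are hypothetical stand-ins, not vela code):

    import copy

    class FakeTensor:
        def __init__(self, shape):
            self.shape = shape

    def reshape_in_place(t, reorder):  # old style: mutates the shared tensor
        t.shape = [t.shape[i] for i in reorder]
        return t

    def clone_and_reshape(t, reorder):  # new style: each consumer gets a copy
        c = copy.copy(t)
        c.shape = [t.shape[i] for i in reorder]
        return c

    w = FakeTensor([1, 3, 3, 8])
    reshape_in_place(w, (1, 2, 3, 0))
    assert w.shape == [3, 3, 8, 1]  # every other consumer of w now sees this
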
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index a81b1fb4..450e091e 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -21,7 +21,6 @@ from collections import namedtuple
 import numpy as np
 
 from ethosu import mlw_codec
-from .architecture_features import Block
 from .data_type import DataType
 from .errors import UnsupportedFeatureError
 from .nn_graph import SchedulingStrategy
@@ -35,6 +34,46 @@ from .tensor import TensorPurpose
 from .tensor import TensorSubPurpose
 
 
+# Contains meta info for a weight compression. If two tensors have identical weight compression config,
+# then they also will have identical compressed weights.
+WeightCompressionConfig = namedtuple(
+    "WeightCompressionConfig", ["npu_block_type", "ofm_block_depth", "ofm_depth_step", "equivalence_id"]
+)
+
+
+def create_weight_compression_config(tens, npu_block_type, ofm_block_depth, ofm_depth_step):
+    # Note: for an ofm block only its depth is used in weight compression.
+    # And block depth > ofm depth gives same result as block depth == ofm depth
+    block_depth = min(ofm_block_depth, tens.quant_values.shape[-1])
+    return WeightCompressionConfig(npu_block_type, block_depth, ofm_depth_step, tens.equivalence_id)
+
+
+def set_storage_shape(tens):
+    # Sets the storage shape depending on the tensor's sub purpose
+    if tens.sub_purpose == TensorSubPurpose.DoubleBuffer and len(tens.compressed_values) > 2:
+        offset = 2 * np.amax([len(x) for x in tens.compressed_values])
+        assert offset % 16 == 0
+    else:
+        offset = tens.weight_compressed_offsets[-1]
+    tens.storage_shape = [1, 1, 1, offset]
+
+
+class CompressedWeightCache:
+    # Contains weight compressions for all weight tensors in a graph
+    def __init__(self):
+        self.cache = {}  # maps from WeightCompressionConfig to a tensor clone containing compressed weights
+
+    def get_tensor_with_same_compression(self, wcc):
+        return self.cache.get(wcc)
+
+    def add(self, tens):
+        # Adds the compressed weights from the tensor to the cache
+        wcc = tens.weight_compression_config
+        # Clone the tensor to make sure that nothing related to the weight compression is modified
+        tens_clone = tens.clone("_weights{}_{}".format(wcc.ofm_block_depth, wcc.ofm_depth_step))
+        self.cache[wcc] = tens_clone
+
+
 def encode(weight_stream):
     assert np.amin(weight_stream) >= -255
     assert np.amax(weight_stream) <= 255
@@ -51,7 +90,7 @@ def encode(weight_stream):
     return compressed
 
 
-def generate_brick(arch, brick_weights, ofm_block, block_traversal, ifm_bitdepth):
+def generate_brick(arch, brick_weights, ofm_block_depth, block_traversal, ifm_bitdepth):
     is_depthwise = block_traversal == TensorBlockTraversal.DepthWise
     is_partkernel = block_traversal == TensorBlockTraversal.PartKernelFirst
     subkernel_max = arch.subkernel_max
@@ -74,8 +113,8 @@ def generate_brick(arch, brick_weights, ofm_block, block_traversal, ifm_bitdepth
     stream = []
 
     # Top level striping - OFM blocks in the entire brick's depth
-    for ofm_block_z in range(0, ofm_depth, ofm_block.depth):
-        clipped_ofm_block_depth = min(ofm_block.depth, ofm_depth - ofm_block_z)
+    for ofm_block_z in range(0, ofm_depth, ofm_block_depth):
+        clipped_ofm_block_depth = min(ofm_block_depth, ofm_depth - ofm_block_z)
         # IFM blocks required for the brick
         for ifm_block_z in range(0, (1 if is_depthwise else ifm_depth), ifm_block_depth):
             if is_depthwise:
@@ -139,20 +178,23 @@
 
 
 # Compress the weights
-def compress_weights(tens, arch, npu_block_type, ofm_block, ofm_depth_step, min_val=None, max_val=None):
+def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth_step):
     assert tens.purpose == TensorPurpose.Weights
     assert tens.format == TensorFormat.WeightsCompressed
 
-    WeightCompressionConfig = namedtuple("WeightCompressionConfig", ["npu_block_type", "ofm_block", "ofm_depth_step"])
-
-    # check if weights have already been compressed
-    wcc = tens.weight_compression_config
-    if wcc is not None:
-        assert wcc.npu_block_type == npu_block_type, "Weights not used by the same operator type"
-
-        if wcc.ofm_block == ofm_block and wcc.ofm_depth_step == ofm_depth_step:
-            return
-
+    # Check the weight cache
+    if nng.weight_cache is None:
+        nng.weight_cache = CompressedWeightCache()
+    wcc = create_weight_compression_config(tens, npu_block_type, ofm_block_depth, ofm_depth_step)
+    tens.weight_compression_config = wcc
+    tens_cached = nng.weight_cache.get_tensor_with_same_compression(wcc)
+    if tens_cached is not None:
+        # Cache hit, copy weights from the cache
+        tens.copy_compressed_weight_info(tens_cached)
+        set_storage_shape(tens)
+        return
+
+    # No cache hit, perform the compression
     assert tens.quantization is not None
     assert tens.quantization.scale_f32 is not None
     assert tens.quantization.zero_point is not None
@@ -173,7 +215,6 @@
     compressed_offsets = []
     encoded_streams = []
     offset = 0
-    max_single_buffer_len = 0
 
     ifm_bitdepth = tens.consumer_list[0].inputs[0].dtype.size_in_bits()
     ifm_depth = weights.shape[-2]
@@ -200,14 +241,10 @@
             brick_weights = weights[:, :, :, idx : idx + count]
 
             # Encode all weights into one chunk
-            raw_stream = generate_brick(arch, brick_weights, ofm_block, tens.block_traversal, ifm_bitdepth)
+            raw_stream = generate_brick(arch, brick_weights, ofm_block_depth, tens.block_traversal, ifm_bitdepth)
             encoded = encode(raw_stream)
             encoded_streams.append(encoded)
 
-            # Remember maximum encoded length for DoubleBuffering
-            if max_single_buffer_len < len(encoded):
-                max_single_buffer_len = len(encoded)
-
             # Remember where we put it for linear addressing
             compressed_offsets.append(offset)
             offset += len(encoded)
@@ -219,18 +256,14 @@
     # Also track complete length in the offsets array
     compressed_offsets.append(offset)
 
-    if tens.sub_purpose == TensorSubPurpose.DoubleBuffer and len(encoded_streams) > 2:
-        offset = 2 * max_single_buffer_len
-        assert offset % 16 == 0
-
-    tens.storage_shape = [1, 1, 1, offset]
     tens.weight_compression_scales = compression_scales
-    tens.weight_compression_config = WeightCompressionConfig(npu_block_type, ofm_block, ofm_depth_step)
     tens.weight_compressed_offsets = compressed_offsets
     tens.compression_scale_for_worst_weight_stream = np.amax(compression_scales)
     tens.storage_compression_scale = tens.bandwidth_compression_scale = np.average(compression_scales)
     tens.compressed_values = encoded_streams
     tens.brick_size = (weights_shape[0], weights_shape[1], weights_shape[2], min(tens.shape[-1], ofm_depth_step))
+    set_storage_shape(tens)
+    nng.weight_cache.add(tens)
 
 
 def calc_scales_and_pack_biases(tens, arch, oc_quantum, rescale_for_faf=False):
@@ -352,39 +385,29 @@ def update_pass_weight_and_scale_tensors(nng, arch):
 
     for sg in nng.subgraphs:
         for ps in sg.passes:
-            if ps.weight_tensor is not None:
-                npu_usage_of_tensor = find_npu_usage_of_tensor(ps.weight_tensor)
+            tens = ps.weight_tensor
+            if tens is not None:
+                npu_usage_of_tensor = find_npu_usage_of_tensor(tens)
                 if npu_usage_of_tensor == NpuBlockType.ConvolutionDepthWise:
-                    ps.weight_tensor.quant_values = np.transpose(ps.weight_tensor.quant_values, (0, 1, 3, 2))
-                    ps.weight_tensor.shape = ps.weight_tensor.storage_shape = ps.weight_tensor.bandwidth_shape = list(
-                        ps.weight_tensor.quant_values.shape
-                    )
-                    ps.weight_tensor.weight_transpose_depthwise = True
+                    tens.quant_values = np.transpose(tens.quant_values, (0, 1, 3, 2))
+                    tens.shape = tens.storage_shape = tens.bandwidth_shape = list(tens.quant_values.shape)
+                    tens.weight_transpose_depthwise = True
 
-                needs_dma = len(ps.weight_tensor.ops) == 1 and ps.weight_tensor.ops[0].type == "DMA"
+                needs_dma = tens.needs_dma()
                 if ps.cascade.strategy == SchedulingStrategy.WeightStream and needs_dma:
                     ofm_depth_step = ps.block_config[-1]
                 else:
-                    ofm_depth_step = ps.weight_tensor.shape[-1]
-
+                    ofm_depth_step = tens.shape[-1]
                 compress_weights(
-                    ps.weight_tensor,
-                    arch,
-                    npu_usage_of_tensor,
-                    Block(ps.block_config[-3], ps.block_config[-4], ps.block_config[-1]),
-                    ofm_depth_step,
+                    arch, nng, tens, npu_usage_of_tensor, ps.block_config[-1], ofm_depth_step,
                 )
                 # Update source tensor
-                if len(ps.weight_tensor.ops) == 1 and ps.weight_tensor.ops[0].type == "DMA":
-                    src_tens = ps.weight_tensor.ops[0].inputs[0]
-                    src_tens.shape = ps.weight_tensor.shape
-                    src_tens.weight_transpose_depthwise = ps.weight_tensor.weight_transpose_depthwise
-                    src_tens.quant_values = ps.weight_tensor.quant_values
-                    src_tens.compressed_values = ps.weight_tensor.compressed_values
-                    src_tens.storage_shape = [1, 1, 1, ps.weight_tensor.weight_compressed_offsets[-1]]
-                    src_tens.brick_size = ps.weight_tensor.brick_size
-                    src_tens.weight_compression_scales = ps.weight_tensor.weight_compression_scales
-                    src_tens.weight_compressed_offsets = ps.weight_tensor.weight_compressed_offsets
+                if needs_dma:
+                    src_tens = tens.get_dma_src_tensor()
+                    src_tens.shape = tens.shape
+                    src_tens.quant_values = tens.quant_values
+                    src_tens.copy_compressed_weight_info(tens)
+                    set_storage_shape(src_tens)
 
             if ps.scale_tensor is not None:
                 rescale_for_faf = False
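One subtlety in set_storage_shape above: double-buffered weight storage is
sized as twice the largest encoded stream rather than the sum of all streams,
because only two streams (the one being consumed and the one being DMAed in)
are resident at a time. A toy check with made-up stream lengths, each a
multiple of 16 bytes, which the assert in the patch relies on:

    import numpy as np

    encoded_streams = [b"\x00" * 304, b"\x00" * 496, b"\x00" * 432]  # hypothetical
    offset = 2 * np.amax([len(x) for x in encoded_streams])  # 2 * 496 = 992
    assert offset % 16 == 0
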
--
cgit v1.2.1