From d784af7e8995a10fb403157af48371699c35bbfe Mon Sep 17 00:00:00 2001
From: Tim Hall
Date: Tue, 8 Jun 2021 21:25:57 +0100
Subject: MLBEDSW-4602: Fix Deepspeech scale & bias reuse issue.

 - Deepspeech reuses identical weights and biases throughout the network.
   Since biases are now interleaved with weights there is a scaling issue
   when the ifm scales differ between operations using the same weight and
   scale tensor.

 - This commit uses interleaved weights/scales on their first use but
   separates scales to source memory on subsequent use (if the ifm scale
   is different).

Signed-off-by: Tim Hall
Change-Id: I7aae163438160a919cae04e235966e75355a6148
---
 ethosu/vela/high_level_command_stream.py           |   2 +
 ethosu/vela/high_level_command_stream_generator.py |   2 +
 ethosu/vela/high_level_command_to_npu_op.py        |  23 ++-
 ethosu/vela/live_range.py                          |  10 +-
 ethosu/vela/npu_serialisation.py                   |   2 +
 ethosu/vela/scheduler.py                           |  14 +-
 ethosu/vela/tensor_allocation.py                   |   1 +
 ethosu/vela/weight_compressor.py                   | 178 ++++++++++++---------
 8 files changed, 141 insertions(+), 91 deletions(-)
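The scaling issue described above stems from the per-channel rescale that is folded into the encoded
scale stream: it depends on the consuming operation's ifm and ofm quantisation, not just on the shared
weight and bias values. A rough sketch of that dependency (illustrative only; the helper below is
hypothetical and is not Vela code):

    # Hypothetical helper, not Vela code: shows why a weights+scales blob encoded
    # for one operation cannot be reused by another with a different ifm scale.
    def requant_multiplier(ifm_scale: float, weight_scale: float, ofm_scale: float) -> float:
        # Per-channel rescale packed alongside the bias in the interleaved stream
        return ifm_scale * weight_scale / ofm_scale

    shared_weight_scale = 0.02
    op_a = requant_multiplier(ifm_scale=0.05, weight_scale=shared_weight_scale, ofm_scale=0.1)  # ~0.01
    op_b = requant_multiplier(ifm_scale=0.08, weight_scale=shared_weight_scale, ofm_scale=0.1)  # ~0.016
    assert op_a != op_b  # same weights and biases, but different encoded scales
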
diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py
index d353b482..ddb24824 100644
--- a/ethosu/vela/high_level_command_stream.py
+++ b/ethosu/vela/high_level_command_stream.py
@@ -174,6 +174,7 @@ class NpuStripe(Command):
         ofm_box,
         weight_tensor=None,
         weight_box=None,
+        scale_tensor=None,
         ifm2_tensor=None,
         ifm2_box=None,
         pad_top=0,
@@ -190,6 +191,7 @@ class NpuStripe(Command):
         self.ofm_tensor = ofm_tensor
         self.ofm_box = ofm_box
         self.weight_tensor = weight_tensor
+        self.scale_tensor = scale_tensor
         self.weight_box = weight_box
         self.pad_top = pad_top
         self.pad_bottom = pad_bottom
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index ecd375e9..5a838f88 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -186,6 +186,7 @@ def generate_high_level_commands_for_sched_op(sched_op, schedule):
 
                 # Calculate the weight box - i.e. the subshape of weights needed for this NpuStripe command
                 weight_tensor = op_info.npu_weights_tensor
+                scale_tensor = op_info.npu_scales_tensor
                 if op_info.npu_weights_tensor:
                     weight_box = Box([0, 0, 0, start_channel], [1, 1, 1, end_channel])
 
@@ -211,6 +212,7 @@ def generate_high_level_commands_for_sched_op(sched_op, schedule):
                     ofm_box,
                     weight_tensor,
                     weight_box,
+                    scale_tensor,
                     ifm2_tensor=ifm2_tensor,
                     ifm2_box=ifm2_box,
                     pad_top=pad_top,
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 4ef7bee8..80d0e476 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -267,11 +267,14 @@ def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_sh
     return fm
 
 
-def create_weights(weight_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures) -> List[NpuAddressRange]:
+def create_weights(
+    weight_tensor: Tensor, weight_box: Box, scale_tensor: Tensor, arch: ArchitectureFeatures
+) -> List[NpuAddressRange]:
     """Returns address ranges for weights and scales"""
     weights = []
     biases = []
-    region = get_region(weight_tensor.mem_type, arch)
+    shared_region = get_region(weight_tensor.mem_type, arch)
+    scale_region = scale_tensor and get_region(scale_tensor.mem_type, arch)
 
     w_tensor_src = weight_tensor
     if weight_tensor.src_tensor:
@@ -300,11 +303,19 @@
 
             # Location of weights in tensor
             addr_range = NpuAddressRange(
-                region, int(address + weight_range.weight_offset), round_up(int(weight_range.weight_bytes), 16)
+                shared_region, int(address + weight_range.weight_offset), round_up(int(weight_range.weight_bytes), 16)
             )
             weights.append(addr_range)
-            # Location of biases in tensor
-            addr_range = NpuAddressRange(region, int(address), round_up(int(weight_range.scale_bytes), 16))
+
+            # Location of standalone scales or combined weights tensor scales
+            if scale_tensor:
+                assert scale_tensor.src_tensor is None  # Must be standalone
+                scale_range = scale_tensor.encoded_ranges[key]
+                address = scale_tensor.address + scale_range.offset
+                addr_range = NpuAddressRange(scale_region, int(address), round_up(int(scale_range.scale_bytes), 16))
+            else:
+                addr_range = NpuAddressRange(shared_region, int(address), round_up(int(weight_range.scale_bytes), 16))
+
             biases.append(addr_range)
 
     return weights, biases
@@ -351,7 +362,7 @@ def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: Archit
     npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)
 
     if cmd.weight_tensor is not None:
-        npu_op.weights, npu_op.biases = create_weights(cmd.weight_tensor, cmd.weight_box, arch)
+        npu_op.weights, npu_op.biases = create_weights(cmd.weight_tensor, cmd.weight_box, cmd.scale_tensor, arch)
     npu_op.activation = create_npu_activation(op)
     npu_op.fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
     npu_op.rounding_mode = get_rounding_mode(op, npu_op.fused_quantize)
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index d75a167d..b687a9e7 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -344,16 +344,14 @@ def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_
             lr_graph, tens, target_mem_area, target_mem_type_set
         ):
             continue
-
         rng = lr_graph.get_or_create_range(tens)
         rng.mark_usage(sg_time)
 
     for sched_op, op_info in sg.schedule.cost_map.items():
-        if op_info.npu_weights_tensor and not (
-            tensor_should_be_ignored(lr_graph, op_info.npu_weights_tensor, target_mem_area, target_mem_type_set)
-        ):
-            rng = lr_graph.get_or_create_range(op_info.npu_weights_tensor)
-            rng.mark_usage(sg_time)
+        for tensor in [op_info.npu_weights_tensor, op_info.npu_scales_tensor]:
+            if tensor and not (tensor_should_be_ignored(lr_graph, tensor, target_mem_area, target_mem_type_set)):
+                rng = lr_graph.get_or_create_range(tensor)
+                rng.mark_usage(sg_time)
 
     lr_graph.current_time += 1
     return lr_graph
diff --git a/ethosu/vela/npu_serialisation.py b/ethosu/vela/npu_serialisation.py
index 39a7f21f..f462168a 100644
--- a/ethosu/vela/npu_serialisation.py
+++ b/ethosu/vela/npu_serialisation.py
@@ -98,6 +98,8 @@ def serialise_npu_subgraph_into_tensors(sg, arch, scratch_tens, scratch_fast_ten
         op_info = sg.schedule.cost_map[sched_op]
         if op_info.npu_weights_tensor:
             copy_compressed_values_to_memory_tensor(sg.flash_tensor, op_info.npu_weights_tensor)
+        if op_info.npu_scales_tensor:
+            copy_compressed_values_to_memory_tensor(sg.flash_tensor, op_info.npu_scales_tensor)
 
         if ifm_tensor and ifm_tensor.mem_type not in (MemType.Scratch, MemType.Scratch_fast):
             copy_ifm_values_to_memory_tensor(sg.flash_tensor, ifm_tensor)
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 00a4dfc7..71007a32 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -94,6 +94,7 @@ class SchedulerOpInfo:
         self.time_index = None  # Set by update_op_memory_snapshot
         self.ofm_depth_slices: List[int] = [0, stripe.depth]
         self.npu_weights_tensor = None
+        self.npu_scales_tensor = None
         self.buffered_weight_tensor = None
         self.cycles = None
         self.slack_buffering_cycles = 0
@@ -248,7 +249,10 @@ class SchedulerOperation:
         scheduler_op_info = SchedulerOpInfo(block_config, 0, ifm_shape, ifm2_shape, ofm_shape)
         if self.parent_op.weights:
             # Default full-depth weight encoding with no buffering
-            scheduler_op_info.npu_weights_tensor = weight_compressor.encode_weight_and_scale_tensor(
+            (
+                scheduler_op_info.npu_weights_tensor,
+                scheduler_op_info.npu_scales_tensor,
+            ) = weight_compressor.encode_weight_and_scale_tensor(
                 self.arch,
                 self.parent_op,
                 self.parent_op.weights,
@@ -537,7 +541,7 @@ class Scheduler:
         ofm_full_depth_slices = [0, ref_cost.stripe.depth]
 
         # Encode weights for the full depth
-        full_weights = weight_compressor.encode_weight_and_scale_tensor(
+        full_weights, full_scales = weight_compressor.encode_weight_and_scale_tensor(
             self.arch,
             sched_op.parent_op,
             weight_tensor,
@@ -552,9 +556,11 @@ class Scheduler:
         # No buffering required - take all the weights from permanent storage
         if sched_op.op_type == Op.FullyConnected or not needs_dma:
             cost.npu_weights_tensor = full_weights
+            cost.npu_scales_tensor = full_scales
             return
 
         encoded_weights = full_weights
+        encoded_scales = full_scales
 
         # How many NPU cycles are available under the previously executing
         # operator and SRAM unused for performing buffered DMA transfers
@@ -609,7 +615,7 @@ class Scheduler:
 
                 # Encode weights based depth slices
                 cost.ofm_depth_slices = depth_slices
-                encoded_weights = weight_compressor.encode_weight_and_scale_tensor(
+                encoded_weights, encoded_scales = weight_compressor.encode_weight_and_scale_tensor(
                     self.arch,
                     sched_op.parent_op,
                     weight_tensor,
@@ -665,8 +671,10 @@ class Scheduler:
             # Don't slice or buffer - use the whole depth from persistent storage
             cost.ofm_depth_slices = ofm_full_depth_slices
             encoded_weights = full_weights
+            encoded_scales = full_scales
 
         cost.npu_weights_tensor = encoded_weights
+        cost.npu_scales_tensor = encoded_scales
 
     def propose_minimal_schedule(self) -> Schedule:
"""Proposes scheduling parameters where every operator is subdivided into the smallest stripe that satisfies the diff --git a/ethosu/vela/tensor_allocation.py b/ethosu/vela/tensor_allocation.py index d3e2a037..4b5e5e42 100644 --- a/ethosu/vela/tensor_allocation.py +++ b/ethosu/vela/tensor_allocation.py @@ -49,6 +49,7 @@ def linear_allocate_live_ranges(live_ranges, alloc_granularity=Tensor.Allocation if tens.weight_compression_config is not None: for allocated_tens in allocated_tensors: if allocated_tens.weight_compression_config == tens.weight_compression_config: + assert allocated_tens.scale_compression_config == tens.scale_compression_config address = allocated_tens.address break if tens.purpose == TensorPurpose.LUT: diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py index 652d0168..4ce03d55 100644 --- a/ethosu/vela/weight_compressor.py +++ b/ethosu/vela/weight_compressor.py @@ -40,10 +40,11 @@ from ethosu import mlw_codec # Contains meta info for a weight compression. If two tensors have identical weight compression config, # then they also will have identical compressed weights. WeightCompressionConfig = namedtuple( - "WeightCompressionConfig", - ["npu_block_type", "ofm_block_depth", "ofm_depth_step", "dilation", "weight_value_id", "scale_value_id"], + "WeightCompressionConfig", ["npu_block_type", "ofm_block_depth", "ofm_depth_step", "dilation", "weight_value_id"], ) +ScaleCompressionConfig = namedtuple("ScaleCompressionConfig", ["scale_value_id", "ifm_scale", "ofm_scale"]) + WeightKey = namedtuple("WeightKey", ["core", "depth"]) @@ -68,6 +69,7 @@ class NpuWeightTensor(Tensor): self.encoded_ranges = OrderedDict() self.hw_traversal = NpuBlockTraversal.DEPTH_FIRST self.dtype = DataType.uint8 + self.scale_compression_config = None class CompressedWeightCache: @@ -95,15 +97,11 @@ class CompressedWeightCache: return cache_obj[1] if cache_obj else None -def create_weight_compression_config( - weight_tens, scale_tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation -): +def create_weight_compression_config(weight_tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation): # Note: for an ofm block only its depth is used in weight compression. 
     # And block depth > ofm depth gives same result as block depth == ofm depth
     block_depth = min(ofm_block_depth, weight_tens.quant_values.shape[-1])
-    return WeightCompressionConfig(
-        npu_block_type, block_depth, ofm_depth_step, dilation, weight_tens.value_id, scale_tens.value_id
-    )
+    return WeightCompressionConfig(npu_block_type, block_depth, ofm_depth_step, dilation, weight_tens.value_id)
 
 
 def encode_weights(
@@ -277,72 +275,86 @@ def _prepare_scale_and_bias(arch, tens, rescale_for_faf):
 
 def encode_weight_and_scale_tensor(
     arch, op, weight_tens, scale_tens, kernel, block_config, depth_offsets, rescale_for_faf=False
-) -> NpuWeightTensor:
+) -> (NpuWeightTensor, NpuWeightTensor):
     npu_block_type = op.type.npu_block_type
 
+    ifm_scale = scale_tens and scale_tens.consumer_list[0].get_input_quantization().scale_f32
+    ofm_scale = scale_tens and scale_tens.consumer_list[0].get_output_quantization().scale_f32
+
     wcc = create_weight_compression_config(
-        weight_tens, scale_tens, npu_block_type, block_config.ofm_block.depth, hash(str(depth_offsets)), kernel.dilation
+        weight_tens, npu_block_type, block_config.ofm_block.depth, hash(str(depth_offsets)), kernel.dilation
     )
+    scc = ScaleCompressionConfig(scale_tens and scale_tens.value_id, ifm_scale, ofm_scale)
+
     tens_cached = CompressedWeightCache.get_tensor_with_same_compression(wcc)
     if tens_cached is not None:
-        return tens_cached
+        if tens_cached.scale_compression_config == scc:
+            return tens_cached, None
+        npu_tensor = NpuWeightTensor(scale_tens.name)
+        do_weights = False
+        do_scales = True
+    else:
+        npu_tensor = NpuWeightTensor(weight_tens.name)
+        do_weights = True
+        do_scales = True
 
-    npu_tensor = NpuWeightTensor(weight_tens.name)
     npu_tensor.weight_compression_config = wcc
+    npu_tensor.scale_compression_config = scc
+
+    # Ensure depth offsets are terminated at end of OFM shape
+    assert len(depth_offsets) > 1, "Require closed depth ranges"
 
-    # No cache hit, perform the compression
-    assert weight_tens.quantization is not None
-    assert weight_tens.quantization.scale_f32 is not None
-    assert weight_tens.quantization.zero_point is not None
+    ifm_bitdepth = op.inputs[0].dtype.size_in_bits()
 
-    zero_point = weight_tens.quantization.zero_point
-    quant_buf = weight_tens.quant_values.astype(np.int64)
+    # No cache hit, need to perform the encoding
+    if do_weights:
+        assert weight_tens.quantization is not None
+        assert weight_tens.quantization.scale_f32 is not None
+        assert weight_tens.quantization.zero_point is not None
 
-    # Early zero-point correction
-    weights = quant_buf - zero_point
+        # Early zero-point correction
+        quant_buf = weight_tens.quant_values.astype(np.int64)
+        weights = quant_buf - weight_tens.quantization.zero_point
 
-    if len(weights.shape) == 2:
-        weights = np.expand_dims(np.expand_dims(weights, axis=0), axis=0)
+        if len(weights.shape) == 2:
+            weights = np.expand_dims(np.expand_dims(weights, axis=0), axis=0)
 
-    # Expect this (undilated) equivalence
-    assert kernel.height == weights.shape[0]
-    assert kernel.width == weights.shape[1]
-    # Ensure depth offsets are terminated at end of OFM shape
-    assert len(depth_offsets) > 1, "Require closed depth ranges"
+        # Expect this (undilated) equivalence
+        assert kernel.height == weights.shape[0]
+        assert kernel.width == weights.shape[1]
 
-    ifm_bitdepth = op.inputs[0].dtype.size_in_bits()
-    ifm_depth = weights.shape[-2]
-
-    # Default HW traversal
-    npu_tensor.hw_traversal = NpuBlockTraversal.DEPTH_FIRST
-
-    if npu_block_type == NpuBlockType.ConvolutionMxN:
-        # Determine which block traversal strategy has better DPU utilization
-        kernel_size = weights.shape[0] * weights.shape[1]
-        depth_utilization = weights.shape[2] / round_up(weights.shape[2], 32 if ifm_bitdepth == 8 else 16)
-        part_kernel_utilization = (weights.shape[2] / round_up(weights.shape[2], 8)) * (
-            kernel_size / round_up(kernel_size, 4 if ifm_bitdepth == 8 else 2)
-        )
-        if part_kernel_utilization >= depth_utilization or ifm_depth <= 8:
-            # Part-kernel first is always better for ifm depths <= 8
-            npu_tensor.hw_traversal = NpuBlockTraversal.PART_KERNEL_FIRST
-
-    if op.type == Op.Conv2DBackpropInputSwitchedBias:
-        # Transpose Convoluion, reverse weights in H and W axes
-        weights = np.flip(weights, axis=(0, 1))
+        ifm_depth = weights.shape[-2]
+
+        # Default HW traversal
+        npu_tensor.hw_traversal = NpuBlockTraversal.DEPTH_FIRST
+
+        if npu_block_type == NpuBlockType.ConvolutionMxN:
+            # Determine which block traversal strategy has better DPU utilization
+            kernel_size = weights.shape[0] * weights.shape[1]
+            depth_utilization = weights.shape[2] / round_up(weights.shape[2], 32 if ifm_bitdepth == 8 else 16)
+            part_kernel_utilization = (weights.shape[2] / round_up(weights.shape[2], 8)) * (
+                kernel_size / round_up(kernel_size, 4 if ifm_bitdepth == 8 else 2)
+            )
+            if part_kernel_utilization >= depth_utilization or ifm_depth <= 8:
+                # Part-kernel first is always better for ifm depths <= 8
+                npu_tensor.hw_traversal = NpuBlockTraversal.PART_KERNEL_FIRST
+
+        if op.type == Op.Conv2DBackpropInputSwitchedBias:
+            # Transpose Convoluion, reverse weights in H and W axes
+            weights = np.flip(weights, axis=(0, 1))
 
     encoded_stream = bytearray()
     max_single_buffer_len = 0
 
     is_depthwise = npu_block_type == NpuBlockType.ConvolutionDepthWise
 
     # Bias & scale
-    if scale_tens:
+    if do_scales:
         quantised_scales, biases = _prepare_scale_and_bias(arch, scale_tens, rescale_for_faf)
         scale_tens.element_size_bytes = 10
 
     # Slice the weight stream up depth-ways into bricks and compress
-    full_ofm_depth = quant_buf.shape[-1]
+    full_ofm_depth = weight_tens.quant_values.shape[-1]
     ofm_block_depth = block_config.ofm_block.depth
 
     weight_range_index = 0
@@ -352,11 +364,12 @@ def encode_weight_and_scale_tensor(
         depth_length = depth_offsets[idx + 1] - depth_offset
 
         # Get the weights necessary for this brick
-        brick_weights = weights[:, :, :, depth_offset : depth_offset + depth_length]
+        if do_weights:
+            brick_weights = weights[:, :, :, depth_offset : depth_offset + depth_length]
 
         buffer_start_offset = len(encoded_stream)
 
-        # For each core, deinterleave weights from the larger volume
+        # For each core, deinterleave weights/scales from the larger volume
         # and generate separate compressed streams.
         for core in range(0, min(arch.ncores, full_ofm_depth)):
@@ -370,7 +383,7 @@ def encode_weight_and_scale_tensor(
             weight_range_index += 1
 
             # Scales & biases
-            if scale_tens:
+            if do_scales:
                 scale_stream = []
                 core_scales = quantised_scales[
                     depth_offset + core : depth_offset + core + depth_length : arch.ncores
@@ -389,36 +402,49 @@ def encode_weight_and_scale_tensor(
                     encoded_stream.extend(bytearray(16 - remainder))
 
             # Weights
-            core_weights = core_deinterleave(brick_weights, core, arch.ncores)
-            encoded_substream, _ = encode_weights(
-                accelerator=arch.accelerator_config,
-                weights_volume=core_weights,
-                dilation_xy=kernel.dilation,
-                ifm_bitdepth=ifm_bitdepth,
-                ofm_block_depth=core_block_depth,
-                is_depthwise=is_depthwise,
-                block_traversal=npu_tensor.hw_traversal,
-            )
-
-            weight_range.weight_offset = len(encoded_stream) - weight_range.offset
-            weight_range.weight_bytes = len(encoded_substream)
-
-            # Append encoded weights section
-            encoded_stream.extend(encoded_substream)
-            assert len(encoded_stream) % 16 == 0
-
-            # Record encoded range in weights tensor
+            if do_weights:
+                core_weights = core_deinterleave(brick_weights, core, arch.ncores)
+                encoded_substream, _ = encode_weights(
+                    accelerator=arch.accelerator_config,
+                    weights_volume=core_weights,
+                    dilation_xy=kernel.dilation,
+                    ifm_bitdepth=ifm_bitdepth,
+                    ofm_block_depth=core_block_depth,
+                    is_depthwise=is_depthwise,
+                    block_traversal=npu_tensor.hw_traversal,
+                )
+                weight_range.weight_offset = len(encoded_stream) - weight_range.offset
+                weight_range.weight_bytes = len(encoded_substream)
+                # Append encoded section
+                encoded_stream.extend(encoded_substream)
+                assert len(encoded_stream) % 16 == 0
+
+            # Record encoded range in tensor
             npu_tensor.encoded_ranges[key] = weight_range
 
         # Remember maximum encoded length for DoubleBuffering
        max_single_buffer_len = max(max_single_buffer_len, len(encoded_stream) - buffer_start_offset)
 
+    # Attach buffer to tensor
     npu_tensor.buffer = encoded_stream
     npu_tensor.max_range_bytes = max_single_buffer_len
     npu_tensor.set_all_shapes([1, 1, 1, len(encoded_stream)])
     npu_tensor.format = TensorFormat.WeightsCompressed
-    npu_tensor.purpose = TensorPurpose.Weights
-    npu_tensor.mem_area = weight_tens.mem_area
-    npu_tensor.mem_type = weight_tens.mem_type
-    CompressedWeightCache.add(npu_tensor)
-    return npu_tensor
+
+    # Scale only tensor
+    if not do_weights:
+        npu_tensor.weight_compression_config = None
+        npu_tensor.purpose = TensorPurpose.FSBias
+        npu_tensor.mem_area = scale_tens.mem_area
+        npu_tensor.mem_type = scale_tens.mem_type
+        weights_tensor = tens_cached
+        scale_tensor = npu_tensor
+    else:
+        npu_tensor.purpose = TensorPurpose.Weights
+        npu_tensor.mem_area = weight_tens.mem_area
+        npu_tensor.mem_type = weight_tens.mem_type
+        weights_tensor = npu_tensor
+        scale_tensor = None
+    CompressedWeightCache.add(weights_tensor)
+
+    return weights_tensor, scale_tensor
--
cgit v1.2.1