author     Tim Hall <tim.hall@arm.com>  2021-06-08 21:25:57 +0100
committer  Tim Hall <tim.hall@arm.com>  2021-06-08 21:25:57 +0100
commit     d784af7e8995a10fb403157af48371699c35bbfe (patch)
tree       bf40b35b030d560049cef9411293b51e3d70ff4a /ethosu/vela/weight_compressor.py
parent     225e19d3640288e991475ee4c49cb3ffd83cc83b (diff)
download   ethos-u-vela-d784af7e8995a10fb403157af48371699c35bbfe.tar.gz
MLBEDSW-4602: Fix Deepspeech scale & bias reuse issue.
- Deepspeech reuses identical weights and biases throughout the network. Since biases are now interleaved with weights, there is a scaling issue when the ifm scales differ between operations using the same weight and scale tensor.
- This commit uses interleaved weights/scales on their first use but separates scales to source memory on subsequent use (if the ifm scale is different).

Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: I7aae163438160a919cae04e235966e75355a6148
Diffstat (limited to 'ethosu/vela/weight_compressor.py')
-rw-r--r--  ethosu/vela/weight_compressor.py  178
1 file changed, 102 insertions(+), 76 deletions(-)
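In essence, the change splits the cache key in two: a WeightCompressionConfig that no longer includes the scale tensor, and a new ScaleCompressionConfig carrying the scale value id plus the ifm/ofm scales. Below is a minimal standalone sketch of the resulting reuse decision; the namedtuples mirror the diff, but the cache dictionary and the string "blobs" are simplified stand-ins for Vela's CompressedWeightCache and encoder, not the actual implementation.

from collections import namedtuple

WeightCompressionConfig = namedtuple(
    "WeightCompressionConfig",
    ["npu_block_type", "ofm_block_depth", "ofm_depth_step", "dilation", "weight_value_id"],
)
ScaleCompressionConfig = namedtuple("ScaleCompressionConfig", ["scale_value_id", "ifm_scale", "ofm_scale"])

_cache = {}  # WeightCompressionConfig -> (encoded blob, ScaleCompressionConfig of the cached entry)


def encode(wcc, scc):
    # Full reuse when both weight and scale configs match; re-encode only the scales when
    # the weights match but the ifm/ofm scales differ; otherwise encode everything interleaved.
    hit = _cache.get(wcc)
    if hit is not None:
        blob, cached_scc = hit
        if cached_scc == scc:
            return blob, None                                       # interleaved weights+scales reused as-is
        return blob, "scales(ifm_scale=%s)" % scc.ifm_scale         # weights reused, scales encoded separately
    blob = "weights+scales(weights=%s, ifm_scale=%s)" % (wcc.weight_value_id, scc.ifm_scale)
    _cache[wcc] = (blob, scc)
    return blob, None


wcc = WeightCompressionConfig("ConvolutionMxN", 32, 0, (1, 1), 7)
print(encode(wcc, ScaleCompressionConfig(3, 0.05, 0.1)))  # first use: interleaved weights+scales, no extra tensor
print(encode(wcc, ScaleCompressionConfig(3, 0.02, 0.1)))  # reuse with different ifm scale: cached weights + separate scales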
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index 652d0168..4ce03d55 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -40,10 +40,11 @@ from ethosu import mlw_codec
# Contains meta info for a weight compression. If two tensors have identical weight compression config,
# then they also will have identical compressed weights.
WeightCompressionConfig = namedtuple(
- "WeightCompressionConfig",
- ["npu_block_type", "ofm_block_depth", "ofm_depth_step", "dilation", "weight_value_id", "scale_value_id"],
+ "WeightCompressionConfig", ["npu_block_type", "ofm_block_depth", "ofm_depth_step", "dilation", "weight_value_id"],
)
+ScaleCompressionConfig = namedtuple("ScaleCompressionConfig", ["scale_value_id", "ifm_scale", "ofm_scale"])
+
WeightKey = namedtuple("WeightKey", ["core", "depth"])
@@ -68,6 +69,7 @@ class NpuWeightTensor(Tensor):
self.encoded_ranges = OrderedDict()
self.hw_traversal = NpuBlockTraversal.DEPTH_FIRST
self.dtype = DataType.uint8
+ self.scale_compression_config = None
class CompressedWeightCache:
@@ -95,15 +97,11 @@ class CompressedWeightCache:
return cache_obj[1] if cache_obj else None
-def create_weight_compression_config(
- weight_tens, scale_tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation
-):
+def create_weight_compression_config(weight_tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation):
# Note: for an ofm block only its depth is used in weight compression.
# And block depth > ofm depth gives same result as block depth == ofm depth
block_depth = min(ofm_block_depth, weight_tens.quant_values.shape[-1])
- return WeightCompressionConfig(
- npu_block_type, block_depth, ofm_depth_step, dilation, weight_tens.value_id, scale_tens.value_id
- )
+ return WeightCompressionConfig(npu_block_type, block_depth, ofm_depth_step, dilation, weight_tens.value_id)
def encode_weights(
@@ -277,72 +275,86 @@ def _prepare_scale_and_bias(arch, tens, rescale_for_faf):
def encode_weight_and_scale_tensor(
arch, op, weight_tens, scale_tens, kernel, block_config, depth_offsets, rescale_for_faf=False
-) -> NpuWeightTensor:
+) -> (NpuWeightTensor, NpuWeightTensor):
npu_block_type = op.type.npu_block_type
+ ifm_scale = scale_tens and scale_tens.consumer_list[0].get_input_quantization().scale_f32
+ ofm_scale = scale_tens and scale_tens.consumer_list[0].get_output_quantization().scale_f32
+
wcc = create_weight_compression_config(
- weight_tens, scale_tens, npu_block_type, block_config.ofm_block.depth, hash(str(depth_offsets)), kernel.dilation
+ weight_tens, npu_block_type, block_config.ofm_block.depth, hash(str(depth_offsets)), kernel.dilation
)
+ scc = ScaleCompressionConfig(scale_tens and scale_tens.value_id, ifm_scale, ofm_scale)
+
tens_cached = CompressedWeightCache.get_tensor_with_same_compression(wcc)
if tens_cached is not None:
- return tens_cached
+ if tens_cached.scale_compression_config == scc:
+ return tens_cached, None
+ npu_tensor = NpuWeightTensor(scale_tens.name)
+ do_weights = False
+ do_scales = True
+ else:
+ npu_tensor = NpuWeightTensor(weight_tens.name)
+ do_weights = True
+ do_scales = True
- npu_tensor = NpuWeightTensor(weight_tens.name)
npu_tensor.weight_compression_config = wcc
+ npu_tensor.scale_compression_config = scc
+
+ # Ensure depth offsets are terminated at end of OFM shape
+ assert len(depth_offsets) > 1, "Require closed depth ranges"
- # No cache hit, perform the compression
- assert weight_tens.quantization is not None
- assert weight_tens.quantization.scale_f32 is not None
- assert weight_tens.quantization.zero_point is not None
+ ifm_bitdepth = op.inputs[0].dtype.size_in_bits()
- zero_point = weight_tens.quantization.zero_point
- quant_buf = weight_tens.quant_values.astype(np.int64)
+ # No cache hit, need to perform the encoding
+ if do_weights:
+ assert weight_tens.quantization is not None
+ assert weight_tens.quantization.scale_f32 is not None
+ assert weight_tens.quantization.zero_point is not None
- # Early zero-point correction
- weights = quant_buf - zero_point
+ # Early zero-point correction
+ quant_buf = weight_tens.quant_values.astype(np.int64)
+ weights = quant_buf - weight_tens.quantization.zero_point
- if len(weights.shape) == 2:
- weights = np.expand_dims(np.expand_dims(weights, axis=0), axis=0)
+ if len(weights.shape) == 2:
+ weights = np.expand_dims(np.expand_dims(weights, axis=0), axis=0)
- # Expect this (undilated) equivalence
- assert kernel.height == weights.shape[0]
- assert kernel.width == weights.shape[1]
- # Ensure depth offsets are terminated at end of OFM shape
- assert len(depth_offsets) > 1, "Require closed depth ranges"
+ # Expect this (undilated) equivalence
+ assert kernel.height == weights.shape[0]
+ assert kernel.width == weights.shape[1]
- ifm_bitdepth = op.inputs[0].dtype.size_in_bits()
- ifm_depth = weights.shape[-2]
-
- # Default HW traversal
- npu_tensor.hw_traversal = NpuBlockTraversal.DEPTH_FIRST
-
- if npu_block_type == NpuBlockType.ConvolutionMxN:
- # Determine which block traversal strategy has better DPU utilization
- kernel_size = weights.shape[0] * weights.shape[1]
- depth_utilization = weights.shape[2] / round_up(weights.shape[2], 32 if ifm_bitdepth == 8 else 16)
- part_kernel_utilization = (weights.shape[2] / round_up(weights.shape[2], 8)) * (
- kernel_size / round_up(kernel_size, 4 if ifm_bitdepth == 8 else 2)
- )
- if part_kernel_utilization >= depth_utilization or ifm_depth <= 8:
- # Part-kernel first is always better for ifm depths <= 8
- npu_tensor.hw_traversal = NpuBlockTraversal.PART_KERNEL_FIRST
-
- if op.type == Op.Conv2DBackpropInputSwitchedBias:
- # Transpose Convoluion, reverse weights in H and W axes
- weights = np.flip(weights, axis=(0, 1))
+ ifm_depth = weights.shape[-2]
+
+ # Default HW traversal
+ npu_tensor.hw_traversal = NpuBlockTraversal.DEPTH_FIRST
+
+ if npu_block_type == NpuBlockType.ConvolutionMxN:
+ # Determine which block traversal strategy has better DPU utilization
+ kernel_size = weights.shape[0] * weights.shape[1]
+ depth_utilization = weights.shape[2] / round_up(weights.shape[2], 32 if ifm_bitdepth == 8 else 16)
+ part_kernel_utilization = (weights.shape[2] / round_up(weights.shape[2], 8)) * (
+ kernel_size / round_up(kernel_size, 4 if ifm_bitdepth == 8 else 2)
+ )
+ if part_kernel_utilization >= depth_utilization or ifm_depth <= 8:
+ # Part-kernel first is always better for ifm depths <= 8
+ npu_tensor.hw_traversal = NpuBlockTraversal.PART_KERNEL_FIRST
+
+ if op.type == Op.Conv2DBackpropInputSwitchedBias:
+            # Transpose Convolution, reverse weights in H and W axes
+ weights = np.flip(weights, axis=(0, 1))
encoded_stream = bytearray()
max_single_buffer_len = 0
is_depthwise = npu_block_type == NpuBlockType.ConvolutionDepthWise
# Bias & scale
- if scale_tens:
+ if do_scales:
quantised_scales, biases = _prepare_scale_and_bias(arch, scale_tens, rescale_for_faf)
scale_tens.element_size_bytes = 10
# Slice the weight stream up depth-ways into bricks and compress
- full_ofm_depth = quant_buf.shape[-1]
+ full_ofm_depth = weight_tens.quant_values.shape[-1]
ofm_block_depth = block_config.ofm_block.depth
weight_range_index = 0
@@ -352,11 +364,12 @@ def encode_weight_and_scale_tensor(
depth_length = depth_offsets[idx + 1] - depth_offset
# Get the weights necessary for this brick
- brick_weights = weights[:, :, :, depth_offset : depth_offset + depth_length]
+ if do_weights:
+ brick_weights = weights[:, :, :, depth_offset : depth_offset + depth_length]
buffer_start_offset = len(encoded_stream)
- # For each core, deinterleave weights from the larger volume
+ # For each core, deinterleave weights/scales from the larger volume
# and generate separate compressed streams.
for core in range(0, min(arch.ncores, full_ofm_depth)):
@@ -370,7 +383,7 @@ def encode_weight_and_scale_tensor(
weight_range_index += 1
# Scales & biases
- if scale_tens:
+ if do_scales:
scale_stream = []
core_scales = quantised_scales[
depth_offset + core : depth_offset + core + depth_length : arch.ncores
@@ -389,36 +402,49 @@ def encode_weight_and_scale_tensor(
encoded_stream.extend(bytearray(16 - remainder))
# Weights
- core_weights = core_deinterleave(brick_weights, core, arch.ncores)
- encoded_substream, _ = encode_weights(
- accelerator=arch.accelerator_config,
- weights_volume=core_weights,
- dilation_xy=kernel.dilation,
- ifm_bitdepth=ifm_bitdepth,
- ofm_block_depth=core_block_depth,
- is_depthwise=is_depthwise,
- block_traversal=npu_tensor.hw_traversal,
- )
-
- weight_range.weight_offset = len(encoded_stream) - weight_range.offset
- weight_range.weight_bytes = len(encoded_substream)
-
- # Append encoded weights section
- encoded_stream.extend(encoded_substream)
- assert len(encoded_stream) % 16 == 0
-
- # Record encoded range in weights tensor
+ if do_weights:
+ core_weights = core_deinterleave(brick_weights, core, arch.ncores)
+ encoded_substream, _ = encode_weights(
+ accelerator=arch.accelerator_config,
+ weights_volume=core_weights,
+ dilation_xy=kernel.dilation,
+ ifm_bitdepth=ifm_bitdepth,
+ ofm_block_depth=core_block_depth,
+ is_depthwise=is_depthwise,
+ block_traversal=npu_tensor.hw_traversal,
+ )
+ weight_range.weight_offset = len(encoded_stream) - weight_range.offset
+ weight_range.weight_bytes = len(encoded_substream)
+ # Append encoded section
+ encoded_stream.extend(encoded_substream)
+ assert len(encoded_stream) % 16 == 0
+
+ # Record encoded range in tensor
npu_tensor.encoded_ranges[key] = weight_range
# Remember maximum encoded length for DoubleBuffering
max_single_buffer_len = max(max_single_buffer_len, len(encoded_stream) - buffer_start_offset)
+ # Attach buffer to tensor
npu_tensor.buffer = encoded_stream
npu_tensor.max_range_bytes = max_single_buffer_len
npu_tensor.set_all_shapes([1, 1, 1, len(encoded_stream)])
npu_tensor.format = TensorFormat.WeightsCompressed
- npu_tensor.purpose = TensorPurpose.Weights
- npu_tensor.mem_area = weight_tens.mem_area
- npu_tensor.mem_type = weight_tens.mem_type
- CompressedWeightCache.add(npu_tensor)
- return npu_tensor
+
+ # Scale only tensor
+ if not do_weights:
+ npu_tensor.weight_compression_config = None
+ npu_tensor.purpose = TensorPurpose.FSBias
+ npu_tensor.mem_area = scale_tens.mem_area
+ npu_tensor.mem_type = scale_tens.mem_type
+ weights_tensor = tens_cached
+ scale_tensor = npu_tensor
+ else:
+ npu_tensor.purpose = TensorPurpose.Weights
+ npu_tensor.mem_area = weight_tens.mem_area
+ npu_tensor.mem_type = weight_tens.mem_type
+ weights_tensor = npu_tensor
+ scale_tensor = None
+ CompressedWeightCache.add(weights_tensor)
+
+ return weights_tensor, scale_tensor
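After this change encode_weight_and_scale_tensor() returns a pair rather than a single tensor. A hedged call-site sketch follows; the stub only illustrates the new unpacking contract, and its name and return values are hypothetical rather than taken from the Vela code base.

from typing import Optional, Tuple


def encode_weight_and_scale_tensor_stub(reused_with_new_ifm_scale: bool) -> Tuple[str, Optional[str]]:
    # Stand-in for encode_weight_and_scale_tensor(): returns (weights tensor, separate scale tensor or None).
    if reused_with_new_ifm_scale:
        return "cached_weights", "re_encoded_scales"
    return "encoded_weights_and_scales", None


weights_tensor, scale_tensor = encode_weight_and_scale_tensor_stub(True)
# weights_tensor is always valid (cached or freshly encoded); scale_tensor is only set when
# previously cached weights are reused with a different ifm scale and the scales must be
# placed separately in source memory.
if scale_tensor is not None:
    print("separate scale tensor:", scale_tensor)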