aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/weight_compressor.py
diff options
context:
space:
mode:
authorLouis Verhaard <louis.verhaard@arm.com>2020-05-07 08:12:58 +0200
committerTim Hall <tim.hall@arm.com>2020-06-18 17:53:52 +0100
commit3c07c97e0202c1cf01eba06c24b37a8f15ff7a7c (patch)
tree5856b7727a99b3c0baa00f5486f0c3b53e8e38e6 /ethosu/vela/weight_compressor.py
parent86d49935c3736c7aaa419abda07fa20c37c991a8 (diff)
downloadethos-u-vela-3c07c97e0202c1cf01eba06c24b37a8f15ff7a7c.tar.gz
MLBEDSW-1941: Bug fix shared weights
If same weight tensor was used with different block configs, errors would occur. Fixed by always cloning weight tensors, using a global weight compression cache and modifying the linear allocator to detect multiple usage of same weight compression. Change-Id: I91ca59176e1c59c66e0ac7a4227f2b5f0b47053f Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Diffstat (limited to 'ethosu/vela/weight_compressor.py')
-rw-r--r--ethosu/vela/weight_compressor.py127
1 files changed, 75 insertions, 52 deletions
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index a81b1fb4..450e091e 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -21,7 +21,6 @@ from collections import namedtuple
import numpy as np
from ethosu import mlw_codec
-from .architecture_features import Block
from .data_type import DataType
from .errors import UnsupportedFeatureError
from .nn_graph import SchedulingStrategy
@@ -35,6 +34,46 @@ from .tensor import TensorPurpose
from .tensor import TensorSubPurpose
+# Contains meta info for a weight compression. If two tensors have identical weight compression config,
+# then they also will have identical compressed weights.
+WeightCompressionConfig = namedtuple(
+ "WeightCompressionConfig", ["npu_block_type", "ofm_block_depth", "ofm_depth_step", "equivalence_id"]
+)
+
+
+def create_weight_compression_config(tens, npu_block_type, ofm_block_depth, ofm_depth_step):
+ # Note: for an ofm block only its depth is used in weight compression.
+ # And block depth > ofm depth gives same result as block depth == ofm depth
+ block_depth = min(ofm_block_depth, tens.quant_values.shape[-1])
+ return WeightCompressionConfig(npu_block_type, block_depth, ofm_depth_step, tens.equivalence_id)
+
+
+def set_storage_shape(tens):
+ # Sets the storage shape depending on the tensor's sub purpose
+ if tens.sub_purpose == TensorSubPurpose.DoubleBuffer and len(tens.compressed_values) > 2:
+ offset = 2 * np.amax([len(x) for x in tens.compressed_values])
+ assert offset % 16 == 0
+ else:
+ offset = tens.weight_compressed_offsets[-1]
+ tens.storage_shape = [1, 1, 1, offset]
+
+
+class CompressedWeightCache:
+ # Contains weight compressions for all weight tensors in a graph
+ def __init__(self):
+ self.cache = {} # maps from WeightCompressionConfig to a tensor clone containing compressed weights
+
+ def get_tensor_with_same_compression(self, wcc):
+ return self.cache.get(wcc)
+
+ def add(self, tens):
+ # Adds the compressed weights from the tensor to the cache
+ wcc = tens.weight_compression_config
+ # Clone the tensor to make sure that nothing related to the weight compression is modified
+ tens_clone = tens.clone("_weights{}_{}".format(wcc.ofm_block_depth, wcc.ofm_depth_step))
+ self.cache[wcc] = tens_clone
+
+
def encode(weight_stream):
assert np.amin(weight_stream) >= -255
assert np.amax(weight_stream) <= 255
@@ -51,7 +90,7 @@ def encode(weight_stream):
return compressed
-def generate_brick(arch, brick_weights, ofm_block, block_traversal, ifm_bitdepth):
+def generate_brick(arch, brick_weights, ofm_block_depth, block_traversal, ifm_bitdepth):
is_depthwise = block_traversal == TensorBlockTraversal.DepthWise
is_partkernel = block_traversal == TensorBlockTraversal.PartKernelFirst
subkernel_max = arch.subkernel_max
@@ -74,8 +113,8 @@ def generate_brick(arch, brick_weights, ofm_block, block_traversal, ifm_bitdepth
stream = []
# Top level striping - OFM blocks in the entire brick's depth
- for ofm_block_z in range(0, ofm_depth, ofm_block.depth):
- clipped_ofm_block_depth = min(ofm_block.depth, ofm_depth - ofm_block_z)
+ for ofm_block_z in range(0, ofm_depth, ofm_block_depth):
+ clipped_ofm_block_depth = min(ofm_block_depth, ofm_depth - ofm_block_z)
# IFM blocks required for the brick
for ifm_block_z in range(0, (1 if is_depthwise else ifm_depth), ifm_block_depth):
if is_depthwise:
@@ -139,20 +178,23 @@ def generate_brick(arch, brick_weights, ofm_block, block_traversal, ifm_bitdepth
# Compress the weights
-def compress_weights(tens, arch, npu_block_type, ofm_block, ofm_depth_step, min_val=None, max_val=None):
+def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth_step):
assert tens.purpose == TensorPurpose.Weights
assert tens.format == TensorFormat.WeightsCompressed
- WeightCompressionConfig = namedtuple("WeightCompressionConfig", ["npu_block_type", "ofm_block", "ofm_depth_step"])
-
- # check if weights have already been compressed
- wcc = tens.weight_compression_config
- if wcc is not None:
- assert wcc.npu_block_type == npu_block_type, "Weights not used by the same operator type"
-
- if wcc.ofm_block == ofm_block and wcc.ofm_depth_step == ofm_depth_step:
- return
-
+ # Check the weight cache
+ if nng.weight_cache is None:
+ nng.weight_cache = CompressedWeightCache()
+ wcc = create_weight_compression_config(tens, npu_block_type, ofm_block_depth, ofm_depth_step)
+ tens.weight_compression_config = wcc
+ tens_cached = nng.weight_cache.get_tensor_with_same_compression(wcc)
+ if tens_cached is not None:
+ # Cache hit, copy weights from the cache
+ tens.copy_compressed_weight_info(tens_cached)
+ set_storage_shape(tens)
+ return
+
+ # No cache hit, perform the compression
assert tens.quantization is not None
assert tens.quantization.scale_f32 is not None
assert tens.quantization.zero_point is not None
@@ -173,7 +215,6 @@ def compress_weights(tens, arch, npu_block_type, ofm_block, ofm_depth_step, min_
compressed_offsets = []
encoded_streams = []
offset = 0
- max_single_buffer_len = 0
ifm_bitdepth = tens.consumer_list[0].inputs[0].dtype.size_in_bits()
ifm_depth = weights.shape[-2]
@@ -200,14 +241,10 @@ def compress_weights(tens, arch, npu_block_type, ofm_block, ofm_depth_step, min_
brick_weights = weights[:, :, :, idx : idx + count]
# Encode all weights into one chunk
- raw_stream = generate_brick(arch, brick_weights, ofm_block, tens.block_traversal, ifm_bitdepth)
+ raw_stream = generate_brick(arch, brick_weights, ofm_block_depth, tens.block_traversal, ifm_bitdepth)
encoded = encode(raw_stream)
encoded_streams.append(encoded)
- # Remember maximum encoded length for DoubleBuffering
- if max_single_buffer_len < len(encoded):
- max_single_buffer_len = len(encoded)
-
# Remember where we put it for linear addressing
compressed_offsets.append(offset)
offset += len(encoded)
@@ -219,18 +256,14 @@ def compress_weights(tens, arch, npu_block_type, ofm_block, ofm_depth_step, min_
# Also track complete length in the offsets array
compressed_offsets.append(offset)
- if tens.sub_purpose == TensorSubPurpose.DoubleBuffer and len(encoded_streams) > 2:
- offset = 2 * max_single_buffer_len
- assert offset % 16 == 0
-
- tens.storage_shape = [1, 1, 1, offset]
tens.weight_compression_scales = compression_scales
- tens.weight_compression_config = WeightCompressionConfig(npu_block_type, ofm_block, ofm_depth_step)
tens.weight_compressed_offsets = compressed_offsets
tens.compression_scale_for_worst_weight_stream = np.amax(compression_scales)
tens.storage_compression_scale = tens.bandwidth_compression_scale = np.average(compression_scales)
tens.compressed_values = encoded_streams
tens.brick_size = (weights_shape[0], weights_shape[1], weights_shape[2], min(tens.shape[-1], ofm_depth_step))
+ set_storage_shape(tens)
+ nng.weight_cache.add(tens)
def calc_scales_and_pack_biases(tens, arch, oc_quantum, rescale_for_faf=False):
@@ -352,39 +385,29 @@ def update_pass_weight_and_scale_tensors(nng, arch):
for sg in nng.subgraphs:
for ps in sg.passes:
- if ps.weight_tensor is not None:
- npu_usage_of_tensor = find_npu_usage_of_tensor(ps.weight_tensor)
+ tens = ps.weight_tensor
+ if tens is not None:
+ npu_usage_of_tensor = find_npu_usage_of_tensor(tens)
if npu_usage_of_tensor == NpuBlockType.ConvolutionDepthWise:
- ps.weight_tensor.quant_values = np.transpose(ps.weight_tensor.quant_values, (0, 1, 3, 2))
- ps.weight_tensor.shape = ps.weight_tensor.storage_shape = ps.weight_tensor.bandwidth_shape = list(
- ps.weight_tensor.quant_values.shape
- )
- ps.weight_tensor.weight_transpose_depthwise = True
+ tens.quant_values = np.transpose(tens.quant_values, (0, 1, 3, 2))
+ tens.shape = tens.storage_shape = tens.bandwidth_shape = list(tens.quant_values.shape)
+ tens.weight_transpose_depthwise = True
- needs_dma = len(ps.weight_tensor.ops) == 1 and ps.weight_tensor.ops[0].type == "DMA"
+ needs_dma = tens.needs_dma()
if ps.cascade.strategy == SchedulingStrategy.WeightStream and needs_dma:
ofm_depth_step = ps.block_config[-1]
else:
- ofm_depth_step = ps.weight_tensor.shape[-1]
-
+ ofm_depth_step = tens.shape[-1]
compress_weights(
- ps.weight_tensor,
- arch,
- npu_usage_of_tensor,
- Block(ps.block_config[-3], ps.block_config[-4], ps.block_config[-1]),
- ofm_depth_step,
+ arch, nng, tens, npu_usage_of_tensor, ps.block_config[-1], ofm_depth_step,
)
# Update source tensor
- if len(ps.weight_tensor.ops) == 1 and ps.weight_tensor.ops[0].type == "DMA":
- src_tens = ps.weight_tensor.ops[0].inputs[0]
- src_tens.shape = ps.weight_tensor.shape
- src_tens.weight_transpose_depthwise = ps.weight_tensor.weight_transpose_depthwise
- src_tens.quant_values = ps.weight_tensor.quant_values
- src_tens.compressed_values = ps.weight_tensor.compressed_values
- src_tens.storage_shape = [1, 1, 1, ps.weight_tensor.weight_compressed_offsets[-1]]
- src_tens.brick_size = ps.weight_tensor.brick_size
- src_tens.weight_compression_scales = ps.weight_tensor.weight_compression_scales
- src_tens.weight_compressed_offsets = ps.weight_tensor.weight_compressed_offsets
+ if needs_dma:
+ src_tens = tens.get_dma_src_tensor()
+ src_tens.shape = tens.shape
+ src_tens.quant_values = tens.quant_values
+ src_tens.copy_compressed_weight_info(tens)
+ set_storage_shape(src_tens)
if ps.scale_tensor is not None:
rescale_for_faf = False