author     Louis Verhaard <louis.verhaard@arm.com>  2020-05-07 08:12:58 +0200
committer  Tim Hall <tim.hall@arm.com>              2020-06-18 17:53:52 +0100
commit     3c07c97e0202c1cf01eba06c24b37a8f15ff7a7c (patch)
tree       5856b7727a99b3c0baa00f5486f0c3b53e8e38e6
parent     86d49935c3736c7aaa419abda07fa20c37c991a8 (diff)
download   ethos-u-vela-3c07c97e0202c1cf01eba06c24b37a8f15ff7a7c.tar.gz
MLBEDSW-1941: Bug fix shared weights
If the same weight tensor was used with different block configs, errors
would occur. Fixed by always cloning weight tensors, using a global weight
compression cache, and modifying the linear allocator to detect multiple
uses of the same weight compression.

Change-Id: I91ca59176e1c59c66e0ac7a4227f2b5f0b47053f
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
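As a rough, self-contained sketch of the caching scheme this commit introduces (hypothetical names, not the actual vela API): compression results are keyed by a small config tuple, so two weight tensors with identical configs are compressed only once, and the linear allocator can later place such duplicates at the same address.

# Minimal sketch of the caching idea, with made-up names (CompressionKey,
# SimpleCompressionCache); the real patch uses WeightCompressionConfig and
# CompressedWeightCache in ethosu/vela/weight_compressor.py.
from collections import namedtuple

CompressionKey = namedtuple(
    "CompressionKey", ["block_type", "ofm_block_depth", "ofm_depth_step", "equivalence_id"]
)

class SimpleCompressionCache:
    def __init__(self):
        self._cache = {}  # CompressionKey -> compressed payload

    def get_or_compress(self, key, compress_fn):
        # Reuse a previously computed compression for an identical config,
        # otherwise compress once and remember the result.
        if key not in self._cache:
            self._cache[key] = compress_fn()
        return self._cache[key]

if __name__ == "__main__":
    cache = SimpleCompressionCache()
    key = CompressionKey("ConvolutionMxN", 32, 32, equivalence_id=7)
    first = cache.get_or_compress(key, lambda: b"compressed-weights")
    second = cache.get_or_compress(key, lambda: b"recompressed")  # not called: cache hit
    assert first is second  # identical configs share one compression result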
-rw-r--r--  ethosu/vela/compiler_driver.py                         5
-rw-r--r--  ethosu/vela/high_level_command_stream_generator.py     8
-rw-r--r--  ethosu/vela/mark_tensors.py                           17
-rw-r--r--  ethosu/vela/nn_graph.py                                1
-rw-r--r--  ethosu/vela/tensor.py                                 34
-rw-r--r--  ethosu/vela/tensor_allocation.py                      16
-rw-r--r--  ethosu/vela/tflite_reader.py                          34
-rw-r--r--  ethosu/vela/weight_compressor.py                     127
8 files changed, 138 insertions, 104 deletions
diff --git a/ethosu/vela/compiler_driver.py b/ethosu/vela/compiler_driver.py
index 64aff06b..b6a98a64 100644
--- a/ethosu/vela/compiler_driver.py
+++ b/ethosu/vela/compiler_driver.py
@@ -144,13 +144,14 @@ def compiler_driver(nng, arch, options, scheduler_options):
# processed first during serialization into tensors
first_npu_sg = nng.subgraphs[1]
assert first_npu_sg.placement == PassPlacement.Npu
+ # Use the linear allocator for constant tensors
tensor_allocation.allocate_tensors(
nng,
first_npu_sg,
arch,
permanent_storage,
scheduler_options.use_ifm_ofm_overlap,
- options.tensor_allocator,
+ TensorAllocator.LinearAlloc,
options.verbose_allocation,
options.show_minimum_possible_allocation,
lr_graph_flash,
@@ -195,7 +196,7 @@ def compiler_driver(nng, arch, options, scheduler_options):
arch,
permanent_storage,
scheduler_options.use_ifm_ofm_overlap,
- options.tensor_allocator,
+ TensorAllocator.LinearAlloc,
options.verbose_allocation,
options.show_minimum_possible_allocation,
)
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index 0cc70a7f..3b968dc8 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -27,12 +27,8 @@ from .operation import NpuBlockType
from .tensor import TensorPurpose
-def need_dma(tens):
- return len(tens.ops) == 1 and tens.ops[0].type == "DMA"
-
-
def dma_if_necessary(ps, box, tensor):
- if need_dma(tensor):
+ if tensor.needs_dma():
dma_op = tensor.ops[0]
in_tensor = dma_op.inputs[0]
yield DMA(in_tensor, tensor, box)
@@ -93,7 +89,7 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
if strat == SchedulingStrategy.WeightStream:
ofm_step = block_config[-1]
ofm_stop = ofm_end[-1]
- if weight_tensor is None or not need_dma(weight_tensor):
+ if weight_tensor is None or not weight_tensor.needs_dma():
ofm_step = ofm_stop
for start in range(ofm_start[-1], ofm_stop, ofm_step):
end = min(start + ofm_step, ofm_stop)
diff --git a/ethosu/vela/mark_tensors.py b/ethosu/vela/mark_tensors.py
index e7b3e50f..cd70446b 100644
--- a/ethosu/vela/mark_tensors.py
+++ b/ethosu/vela/mark_tensors.py
@@ -17,7 +17,6 @@
# Mark purpose and select formats for Tensors. Also compresses the weights.
from . import rewrite_graph
from . import weight_compressor
-from .architecture_features import Block
from .operation import NpuBlockType
from .tensor import TensorFormat
from .tensor import TensorPurpose
@@ -348,18 +347,12 @@ def mark_tensor_format(nng, arch, verbose_tensor_format=False):
for tens, fmt in formats_for_tensor.items():
tens.set_format(fmt, arch)
if fmt == TensorFormat.WeightsCompressed and tens.values is not None:
- npu_block_type = find_npu_usage_of_tensor(tens)
- if len(tens.ops) == 1 and tens.ops[0].type == "DMA":
- weight_compressor.compress_weights(tens, arch, npu_block_type, Block(32, 32, 32), 32)
+ src_tens = tens.get_dma_src_tensor()
+ if src_tens is not None:
+ npu_block_type = find_npu_usage_of_tensor(tens)
+ weight_compressor.compress_weights(arch, nng, tens, npu_block_type, 32, 32)
# Alias compressed weights back into source tensor
- src_tens = tens.ops[0].inputs[0]
- src_tens.compressed_values = tens.compressed_values
- src_tens.storage_shape = tens.storage_shape
- src_tens.brick_size = tens.brick_size
- src_tens.weight_compression_scales = tens.weight_compression_scales
- src_tens.weight_compressed_offsets = tens.weight_compressed_offsets
- src_tens.compression_scale_for_worst_weight_stream = tens.compression_scale_for_worst_weight_stream
- src_tens.storage_compression_scale = tens.storage_compression_scale
+ src_tens.copy_compressed_weight_info(tens)
if verbose_tensor_format:
nng.print_passes_with_tensors()
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index ed2ab322..ea35c087 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -485,6 +485,7 @@ class Graph:
self.bits_per_element = {}
self.total_size = {}
self.total_elements = {}
+ self.weight_cache = None # See CompressedWeightCache
def get_root_subgraph(self):
return self.subgraphs[0]
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 160cf630..2f91f61c 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -225,7 +225,6 @@ class Tensor:
"quantization",
"weight_compressed_offsets",
"element_size_bytes",
- "reshaped",
"block_traversal",
"offset",
"cpu_tensor",
@@ -273,8 +272,6 @@ class Tensor:
# quantization parameters
self.quantization = None
-
- self.reshaped = False
self.block_traversal = TensorBlockTraversal.Default
self.resampling_mode = resampling_mode.NONE
@@ -294,20 +291,13 @@ class Tensor:
res.values = self.values
res.quant_values = self.quant_values
- res.compressed_values = self.compressed_values
res.mem_area = self.mem_area
res.format = self.format
res.purpose = self.purpose
res.sub_purpose = self.sub_purpose
res.alignment = self.alignment
- res.weight_transpose_depthwise = self.weight_transpose_depthwise
-
- res.storage_compression_scale = self.storage_compression_scale
res.bandwidth_compression_scale = self.bandwidth_compression_scale
- res.compression_scale_for_worst_weight_stream = self.compression_scale_for_worst_weight_stream
- res.weight_compression_scales = self.weight_compression_scales
res.storage_rounding_quantum = self.storage_rounding_quantum
- res.brick_size = self.brick_size
res.address = 0
if self.quantization is not None:
@@ -317,6 +307,7 @@ class Tensor:
res.resampling_mode = self.resampling_mode
+ res.copy_compressed_weight_info(self)
return res
def clone_into_fast_storage(self, arch):
@@ -324,6 +315,19 @@ class Tensor:
res.mem_area = arch.fast_storage_mem_area
return res
+ def copy_compressed_weight_info(self, src_tens):
+ # Copies compressed values + all related weight compression info from the given tensor
+ self.compressed_values = src_tens.compressed_values
+ self.storage_shape = src_tens.storage_shape
+ self.brick_size = src_tens.brick_size
+ self.weight_compression_scales = src_tens.weight_compression_scales
+ self.weight_compressed_offsets = src_tens.weight_compressed_offsets
+ self.weight_transpose_depthwise = src_tens.weight_transpose_depthwise
+ self.compression_scale_for_worst_weight_stream = src_tens.compression_scale_for_worst_weight_stream
+ self.storage_compression_scale = src_tens.storage_compression_scale
+ self.block_traversal = src_tens.block_traversal
+ self.weight_compression_config = src_tens.weight_compression_config
+
def set_format(self, fmt, arch):
self.format = fmt
shape_len = 0
@@ -527,6 +531,14 @@ class Tensor:
return strides
+ def needs_dma(self):
+ return len(self.ops) == 1 and self.ops[0].type == "DMA"
+
+ def get_dma_src_tensor(self):
+ # For weight tensors that need DMA: returns the source tensor in Flash, else None
+ # Note: for DMA ops, Pass.weight_tensor is referring to the SRAM weight tensor
+ return self.ops[0].inputs[0] if self.needs_dma() else None
+
def compressed_stream_index_from_coord(self, coord):
assert self.format == TensorFormat.WeightsCompressed
assert len(self.compressed_values) > 0
@@ -575,7 +587,7 @@ class Tensor:
if len(self.weight_compressed_offsets) == 0:
return 0
- if len(self.ops) == 1 and self.ops[0].type == "DMA" and self.sub_purpose == TensorSubPurpose.DoubleBuffer:
+ if self.needs_dma() and self.sub_purpose == TensorSubPurpose.DoubleBuffer:
depth = orig_coord[-1]
brick_depth = self.brick_size[-1]
# Clamp position at final element index
diff --git a/ethosu/vela/tensor_allocation.py b/ethosu/vela/tensor_allocation.py
index cd2b570f..e3952df3 100644
--- a/ethosu/vela/tensor_allocation.py
+++ b/ethosu/vela/tensor_allocation.py
@@ -27,18 +27,26 @@ from .nn_graph import TensorAllocator
from .tensor import MemArea
-def linear_allocate_live_ranges(live_ranges, alloc_granularity=256):
+def linear_allocate_live_ranges(live_ranges, alloc_granularity=16):
+ # Allocates using increasing addresses. Duplicate constant tensors will be allocated to the same address
total_sz = 0
allocated_tensors = []
- # just assign increasing addresses
+ # just assign increasing addresses, except for duplicates
for tens, lr in live_ranges.ranges.items():
if tens in allocated_tensors:
continue
- lr.set_address(total_sz)
+ address = total_sz
+ if tens.weight_compression_config is not None:
+ for allocated_tens in allocated_tensors:
+ if allocated_tens.weight_compression_config == tens.weight_compression_config:
+ address = allocated_tens.address
+ break
+ lr.set_address(address)
allocated_tensors += lr.tensors
- total_sz += numeric_util.round_up(int(math.ceil(lr.size)), alloc_granularity)
+ if address == total_sz:
+ total_sz += numeric_util.round_up(int(math.ceil(lr.size)), alloc_granularity)
return total_sz
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 109ae0ec..5ab90f04 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -39,24 +39,24 @@ def decode_str(s):
return s.decode("utf-8")
-def reshape_tensor_add_const_op(tens, reorder):
- if not tens.reshaped:
- original_shape = tens.shape
- tens.name = tens.name + "_reshape"
- tens.shape = [original_shape[idx] for idx in reorder]
- tens.bandwidth_shape = tens.shape
- tens.storage_shape = tens.shape
+def clone_and_reshape_tensor(src_tens, reorder):
- if tens.values is not None:
- tens.values = tens.values.transpose(reorder)
+ tens = src_tens.clone("_reshape")
+ tens.shape = [src_tens.shape[idx] for idx in reorder]
+ tens.bandwidth_shape = tens.shape
+ tens.storage_shape = tens.shape
- if tens.quant_values is not None:
- tens.quant_values = tens.quant_values.transpose(reorder)
+ if tens.values is not None:
+ tens.values = tens.values.transpose(reorder)
- op = Operation("Const", tens.name)
- op.outputs = [tens]
- tens.ops = [op]
- tens.reshaped = True
+ if tens.quant_values is not None:
+ tens.quant_values = tens.quant_values.transpose(reorder)
+
+ op = Operation("Const", tens.name)
+ op.outputs = [tens]
+ tens.ops = [op]
+
+ return tens
class TFLiteSubgraph:
@@ -137,10 +137,10 @@ class TFLiteSubgraph:
activation_function_to_split_out = None
if op_type.startswith("DepthwiseConv2d") or op_type.startswith("Conv2D"):
- reshape_tensor_add_const_op(inputs[1], (1, 2, 3, 0))
+ inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0))
if op_type.startswith("FullyConnected"):
- reshape_tensor_add_const_op(inputs[1], (1, 0))
+ inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0))
if opt_serializer is not None:
op.attrs = opt_serializer.deserialize(op_data.BuiltinOptions(), op_data.CustomOptionsAsNumpy())
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index a81b1fb4..450e091e 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -21,7 +21,6 @@ from collections import namedtuple
import numpy as np
from ethosu import mlw_codec
-from .architecture_features import Block
from .data_type import DataType
from .errors import UnsupportedFeatureError
from .nn_graph import SchedulingStrategy
@@ -35,6 +34,46 @@ from .tensor import TensorPurpose
from .tensor import TensorSubPurpose
+# Contains meta info for a weight compression. If two tensors have identical weight compression config,
+# then they also will have identical compressed weights.
+WeightCompressionConfig = namedtuple(
+ "WeightCompressionConfig", ["npu_block_type", "ofm_block_depth", "ofm_depth_step", "equivalence_id"]
+)
+
+
+def create_weight_compression_config(tens, npu_block_type, ofm_block_depth, ofm_depth_step):
+ # Note: for an ofm block only its depth is used in weight compression.
+ # And block depth > ofm depth gives same result as block depth == ofm depth
+ block_depth = min(ofm_block_depth, tens.quant_values.shape[-1])
+ return WeightCompressionConfig(npu_block_type, block_depth, ofm_depth_step, tens.equivalence_id)
+
+
+def set_storage_shape(tens):
+ # Sets the storage shape depending on the tensor's sub purpose
+ if tens.sub_purpose == TensorSubPurpose.DoubleBuffer and len(tens.compressed_values) > 2:
+ offset = 2 * np.amax([len(x) for x in tens.compressed_values])
+ assert offset % 16 == 0
+ else:
+ offset = tens.weight_compressed_offsets[-1]
+ tens.storage_shape = [1, 1, 1, offset]
+
+
+class CompressedWeightCache:
+ # Contains weight compressions for all weight tensors in a graph
+ def __init__(self):
+ self.cache = {} # maps from WeightCompressionConfig to a tensor clone containing compressed weights
+
+ def get_tensor_with_same_compression(self, wcc):
+ return self.cache.get(wcc)
+
+ def add(self, tens):
+ # Adds the compressed weights from the tensor to the cache
+ wcc = tens.weight_compression_config
+ # Clone the tensor to make sure that nothing related to the weight compression is modified
+ tens_clone = tens.clone("_weights{}_{}".format(wcc.ofm_block_depth, wcc.ofm_depth_step))
+ self.cache[wcc] = tens_clone
+
+
def encode(weight_stream):
assert np.amin(weight_stream) >= -255
assert np.amax(weight_stream) <= 255
@@ -51,7 +90,7 @@ def encode(weight_stream):
return compressed
-def generate_brick(arch, brick_weights, ofm_block, block_traversal, ifm_bitdepth):
+def generate_brick(arch, brick_weights, ofm_block_depth, block_traversal, ifm_bitdepth):
is_depthwise = block_traversal == TensorBlockTraversal.DepthWise
is_partkernel = block_traversal == TensorBlockTraversal.PartKernelFirst
subkernel_max = arch.subkernel_max
@@ -74,8 +113,8 @@ def generate_brick(arch, brick_weights, ofm_block, block_traversal, ifm_bitdepth
stream = []
# Top level striping - OFM blocks in the entire brick's depth
- for ofm_block_z in range(0, ofm_depth, ofm_block.depth):
- clipped_ofm_block_depth = min(ofm_block.depth, ofm_depth - ofm_block_z)
+ for ofm_block_z in range(0, ofm_depth, ofm_block_depth):
+ clipped_ofm_block_depth = min(ofm_block_depth, ofm_depth - ofm_block_z)
# IFM blocks required for the brick
for ifm_block_z in range(0, (1 if is_depthwise else ifm_depth), ifm_block_depth):
if is_depthwise:
@@ -139,20 +178,23 @@ def generate_brick(arch, brick_weights, ofm_block, block_traversal, ifm_bitdepth
# Compress the weights
-def compress_weights(tens, arch, npu_block_type, ofm_block, ofm_depth_step, min_val=None, max_val=None):
+def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth_step):
assert tens.purpose == TensorPurpose.Weights
assert tens.format == TensorFormat.WeightsCompressed
- WeightCompressionConfig = namedtuple("WeightCompressionConfig", ["npu_block_type", "ofm_block", "ofm_depth_step"])
-
- # check if weights have already been compressed
- wcc = tens.weight_compression_config
- if wcc is not None:
- assert wcc.npu_block_type == npu_block_type, "Weights not used by the same operator type"
-
- if wcc.ofm_block == ofm_block and wcc.ofm_depth_step == ofm_depth_step:
- return
-
+ # Check the weight cache
+ if nng.weight_cache is None:
+ nng.weight_cache = CompressedWeightCache()
+ wcc = create_weight_compression_config(tens, npu_block_type, ofm_block_depth, ofm_depth_step)
+ tens.weight_compression_config = wcc
+ tens_cached = nng.weight_cache.get_tensor_with_same_compression(wcc)
+ if tens_cached is not None:
+ # Cache hit, copy weights from the cache
+ tens.copy_compressed_weight_info(tens_cached)
+ set_storage_shape(tens)
+ return
+
+ # No cache hit, perform the compression
assert tens.quantization is not None
assert tens.quantization.scale_f32 is not None
assert tens.quantization.zero_point is not None
@@ -173,7 +215,6 @@ def compress_weights(tens, arch, npu_block_type, ofm_block, ofm_depth_step, min_
compressed_offsets = []
encoded_streams = []
offset = 0
- max_single_buffer_len = 0
ifm_bitdepth = tens.consumer_list[0].inputs[0].dtype.size_in_bits()
ifm_depth = weights.shape[-2]
@@ -200,14 +241,10 @@ def compress_weights(tens, arch, npu_block_type, ofm_block, ofm_depth_step, min_
brick_weights = weights[:, :, :, idx : idx + count]
# Encode all weights into one chunk
- raw_stream = generate_brick(arch, brick_weights, ofm_block, tens.block_traversal, ifm_bitdepth)
+ raw_stream = generate_brick(arch, brick_weights, ofm_block_depth, tens.block_traversal, ifm_bitdepth)
encoded = encode(raw_stream)
encoded_streams.append(encoded)
- # Remember maximum encoded length for DoubleBuffering
- if max_single_buffer_len < len(encoded):
- max_single_buffer_len = len(encoded)
-
# Remember where we put it for linear addressing
compressed_offsets.append(offset)
offset += len(encoded)
@@ -219,18 +256,14 @@ def compress_weights(tens, arch, npu_block_type, ofm_block, ofm_depth_step, min_
# Also track complete length in the offsets array
compressed_offsets.append(offset)
- if tens.sub_purpose == TensorSubPurpose.DoubleBuffer and len(encoded_streams) > 2:
- offset = 2 * max_single_buffer_len
- assert offset % 16 == 0
-
- tens.storage_shape = [1, 1, 1, offset]
tens.weight_compression_scales = compression_scales
- tens.weight_compression_config = WeightCompressionConfig(npu_block_type, ofm_block, ofm_depth_step)
tens.weight_compressed_offsets = compressed_offsets
tens.compression_scale_for_worst_weight_stream = np.amax(compression_scales)
tens.storage_compression_scale = tens.bandwidth_compression_scale = np.average(compression_scales)
tens.compressed_values = encoded_streams
tens.brick_size = (weights_shape[0], weights_shape[1], weights_shape[2], min(tens.shape[-1], ofm_depth_step))
+ set_storage_shape(tens)
+ nng.weight_cache.add(tens)
def calc_scales_and_pack_biases(tens, arch, oc_quantum, rescale_for_faf=False):
@@ -352,39 +385,29 @@ def update_pass_weight_and_scale_tensors(nng, arch):
for sg in nng.subgraphs:
for ps in sg.passes:
- if ps.weight_tensor is not None:
- npu_usage_of_tensor = find_npu_usage_of_tensor(ps.weight_tensor)
+ tens = ps.weight_tensor
+ if tens is not None:
+ npu_usage_of_tensor = find_npu_usage_of_tensor(tens)
if npu_usage_of_tensor == NpuBlockType.ConvolutionDepthWise:
- ps.weight_tensor.quant_values = np.transpose(ps.weight_tensor.quant_values, (0, 1, 3, 2))
- ps.weight_tensor.shape = ps.weight_tensor.storage_shape = ps.weight_tensor.bandwidth_shape = list(
- ps.weight_tensor.quant_values.shape
- )
- ps.weight_tensor.weight_transpose_depthwise = True
+ tens.quant_values = np.transpose(tens.quant_values, (0, 1, 3, 2))
+ tens.shape = tens.storage_shape = tens.bandwidth_shape = list(tens.quant_values.shape)
+ tens.weight_transpose_depthwise = True
- needs_dma = len(ps.weight_tensor.ops) == 1 and ps.weight_tensor.ops[0].type == "DMA"
+ needs_dma = tens.needs_dma()
if ps.cascade.strategy == SchedulingStrategy.WeightStream and needs_dma:
ofm_depth_step = ps.block_config[-1]
else:
- ofm_depth_step = ps.weight_tensor.shape[-1]
-
+ ofm_depth_step = tens.shape[-1]
compress_weights(
- ps.weight_tensor,
- arch,
- npu_usage_of_tensor,
- Block(ps.block_config[-3], ps.block_config[-4], ps.block_config[-1]),
- ofm_depth_step,
+ arch, nng, tens, npu_usage_of_tensor, ps.block_config[-1], ofm_depth_step,
)
# Update source tensor
- if len(ps.weight_tensor.ops) == 1 and ps.weight_tensor.ops[0].type == "DMA":
- src_tens = ps.weight_tensor.ops[0].inputs[0]
- src_tens.shape = ps.weight_tensor.shape
- src_tens.weight_transpose_depthwise = ps.weight_tensor.weight_transpose_depthwise
- src_tens.quant_values = ps.weight_tensor.quant_values
- src_tens.compressed_values = ps.weight_tensor.compressed_values
- src_tens.storage_shape = [1, 1, 1, ps.weight_tensor.weight_compressed_offsets[-1]]
- src_tens.brick_size = ps.weight_tensor.brick_size
- src_tens.weight_compression_scales = ps.weight_tensor.weight_compression_scales
- src_tens.weight_compressed_offsets = ps.weight_tensor.weight_compressed_offsets
+ if needs_dma:
+ src_tens = tens.get_dma_src_tensor()
+ src_tens.shape = tens.shape
+ src_tens.quant_values = tens.quant_values
+ src_tens.copy_compressed_weight_info(tens)
+ set_storage_shape(src_tens)
if ps.scale_tensor is not None:
rescale_for_faf = False