author     Tim Hall <tim.hall@arm.com>    2021-05-27 18:49:40 +0100
committer  Tim Hall <tim.hall@arm.com>    2021-05-27 18:57:39 +0100
commit     d8339a75c9b655c0507e34238078fdad068b4023 (patch)
tree       36a14726b30760169a83c0356803b480992fade8 /ethosu/vela/cascade_builder.py
parent     64556f32ff7bfca6036a6598034464b13b64a4ef (diff)
download   ethos-u-vela-d8339a75c9b655c0507e34238078fdad068b4023.tar.gz
MLBEDSW-4034: New Scheduler Size or Performance Optimisation
- Merged dev/scheduler at 83639f90e8c828f70de6e29142355a940224959b

Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: I0050529d4b42da93768c7264296434dd877fb5b4
Diffstat (limited to 'ethosu/vela/cascade_builder.py')
-rw-r--r--  ethosu/vela/cascade_builder.py  260
1 file changed, 260 insertions(+), 0 deletions(-)
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py
new file mode 100644
index 00000000..e4fa67e9
--- /dev/null
+++ b/ethosu/vela/cascade_builder.py
@@ -0,0 +1,260 @@
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Description:
+# Groups Operators in a schedule together to form Cascades.
+from .numeric_util import round_up
+from .operation import NpuBlockType
+from .shape4d import Shape4D
+
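+# Ops whose block type appears in this tuple are never grouped into a cascade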
+non_cascadable_blocks = (
+ NpuBlockType.Default,
+ NpuBlockType.VectorProduct,
+ NpuBlockType.ElementWise,
+ NpuBlockType.ReduceSum,
+)
+
+
+class CascadeInfo:
+ """Contains metadata about a cascade"""
+
+ def __init__(self, start, end, buffers, mem_usage: int):
+ self.start = start
+ self.end = end
+ self.buffers = buffers
+ self.mem_usage = mem_usage
+
+
+class BufferMap:
+ """Caches the buffers seen"""
+
+ def __init__(self):
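+ # Maps (producer, consumer) pairs to cached (buffer_shape, buffer_size) tuples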
+ self.buffer_map = {}
+
+ def get_buffer(self, producer, consumer, cost):
+ assert producer or consumer
+ key = (producer, consumer)
+ if key not in self.buffer_map:
+ # No cached buffer between these two SchedulerOperations
+ if consumer is None:
+ # There are either no consumers or multiple consumers - FeatureMap needs to be stored in full
+ buffer_shape = producer.ofm.shape
+ buffer_size = producer.ofm_size_in_bytes()
+ elif producer is None:
+ # First Op in subgraph or cascade - FeatureMap needs to be stored in full
+ buffer_shape = consumer.ifm.shape
+ buffer_size = consumer.ifm_size_in_bytes()
+ elif producer.requires_full_ofm or consumer.requires_full_ifm:
+ # FeatureMap needs to be stored in full
+ buffer_shape = max(producer.ofm.shape, consumer.ifm.shape)
+ buffer_size = max(producer.ofm_size_in_bytes(), consumer.ifm_size_in_bytes())
+ else:
+ # Use a rolling buffer
+ buffer_shape = rolling_buffer_shape(cost[producer].stripe, cost[consumer].stripe_input)
+ buffer_size = buffer_shape.elements() * producer.ofm.dtype.size_in_bytes()
+
+ self.buffer_map[key] = (buffer_shape, buffer_size)
+
+ return self.buffer_map[key]
+
+
+def rolling_buffer_shape(producer_stripe: Shape4D, consumer_stripe_input: Shape4D) -> Shape4D:
+ """Calculates the storage shape of the rolling buffer between two SchedulerOperations in a Cascade"""
+ buffer_height = round_up(producer_stripe.height + consumer_stripe_input.height, consumer_stripe_input.height)
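+ # e.g. (hypothetical values) a producer stripe of height 4 feeding a consumer
+ # stripe input of height 6 needs round_up(4 + 6, 6) = 12 rows of buffering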
+ # Rolling buffers have to conform to NHCWB16 format
+ return consumer_stripe_input.with_height(buffer_height).with_depth(round_up(producer_stripe.depth, 16))
+
+
+class CascadeBuilder:
+ """Class for grouping SchedulerOperations into cascades"""
+
+ def __init__(self, sched_ops, spilling, non_local_mem_usage=None):
+ self.sched_ops = sched_ops
+ self.no_cascade = 0
+ self.non_local_mem_usage = non_local_mem_usage if non_local_mem_usage else {}
+ self.spilling = spilling
+
+ def _is_cascadable(self, sched_op, cost) -> bool:
+ """Checks if 'sched_op' can be cascaded"""
+ return (
+ sched_op.op_type.npu_block_type not in non_cascadable_blocks
+ and cost.stripe.height < sched_op.ofm.shape.height
+ )
+
+ def _estimate_sram_usage(self, sched_op, cost) -> int:
+ """Estimate the SRAM required for the Op if all FeatureMaps are in SRAM"""
+ ifm2_size = sched_op.ifm2_size_in_bytes()
+ if sched_op.requires_full_ifm:
+ ifm_size = sched_op.ifm_size_in_bytes()
+ else:
+ ifm_size = (
+ cost.stripe_input.with_depth(round_up(cost.stripe_input.depth, 16)).elements()
+ * sched_op.ifm.dtype.size_in_bytes()
+ )
+ if sched_op.requires_full_ofm:
+ ofm_size = sched_op.ofm_size_in_bytes()
+ else:
+ ofm_size = (
+ cost.stripe.with_depth(round_up(cost.stripe.depth, 16)).elements() * sched_op.ofm.dtype.size_in_bytes()
+ )
+
+ return ifm_size + ifm2_size + ofm_size + self.non_local_mem_usage.get(sched_op, 0)
+
+ def build_cascades(self, ref_schedule, fallback_schedule, guiding_mem_limit):
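+ """Creates cascades from the Ops in ref_schedule, using fallback_schedule
+ costs for any Op that is not placed in a cascade"""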
+ ref_cost = ref_schedule.cost_map
+ fallback_cost = fallback_schedule.cost_map
+ cost = {}
+ cascade_map = {}
+ buffers = BufferMap()
+
+ # Peak memory usage so far - updated continuously, unless using Dedicated SRAM, where this is a hard limit
+ peak_sram_usage = guiding_mem_limit
+
+ idx = 0
+ while idx < len(self.sched_ops):
+ op = self.sched_ops[idx]
+ if op in cost:
+ # Already processed this Op
+ idx += 1
+ continue
+
+ if not self._is_cascadable(op, ref_cost[op]):
+ # Op is not a candidate for cascading - assign fallback cost
+ cost[op] = fallback_cost[op]
+ if not self.spilling:
+ peak_sram_usage = max(self._estimate_sram_usage(op, fallback_cost[op]), peak_sram_usage)
+ idx += 1
+ continue
+
+ # Propose a cascade starting with this Op
+ cascade_start = op.index
+ # Keep track of which Ops are in the proposed cascade as well as the best cascade so far
+ ops_in_cascade = [op]
+ ops_in_best_cascade = [op]
+ # Get the size of the weight buffer
+ weight_buffer = 0
+ if ref_cost[op].buffered_weight_tensor:
+ weight_buffer = ref_cost[op].buffered_weight_tensor.storage_size()
+
+ # The first IFM needs to be stored in full
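+ # (when spilling to Dedicated SRAM, the full IFM is not resident in SRAM)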
+ cascade_ifm_size = op.ifm_size_in_bytes() if not self.spilling else 0
+
+ # Add non-local memory usage
+ cascade_ifm_size += self.non_local_mem_usage.get(op, 0)
+
+ # Sum of all intermediate cascade buffers (including weight buffers)
+ cascade_buffers = weight_buffer
+ # Best cascade size - Initially it's the fallback cost of the first Op in the cascade
+ best_cascade_size = self._estimate_sram_usage(op, fallback_cost[op])
+
+ # Op is the producer of the OFM consumed by the next Op to consider
+ producer = op
+ while True:
+ dependants = producer.get_dependants()
+ if len(dependants) != 1:
+ # producer is either the last Op in the schedule or the start of a branch
+ break
+
+ current_op = dependants[0]
+ if (
+ current_op in cost
+ or current_op not in ref_cost
+ or not self._is_cascadable(current_op, ref_cost[current_op])
+ or producer.ofm.shape != current_op.ifm.shape
+ ):
+ # Current op has already been processed or cannot be cascaded
+ break
+
+ # Get the size of the FeatureMap buffers between current and neighbouring Ops
+ op_full_ifm = current_op.ifm_size_in_bytes()
+ op_full_ofm = current_op.ofm_size_in_bytes()
+ _, op_ifm_buffer = buffers.get_buffer(producer, current_op, ref_cost)
+
+ # Get the size of the weight buffer
+ op_weight_buffer = 0
+ if ref_cost[current_op].buffered_weight_tensor:
+ op_weight_buffer = ref_cost[current_op].buffered_weight_tensor.storage_size()
+
+ # Calculate the uncascaded memory requirement for current Op
+ uncascaded_sram_usage = op_full_ifm + op_full_ofm + self.non_local_mem_usage.get(current_op, 0)
+
+ # Add current Op to cascade
+ ops_in_cascade.append(current_op)
+
+ # Increase the accumulated intermediate buffers in the cascade
+ cascade_buffers += op_ifm_buffer + op_weight_buffer
+
+ if self.spilling:
+ # For Dedicated SRAM only the intermediate buffers are in SRAM
+ if uncascaded_sram_usage < peak_sram_usage or cascade_buffers > peak_sram_usage:
+ # Stop cascading: either the Op fits in its entirety or the accumulated buffers no longer fit
+ break
+ else:
+ # Any addition to the cascade that fits is the new best cascade for Dedicated SRAM
+ ops_in_best_cascade = ops_in_cascade.copy()
+ best_cascade_size = cascade_buffers
+
+ else:
+ # Calculate the total size of the current cascade
+ cascade_size = cascade_ifm_size + cascade_buffers + op_full_ofm
+
+ # Determine if cascading search should stop
+ if (
+ uncascaded_sram_usage < peak_sram_usage
+ and best_cascade_size < peak_sram_usage
+ or (cascade_ifm_size + cascade_buffers) > best_cascade_size
+ ):
+ # Stop: both the existing cascade and the current Op fit under the peak, or the cascade has outgrown the best size found so far
+ break
+
+ # Determine if current cascade is the best so far
+ if cascade_size < best_cascade_size:
+ best_cascade_size = cascade_size
+ ops_in_best_cascade = ops_in_cascade.copy()
+
+ producer = current_op
+
+ if len(ops_in_best_cascade) > 1:
+ # A cascade was created - assign cascade and ref_cost to all of the Ops
+ cascade_end = cascade_start + (len(ops_in_best_cascade) - 1)
+ buffers_in_cascade = {}
+ prev_op = None
+ for cascaded_op in ops_in_best_cascade:
+ cost[cascaded_op] = ref_cost[cascaded_op]
+ cost[cascaded_op].cascade = cascade_end
+ if prev_op:
+ rolling_buffer_shape, _ = buffers.get_buffer(prev_op, cascaded_op, ref_cost)
+ buffers_in_cascade[cascaded_op] = rolling_buffer_shape
+
+ prev_op = cascaded_op
+
+ # Create a CascadeInfo for the cascade
+ cascade_map[cascade_end] = CascadeInfo(
+ cascade_start, cascade_end, buffers_in_cascade, best_cascade_size
+ )
+ if not self.spilling:
+ # Update peak memory usage
+ peak_sram_usage = max(best_cascade_size, peak_sram_usage)
+ else:
+ # Assign fallback cost to the initial Op
+ cost[op] = fallback_cost[op]
+ if not self.spilling:
+ peak_sram_usage = max(self._estimate_sram_usage(op, fallback_cost[op]), peak_sram_usage)
+
+ # Update costing and cascade information for the ref_schedule
+ ref_schedule.cost_map = cost
+ ref_schedule.cascades = cascade_map
+ return ref_schedule