diff options
author | Jonas Ohlsson <jonas.ohlsson@arm.com> | 2022-03-30 10:30:25 +0200 |
---|---|---|
committer | Jonas Ohlsson <jonas.ohlsson@arm.com> | 2022-03-30 15:54:14 +0200 |
commit | d85750702229af97c0b0bbda6e397a23254b6144 (patch) | |
tree | 389962105a35d5cef595cfeb5d640bd59a0d0ff8 /ethosu/vela/scheduler.py | |
parent | cc5f4de1c35ba44fca7ff6295c6ae846f8242344 (diff) | |
download | ethos-u-vela-d85750702229af97c0b0bbda6e397a23254b6144.tar.gz |
Update version of Black to 22.3.0
Update version of Black to 22.3.0 due to updated dependencies.
Updates to fix reported issues due to new version.
Signed-off-by: Jonas Ohlsson <jonas.ohlsson@arm.com>
Change-Id: I60056aae452093ce8dcea1f499ecced22b25eef1
Diffstat (limited to 'ethosu/vela/scheduler.py')
-rw-r--r-- | ethosu/vela/scheduler.py | 59 |
1 files changed, 50 insertions, 9 deletions
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py index fe2d711e..a19d0531 100644 --- a/ethosu/vela/scheduler.py +++ b/ethosu/vela/scheduler.py @@ -113,7 +113,13 @@ class SchedulerOpInfo: self.full_weight_transfer_cycles = 0 def copy(self): - res = SchedulerOpInfo(self.block_config, self.weights_size, self.stripe_input, self.stripe_input2, self.stripe,) + res = SchedulerOpInfo( + self.block_config, + self.weights_size, + self.stripe_input, + self.stripe_input2, + self.stripe, + ) res.cascade = self.cascade return res @@ -135,7 +141,10 @@ class SchedulerOptions: """Contains options for the Scheduler""" def __init__( - self, optimization_strategy, sram_target, verbose_schedule, + self, + optimization_strategy, + sram_target, + verbose_schedule, ): self.optimization_strategy = optimization_strategy self.optimization_sram_limit = sram_target @@ -175,15 +184,28 @@ class SchedulerOperation: ) self.ifm_ublock = arch.ifm_ublock - self.ifm = SchedulerTensor(ps.ifm_shapes[0], ps.ifm_tensor.dtype, ps.ifm_tensor.mem_area, ps.ifm_tensor.format,) + self.ifm = SchedulerTensor( + ps.ifm_shapes[0], + ps.ifm_tensor.dtype, + ps.ifm_tensor.mem_area, + ps.ifm_tensor.format, + ) self.ifm2 = None if ps.ifm2_tensor: self.ifm2 = SchedulerTensor( - ps.ifm_shapes[1], ps.ifm2_tensor.dtype, ps.ifm2_tensor.mem_area, ps.ifm2_tensor.format, + ps.ifm_shapes[1], + ps.ifm2_tensor.dtype, + ps.ifm2_tensor.mem_area, + ps.ifm2_tensor.format, ) - self.ofm = SchedulerTensor(ps.ofm_shapes[0], ps.ofm_tensor.dtype, ps.ofm_tensor.mem_area, ps.ofm_tensor.format,) + self.ofm = SchedulerTensor( + ps.ofm_shapes[0], + ps.ofm_tensor.dtype, + ps.ofm_tensor.mem_area, + ps.ofm_tensor.format, + ) # Input volume width and height required to produce the smallest possible stripe self.min_stripe_input_w, self.min_stripe_input_h = self._calculate_min_stripe_input() @@ -481,7 +503,11 @@ class Scheduler: lr_graph = live_range.LiveRangeGraph() for mem_area, mem_type_set in memories_list: live_range.extract_live_ranges_from_cascaded_passes( - self.nng.get_root_subgraph(), mem_area, mem_type_set, lr_graph, Tensor.AllocationQuantum, + self.nng.get_root_subgraph(), + mem_area, + mem_type_set, + lr_graph, + Tensor.AllocationQuantum, ) # Populate time-array with memory used by live ranges @@ -923,7 +949,11 @@ class Scheduler: return best_schedule def optimize_schedule( - self, schedule: Schedule, max_sched: Schedule, max_template: Schedule, options: SchedulerOptions, + self, + schedule: Schedule, + max_sched: Schedule, + max_template: Schedule, + options: SchedulerOptions, ) -> Schedule: """Extracts sub-schedules based on the cascades and optimizes them and applies them to the final schedule""" sram_limit = options.optimization_sram_limit @@ -994,7 +1024,11 @@ class Scheduler: lr_graph = live_range.LiveRangeGraph() for mem_area, mem_type_set in memories_list: live_range.extract_live_ranges_from_cascaded_passes( - self.nng.get_root_subgraph(), mem_area, mem_type_set, lr_graph, Tensor.AllocationQuantum, + self.nng.get_root_subgraph(), + mem_area, + mem_type_set, + lr_graph, + Tensor.AllocationQuantum, ) max_mem_usage = lr_graph.get_temporal_memory_usage(fast_storage_mem_area) @@ -1252,7 +1286,14 @@ def schedule_passes(nng: Graph, arch: ArchitectureFeatures, options, scheduler_o cascaded_passes = [] for idx, ps in enumerate(sg.passes): cps = CascadedPass( - ps.name, SchedulingStrategy.WeightStream, ps.inputs, [], ps.outputs, [ps], ps.placement, False, + ps.name, + SchedulingStrategy.WeightStream, + ps.inputs, + [], + ps.outputs, + [ps], + ps.placement, + False, ) cps.time = idx |