diff options
author | Jonas Ohlsson <jonas.ohlsson@arm.com> | 2022-03-01 12:39:55 +0100 |
---|---|---|
committer | Jonas Ohlsson <jonas.ohlsson@arm.com> | 2022-03-21 11:09:39 +0100 |
commit | 845e23200d471e44f274940846e400d170b5ff37 (patch) | |
tree | 28a01492bf11f0ff69309ead9bd8a1bad9e14cbb /ethosu/vela/scheduler.py | |
parent | d2b5510697e7789f5a416f9d80d3cb640eecc092 (diff) | |
download | ethos-u-vela-845e23200d471e44f274940846e400d170b5ff37.tar.gz |
MLBEDSW-3367 Add mypy to pre-commit
Add mypy to pre-commit and clean up all reported errors.
Signed-off-by: Jonas Ohlsson <jonas.ohlsson@arm.com>
Change-Id: If7dc869f5fecdb0e2db40f14e7d9db21aa33df71
Diffstat (limited to 'ethosu/vela/scheduler.py')
-rw-r--r-- | ethosu/vela/scheduler.py | 35 |
1 files changed, 22 insertions, 13 deletions
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py index 284848f5..73133bcd 100644 --- a/ethosu/vela/scheduler.py +++ b/ethosu/vela/scheduler.py @@ -17,6 +17,9 @@ # Description: # The scheduler creates and searches for an optimal plan for the network, selecting block configurations and # subdivisions for the Operators +# For Class name forward references for the type annotations. (see PEP 563). +from __future__ import annotations + import copy from collections import namedtuple from enum import auto @@ -25,6 +28,11 @@ from typing import Dict from typing import List from typing import Optional from typing import Tuple +from typing import TYPE_CHECKING + +# Import needed for Type annotations. Only import for Type checking to avoid run-time errors due to cyclic import. +if TYPE_CHECKING: + from .npu_performance import CycleCost import numpy as np @@ -57,6 +65,7 @@ from .tensor import Tensor from .tensor import TensorFormat from .tensor import TensorPurpose from .tensor import TensorSubPurpose +from .weight_compressor import NpuWeightTensor def shape_for_format(shape: Shape4D, tensor_format: TensorFormat) -> Shape4D: @@ -95,10 +104,10 @@ class SchedulerOpInfo: self.cascade = 0 # Assigned by CascadeBuilder. 0 means not part of a cascade self.time_index = None # Set by update_op_memory_snapshot self.ofm_depth_slices: List[int] = [0, stripe.depth] - self.npu_weights_tensor = None - self.npu_scales_tensor = None - self.buffered_weight_tensor = None - self.cycles = None + self.npu_weights_tensor: Optional[NpuWeightTensor] = None + self.npu_scales_tensor: Optional[NpuWeightTensor] = None + self.buffered_weight_tensor: Optional[Tensor] = None + self.cycles: Optional[CycleCost] = None self.slack_buffering_cycles = 0 self.slack_buffering_memory = 0 self.full_weight_transfer_cycles = 0 @@ -230,7 +239,7 @@ class SchedulerOperation: def create_scheduler_info(self, nng: Graph, stripe: Shape4D) -> SchedulerOpInfo: """Returns schedule info about this SchedulerOperation based on how many ofm elements it should produce""" ifm_shape = self.ifm.shape - ifm2_shape = self.ifm2 and self.ifm2.shape + ifm2_shape = self.ifm2.shape if self.ifm2 is not None else None ofm_shape = stripe if ofm_shape != self.ofm.shape: @@ -273,14 +282,14 @@ class SchedulerOperation: return get_ifm_area_required(ofm_shape_to_produce, self.kernel, self.resampling_mode) - def _calculate_min_stripe_input(self) -> Shape4D: + def _calculate_min_stripe_input(self) -> Tuple[int, int]: # Calculate the input volume required height and width for the smallest possible stripe (h,w = 1,1) min_stripe = self.ofm.shape.with_hw(1, 1) return self._get_stripe_input_requirement(min_stripe) def _get_block_config( self, ifm_shape: Shape4D, ifm2_shape: Optional[Shape4D], uses_scalar: bool, ofm_shape: Shape4D - ) -> ArchitectureBlockConfig: + ) -> Optional[ArchitectureBlockConfig]: # Returns a block config and SHRAM layout lut_banks = 2 if self.parent_op.activation_lut else 0 return find_block_config( @@ -325,7 +334,7 @@ class Schedule: self.cost_map: Dict[SchedulerOperation, SchedulerOpInfo] = {} self.cascades: Dict[int, CascadeInfo] = {} self.fast_storage_peak_usage = 0 - self.memory_snapshot = None + self.memory_snapshot: Optional[List[int]] = None @property def name(self): @@ -340,7 +349,7 @@ class Scheduler: self.sg = sg self.arch = arch self.sched_ops: List[SchedulerOperation] = [] - self.max_schedule = None + self.max_schedule: Optional[Schedule] = None self.scheduler_options = options def avoid_nhcwb16_for_ofm(self, tens, ps, arch): @@ -524,7 +533,7 @@ class Scheduler: def propose_operator_buffering( self, sched_op: SchedulerOperation, - prev_op: SchedulerOperation, + prev_op: Optional[SchedulerOperation], buffered_schedule: Schedule, ref_schedule: Schedule, staging_limit_bytes, @@ -605,7 +614,7 @@ class Scheduler: cost.npu_scales_tensor = full_scales return - encoded_weights = full_weights + encoded_weights: Optional[NpuWeightTensor] = full_weights encoded_scales = full_scales # How many NPU cycles are available under the previously executing @@ -681,7 +690,7 @@ class Scheduler: cost.block_config, cost.ofm_depth_slices, ) - + assert encoded_weights is not None # Chosen buffering might not fit at all, iterate until it does # or until the minimum usable slice size is reached if ( @@ -747,7 +756,7 @@ class Scheduler: cost_map = min_schedule.cost_map # Keep track of the previous Op - which consumes the current Op's OFM - prev_op = None + prev_op: Optional[SchedulerOperation] = None for sched_op in reversed(self.sched_ops): min_stripe_height = prev_op.kernel.stride.y if prev_op else 1 min_stripe = sched_op.ofm.shape.with_height(min_stripe_height) |