aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/register_command_stream_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/register_command_stream_util.py')
-rw-r--r--ethosu/vela/register_command_stream_util.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/ethosu/vela/register_command_stream_util.py b/ethosu/vela/register_command_stream_util.py
index 3751d88e..83126ead 100644
--- a/ethosu/vela/register_command_stream_util.py
+++ b/ethosu/vela/register_command_stream_util.py
@@ -163,7 +163,7 @@ def get_h_ranges(
return [get_address_range(fm, strides, y, x0, c0, y, x1, c1) for y in range(y0, y1 + 1)]
-def get_address_ranges_for_area(fm: NpuFeatureMap, start: PointXYZ, end: PointXYZ) -> List[NpuAddressRange]:
+def get_address_ranges_for_area(fm: NpuFeatureMap, start: PointXYZ, end: PointXYZ) -> List[Optional[NpuAddressRange]]:
"""
Returns a list of adddress ranges that covers the area start - end (inclusive).
Divides the area in horizontal "stripes" of height 1, and returns the address ranges for these "stripes".
@@ -183,7 +183,7 @@ def get_address_ranges_for_area(fm: NpuFeatureMap, start: PointXYZ, end: PointXY
h, w, c = fm.shape
y0, x0, c0 = start.y, start.x, start.z
y1, x1, c1 = min(end.y, h - 1), min(end.x, w - 1), min(end.z, c - 1)
- ranges = []
+ ranges: List[Optional[NpuAddressRange]] = []
if x0 < width_0 and y0 < height_0:
# Horizontal ranges for tile 0
ranges.extend(get_h_ranges(fm, strides, y0, x0, c0, min(y1, height_0 - 1), min(x1, width_0 - 1), c1))
@@ -373,7 +373,7 @@ def intersects(
else:
# The OFM produces a part of the IFM (e.g. a stripe), or the IFM consumes part of the OFM.
# In this case, address comparison between the two areas is needed
- ifm_ranges = get_address_ranges_for_area(ifm, ifm_start_coord, ifm_end_coord)
+ ifm_ranges: List[Optional[NpuAddressRange]] = get_address_ranges_for_area(ifm, ifm_start_coord, ifm_end_coord)
prev_ofm_ranges = get_address_ranges_for_area(prev_ofm, ofm_start_coord, ofm_end_coord)
res = range_lists_overlap(ifm_ranges, prev_ofm_ranges)
return res