aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/cascade_builder.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py
index ba210032..3a3026fe 100644
--- a/ethosu/vela/cascade_builder.py
+++ b/ethosu/vela/cascade_builder.py
@@ -76,8 +76,11 @@ class BufferMap:
def rolling_buffer_shape(producer_stripe: Shape4D, consumer_stripe_input: Shape4D) -> Shape4D:
"""Calculates the storage shape of the rolling buffer between two SchedulerOperations in a Cascade"""
buffer_height = round_up(producer_stripe.height + consumer_stripe_input.height, consumer_stripe_input.height)
+ # Striding on the consumer op can result in IFM widths that are narrower than the OFM width of the producer.
+ # Therefore, the maximum of the two needs to be used.
+ buffer_width = max(producer_stripe.width, consumer_stripe_input.width)
# Rolling buffers have to conform to NHCWB16 format
- return consumer_stripe_input.with_height(buffer_height).with_depth(round_up(producer_stripe.depth, 16))
+ return Shape4D([1, buffer_height, buffer_width, round_up(producer_stripe.depth, 16)])
class CascadeBuilder: