aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/shared_buffer_allocation.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/shared_buffer_allocation.py')
-rw-r--r--ethosu/vela/shared_buffer_allocation.py19
1 files changed, 14 insertions, 5 deletions
diff --git a/ethosu/vela/shared_buffer_allocation.py b/ethosu/vela/shared_buffer_allocation.py
index aa5f4c86..f52d3a92 100644
--- a/ethosu/vela/shared_buffer_allocation.py
+++ b/ethosu/vela/shared_buffer_allocation.py
@@ -47,6 +47,7 @@ class SharedBufferAllocation:
self.kernel = Kernel(1, 1)
self.is_elementwise = ps.npu_block_type == NpuBlockType.ElementWise
self.uses_lut = False
+ self.ifm_count = 1
if ps.primary_op:
strides = ps.primary_op.attrs.get("strides", strides)
@@ -82,11 +83,19 @@ class SharedBufferAllocation:
if ifm_tensor:
self.ifm_resampling_mode = ifm_tensor.resampling_mode
self.ifm_bits = ifm_tensor.dtype.size_in_bits()
- if ifm_tensor.shape == [] and self.is_elementwise:
- # Elementwise operator with scalar in ifm, use ifm2 depth
- self.ifm_depth = ifm2_tensor.shape[-1]
- else:
+
+ if ifm_tensor.shape != []:
self.ifm_depth = ifm_tensor.shape[-1]
+
+ if self.is_elementwise:
+ self.ifm_count = 2
+ if ifm_tensor.shape == []: # Scalar in ifm1
+ assert ifm2_tensor
+ self.ifm_depth = ifm2_tensor.shape[-1]
+ self.ifm_count = 1
+ elif not ifm2_tensor or ifm2_tensor.shape == []: # Scalar in ifm2
+ self.ifm_count = 1
+
if self.ifm_bits == 16:
if ps.npu_block_type != NpuBlockType.Pooling and has_scale:
self.use_accumulator_element = SHRAMElements.Acc40
@@ -137,7 +146,7 @@ class SharedBufferAllocation:
acc_banks = ofm_config.banks[self.use_accumulator_element]
# Update bank counts for IFM and Accumulator
- self.banks_required[SharedBufferArea.IFM] = ifm_config.banks[self.use_ifm_element]
+ self.banks_required[SharedBufferArea.IFM] = ifm_config.banks[self.use_ifm_element] * self.ifm_count
self.banks_required[SharedBufferArea.Accumulators] = 0 if self.is_elementwise else acc_banks
# Validating calculates bank layout and returns validity