aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/shared_buffer_allocation.py14
1 files changed, 8 insertions, 6 deletions
diff --git a/ethosu/vela/shared_buffer_allocation.py b/ethosu/vela/shared_buffer_allocation.py
index 7657dffa..58856a3e 100644
--- a/ethosu/vela/shared_buffer_allocation.py
+++ b/ethosu/vela/shared_buffer_allocation.py
@@ -44,7 +44,7 @@ class SharedBufferAllocation:
strides = (1, 1, 1, 1)
dilation = (1, 1, 1, 1)
self.kernel = Kernel(1, 1)
- is_elementwise = ps.npu_block_type == NpuBlockType.ElementWise
+ self.is_elementwise = ps.npu_block_type == NpuBlockType.ElementWise
self.uses_lut = False
if ps.primary_op:
@@ -63,14 +63,14 @@ class SharedBufferAllocation:
self.kernel = Kernel(k_w, k_h, strides[2], strides[1], dilation[2], dilation[1])
self.uses_lut = ps.primary_op.activation_lut is not None
- self.is_equal_depth_op = is_elementwise or ps.npu_block_type in (
+ self.is_equal_depth_op = self.is_elementwise or ps.npu_block_type in (
NpuBlockType.ConvolutionDepthWise,
NpuBlockType.Pooling,
)
self.strides = strides
self.use_accumulator_element = SHRAMElements.Acc32
- if is_elementwise:
+ if self.is_elementwise:
self.use_ifm_element = SHRAMElements.IFM8_Elementwise
else:
self.use_ifm_element = SHRAMElements.IFM8
@@ -81,7 +81,7 @@ class SharedBufferAllocation:
if ifm_tensor:
self.ifm_resampling_mode = ifm_tensor.resampling_mode
self.ifm_bits = ifm_tensor.dtype.size_in_bits()
- if ifm_tensor.shape == [] and is_elementwise:
+ if ifm_tensor.shape == [] and self.is_elementwise:
# Elementwise operator with scalar in ifm, use ifm2 depth
self.ifm_depth = ifm2_tensor.shape[-1]
else:
@@ -94,7 +94,7 @@ class SharedBufferAllocation:
self.use_ifm_element == SHRAMElements.IFM16_Elementwise
)
elif self.ifm_bits == 32:
- assert is_elementwise or ps.npu_block_type == NpuBlockType.ReduceSum, "Unsupported 32-bit IFM operation"
+ assert self.is_elementwise or ps.npu_block_type == NpuBlockType.ReduceSum, "Unsupported 32-bit IFM operation"
self.use_ifm_element = SHRAMElements.IFM32
else:
assert self.ifm_bits == 8, "Unexpected IFM bitdepth"
@@ -131,9 +131,11 @@ class SharedBufferAllocation:
if ofm_config is None:
return None
+ acc_banks = ofm_config.banks[self.use_accumulator_element]
+
# Update bank counts for IFM and Accumulator
self.banks_required[SharedBufferArea.IFM] = ifm_config.banks[self.use_ifm_element]
- self.banks_required[SharedBufferArea.Accumulators] = ofm_config.banks[self.use_accumulator_element]
+ self.banks_required[SharedBufferArea.Accumulators] = 0 if self.is_elementwise else acc_banks
# Validating calculates bank layout and returns validity
if not self.is_valid():