diff options
Diffstat (limited to 'ethosu')
-rw-r--r-- | ethosu/vela/shared_buffer_allocation.py | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/ethosu/vela/shared_buffer_allocation.py b/ethosu/vela/shared_buffer_allocation.py index 7657dffa..58856a3e 100644 --- a/ethosu/vela/shared_buffer_allocation.py +++ b/ethosu/vela/shared_buffer_allocation.py @@ -44,7 +44,7 @@ class SharedBufferAllocation: strides = (1, 1, 1, 1) dilation = (1, 1, 1, 1) self.kernel = Kernel(1, 1) - is_elementwise = ps.npu_block_type == NpuBlockType.ElementWise + self.is_elementwise = ps.npu_block_type == NpuBlockType.ElementWise self.uses_lut = False if ps.primary_op: @@ -63,14 +63,14 @@ class SharedBufferAllocation: self.kernel = Kernel(k_w, k_h, strides[2], strides[1], dilation[2], dilation[1]) self.uses_lut = ps.primary_op.activation_lut is not None - self.is_equal_depth_op = is_elementwise or ps.npu_block_type in ( + self.is_equal_depth_op = self.is_elementwise or ps.npu_block_type in ( NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling, ) self.strides = strides self.use_accumulator_element = SHRAMElements.Acc32 - if is_elementwise: + if self.is_elementwise: self.use_ifm_element = SHRAMElements.IFM8_Elementwise else: self.use_ifm_element = SHRAMElements.IFM8 @@ -81,7 +81,7 @@ class SharedBufferAllocation: if ifm_tensor: self.ifm_resampling_mode = ifm_tensor.resampling_mode self.ifm_bits = ifm_tensor.dtype.size_in_bits() - if ifm_tensor.shape == [] and is_elementwise: + if ifm_tensor.shape == [] and self.is_elementwise: # Elementwise operator with scalar in ifm, use ifm2 depth self.ifm_depth = ifm2_tensor.shape[-1] else: @@ -94,7 +94,7 @@ class SharedBufferAllocation: self.use_ifm_element == SHRAMElements.IFM16_Elementwise ) elif self.ifm_bits == 32: - assert is_elementwise or ps.npu_block_type == NpuBlockType.ReduceSum, "Unsupported 32-bit IFM operation" + assert self.is_elementwise or ps.npu_block_type == NpuBlockType.ReduceSum, "Unsupported 32-bit IFM operation" self.use_ifm_element = SHRAMElements.IFM32 else: assert self.ifm_bits == 8, "Unexpected IFM bitdepth" @@ -131,9 +131,11 @@ class SharedBufferAllocation: if ofm_config is None: return None + acc_banks = ofm_config.banks[self.use_accumulator_element] + # Update bank counts for IFM and Accumulator self.banks_required[SharedBufferArea.IFM] = ifm_config.banks[self.use_ifm_element] - self.banks_required[SharedBufferArea.Accumulators] = ofm_config.banks[self.use_accumulator_element] + self.banks_required[SharedBufferArea.Accumulators] = 0 if self.is_elementwise else acc_banks # Validating calculates bank layout and returns validity if not self.is_valid(): |