about summary refs log tree commit diff
path: root/ethosu
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu')
-rw-r--r-- ethosu/vela/tensor.py 11
1 files changed, 3 insertions, 8 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 46040a46..5d0206cc 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -331,10 +331,6 @@ class Tensor:
self.storage_rounding_quantum = arch.storage_rounding_quantums[self.format]
self.storage_rounding_quantum = self.storage_rounding_quantum[-shape_len:]
- if self.format == TensorFormat.NHCWB16:
- self.storage_rounding_quantum = self.storage_rounding_quantum[:-1] + (
- int(self.storage_rounding_quantum[-1] / self.dtype.size_in_bytes()),
- )
self.brick_size = arch.brick_sizes[self.format]
self.brick_size = self.brick_size[-shape_len:]
if self.shape is None:
@@ -491,7 +487,7 @@ class Tensor:
stride_order = [4, 1, 3, 2, 0]
elif self.format == TensorFormat.NHCWB16:
- channel_divisor = int(16 / self.element_size())
+ channel_divisor = 16
augmented_shape = augmented_shape[0:4] + [1]
augmented_coord = (
[augmented_coord[0], augmented_coord[3] // channel_divisor]
@@ -515,11 +511,10 @@ class Tensor:
stride *= augmented_shape[i]
else:
assert len(strides) == 5
- channel_divisor = int(16 / self.element_size())
strides[4] = stride
- strides[3] = channel_divisor # STRIDE_X
+ strides[3] = 16 * stride # STRIDE_X
strides[1] = strides[3] * augmented_shape[2] # STRIDE_C
- strides[2] = augmented_shape[2] * augmented_shape[3] # STRIDE_Y
+ strides[2] = augmented_shape[2] * augmented_shape[3] * stride # STRIDE_Y
strides[0] = strides[2] * augmented_shape[1] # STRIDE_N
return strides, augmented_coord