aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPatrik Gustavsson <patrik.gustavsson@arm.com>2020-05-05 17:49:35 +0200
committerTim Hall <tim.hall@arm.com>2020-06-18 17:53:52 +0100
commit2213e90570af328418d4f4a0d54269ed21dc40bc (patch)
treea40f53e04e1c9200eec5101ee7fc884a3e5feed0
parentcf72890e5dd89dada3189816a61d174b984086bd (diff)
downloadethos-u-vela-2213e90570af328418d4f4a0d54269ed21dc40bc.tar.gz
MLBEDSW-2241: Fix for NHCWB16 with int16
Changes in strides and rounding for int16 and NHCWB16 Change-Id: I195890215b55ee7a4eab2e6ce4da95fb41587acb Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
-rw-r--r--ethosu/vela/tensor.py11
1 files changed, 3 insertions, 8 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 46040a46..5d0206cc 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -331,10 +331,6 @@ class Tensor:
self.storage_rounding_quantum = arch.storage_rounding_quantums[self.format]
self.storage_rounding_quantum = self.storage_rounding_quantum[-shape_len:]
- if self.format == TensorFormat.NHCWB16:
- self.storage_rounding_quantum = self.storage_rounding_quantum[:-1] + (
- int(self.storage_rounding_quantum[-1] / self.dtype.size_in_bytes()),
- )
self.brick_size = arch.brick_sizes[self.format]
self.brick_size = self.brick_size[-shape_len:]
if self.shape is None:
@@ -491,7 +487,7 @@ class Tensor:
stride_order = [4, 1, 3, 2, 0]
elif self.format == TensorFormat.NHCWB16:
- channel_divisor = int(16 / self.element_size())
+ channel_divisor = 16
augmented_shape = augmented_shape[0:4] + [1]
augmented_coord = (
[augmented_coord[0], augmented_coord[3] // channel_divisor]
@@ -515,11 +511,10 @@ class Tensor:
stride *= augmented_shape[i]
else:
assert len(strides) == 5
- channel_divisor = int(16 / self.element_size())
strides[4] = stride
- strides[3] = channel_divisor # STRIDE_X
+ strides[3] = 16 * stride # STRIDE_X
strides[1] = strides[3] * augmented_shape[2] # STRIDE_C
- strides[2] = augmented_shape[2] * augmented_shape[3] # STRIDE_Y
+ strides[2] = augmented_shape[2] * augmented_shape[3] * stride # STRIDE_Y
strides[0] = strides[2] * augmented_shape[1] # STRIDE_N
return strides, augmented_coord