aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tensor.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r--ethosu/vela/tensor.py37
1 files changed, 20 insertions, 17 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index bc0597f6..5e97cfe8 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -414,27 +414,30 @@ class Tensor:
return rounded_size
def storage_shape_for_sub_purpose(self, sub_purpose, param_a, param_b):
- shp = list(self.storage_shape)
if sub_purpose == TensorSubPurpose.DoubleBuffer:
+ shp = list(self.shape)
assert len(shp) >= 2
shp[-1] = min(shp[-1], param_a * 2)
- elif sub_purpose == TensorSubPurpose.RollingBufferX:
- assert len(shp) == 4
- shp[0] = 1
- shp[2] = min(shp[2], param_a)
- elif sub_purpose == TensorSubPurpose.RollingBufferY:
- assert len(shp) == 4
- shp[0] = 1
- shp[1] = min(shp[1], param_a)
- elif sub_purpose == TensorSubPurpose.RollingBufferXY:
- assert len(shp) == 4
- shp[0] = 1
- shp[2] = min(shp[2], param_a)
- shp[1] = min(shp[1], param_b)
- elif sub_purpose == TensorSubPurpose.Standard:
- pass
else:
- assert 0, "did not expect new sub purpose %s" % (sub_purpose,)
+ shp = list(self.storage_shape)
+ if sub_purpose == TensorSubPurpose.RollingBufferX:
+ assert len(shp) == 4
+ shp[0] = 1
+ shp[2] = min(shp[2], param_a)
+ elif sub_purpose == TensorSubPurpose.RollingBufferY:
+ assert len(shp) == 4
+ shp[0] = 1
+ shp[1] = min(shp[1], param_a)
+ elif sub_purpose == TensorSubPurpose.RollingBufferXY:
+ assert len(shp) == 4
+ shp[0] = 1
+ shp[2] = min(shp[2], param_a)
+ shp[1] = min(shp[1], param_b)
+ elif sub_purpose == TensorSubPurpose.Standard:
+ pass
+ else:
+ assert 0, "did not expect new sub purpose %s" % (sub_purpose,)
+
return shp
def set_new_sub_purpose(self, sub_purpose, param_a=None, param_b=None):