diff options
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r-- | ethosu/vela/tensor.py | 37 |
1 files changed, 20 insertions, 17 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index bc0597f6..5e97cfe8 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -414,27 +414,30 @@ class Tensor: return rounded_size def storage_shape_for_sub_purpose(self, sub_purpose, param_a, param_b): - shp = list(self.storage_shape) if sub_purpose == TensorSubPurpose.DoubleBuffer: + shp = list(self.shape) assert len(shp) >= 2 shp[-1] = min(shp[-1], param_a * 2) - elif sub_purpose == TensorSubPurpose.RollingBufferX: - assert len(shp) == 4 - shp[0] = 1 - shp[2] = min(shp[2], param_a) - elif sub_purpose == TensorSubPurpose.RollingBufferY: - assert len(shp) == 4 - shp[0] = 1 - shp[1] = min(shp[1], param_a) - elif sub_purpose == TensorSubPurpose.RollingBufferXY: - assert len(shp) == 4 - shp[0] = 1 - shp[2] = min(shp[2], param_a) - shp[1] = min(shp[1], param_b) - elif sub_purpose == TensorSubPurpose.Standard: - pass else: - assert 0, "did not expect new sub purpose %s" % (sub_purpose,) + shp = list(self.storage_shape) + if sub_purpose == TensorSubPurpose.RollingBufferX: + assert len(shp) == 4 + shp[0] = 1 + shp[2] = min(shp[2], param_a) + elif sub_purpose == TensorSubPurpose.RollingBufferY: + assert len(shp) == 4 + shp[0] = 1 + shp[1] = min(shp[1], param_a) + elif sub_purpose == TensorSubPurpose.RollingBufferXY: + assert len(shp) == 4 + shp[0] = 1 + shp[2] = min(shp[2], param_a) + shp[1] = min(shp[1], param_b) + elif sub_purpose == TensorSubPurpose.Standard: + pass + else: + assert 0, "did not expect new sub purpose %s" % (sub_purpose,) + return shp def set_new_sub_purpose(self, sub_purpose, param_a=None, param_b=None): |