aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/graph_optimiser.py14
-rw-r--r--ethosu/vela/tensor.py5
2 files changed, 10 insertions, 9 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index cb0cc643..23ddf833 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -205,7 +205,7 @@ def fixup_fully_connected_input(op, arch):
reshape_op.inputs = [inp, new_shape_tens]
reshape_op.attrs["new_shape"] = desired_shape
reshape_out = inp.clone("_reshaped")
- reshape_out.shape = reshape_out.storage_shape = reshape_out.bandwidth_shape = desired_shape
+ reshape_out.set_all_shapes(desired_shape)
reshape_out.ops = [reshape_op]
reshape_op.outputs = [reshape_out]
@@ -235,7 +235,7 @@ def fixup_pack_input(op, arch):
reshape_op.inputs = [inp, new_shape_tens]
reshape_op.attrs["new_shape"] = desired_shape
reshape_out = inp.clone("_reshaped")
- reshape_out.shape = reshape_out.storage_shape = reshape_out.bandwidth_shape = desired_shape
+ reshape_out.set_all_shapes(desired_shape)
reshape_out.ops = [reshape_op]
reshape_op.outputs = [reshape_out]
@@ -308,7 +308,7 @@ def fixup_unpack_output(tens, arch):
reshape_op = Operation("Reshape", reshape_name)
reshape_op.outputs = [out_tens]
reshape_in = out_tens.clone("_reshaped")
- reshape_in.shape = reshape_in.storage_shape = reshape_in.bandwidth_shape = reshape_input_shape
+ reshape_in.set_all_shapes(reshape_input_shape)
reshape_in.ops = [op]
out_tens.ops = [reshape_op]
reshape_op.inputs = [reshape_in, new_shape_tens]
@@ -425,9 +425,7 @@ def convert_depthwise_to_conv(op, arch):
del op.attrs["depth_multiplier"]
weight_tensor.quant_values = np.transpose(weight_tensor.quant_values, (0, 1, 3, 2))
- weight_tensor.shape = weight_tensor.storage_shape = weight_tensor.bandwidth_shape = list(
- weight_tensor.quant_values.shape
- )
+ weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
else:
raise UnsupportedFeatureError(
"Unsupported DepthwiseConv2d with depth_multiplier = {}, ifm channels = {}, ofm channels = {}".format(
@@ -441,9 +439,7 @@ def reorder_depthwise_weights(op, arch):
if "DepthwiseConv2d" in op.type:
weight_tensor = op.inputs[1]
weight_tensor.quant_values = np.transpose(weight_tensor.quant_values, (0, 1, 3, 2))
- weight_tensor.shape = weight_tensor.storage_shape = weight_tensor.bandwidth_shape = list(
- weight_tensor.quant_values.shape
- )
+ weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
weight_tensor.weight_transpose_depthwise = True
return op
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 42ba853d..1a071e61 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -690,6 +690,11 @@ class Tensor:
return True
return False
+ def set_all_shapes(self, shape):
+ self.shape = shape
+ self.storage_shape = shape
+ self.bandwidth_shape = shape
+
def __str__(self):
return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype)