From 6a8d424b4d41fb5ea69996dd227ea74f794f7a64 Mon Sep 17 00:00:00 2001 From: Michael McGeagh Date: Tue, 28 Jul 2020 12:17:59 +0100 Subject: vela: Move common functionality There is a repeating pattern of setting the 3 different shapes in a tensor to a single shape value. This adds a new function in the tensor class that does this for you. Changed existing instances of manually setting shape to use this new function. Signed-off-by: Michael McGeagh Change-Id: Ibc74e741ea47cec473e6be42cc102f721ec63b11 --- ethosu/vela/graph_optimiser.py | 14 +++++--------- ethosu/vela/tensor.py | 5 +++++ 2 files changed, 10 insertions(+), 9 deletions(-) (limited to 'ethosu/vela') diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index cb0cc643..23ddf833 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -205,7 +205,7 @@ def fixup_fully_connected_input(op, arch): reshape_op.inputs = [inp, new_shape_tens] reshape_op.attrs["new_shape"] = desired_shape reshape_out = inp.clone("_reshaped") - reshape_out.shape = reshape_out.storage_shape = reshape_out.bandwidth_shape = desired_shape + reshape_out.set_all_shapes(desired_shape) reshape_out.ops = [reshape_op] reshape_op.outputs = [reshape_out] @@ -235,7 +235,7 @@ def fixup_pack_input(op, arch): reshape_op.inputs = [inp, new_shape_tens] reshape_op.attrs["new_shape"] = desired_shape reshape_out = inp.clone("_reshaped") - reshape_out.shape = reshape_out.storage_shape = reshape_out.bandwidth_shape = desired_shape + reshape_out.set_all_shapes(desired_shape) reshape_out.ops = [reshape_op] reshape_op.outputs = [reshape_out] @@ -308,7 +308,7 @@ def fixup_unpack_output(tens, arch): reshape_op = Operation("Reshape", reshape_name) reshape_op.outputs = [out_tens] reshape_in = out_tens.clone("_reshaped") - reshape_in.shape = reshape_in.storage_shape = reshape_in.bandwidth_shape = reshape_input_shape + reshape_in.set_all_shapes(reshape_input_shape) reshape_in.ops = [op] out_tens.ops = [reshape_op] 
reshape_op.inputs = [reshape_in, new_shape_tens] @@ -425,9 +425,7 @@ def convert_depthwise_to_conv(op, arch): del op.attrs["depth_multiplier"] weight_tensor.quant_values = np.transpose(weight_tensor.quant_values, (0, 1, 3, 2)) - weight_tensor.shape = weight_tensor.storage_shape = weight_tensor.bandwidth_shape = list( - weight_tensor.quant_values.shape - ) + weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape)) else: raise UnsupportedFeatureError( "Unsupported DepthwiseConv2d with depth_multiplier = {}, ifm channels = {}, ofm channels = {}".format( @@ -441,9 +439,7 @@ def reorder_depthwise_weights(op, arch): if "DepthwiseConv2d" in op.type: weight_tensor = op.inputs[1] weight_tensor.quant_values = np.transpose(weight_tensor.quant_values, (0, 1, 3, 2)) - weight_tensor.shape = weight_tensor.storage_shape = weight_tensor.bandwidth_shape = list( - weight_tensor.quant_values.shape - ) + weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape)) weight_tensor.weight_transpose_depthwise = True return op diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index 42ba853d..1a071e61 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -690,6 +690,11 @@ class Tensor: return True return False + def set_all_shapes(self, shape): + self.shape = shape + self.storage_shape = shape + self.bandwidth_shape = shape + def __str__(self): return "<Tensor %s shape=%s dtype=%s>" % (self.name, self.shape, self.dtype) -- cgit v1.2.1