Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--  ethosu/vela/tflite_graph_optimiser.py  105
1 file changed, 54 insertions(+), 51 deletions(-)
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 28dead10..a12eeb37 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1982,58 +1982,59 @@ def convert_mean_to_depthwise_conv(op, arch, nng):
     max_kernel_size = 4096
     max_height = 64
     inp, axis = op.inputs
-    shape = inp.shape
-    ofm_shape = op.ofm.shape
-    dims = len(shape)
-    dims_ofm = len(ofm_shape)
+    dims = len(inp.shape)
+    dims_ofm = len(op.ofm.shape)
     ofmq = op.ofm.quantization
     ifmq = op.ifm.quantization
 
-    # Height and width axes have different index depending on dimensions
-    if axis.shape == [] or axis.shape[0] == 1:  # single axis
-        axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
-        # If dims is 4, axis 1 refers to h-dimension
-        if dims == 4:
-            reduce_h, reduce_w = (True, False) if axis == 1 else (False, True)
-        else:
-            reduce_h, reduce_w = (True, False) if axis == 0 else (False, True)
-    else:  # multiple axes
-        axis = sorted(axis.values)
-        reduce_h, reduce_w = (True, True)
-
-    # Change dimensions to 4
-    def extend_dims(dim, in_shape):
-        if dim < 4:
-            in_shape = [1] + in_shape
-            if dim == 2:
-                in_shape += [1]
-        return in_shape
-
-    if dims < 4 or dims_ofm < 4:
-        # Fix the ofm dimension when keep_dims is false
-        # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
-        if isinstance(axis, int) and dims_ofm + 1 == dims:
-            ofm_shape.insert(axis, 1)
-        elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
-            for i in axis:
-                ofm_shape.insert(i, 1)
-        shape = extend_dims(dims, shape)
-        dims_ofm = len(ofm_shape)
-        ofm_shape = extend_dims(dims_ofm, ofm_shape)
-        op.set_ifm_ofm_shapes()
-
-    # Compute kernel sizes for our convolutions
-    h = shape[1] if reduce_h else 1
-    w = shape[2] if reduce_w else 1
+    # reduce_axis[i] is True if axis i should be reduced
+    if axis.shape == []:
+        reduce_axis = [True if i == axis.values else False for i in range(dims)]
+    else:
+        reduce_axis = [True if i in axis.values else False for i in range(dims)]
+
+    ifm_shape = inp.shape.copy()
+    intermediate_shape = op.ofm.shape.copy()
+
+    # Fix intermediate_shape when keep_dims is false
+    # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the intermediate_shape should be 1xHx1xC
+    if dims_ofm < dims:
+        for i in range(dims):
+            if reduce_axis[i]:
+                intermediate_shape.insert(i, 1)
+
+    # Reshape to 4D
+    if dims == 2:
+        # Reshape 2D input -> 1xHxWx1 so that both original axes can be reduced
+        reduce_axis = [False] + reduce_axis + [False]
+        ifm_shape = [1] + ifm_shape + [1]
+        intermediate_shape = [1] + intermediate_shape + [1]
+    elif dims == 3:
+        # Reshape HxWxC -> 1xHxWxC
+        reduce_axis = [False] + reduce_axis
+        ifm_shape = [1] + ifm_shape
+        intermediate_shape = [1] + intermediate_shape
+
+    # If all dimensions to reduce have shape 1, the operation is essentially a memcpy.
+    # We can then remove the whole op by propagating the ofm to previous ops.
+    if not any([reduce_axis[i] and ifm_shape[i] > 1 for i in range(4)]):
+        op.type = Op.Memcpy
+        op = bypass_memory_only_ops(op, arch, nng)
+        return op
+
+    # Compute kernel sizes for our convolutions.
+    # The batch and depth axes are only supported if their shapes are 1,
+    # hence a reduction in the batch or depth axis is implicit.
+    h = ifm_shape[1] if reduce_axis[1] else 1
+    w = ifm_shape[2] if reduce_axis[2] else 1
+
     num_elements_in_axis = h * w
 
     # If one convolution is enough, but height is greater than max kernel height
     # reshape from HxW to 1x(HxW)
     # This can only be done if the mean is computed over both H and W
-    if h > max_height and num_elements_in_axis <= max_kernel_size and reduce_h and reduce_w:
-        shape = [shape[0], 1, h * w, shape[3]]
-        op.ifm_shapes[0] = Shape4D(shape)
-        op.ifm.shape = shape
+    if h > max_height and num_elements_in_axis <= max_kernel_size and reduce_axis[1] and reduce_axis[2]:
+        ifm_shape = [ifm_shape[0], 1, h * w, ifm_shape[3]]
         w = h * w
         h = 1
 
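
The hunk above replaces the old reduce_h/reduce_w flags and the extend_dims helper with a single boolean mask plus explicit 4D shapes. A minimal standalone sketch of that bookkeeping, assuming example shapes and a hypothetical helper name (this is not Vela code):

    # Sketch of the reduce-axis bookkeeping; mean_shapes and all values
    # are illustrative, not Vela APIs.
    def mean_shapes(ifm_shape, ofm_shape, axis_values):
        dims, dims_ofm = len(ifm_shape), len(ofm_shape)
        reduce_axis = [i in axis_values for i in range(dims)]
        intermediate_shape = list(ofm_shape)
        # Re-insert reduced axes when keep_dims was false (1xHxC -> 1xHx1xC)
        if dims_ofm < dims:
            for i in range(dims):
                if reduce_axis[i]:
                    intermediate_shape.insert(i, 1)
        ifm_shape = list(ifm_shape)
        if dims == 2:  # place both axes in the H and W slots: 1xHxWx1
            reduce_axis = [False] + reduce_axis + [False]
            ifm_shape = [1] + ifm_shape + [1]
            intermediate_shape = [1] + intermediate_shape + [1]
        elif dims == 3:  # HxWxC -> 1xHxWxC
            reduce_axis = [False] + reduce_axis
            ifm_shape = [1] + ifm_shape
            intermediate_shape = [1] + intermediate_shape
        return reduce_axis, ifm_shape, intermediate_shape

    # MEAN over axis 2 of a 1xHxWxC input with keep_dims=False:
    assert mean_shapes([1, 8, 16, 4], [1, 8, 4], [2]) == (
        [False, False, True, False], [1, 8, 16, 4], [1, 8, 1, 4])

The same intermediate_shape is stamped onto every intermediate tensor further down, so the conv/add chain always operates in 4D even when the MEAN was declared with keep_dims=False.
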
@@ -2065,10 +2066,11 @@ def convert_mean_to_depthwise_conv(op, arch, nng):
             }
         )
 
-        b, _, _, c = shape
+        b, _, _, c = ifm_shape
 
         intermediate_tensor = op.ofm.clone(suffix=f"_conv_sum_{i}", set_unique=True)
         intermediate_tensor.dtype = DataType.int32
+        intermediate_tensor.shape = intermediate_shape
         intermediate_op.set_output_tensor(intermediate_tensor)
 
         # as we have several convs, scaling/rounding must be done after the sum has been calculated
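
Cloning op.ofm would otherwise inherit its possibly rank-reduced shape, hence the explicit intermediate_shape assignment. The int32 dtype matters because the unscaled partial sums can overflow narrower types; a back-of-the-envelope check using the max_kernel_size limit from this pass (assuming 8-bit input data, so a per-element maximum of 255):

    # Worst-case unscaled sum for a mean over max_kernel_size 8-bit values.
    max_kernel_size = 4096               # limit used by this pass
    worst_case = 255 * max_kernel_size   # 1044480
    assert worst_case > 2**15 - 1        # overflows int16
    assert worst_case < 2**31 - 1        # fits in int32
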
@@ -2081,11 +2083,11 @@ def convert_mean_to_depthwise_conv(op, arch, nng):
             weight_h = height_per_conv
 
         # compute ifm read offset and shape for the convolution
-        read_shape_h = weight_h if reduce_h else shape[1]
-        read_shape_w = w if reduce_w else shape[2]
+        read_shape_h = weight_h if reduce_axis[1] else ifm_shape[1]
+        read_shape_w = w if reduce_axis[2] else ifm_shape[2]
 
         intermediate_op.read_offsets[0] = Shape4D([0, i * height_per_conv, 0, 0])
-        intermediate_op.read_shapes[0] = Shape4D(shape).with_hw(read_shape_h, read_shape_w)
+        intermediate_op.read_shapes[0] = Shape4D(ifm_shape).with_hw(read_shape_h, read_shape_w)
 
         weight_quant = QuantizationParameters(0, 255, scale_f32=1.0, zero_point=0)
         weight_shape = [weight_h, w, c, b]
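
When the reduced height exceeds max_height, the pass emits several convolutions, each reading a horizontal slice of the IFM through read_offsets/read_shapes instead of reshaping the tensor itself. With illustrative numbers (the loop structure is paraphrased from the surrounding code, not quoted):

    # Illustrative tiling of a tall reduction: h = 150 rows, at most
    # height_per_conv = 64 rows per convolution (not Vela code).
    h, height_per_conv = 150, 64
    num_convs = -(-h // height_per_conv)  # ceil division -> 3 convs
    for i in range(num_convs):
        weight_h = min(height_per_conv, h - i * height_per_conv)  # 64, 64, 22
        read_offset_h = i * height_per_conv                       # 0, 64, 128
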
@@ -2112,9 +2114,9 @@ def convert_mean_to_depthwise_conv(op, arch, nng):
         intermediate_op.inputs.append(bias)
         intermediate_op.set_ifm_ofm_shapes()
 
-        # We want to avoid reshaping the tensor directly, to not affect other ops
+        # We want to avoid reshaping the ifm tensor directly, to not affect other ops
         # so we update the shape explicitly for this operation
-        intermediate_op.ifm_shapes[0] = Shape4D(shape)
+        intermediate_op.ifm_shapes[0] = Shape4D(ifm_shape)
 
         convs.append(intermediate_op)
         DebugDatabase.add_optimised(op, intermediate_op)
@@ -2128,6 +2130,7 @@ def convert_mean_to_depthwise_conv(op, arch, nng):
     while len(convs):
         intermediate_tensor = op.ofm.clone(suffix=f"_add_sum_{idx}", set_unique=True)
         intermediate_tensor.dtype = DataType.int32
+        intermediate_tensor.shape = intermediate_shape
 
         one_scale_quant = QuantizationParameters(scale_f32=1.0, zero_point=0)
 
@@ -2136,7 +2139,6 @@ def convert_mean_to_depthwise_conv(op, arch, nng):
             ifm2 = convs.pop().ofm
         else:
             ifm2 = prev_add_op.ofm
-
         intermediate_op = create_add(f"{op.name}_add_{idx}", ifm, ifm2, one_scale_quant)
         intermediate_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])
         intermediate_op.set_output_tensor(intermediate_tensor)
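
The partial sums produced by the convolutions are folded together with int32 ADDs whose ExplicitScaling is the identity (shift 0, multiplier 1), so no rounding occurs until the final multiply. Schematically, with plain integers standing in for the intermediate tensors (not Vela code):

    # Each loop step mirrors one create_add in the pass above.
    partial_sums = [10000, 20000, 30000]  # int32 outputs of the convs
    acc = partial_sums.pop()
    while partial_sums:
        acc = acc + partial_sums.pop()
    assert acc == 60000
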
@@ -2180,6 +2182,7 @@ def convert_mean_to_depthwise_conv(op, arch, nng):
     )
     op.set_input_tensor(scalar, 1)
     op.set_ifm_ofm_shapes()
+    op.ofm_shapes[0] = Shape4D(intermediate_shape)
 
     # Reference using TFL rounding for the multiply
     op.rounding_mode = RoundingMode.TFLite
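
The final multiply scales the accumulated sum by 1/num_elements_in_axis, and forcing ofm_shapes[0] to the 4D intermediate_shape keeps it consistent with the conv/add chain. Numerically the whole lowering is sum-then-scale; a quick NumPy check of that identity, with assumed example shapes:

    import numpy as np

    # Mean over H and W as the lowering computes it: integer sum, one scale.
    ifm = np.arange(1 * 8 * 16 * 4, dtype=np.int32).reshape(1, 8, 16, 4)
    num_elements_in_axis = 8 * 16
    int_sum = ifm.sum(axis=(1, 2), keepdims=True)  # what the convs + adds produce
    mean = int_sum / num_elements_in_axis          # the final scalar multiply
    assert np.allclose(mean, ifm.mean(axis=(1, 2), keepdims=True))
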