diff options
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 105 |
1 files changed, 54 insertions, 51 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 28dead10..a12eeb37 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -1982,58 +1982,59 @@ def convert_mean_to_depthwise_conv(op, arch, nng): max_kernel_size = 4096 max_height = 64 inp, axis = op.inputs - shape = inp.shape - ofm_shape = op.ofm.shape - dims = len(shape) - dims_ofm = len(ofm_shape) + dims = len(inp.shape) + dims_ofm = len(op.ofm.shape) ofmq = op.ofm.quantization ifmq = op.ifm.quantization - # Height and width axes have different index depending on dimensions - if axis.shape == [] or axis.shape[0] == 1: # single axis - axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0]) - # If dims is 4, axis 1 refers to h-dimension - if dims == 4: - reduce_h, reduce_w = (True, False) if axis == 1 else (False, True) - else: - reduce_h, reduce_w = (True, False) if axis == 0 else (False, True) - else: # multiple axes - axis = sorted(axis.values) - reduce_h, reduce_w = (True, True) - - # Change dimensions to 4 - def extend_dims(dim, in_shape): - if dim < 4: - in_shape = [1] + in_shape - if dim == 2: - in_shape += [1] - return in_shape - - if dims < 4 or dims_ofm < 4: - # Fix the ofm dimension when keep_dims is false - # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC - if isinstance(axis, int) and dims_ofm + 1 == dims: - ofm_shape.insert(axis, 1) - elif isinstance(axis, list) and (dims_ofm + len(axis) == dims): - for i in axis: - ofm_shape.insert(i, 1) - shape = extend_dims(dims, shape) - dims_ofm = len(ofm_shape) - ofm_shape = extend_dims(dims_ofm, ofm_shape) - op.set_ifm_ofm_shapes() - - # Compute kernel sizes for our convolutions - h = shape[1] if reduce_h else 1 - w = shape[2] if reduce_w else 1 + # reduce_axis[i] is true if axis i should be reduced + if axis.shape == []: + reduce_axis = [True if i == axis.values else False for i in range(dims)] + else: + reduce_axis = [True if i in axis.values else False for i in range(dims)] + + ifm_shape = inp.shape.copy() + intermediate_shape = op.ofm.shape.copy() + + # Fix intermediate_shape when keep_dims is false + # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the intermediate_shape should be 1xHx1xC + if dims_ofm < dims: + for i in range(dims): + if reduce_axis[i]: + intermediate_shape.insert(i, 1) + + # Reshape to 4D + if dims == 2: + # Reshape WxC -> 1xHxWx1 to support both axes + reduce_axis = [False] + reduce_axis + [False] + ifm_shape = [1] + ifm_shape + [1] + intermediate_shape = [1] + intermediate_shape + [1] + elif dims == 3: + # Reshape to 4D HxWxC -> 1xHxWxC + reduce_axis = [False] + reduce_axis + ifm_shape = [1] + ifm_shape + intermediate_shape = [1] + intermediate_shape + + # If all dimensions to reduce have shape 1, the operation is essentially a memcpy. + # We can then remove the whole op by propagating ofm to previous ops + if not any([reduce_axis[i] and ifm_shape[i] > 1 for i in range(4)]): + op.type = Op.Memcpy + op = bypass_memory_only_ops(op, arch, nng) + return op + + # Compute kernel sizes for our convolutions. + # batch and depth axes are only supported if their shapes are 1. + # hence reduction in batch or depth axis is implicit. + h = ifm_shape[1] if reduce_axis[1] else 1 + w = ifm_shape[2] if reduce_axis[2] else 1 + num_elements_in_axis = h * w # If one convolution is enough, but height is greater than max kernel height # reshape from HxW to 1x(HxW) # This can only be done if the mean is computed over both H and W - if h > max_height and num_elements_in_axis <= max_kernel_size and reduce_h and reduce_w: - shape = [shape[0], 1, h * w, shape[3]] - op.ifm_shapes[0] = Shape4D(shape) - op.ifm.shape = shape + if h > max_height and num_elements_in_axis <= max_kernel_size and reduce_axis[1] and reduce_axis[2]: + ifm_shape = [ifm_shape[0], 1, h * w, ifm_shape[3]] w = h * w h = 1 @@ -2065,10 +2066,11 @@ def convert_mean_to_depthwise_conv(op, arch, nng): } ) - b, _, _, c = shape + b, _, _, c = ifm_shape intermediate_tensor = op.ofm.clone(suffix=f"_conv_sum_{i}", set_unique=True) intermediate_tensor.dtype = DataType.int32 + intermediate_tensor.shape = intermediate_shape intermediate_op.set_output_tensor(intermediate_tensor) # as we have several convs, scaling/rounding must be done after the sum has been calculated @@ -2081,11 +2083,11 @@ def convert_mean_to_depthwise_conv(op, arch, nng): weight_h = height_per_conv # compute ifm read offset and shape for the convolution - read_shape_h = weight_h if reduce_h else shape[1] - read_shape_w = w if reduce_w else shape[2] + read_shape_h = weight_h if reduce_axis[1] else ifm_shape[1] + read_shape_w = w if reduce_axis[2] else ifm_shape[2] intermediate_op.read_offsets[0] = Shape4D([0, i * height_per_conv, 0, 0]) - intermediate_op.read_shapes[0] = Shape4D(shape).with_hw(read_shape_h, read_shape_w) + intermediate_op.read_shapes[0] = Shape4D(ifm_shape).with_hw(read_shape_h, read_shape_w) weight_quant = QuantizationParameters(0, 255, scale_f32=1.0, zero_point=0) weight_shape = [weight_h, w, c, b] @@ -2112,9 +2114,9 @@ def convert_mean_to_depthwise_conv(op, arch, nng): intermediate_op.inputs.append(bias) intermediate_op.set_ifm_ofm_shapes() - # We want to avoid reshaping the tensor directly, to not affect other ops + # We want to avoid reshaping the ifm tensor directly, to not affect other ops # so we update the shape explicitly for this operation - intermediate_op.ifm_shapes[0] = Shape4D(shape) + intermediate_op.ifm_shapes[0] = Shape4D(ifm_shape) convs.append(intermediate_op) DebugDatabase.add_optimised(op, intermediate_op) @@ -2128,6 +2130,7 @@ def convert_mean_to_depthwise_conv(op, arch, nng): while len(convs): intermediate_tensor = op.ofm.clone(suffix=f"_add_sum_{idx}", set_unique=True) intermediate_tensor.dtype = DataType.int32 + intermediate_tensor.shape = intermediate_shape one_scale_quant = QuantizationParameters(scale_f32=1.0, zero_point=0) @@ -2136,7 +2139,6 @@ def convert_mean_to_depthwise_conv(op, arch, nng): ifm2 = convs.pop().ofm else: ifm2 = prev_add_op.ofm - intermediate_op = create_add(f"{op.name}_add_{idx}", ifm, ifm2, one_scale_quant) intermediate_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1]) intermediate_op.set_output_tensor(intermediate_tensor) @@ -2180,6 +2182,7 @@ def convert_mean_to_depthwise_conv(op, arch, nng): ) op.set_input_tensor(scalar, 1) op.set_ifm_ofm_shapes() + op.ofm_shapes[0] = Shape4D(intermediate_shape) # Reference using TFL rounding for the multiply op.rounding_mode = RoundingMode.TFLite |