Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--  ethosu/vela/tflite_graph_optimiser.py | 233
1 file changed, 157 insertions(+), 76 deletions(-)
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 78906374..21c02f3c 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -57,6 +57,7 @@ from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation import RoundingMode
+from .operation_util import create_add
from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
from .operation_util import create_cast_op
@@ -942,9 +943,10 @@ def add_padding_fields(op, arch, nng):
def reorder_depthwise_weights(op, arch, nng):
if op.type.is_depthwise_conv2d_op():
weight_tensor = op.inputs[1]
- weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
- weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
- weight_tensor.weight_transpose_depthwise = True
+ if not weight_tensor.weight_transpose_depthwise:
+ weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
+ weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
+ weight_tensor.weight_transpose_depthwise = True
return op
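
Note on the hunk above: the new check makes the pass idempotent, so a depthwise weight tensor that is shared between ops or revisited by the optimiser only has its last two axes swapped once. A minimal sketch of that guard pattern, using a hypothetical StubWeightTensor stand-in rather than the real vela Tensor class:

    import numpy as np

    class StubWeightTensor:
        # Hypothetical stand-in for the vela weight tensor; only the fields
        # the guard relies on are modelled.
        def __init__(self, values):
            self.values = values
            self.shape = list(values.shape)
            self.weight_transpose_depthwise = False

    def reorder_depthwise_weights_once(weight_tensor):
        # Swap the last two weight axes exactly once and record that the
        # reordering has happened, so later visits become no-ops.
        if not weight_tensor.weight_transpose_depthwise:
            weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
            weight_tensor.shape = list(weight_tensor.values.shape)
            weight_tensor.weight_transpose_depthwise = True
        return weight_tensor

    w = StubWeightTensor(np.ones((1, 3, 3, 8)))
    reorder_depthwise_weights_once(w)
    reorder_depthwise_weights_once(w)  # second visit leaves the layout untouched
    assert w.values.shape == (1, 3, 8, 3)
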
@@ -1949,44 +1951,45 @@ def fixup_or_check_asymmetric_weights(force_symmetric_int_weights):
def convert_mean_to_depthwise_conv(op, arch, nng):
+ """
+ When h x w <= 4096 When h x w > 4096 there is a need to split into several ops.
+ Do this by splitting up h and change the read_offset/shape.
+ Below is an example where ifm is 1x190x64x1
+ MEAN MEAN
+ | |-----------------------|----------------------|
+ DepthwiseConv2DBias 1_DepthwiseConv2DBias 2_DepthwiseConv2DBias 3_DepthwiseConv2DBias
+ | | | |
+ MUL |---------ADD-----------| |
+ | |
+ |----------------ADD---------------|
+ |
+ MUL
+ 1_DepthwiseConv2DBias: read_offset [0, 0, 0, 0]> read_shape [1, 64, 64, 1]>
+ 2_DepthwiseConv2DBias: read_offset [0, 64, 0, 0]> read_shape [1, 64, 64, 1]>
+ 3_DepthwiseConv2DBias: read_offset [0, 128, 0, 0]> read_shape [1, 62, 64, 1]>
+ """
if op.type == Op.Mean and op.run_on_npu:
+ max_kernel_size = 4096
+ max_height = 64
inp, axis = op.inputs
shape = inp.shape
ofm_shape = op.ofm.shape
dims = len(shape)
dims_ofm = len(ofm_shape)
+ ofmq = op.ofm.quantization
+ ifmq = op.ifm.quantization
# Height and width axes have different index depending on dimensions
if axis.shape == [] or axis.shape[0] == 1: # single axis
axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
- if dims in (2, 3):
- # If dims is 2 or 3, axis 0 refers to h-dimension
- h, w = (shape[axis], 1) if axis == 0 else (1, shape[axis])
+ # If dims is 4, axis 1 refers to h-dimension
+ if dims == 4:
+ reduce_h, reduce_w = (True, False) if axis == 1 else (False, True)
else:
- # If dims is 4, axis 1 refers to h-dimension
- h, w = (shape[axis], 1) if axis == 1 else (1, shape[axis])
+ reduce_h, reduce_w = (True, False) if axis == 0 else (False, True)
else: # multiple axes
axis = sorted(axis.values)
- h, w = [shape[i] for i in axis]
-
- # Set necessary depthwise attributes
- op.attrs.update(
- {
- "padding": Padding.VALID,
- "stride_h": 1,
- "stride_w": 1,
- "strides": (1, 1, 1, 1),
- "depth_multiplier": 1,
- "channel_multiplier": 1,
- "dilation_h_factor": 1,
- "dilation_w_factor": 1,
- "dilation": (1, 1, 1, 1),
- }
- )
- # Change op type
- op.type = Op.DepthwiseConv2DBias
- # Set IFM/OFM shapes after changing op type
- op.set_ifm_ofm_shapes()
+ reduce_h, reduce_w = (True, True)
# Change dimensions to 4
def extend_dims(dim, in_shape):
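
To make the read_offset/read_shape values in the docstring above concrete, here is a small standalone sketch, not part of the patch, that reproduces the split for the 1x190x64x1 example using the same max_kernel_size = 4096 and max_height = 64 limits introduced by the new code:

    import math

    # Illustrative only: per-convolution read offsets and kernel heights for an
    # ifm of 1x190x64x1 reduced over both H and W.
    max_kernel_size = 4096          # h * w limit for one DepthwiseConv2DBias
    max_height = 64                 # kernel height limit
    h, w = 190, 64

    height_per_conv = min(max_kernel_size // w, h)      # 64
    height_per_conv = min(height_per_conv, max_height)  # still 64
    num_convs = math.ceil(h / height_per_conv)          # 3

    splits = []
    for i in range(num_convs):
        is_last_op = i == (num_convs - 1)
        if is_last_op and h % height_per_conv != 0:
            weight_h = h % height_per_conv
        else:
            weight_h = height_per_conv
        splits.append((i * height_per_conv, weight_h))

    print(splits)  # [(0, 64), (64, 64), (128, 62)], matching the docstring
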
@@ -2009,63 +2012,140 @@ def convert_mean_to_depthwise_conv(op, arch, nng):
ofm_shape = extend_dims(dims_ofm, ofm_shape)
op.set_ifm_ofm_shapes()
- # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
- if h > 64:
- # This can only happen and be done for multiple axes, and
- # h * w <= 4096 for DepthwiseConv2DBias
- # which is checked in supported ops
+ # Compute kernel sizes for our convolutions
+ h = shape[1] if reduce_h else 1
+ w = shape[2] if reduce_w else 1
+ num_elements_in_axis = h * w
+
+ # If one convolution is enough, but the height is greater than the max kernel height,
+ # reshape from HxW to 1x(HxW).
+ # This can only be done if the mean is computed over both H and W
+ if h > max_height and num_elements_in_axis <= max_kernel_size and reduce_h and reduce_w:
shape = [shape[0], 1, h * w, shape[3]]
op.ifm_shapes[0] = Shape4D(shape)
- weight_shape = [1, h * w, shape[3], shape[0]]
- else:
- # Set weight shape to [H,W,C,B]
- weight_shape = [h, w, shape[3], shape[0]]
+ op.ifm.shape = shape
+ w = h * w
+ h = 1
+
+ intermediate_op = None
+ height_per_conv = min(max_kernel_size // w, h)
+ height_per_conv = min(height_per_conv, max_height)
+ num_convs = math.ceil(h / height_per_conv)
+ convs = list()
+
+ for i in range(num_convs):
+ is_last_op = i == (num_convs - 1)
+
+ intermediate_op = op.clone(f"{op.name}_conv_{i}")
+
+ intermediate_op.type = Op.DepthwiseConv2DBias
+
+ # Set necessary depthwise attributes
+ intermediate_op.attrs.update(
+ {
+ "padding": Padding.VALID,
+ "stride_h": 1,
+ "stride_w": 1,
+ "strides": (1, 1, 1, 1),
+ "depth_multiplier": 1,
+ "channel_multiplier": 1,
+ "dilation_h_factor": 1,
+ "dilation_w_factor": 1,
+ "dilation": (1, 1, 1, 1),
+ }
+ )
- op.rounding_mode = RoundingMode.HalfUp
- identity_quant = QuantizationParameters(scale_f32=1.0, zero_point=0)
- op.forced_input_quantization = identity_quant
- op.forced_output_quantization = identity_quant
+ b, _, _, c = shape
- # Add unit weight tensor
- op.set_input_tensor(
- create_const_tensor(
- "weights",
+ intermediate_tensor = op.ofm.clone(suffix=f"_conv_sum_{i}", set_unique=True)
+ intermediate_tensor.dtype = DataType.int32
+ intermediate_op.set_output_tensor(intermediate_tensor)
+
+ # As there may be several convs, scaling/rounding must be done after the sum has been calculated
+ intermediate_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])
+
+ # Compute the kernel height for this convolution
+ if is_last_op and h % height_per_conv != 0:
+ weight_h = h % height_per_conv
+ else:
+ weight_h = height_per_conv
+
+ # Compute the ifm read offset and shape for this convolution
+ read_shape_h = weight_h if reduce_h else shape[1]
+ read_shape_w = w if reduce_w else shape[2]
+
+ intermediate_op.read_offsets[0] = Shape4D([0, i * height_per_conv, 0, 0])
+ intermediate_op.read_shapes[0] = Shape4D(shape).with_hw(read_shape_h, read_shape_w)
+
+ weight_quant = QuantizationParameters(0, 255, scale_f32=1.0, zero_point=0)
+ weight_shape = [weight_h, w, c, b]
+ weight_tensor = create_const_tensor(
+ f"{intermediate_op.name}_weights",
weight_shape,
- inp.dtype,
+ DataType.uint8,
np.ones(weight_shape),
- quantization=identity_quant,
- ),
- 1,
- )
- op.weights.values = np.reshape(op.inputs[1].values, weight_shape)
+ TensorPurpose.Weights,
+ quantization=weight_quant,
+ )
- # Input zero point is adjusted after the sum calculation, so we emulate that with a bias
- ofmq, ifmq = op.ofm.quantization, inp.quantization
- bias = -ifmq.zero_point * h * w
- bias_shape = [shape[-1]]
- op.inputs.append(create_const_tensor(op.name + "_bias", bias_shape, DataType.int32, np.ones(bias_shape) * bias))
- DebugDatabase.add_optimised(op, op)
+ weights_1D = np.ones(np.prod(weight_shape))
+ weight_tensor.equivalence_id = create_equivalence_id(tuple(weights_1D))
+ weight_tensor.value_id = weight_tensor.equivalence_id
- # Create intermediate tensor between depthwise conv and mul
- intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
- intermediate.dtype = DataType.int32
+ intermediate_op.set_input_tensor(weight_tensor, 1)
- # Multiply sum with 1/num_elements_in_axis to get the mean
- mul_op = Operation(Op.Mul, op.name + "_mul")
- mul_op.add_input_tensor(intermediate)
- mul_op.set_output_tensor(op.ofm)
- mul_op.forced_input_quantization = identity_quant
+ dtype = DataType.int64 if intermediate_op.ifm.dtype == DataType.int16 else DataType.int32
+ bias_values = [0] * c
+ bias = create_const_tensor(f"{intermediate_op.name}_bias", [c], dtype, bias_values)
+ bias.equivalence_id = create_equivalence_id(tuple(bias_values))
+ bias.value_id = bias.equivalence_id
+ intermediate_op.inputs.append(bias)
+ intermediate_op.set_ifm_ofm_shapes()
- # Set dw conv output to the intermediate tensor
- op.set_output_tensor(intermediate)
+ # Avoid reshaping the tensor directly, so that other ops are not affected;
+ # instead, update the shape explicitly for this operation
+ intermediate_op.ifm_shapes[0] = Shape4D(shape)
- # Move activation from original op to mean op
- mul_op.activation = op.activation
- op.activation = None
+ convs.append(intermediate_op)
+ DebugDatabase.add_optimised(op, intermediate_op)
+
+ # If we have more than one convolution,
+ # we use add operations to accumulate the intermediate tensors
+ if len(convs) > 1:
+ prev_add_op = None
+ idx = 0
+
+ while len(convs):
+ intermediate_tensor = op.ofm.clone(suffix=f"_add_sum_{idx}", set_unique=True)
+ intermediate_tensor.dtype = DataType.int32
+
+ one_scale_quant = QuantizationParameters(scale_f32=1.0, zero_point=0)
+
+ ifm = convs.pop().ofm
+ if not prev_add_op:
+ ifm2 = convs.pop().ofm
+ else:
+ ifm2 = prev_add_op.ofm
+
+ intermediate_op = create_add(f"{op.name}_add_{idx}", ifm, ifm2, one_scale_quant)
+ intermediate_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])
+ intermediate_op.set_output_tensor(intermediate_tensor)
+ intermediate_op.set_ifm_ofm_shapes()
+
+ prev_add_op = intermediate_op
+ idx += 1
+
+ DebugDatabase.add_optimised(op, intermediate_op)
+
+ # Convert the original mean op to the final Mul operation,
+ # which scales and divides by num_elements_in_axis
+ op.type = Op.Mul
+ op.name = f"{op.name}_mul"
+ op.attrs = {}
+ op.set_input_tensor(intermediate_op.ofm, 0)
# The multiplier is calculated in the same way as in the reference,
# clamping the shift value at the price of some precision loss.
- num_elements_in_axis = int(h * w)
output_multiplier, output_shift_vela = quantise_scale(np.double(ifmq.scale_f32) / np.double(ofmq.scale_f32))
# Convert to reference representation shift value
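
The while-loop in the hunk above accumulates the partial sums as a chain: the first Add consumes the outputs of the last two convolutions, and every later Add combines one remaining convolution output with the previous Add. A rough sketch of just that accumulation order, with plain integers standing in for the per-convolution int32 tensors (the real pass emits Add ops with deferred explicit scaling instead):

    # Illustrative only: model the order in which the add chain combines the
    # partial sums produced by the depthwise convolutions.
    conv_sums = [10, 20, 30]     # stand-ins for the int32 outputs of three convs
    convs = list(conv_sums)

    prev_add = None
    while len(convs):
        ifm = convs.pop()
        ifm2 = convs.pop() if prev_add is None else prev_add
        prev_add = ifm + ifm2    # one Add op per iteration in the real pass

    assert prev_add == sum(conv_sums)
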
@@ -2084,18 +2164,19 @@ def convert_mean_to_depthwise_conv(op, arch, nng):
# For int32 scaling is not supported so instead multiply with the scale
# intermediate * scale -> round and shift.
+ identity_quant = QuantizationParameters(scale_f32=1.0, zero_point=0)
scalar = create_const_tensor(
op.name + "_scalar", [1, 1, 1, 1], DataType.int32, [output_multiplier], quantization=identity_quant
)
- mul_op.add_input_tensor(scalar)
- mul_op.set_ifm_ofm_shapes()
+ op.set_input_tensor(scalar, 1)
+ op.set_ifm_ofm_shapes()
# Reference using TFL rounding for the multiply
- mul_op.rounding_mode = RoundingMode.TFLite
+ op.rounding_mode = RoundingMode.TFLite
# Need to use explicit scaling to get the wanted shift
- mul_op.explicit_scaling = ExplicitScaling(False, [output_shift_vela], [1])
- DebugDatabase.add_optimised(op, mul_op)
+ op.explicit_scaling = ExplicitScaling(False, [output_shift_vela], [1])
+ DebugDatabase.add_optimised(op, op)
return op
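
The final Mul then applies the scaling as an integer multiply followed by a rounding right shift, which is roughly what the explicit scaling and rounding mode arrange for. The sketch below only illustrates that multiply-then-rounding-shift idea for folding a real-valued factor such as 1 / num_elements_in_axis into integer arithmetic; it is not the exact multiplier/shift bookkeeping used by the patch or the TFLite reference:

    # Illustrative only: approximate sum / 12160 (the 190 * 64 example) with a
    # Q31 fixed-point multiplier and a rounding right shift.
    num_elements_in_axis = 190 * 64
    multiplier = round((1 << 31) / num_elements_in_axis)  # 176602

    def rounding_right_shift(x, shift):
        # Round-to-nearest arithmetic shift, standing in for the rounding that
        # the scaling stage applies after the multiply.
        return (x + (1 << (shift - 1))) >> shift

    acc = 100 * num_elements_in_axis                  # a sum whose true mean is 100
    mean = rounding_right_shift(acc * multiplier, 31)
    print(mean)                                       # 100
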