aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py29
1 files changed, 20 insertions, 9 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 9b98b8f..5107871 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -243,13 +243,15 @@ def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
return padding, skirt
-def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
+def calc_upscaled_padding_and_skirt(
+ padding_type, kernel_size, stride, input_shape, upscaling_factor_y, upscaling_factor_x
+):
kernel_height, kernel_width = kernel_size[0], kernel_size[1]
if padding_type == Padding.SAME:
- ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
- xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
- right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
- bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
+ ypad = needed_total_padding(int(input_shape.height) * upscaling_factor_y, int(stride[1]), int(kernel_height))
+ xpad = needed_total_padding(int(input_shape.width) * upscaling_factor_x, int(stride[2]), int(kernel_width))
+ right_pad = max(((xpad + 1) // upscaling_factor_x) - 1, 0)
+ bottom_pad = max(((ypad + 1) // upscaling_factor_y) - 1, 0)
left_pad = max(kernel_width - 1 - right_pad, 0)
top_pad = max(kernel_height - 1 - bottom_pad, 0)
elif padding_type == Padding.VALID:
@@ -269,7 +271,11 @@ def fixup_conv2d_backprop(op: Operation, arch, nng) -> Operation:
# flip the inputs
op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
op.type = Op.Conv2DBackpropInputSwitchedBias
- op.ifm_resampling_mode = resampling_mode.TRANSPOSE
+ stride_w = op.kernel.stride.x
+ stride_h = op.kernel.stride.y
+ if stride_w > 1 or stride_h > 1:
+ # Transpose conv2d with upscaling
+ op.ifm_resampling_mode = resampling_mode.TRANSPOSE
# Update strides
op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})
@@ -924,10 +930,15 @@ def add_padding_fields(op, arch, nng):
else:
raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")
- if op.type == Op.Conv2DBackpropInputSwitchedBias:
- upscaling_factor = output_shape.height // input_shape.height
+ if op.type == Op.Conv2DBackpropInputSwitchedBias and op.ifm_resampling_mode == resampling_mode.TRANSPOSE:
+ # Transpose with upscale
padding, skirt = calc_upscaled_padding_and_skirt(
- op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
+ op.attrs["padding"],
+ kernel_size,
+ op.attrs["strides"],
+ input_shape,
+ output_shape.height // input_shape.height,
+ output_shape.width // input_shape.width,
)
else:
padding, skirt = calc_padding_and_skirt(