From 92240e7979018a197b42aab2da16dc002d86f224 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Mon, 25 Mar 2024 22:30:12 +0100 Subject: Reshape weights from TOSA to Vela expected format MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reshape the weight for depthwise conv2d and set the depth_multiplier attribute on the operation. Signed-off-by: Per Åstrand Change-Id: I3b73988fa8c4e0cbe2430874cefe6d002885ec89 --- ethosu/vela/tosa_reader.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py index 6d80e10..5ffefad 100644 --- a/ethosu/vela/tosa_reader.py +++ b/ethosu/vela/tosa_reader.py @@ -189,7 +189,8 @@ class TosaSubgraph: elif op.type.is_conv2d_op(): inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0), False) elif op.type.is_depthwise_conv2d_op(): - inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 0, 3), False) + HWCM_to_HWOI = (0, 1, 3, 2) + inputs[1] = clone_and_reshape_tensor(inputs[1], HWCM_to_HWOI, False) if op.type.needs_bias() and len(inputs) <= op_type.info.indices.biases[0]: # No Bias tensor inputs.append(None) @@ -241,7 +242,13 @@ class TosaSubgraph: if shift != 0: op.explicit_scaling = ExplicitScaling(False, [shift], [1]) if op.type.is_depthwise_conv2d_op(): - op.attrs["depth_multiplier"] = op.weights.shape[3] + assert op.weights.shape[-1] % op.ifm.shape[-1] == 0 + depth_multiplier = op.weights.shape[-1] / op.ifm.shape[-1] + if depth_multiplier > 1: + assert op.ifm.shape[-1] == 1 and op.ofm.shape[-1] == depth_multiplier, ( + "For depth multipliers > 1, IFM channels must be 1 and " + "OFM channels must be equal to the depth multiplier") + op.attrs["depth_multiplier"] = depth_multiplier if op.type == Op.SplitSliceRead: op.read_offsets[0] = Shape4D.from_list(list(op.attrs["start"]), 0) op.read_shapes[0] = op.attrs["size"] -- cgit v1.2.1