diff options
author | Per Åstrand <per.astrand@arm.com> | 2024-03-25 22:30:12 +0100 |
---|---|---|
committer | Per Åstrand <per.astrand@arm.com> | 2024-04-12 12:19:20 +0200 |
commit | 92240e7979018a197b42aab2da16dc002d86f224 (patch) | |
tree | fe1ec0d1790fe40b3fad3dd06fcabf252139c80c | |
parent | 931613df7c68fb1c7cb45c6f69783c86003d7583 (diff) | |
download | ethos-u-vela-92240e7979018a197b42aab2da16dc002d86f224.tar.gz |
Reshape weights from TOSA to Vela expected format
Reshape the weight for depthwise conv2d and set the
depth_multiplier attribute on the operation.
Signed-off-by: Per Åstrand <per.astrand@arm.com>
Change-Id: I3b73988fa8c4e0cbe2430874cefe6d002885ec89
-rw-r--r-- | ethosu/vela/tosa_reader.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py index 6d80e10..5ffefad 100644 --- a/ethosu/vela/tosa_reader.py +++ b/ethosu/vela/tosa_reader.py @@ -189,7 +189,8 @@ class TosaSubgraph: elif op.type.is_conv2d_op(): inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0), False) elif op.type.is_depthwise_conv2d_op(): - inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 0, 3), False) + HWCM_to_HWOI = (0, 1, 3, 2) + inputs[1] = clone_and_reshape_tensor(inputs[1], HWCM_to_HWOI, False) if op.type.needs_bias() and len(inputs) <= op_type.info.indices.biases[0]: # No Bias tensor inputs.append(None) @@ -241,7 +242,13 @@ class TosaSubgraph: if shift != 0: op.explicit_scaling = ExplicitScaling(False, [shift], [1]) if op.type.is_depthwise_conv2d_op(): - op.attrs["depth_multiplier"] = op.weights.shape[3] + assert op.weights.shape[-1] % op.ifm.shape[-1] == 0 + depth_multiplier = op.weights.shape[-1] / op.ifm.shape[-1] + if depth_multiplier > 1: + assert op.ifm.shape[-1] == 1 and op.ofm.shape[-1] == depth_multiplier, ( + "For depth multipliers > 1, IFM channels must be 1 and " + "OFM channels must be equal to the depth multiplier") + op.attrs["depth_multiplier"] = depth_multiplier if op.type == Op.SplitSliceRead: op.read_offsets[0] = Shape4D.from_list(list(op.attrs["start"]), 0) op.read_shapes[0] = op.attrs["size"] |