aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPer Åstrand <per.astrand@arm.com>2024-03-25 22:30:12 +0100
committerPer Åstrand <per.astrand@arm.com>2024-04-12 12:19:20 +0200
commit92240e7979018a197b42aab2da16dc002d86f224 (patch)
treefe1ec0d1790fe40b3fad3dd06fcabf252139c80c
parent931613df7c68fb1c7cb45c6f69783c86003d7583 (diff)
downloadethos-u-vela-main.tar.gz
Reshape weights from TOSA to Vela expected formatHEADmain
Reshape the weight for depthwise conv2d and set the depth_multiplier attribute on the operation. Signed-off-by: Per Åstrand <per.astrand@arm.com> Change-Id: I3b73988fa8c4e0cbe2430874cefe6d002885ec89
-rw-r--r--ethosu/vela/tosa_reader.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py
index 6d80e10..5ffefad 100644
--- a/ethosu/vela/tosa_reader.py
+++ b/ethosu/vela/tosa_reader.py
@@ -189,7 +189,8 @@ class TosaSubgraph:
elif op.type.is_conv2d_op():
inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0), False)
elif op.type.is_depthwise_conv2d_op():
- inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 0, 3), False)
+ HWCM_to_HWOI = (0, 1, 3, 2)
+ inputs[1] = clone_and_reshape_tensor(inputs[1], HWCM_to_HWOI, False)
if op.type.needs_bias() and len(inputs) <= op_type.info.indices.biases[0]:
# No Bias tensor
inputs.append(None)
@@ -241,7 +242,13 @@ class TosaSubgraph:
if shift != 0:
op.explicit_scaling = ExplicitScaling(False, [shift], [1])
if op.type.is_depthwise_conv2d_op():
- op.attrs["depth_multiplier"] = op.weights.shape[3]
+ assert op.weights.shape[-1] % op.ifm.shape[-1] == 0
+ depth_multiplier = op.weights.shape[-1] / op.ifm.shape[-1]
+ if depth_multiplier > 1:
+ assert op.ifm.shape[-1] == 1 and op.ofm.shape[-1] == depth_multiplier, (
+ "For depth multipliers > 1, IFM channels must be 1 and "
+ "OFM channels must be equal to the depth multiplier")
+ op.attrs["depth_multiplier"] = depth_multiplier
if op.type == Op.SplitSliceRead:
op.read_offsets[0] = Shape4D.from_list(list(op.attrs["start"]), 0)
op.read_shapes[0] = op.attrs["size"]