diff options
Diffstat (limited to 'ethosu/vela/tosa_reader.py')
-rw-r--r-- | ethosu/vela/tosa_reader.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py index 6d80e10d..5ffefade 100644 --- a/ethosu/vela/tosa_reader.py +++ b/ethosu/vela/tosa_reader.py @@ -189,7 +189,8 @@ class TosaSubgraph: elif op.type.is_conv2d_op(): inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0), False) elif op.type.is_depthwise_conv2d_op(): - inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 0, 3), False) + HWCM_to_HWOI = (0, 1, 3, 2) + inputs[1] = clone_and_reshape_tensor(inputs[1], HWCM_to_HWOI, False) if op.type.needs_bias() and len(inputs) <= op_type.info.indices.biases[0]: # No Bias tensor inputs.append(None) @@ -241,7 +242,13 @@ class TosaSubgraph: if shift != 0: op.explicit_scaling = ExplicitScaling(False, [shift], [1]) if op.type.is_depthwise_conv2d_op(): - op.attrs["depth_multiplier"] = op.weights.shape[3] + assert op.weights.shape[-1] % op.ifm.shape[-1] == 0 + depth_multiplier = op.weights.shape[-1] / op.ifm.shape[-1] + if depth_multiplier > 1: + assert op.ifm.shape[-1] == 1 and op.ofm.shape[-1] == depth_multiplier, ( + "For depth multipliers > 1, IFM channels must be 1 and " + "OFM channels must be equal to the depth multiplier") + op.attrs["depth_multiplier"] = depth_multiplier if op.type == Op.SplitSliceRead: op.read_offsets[0] = Shape4D.from_list(list(op.attrs["start"]), 0) op.read_shapes[0] = op.attrs["size"] |