aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela')
-rw-r--r--ethosu/vela/graph_optimiser_util.py11
-rw-r--r--ethosu/vela/high_level_command_to_npu_op.py17
-rw-r--r--ethosu/vela/nn_graph.py5
-rw-r--r--ethosu/vela/operation.py12
-rw-r--r--ethosu/vela/pass_packing.py23
-rw-r--r--ethosu/vela/test/test_tflite_supported_operators.py6
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py47
-rw-r--r--ethosu/vela/tflite_supported_operators.py7
-rw-r--r--ethosu/vela/tosa_graph_optimiser.py61
-rw-r--r--ethosu/vela/tosa_reader.py12
10 files changed, 159 insertions, 42 deletions
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index 46762e4d..f1b9e1aa 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -39,6 +39,10 @@ memory_only_ops = (
Op.Identity,
)
+# This list contains ops that requires its ofm shape to be intact in order
+# to be able to decompose it correctly in the graph optimiser step
+ofm_not_replaceable_ops = (Op.Mean,)
+
def _avoid_nhcwb16_for_concat(tens):
# If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
@@ -300,8 +304,11 @@ def bypass_memory_only_ops(op, arch, nng):
ifm_has_multiple_cons = len(op.ifm.consumer_list) > 1
ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
+ producer_ofm_not_replaceable = any(
+ ifm_prod is not None and ifm_prod.type in ofm_not_replaceable_ops for ifm_prod in op.ifm.ops
+ )
- if ifm_has_multiple_cons or ifm_is_cpu_produced:
+ if ifm_has_multiple_cons or ifm_is_cpu_produced or producer_ofm_not_replaceable:
# Convert to a memcpy op
op.type = Op.Memcpy
DebugDatabase.add_optimised(op, op)
@@ -348,6 +355,8 @@ def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_o
"""Creates an average pool for the given concat op/input feature map"""
ofm = concat_op.ofm
avgpool_op = create_avgpool_nop(name)
+ # Enforce original type since this is used in pass packing to group concat ops
+ avgpool_op._original_type = concat_op.type
avgpool_op.inputs = [ifm]
avgpool_op.outputs = [ofm]
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 52d07187..71181d05 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -410,16 +410,20 @@ def create_feature_map(
assert strides is not None
+ multiplied_strides = strides.copy()
if stride_multiplier and stride_multiplier != [1, 1, 1]:
assert (
tens.format == TensorFormat.NHWC
), "Only default stride multiplier ([1, 1, 1]) supported for NHCWB16 format"
# Multiply strides for C/H/W (in that order) with corresponding stride factor
for i, stride_factor in enumerate(stride_multiplier, start=1):
- strides[i] *= stride_factor
+ multiplied_strides[i] *= stride_factor
+
+ # Stride multiplier only affects tiles and addresses for OFM
+ _strides = multiplied_strides if is_ofm else strides
height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(
- box.start_coord, box.end_coord, strides, op_shape4D
+ box.start_coord, box.end_coord, _strides, op_shape4D
)
for idx, offset in enumerate(tile_base_offsets):
@@ -427,7 +431,9 @@ def create_feature_map(
fm.tiles = NpuTileBox(
height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses]
)
- fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
+ fm.strides = NpuShape3D(
+ height=int(multiplied_strides[2]), width=int(multiplied_strides[3]), depth=int(multiplied_strides[1])
+ )
fm.name = tens.name
return fm
@@ -518,8 +524,9 @@ def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: Archit
ifm_height = cmd.ifm_box.get_block().height
ifm_width = cmd.ifm_box.get_block().width
ifm_depth = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)
-
- npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0], op.tile_base_offsets_ifm[0])
+ npu_op.ifm = create_feature_map(
+ cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0], op.tile_base_offsets_ifm[0], op.ifm_stride_multiplier[0]
+ )
npu_op.ifm.shape = NpuShape3D(height=ifm_height, width=ifm_width, depth=ifm_depth)
npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index 3c87f9be..b9eee28b 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -253,7 +253,10 @@ class Subgraph:
for tens in ps.inputs:
for op in tens.ops:
pred_pass = op.scheduled_pass
- assert pred_pass.time < ps.time
+ # Pass with split concat ops may end up with a dependency to
+ # itself since output from concat is produced by several avg pool ops.
+ # Hence pred_pass can be equal to ps.
+ assert pred_pass == ps or pred_pass.time < ps.time
if ps not in pred_pass.successors:
pred_pass.successors.append(ps)
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index a831537b..9a917f22 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -511,6 +511,7 @@ class Operation:
"tile_base_offsets_ifm",
"tile_base_offsets_ofm",
"ofm_stride_multiplier",
+ "ifm_stride_multiplier",
)
def __init__(self, op_type: Op, name: str):
@@ -554,8 +555,9 @@ class Operation:
self.tile_base_offsets_ifm: List[List[int]] = [[0, 0, 0, 0], [0, 0, 0, 0]]
# ofm (nhwc)
self.tile_base_offsets_ofm: List[int] = [0, 0, 0, 0]
- # For interleaved/sparse outputs - stride is multiplied with the stride factor of the corresponding axis
- # Order is [C, H, W] - default is no multiplication
+ # Stride is multiplied with the ifm/ofm stride factor of the corresponding axis
+ # Order is [C, H, W]
+ self.ifm_stride_multiplier: List[List[int]] = [[1, 1, 1], [1, 1, 1]]
self.ofm_stride_multiplier: List[int] = [1, 1, 1]
def clone(self, suffix="_clone"):
@@ -585,6 +587,7 @@ class Operation:
res.ifm_resampling_mode = self.ifm_resampling_mode
res.tile_base_offsets_ifm = [_ifm.copy() for _ifm in self.tile_base_offsets_ifm]
res.tile_base_offsets_ofm = self.tile_base_offsets_ofm.copy()
+ res.ifm_stride_multiplier = [_ifm.copy() for _ifm in self.ifm_stride_multiplier]
res.ofm_stride_multiplier = self.ofm_stride_multiplier.copy()
return res
@@ -763,6 +766,7 @@ class Operation:
offset_start = None
offset_end = None
axis = None
+ strides_tens = None
if self.type == Op.Split:
num_splits = self.attrs.get("num_splits")
axis_tens = self.inputs[0]
@@ -831,7 +835,7 @@ class Operation:
else:
assert False
- return input_tens, outputs, axis, offset_start, offset_end
+ return input_tens, outputs, axis, offset_start, offset_end, strides_tens
def set_activation_lut(self, lut_tensor):
self.activation = ActivationFunction(Op.LUT)
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 0de0341d..f157e67b 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -524,6 +524,27 @@ def pack_into_passes(nng, arch, verbose_packing=False):
# Sort ops by op_index (same call order as in the original graph)
pass_list_top = sorted(pass_list_top, key=lambda ps: -1 if ps.ops[0].op_index is None else ps.ops[0].op_index)
+ # A concat is implemented by several AvgPool ops writing to the same ofm but with slice offset
+ # If there is a cpu op in between, group all AvgPool ops for a concat so that they run
+ # within the same cmd stream
+ last_idx = len(pass_list) - 1
+ for npu_ps in reversed(pass_list):
+ if npu_ps.placement == PassPlacement.Cpu or not npu_ps.ops[0].original_type.is_concat_op():
+ continue
+ # Concat pass found, search forward for the next avgpool op writing to the same ofm
+ idx = pass_list.index(npu_ps)
+ concat_is_split_between_npu_ops = False
+ for next_ps in pass_list[idx + 1 :]:
+ if next_ps.placement == PassPlacement.Cpu:
+ concat_is_split_between_npu_ops = True
+ next_is_concat = next_ps.ops[0].original_type.is_concat_op()
+ if next_is_concat and next_ps.ops[0].ofm == npu_ps.ops[0].ofm and concat_is_split_between_npu_ops:
+ # Avgpool writing to the same OFM and there is a cpu op between them, group them
+ pass_list.remove(npu_ps)
+ insert_index = pass_list.index(next_ps)
+ pass_list.insert(insert_index, npu_ps)
+ break
+
# Sort the rest of the list based on critera 2.
# Search from bottom of list and when a CPU pass is found
# search forward in the list and see if it is possible to join another CPU pass.
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index e65717a8..3b15b318 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -542,7 +542,9 @@ def create_strided_slice():
def test_constraint_stridedslice_stride_values():
# Unsupported strides
op = create_strided_slice()
- op.inputs[3].values = [1, 1, 2, 1]
+ op.inputs[3].values = [1, 2, 2, 1]
+ assert support.is_operator_supported(op)
+ op.inputs[3].values = [2, 1, 1, 1]
assert not support.is_operator_supported(op)
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 687e5d4f..ccbb1f28 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -141,7 +141,7 @@ def rewrite_split_ops(tens, arch, nng):
if not split_op.run_on_npu:
return tens
- inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()
+ inp, outputs, axis, offset_start, offset_end, strides_tens = split_op.get_split_inputs_axis()
tens.ops = []
new_op = Operation(Op.SplitSliceRead, split_op.name)
@@ -150,8 +150,10 @@ def rewrite_split_ops(tens, arch, nng):
if None in (offset_end, offset_start):
read_shape = None
else:
- # the read shape is relative to each start offset
- read_shape = Shape4D([oe - os for oe, os in zip(offset_end, offset_start)])
+ # The read shape is relative to each start offset
+ # Limit read shape to the size of the IFM - offset is not necessarily limited
+ ifm_dims = split_op.ifm_shapes[0].as_list()
+ read_shape = Shape4D([min(oe, ifm_dim) - os for oe, os, ifm_dim in zip(offset_end, offset_start, ifm_dims)])
# For Split the offset cannot be extracted from the tensor so it has to
# be calculated from the index of the output tensor
@@ -182,6 +184,9 @@ def rewrite_split_ops(tens, arch, nng):
new_op.set_output_tensor(tens)
new_op.ifm_shapes.append(Shape4D(inp.shape))
new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
+ # Set stride multiplier in H/W if a stride tensor is provided
+ s_h, s_w = (strides_tens.values[-3], strides_tens.values[-2]) if strides_tens else (1, 1)
+ new_op.ifm_stride_multiplier[0] = [1, s_h, s_w] # C/H/W
DebugDatabase.add_optimised(split_op, new_op)
return tens
@@ -193,18 +198,24 @@ def remove_SplitSliceRead(op, arch):
# Check if it is possible to put the SplitSliceRead on the tensor consumer(s),
# or if an avgpool need to be inserted
# Not possible to move:
+ # - if ifm stride multiplier is larger than one in any dimension
# - if consumer is a Transpose op since ifm shape has been reshaped and can not be changed
# - if consumer is elementwise and ifm needs to be broadcasted
- if op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape) and all(
- consumer is not None
- and consumer.run_on_npu
- and consumer.type not in memory_only_ops
- and consumer.original_type != Op.Transpose
- and check_splitsliceread_to_consumer_shape(op, consumer)
- and not (
- consumer.type.is_binary_elementwise_op() and Shape4D.from_list(consumer.ofm.shape) != op.ofm_shapes[0]
+ if (
+ op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
+ and all(s_mul == 1 for s_mul in op.ifm_stride_multiplier[0])
+ and all(
+ consumer is not None
+ and consumer.run_on_npu
+ and consumer.type not in memory_only_ops
+ and consumer.original_type != Op.Transpose
+ and check_splitsliceread_to_consumer_shape(op, consumer)
+ and not (
+ consumer.type.is_binary_elementwise_op()
+ and Shape4D.from_list(consumer.ofm.shape) != op.ofm_shapes[0]
+ )
+ for consumer in op.ofm.consumer_list
)
- for consumer in op.ofm.consumer_list
):
# SplitSliceRead can be performed by tensor consumer(s)
for cons_op in list(op.ofm.consumer_list):
@@ -219,6 +230,9 @@ def remove_SplitSliceRead(op, arch):
avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
avgpool_op.read_offsets[0] = op.read_offsets[0]
avgpool_op.read_shapes[0] = op.read_shapes[0]
+ if any(s_mul != 1 for s_mul in op.ifm_stride_multiplier[0]):
+ avgpool_op.ifm_stride_multiplier[0] = op.ifm_stride_multiplier[0].copy()
+ avgpool_op.ifm.force_linear_format = True
op.ifm.consumer_list.remove(op)
DebugDatabase.add_optimised(op, avgpool_op)
@@ -827,7 +841,7 @@ def convert_batched_fc_shape(op: Operation, arch, nng) -> Operation:
if op.type == Op.FullyConnected:
# Check if the first dimension indicates batching
if op.ifm_shapes[0].batch > 1:
- batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
+ batching_split = {4: (2, 2), 6: (2, 3), 8: (2, 4), 9: (3, 3), 12: (3, 4), 16: (4, 4)}
n = op.ifm_shapes[0].batch
h, w = batching_split.get(n, (1, n))
op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
@@ -840,6 +854,13 @@ def convert_batched_fc_shape(op: Operation, arch, nng) -> Operation:
n = op.ofm_shapes[0].batch
h, w = batching_split.get(n, (1, n))
op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
+ if h == 1 and w > 4:
+ # If batch can not be found in the split set the weights are going to be
+ # read from memory several times. Convert op to conv2d since this
+ # enables weight buffering.
+ op.type = Op.Conv2DBias
+ op.attrs["padding"] = Padding.SAME
+ DebugDatabase.add_optimised(op, op)
return op
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 91a3ee83..b293a2ef 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -841,10 +841,11 @@ class TFLiteSupportedOperators:
@staticmethod
def constraint_stridedslice_stride_values(op):
- "All Strides values must be 1"
+ "Batch and channel stride values must be 1"
strides = op.inputs[3]
- valid = all(stride == 1 for stride in strides.values)
- return valid, f"Op has strides values {strides.values}"
+ s_c = strides.values[-1]
+ s_n = strides.values[0] if len(strides.values) > 3 else 1
+ return s_n == s_c == 1, f"Op has strides values {strides.values}"
@staticmethod
def constraint_stridedslice_offset_false(op):
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 09b2c526..26d3dcad 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -247,7 +247,11 @@ def fix_sg_input_output_tosa(op, arch, nng):
# consumed by CPU
# Check if operator ifm/ofm are sg ifm/ofm
- ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
+ ifm_is_sg_ifm = op.ifm.ops[0].type in (
+ Op.Placeholder,
+ Op.SubgraphInput,
+ Op.Const,
+ )
ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
# Check if ifm/ofm is produced repectivly consumed by CPU
@@ -302,7 +306,13 @@ def remove_splitsliceread(op, arch):
else:
name = op.name + "_add"
ofm = op.ofm
- ifm2 = create_const_tensor(name + "_zero_scalar", [1], ofm.dtype, [0], quantization=ofm.quantization)
+ ifm2 = create_const_tensor(
+ name + "_zero_scalar",
+ [1],
+ ofm.dtype,
+ [0],
+ quantization=ofm.quantization,
+ )
add_op = create_add_nop(name)
add_op.inputs = [op.ifm, ifm2]
add_op.outputs = [ofm]
@@ -330,7 +340,13 @@ def rewrite_concat(op):
write_offset = [0, 0, 0, 0]
write_offset[axis_4D] = offset
concat_end = offset + op.ifm_shapes[idx][axis_4D]
- create_add_for_concat(op, op.name + str(idx) + "_add", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset))
+ create_add_for_concat(
+ op,
+ op.name + str(idx) + "_add",
+ inp,
+ op.ifm_shapes[idx],
+ Shape4D.from_list(write_offset),
+ )
offset = concat_end
assert op.ofm_shapes[0][axis_4D] == offset
@@ -417,7 +433,10 @@ def rewrite_rescale(op, arch, nng):
DebugDatabase.add_optimised(op, prev_op)
return op
else:
- print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type)
+ print(
+ "Warning, unsupported fusing of TOSA Rescale previous operator is of type:",
+ prev_op.type,
+ )
assert False
elif (
(ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8)
@@ -447,7 +466,7 @@ def rewrite_rescale(op, arch, nng):
for a in equal_attributes:
assert op.attrs[a] == rescale_1.attrs[a] == rescale_2.attrs[a], (
f"Only handling equal {a} for all operands "
- "({op.attrs[a]}, {rescale_1.attrs[a]}, {rescale_2.attrs[a]}) "
+ f"({op.attrs[a]}, {rescale_1.attrs[a]}, {rescale_2.attrs[a]}) "
"for all the rescale operations to be fused with Add!"
)
@@ -486,7 +505,10 @@ def rewrite_rescale(op, arch, nng):
print("Warning, unsupported fusing of TOSA Rescale with Add.")
assert False
else:
- print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type)
+ print(
+ "Warning, unsupported fusing of TOSA Rescale previous operator is of type:",
+ prev_op.type,
+ )
assert False
return op
@@ -519,17 +541,31 @@ def convert_pad_in_width(op):
if left > 0:
shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
zero_tens = create_const_tensor(
- op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
+ op.name + "_left",
+ shape.as_list(),
+ ofm.dtype,
+ shape.elements() * [pad_value],
+ quantization=quant,
)
zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
create_add_for_concat(op, op.name + "_left", zero_tens, shape, shp0)
if right > 0:
shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
zero_tens = create_const_tensor(
- op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
+ op.name + "_right",
+ shape.as_list(),
+ ofm.dtype,
+ shape.elements() * [pad_value],
+ quantization=quant,
)
zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
- create_add_for_concat(op, op.name + "_right", zero_tens, shape, shp0.with_width(ofm_shape.width - right))
+ create_add_for_concat(
+ op,
+ op.name + "_right",
+ zero_tens,
+ shape,
+ shp0.with_width(ofm_shape.width - right),
+ )
op.type = Op.ConcatTFLite
return add_op
@@ -992,7 +1028,12 @@ def tosa_optimise_graph(nng, arch):
)
# Rewite Operators step
- op_rewrite_list = [set_tensor_equivalence, rewrite_rescale, convert_depthwise_to_conv, convert_table_to_lut]
+ op_rewrite_list = [
+ set_tensor_equivalence,
+ rewrite_rescale,
+ convert_depthwise_to_conv,
+ convert_table_to_lut,
+ ]
for idx, sg in enumerate(nng.subgraphs):
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py
index 6d80e10d..670b264a 100644
--- a/ethosu/vela/tosa_reader.py
+++ b/ethosu/vela/tosa_reader.py
@@ -189,7 +189,8 @@ class TosaSubgraph:
elif op.type.is_conv2d_op():
inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0), False)
elif op.type.is_depthwise_conv2d_op():
- inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 0, 3), False)
+ HWCM_to_HWOI = (0, 1, 3, 2)
+ inputs[1] = clone_and_reshape_tensor(inputs[1], HWCM_to_HWOI, False)
if op.type.needs_bias() and len(inputs) <= op_type.info.indices.biases[0]:
# No Bias tensor
inputs.append(None)
@@ -241,7 +242,14 @@ class TosaSubgraph:
if shift != 0:
op.explicit_scaling = ExplicitScaling(False, [shift], [1])
if op.type.is_depthwise_conv2d_op():
- op.attrs["depth_multiplier"] = op.weights.shape[3]
+ assert op.weights.shape[-1] % op.ifm.shape[-1] == 0
+ depth_multiplier = op.weights.shape[-1] / op.ifm.shape[-1]
+ if depth_multiplier > 1:
+ assert op.ifm.shape[-1] == 1 and op.ofm.shape[-1] == depth_multiplier, (
+ "For depth multipliers > 1, IFM channels must be 1 and "
+ "OFM channels must be equal to the depth multiplier"
+ )
+ op.attrs["depth_multiplier"] = depth_multiplier
if op.type == Op.SplitSliceRead:
op.read_offsets[0] = Shape4D.from_list(list(op.attrs["start"]), 0)
op.read_shapes[0] = op.attrs["size"]