author     Patrik Gustavsson <patrik.gustavsson@arm.com>   2020-12-01 16:02:29 +0100
committer  Patrik Gustavsson <patrik.gustavsson@arm.com>   2020-12-18 16:33:32 +0100
commit     2349d429d926e258e9a61d34c7fd97660ab9fb98 (patch)
tree       b5151d0f12428e47d64b1fb2ce4f2f8c19304a0d
parent     528a56df829b65f7a2c61953650b123c461095f7 (diff)
download   ethos-u-vela-2349d429d926e258e9a61d34c7fd97660ab9fb98.tar.gz
MLBEDSW-3654 Add/use op ifm/ofm shapes
Add ifm/ofm shapes to op. Changed to rely on these shapes.

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I571535a1dcadc2bdb04a3c727a8e1c49703b174d
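The pattern this commit introduces is easiest to see in isolation. The sketch below is a simplified illustration, not the Vela implementation: the Tensor and Operation classes are minimal stand-ins, and only the default branch of set_ifm_ofm_shapes (pad every feature-map shape to 4D) is shown. It mirrors the idea that downstream passes read op.ifm_shapes / op.ofm_shapes instead of re-deriving full shapes from tensor.shape at every use.

def full_shape(dim, shape, pad_value):
    # Left-pad `shape` with `pad_value` until it has `dim` dimensions
    return [pad_value] * (dim - len(shape)) + list(shape)

class Tensor:
    def __init__(self, shape):
        self.shape = list(shape)

class Operation:
    def __init__(self, inputs, outputs):
        self.inputs = inputs
        self.outputs = outputs
        self.ifm_shapes = []
        self.ofm_shapes = []

    def set_ifm_ofm_shapes(self):
        # Cache every input/output shape as 4D (N, H, W, C)
        self.ifm_shapes = [full_shape(4, t.shape, 1) for t in self.inputs]
        self.ofm_shapes = [full_shape(4, t.shape, 1) for t in self.outputs]

op = Operation([Tensor([8, 8, 16])], [Tensor([8, 8, 16])])
op.set_ifm_ofm_shapes()
print(op.ifm_shapes)   # [[1, 8, 8, 16]]
print(op.ofm_shapes)   # [[1, 8, 8, 16]]

Caching the 4D shapes once, at graph-optimisation time, is what lets the scheduler, pass packing and high-level command-stream generation changes below drop their repeated full_shape(4, tensor.shape, 1) calls.
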
-rw-r--r--  ethosu/vela/debug_database.py  2
-rw-r--r--  ethosu/vela/graph_optimiser.py  69
-rw-r--r--  ethosu/vela/high_level_command_stream.py  2
-rw-r--r--  ethosu/vela/high_level_command_stream_generator.py  60
-rw-r--r--  ethosu/vela/high_level_command_to_npu_op.py  12
-rw-r--r--  ethosu/vela/insert_dma.py  2
-rw-r--r--  ethosu/vela/live_range.py  4
-rw-r--r--  ethosu/vela/nn_graph.py  2
-rw-r--r--  ethosu/vela/npu_performance.py  10
-rw-r--r--  ethosu/vela/operation.py  50
-rw-r--r--  ethosu/vela/operation_util.py  3
-rw-r--r--  ethosu/vela/pass_packing.py  19
-rw-r--r--  ethosu/vela/scheduler.py  9
-rw-r--r--  ethosu/vela/shared_buffer_allocation.py  9
-rw-r--r--  ethosu/vela/softmax.py  5
-rw-r--r--  ethosu/vela/tensor.py  55
-rw-r--r--  ethosu/vela/test/test_graph_optimiser.py  13
-rw-r--r--  ethosu/vela/test/testutil.py  5
18 files changed, 231 insertions, 100 deletions
diff --git a/ethosu/vela/debug_database.py b/ethosu/vela/debug_database.py
index 4f0a50ae..203503f2 100644
--- a/ethosu/vela/debug_database.py
+++ b/ethosu/vela/debug_database.py
@@ -79,7 +79,7 @@ class DebugDatabase:
src_uid = cls._sourceUID[parent]
uid = len(cls._optimisedUID)
cls._optimisedUID[op] = (uid, src_uid)
- ofm_shape = numeric_util.full_shape(3, op.outputs[0].shape, 1)
+ ofm_shape = op.ofm_shapes[0] if op.ofm_shapes else numeric_util.full_shape(3, op.outputs[0].shape, 1)
cls._optimisedTable.append(
[uid, src_uid, op.type, op.kernel.width, op.kernel.height, ofm_shape[-2], ofm_shape[-3], ofm_shape[-1]]
)
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 4806001f..fdb0fae0 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -75,7 +75,7 @@ def rewrite_concat(tens, arch, nng):
new_op = Operation(Op.ConcatSliceWrite, concat_op.name + str(idx))
new_op.inputs = [inp]
new_op.outputs = [tens]
- new_op.attrs["concat_axis"] = axis
+ new_op.attrs["concat_axis"] = axis + (4 - len(inp.shape))
new_op.attrs["concat_start"] = offset
offset += inp.shape[axis]
new_op.attrs["concat_end"] = offset
@@ -116,21 +116,20 @@ def rewrite_split(tens, arch, nng):
# be calculated from the index of the output tensor
if axis is not None:
# Get the start and end of the split
- offset_start = [0] * len(tens.shape)
- offset_end = [0] * len(tens.shape)
- for out in outputs:
+ offset_start = [0] * 4
+ for idx, out in enumerate(outputs):
if out == tens:
break
- offset_start[axis] += out.shape[axis]
+ axis_4D = axis + (4 - len(out.shape))
+ offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
# If start offset is not a multiple of 16 in the C-dimension, NHCWB16 need to be avoided in the input
if (offset_start[-1] % 16) != 0:
inp.avoid_NHCWB16 = True
-
- offset_end[axis] = offset_start[axis] + tens.shape[axis]
+ else:
+ offset_start = full_shape(4, offset_start, 0)
new_op.attrs["split_start"] = offset_start
- new_op.attrs["split_end"] = offset_end
new_op.run_on_npu = True
new_op.set_output_tensor(tens)
DebugDatabase.add_optimised(split_op, new_op)
@@ -217,6 +216,8 @@ def convert_resizebilinear_1x1_to_add(op):
# Set the add inputs
op.inputs[1] = op.inputs[0]
op.inputs[0] = tens
+ op.ifm_shapes = []
+ op.ofm_shapes = []
return op
@@ -321,13 +322,16 @@ def convert_batched_fc_shape(op, arch, nng):
ifm = op.inputs[0]
ofm = op.outputs[0]
# Check if the FC is 2D and first dimension indicates batching
- if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1:
+ # TODO op.ifm_shapes[0] > 1 is enough when refactoring is complete
+ if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1 and op.ifm_shapes[0][0] > 1:
n = ifm.shape[0]
batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
h, w = batching_split.get(n, (1, n))
prev_op = ifm.ops[0]
desired_shape = [1, h, w, ifm.shape[-1]]
+ op.ifm_shapes[0] = desired_shape
+
if len(ifm.consumer_list) == 1 and prev_op is not None and prev_op.type == Op.Reshape:
# There is a preceding Reshape
# Compare input of prev_op and input of op, to see if prev_op can be removed
@@ -352,6 +356,8 @@ def convert_batched_fc_shape(op, arch, nng):
weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
desired_shape = [1, h, w, ofm.shape[-1]]
+ op.ofm_shapes[0] = desired_shape
+
if (
len(ofm.consumer_list) == 1
and ofm.consumer_list[0] is not None
@@ -451,6 +457,7 @@ def fixup_stridedslice_output(tens, arch, nng):
new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, tens.shape)
for idx, out_tens in enumerate(op.outputs):
+ op.ofm_shapes[idx] = new_shape_tens
reshape_in = out_tens.clone("_reshaped")
reshape_in.set_all_shapes(reshape_input_shape)
reshape_in.ops = [op]
@@ -489,7 +496,6 @@ def fixup_unpack_output(tens, arch, nng):
DebugDatabase.add_optimised(op, reshape_op)
op.outputs[idx] = reshape_in
-
return tens
@@ -582,7 +588,7 @@ def convert_conv_to_fc(op, arch, nng):
# caching/double buffering for the weights.
# (Weights don't need to be reloaded for convs when IFM H and W are 1)
if op.type == Op.Conv2DBias:
- _, h, w, _ = op.inputs[0].shape
+ _, h, w, _ = op.ifm_shapes[0]
kh, kw, _, _ = op.inputs[1].shape
if h == 1 and w == 1 and kh == 1 and kw == 1:
# Overwrite this op as a Fully Connected Op
@@ -595,6 +601,7 @@ def convert_conv_to_fc(op, arch, nng):
weight_tensor = op.inputs[1]
weight_tensor.quant_values = weight_tensor.quant_values.squeeze(axis=(0, 1))
weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
+
# The output from a fully connected is expected to be 2D so we need to add a reshape layer to convert it
# back to 4D afterwards as the next layer is expecting that shape
orig_ofm_tensor = op.outputs[0]
@@ -609,6 +616,7 @@ def convert_conv_to_fc(op, arch, nng):
reshape_op.attrs["new_shape"] = orig_ofm_tensor.shape
reshape_op.inputs = [fc_ofm_tensor, new_shape_tens]
reshape_op.set_output_tensor(orig_ofm_tensor)
+
# Replace this ops OFM to point to the 2D tensor
op.outputs[0] = fc_ofm_tensor
# Record optimisation in debug database
@@ -651,6 +659,8 @@ def fixup_act_reorder(op, arch, nng):
prep_op = get_prepend_op(op)
if prep_op is not None:
act_op = op.clone("_reordered")
+ act_op.ifm_shapes = list(op.ifm_shapes)
+ act_op.ofm_shapes = list(op.ofm_shapes)
# There is only one input tensor, overwrite it
act_op.set_input_tensor(prep_op.inputs[0], 0)
@@ -658,6 +668,8 @@ def fixup_act_reorder(op, arch, nng):
act_op_out = act_op.inputs[0].clone("_acted")
act_op_out.quantization = op.outputs[0].quantization.clone()
act_op.set_output_tensor(act_op_out)
+ act_op.ifm_shapes[0] = full_shape(4, prep_op.inputs[0].shape, 1)
+ act_op.ofm_shapes[0] = full_shape(4, act_op_out.shape, 1)
# Update the consumer list
act_op_out.consumer_list = op.outputs[0].consumer_list.copy()
@@ -704,6 +716,15 @@ def set_tensor_equivalence(op, arch, nng):
return op
+def set_ifm_ofm_op_shapes(op, arch, nng):
+ if op.run_on_npu and op.type.needs_shapes():
+ if op.ifm_shapes or op.ofm_shapes:
+ # Shapes already set
+ return op
+ op.set_ifm_ofm_shapes()
+ return op
+
+
def convert_softmax(op, arch, nng):
if op.type == Op.Softmax and op.run_on_npu:
softmax = SoftMax(op)
@@ -839,7 +860,7 @@ def convert_lrelu_to_mul_max(op, arch):
mul_identity.add_input_tensor(identity_tens)
fm_id = ofm.clone(op.name + "_id")
mul_identity.set_output_tensor(fm_id)
- DebugDatabase.add_optimised(op, mul_alpha)
+ DebugDatabase.add_optimised(op, mul_identity)
# Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
op.type = Op.Maximum
@@ -869,6 +890,8 @@ def convert_to_lut(op, lut_values, lut_name):
quantization.zero_point = 0
tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], np.uint8, quantization=quantization)
op.add_input_tensor(tens)
+ op.ifm_shapes.append(full_shape(4, tens.shape, 1))
+
# The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
# so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
# should be the same as the IFM
@@ -1072,10 +1095,20 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
if verbose_graph:
nng.print_graph()
+ pre_process_list = [
+ supported_operator_check,
+ set_ifm_ofm_op_shapes,
+ # TODO: memory-only Op removal
+ ]
+
+ for idx, sg in enumerate(nng.subgraphs):
+ # rewrite graph pass
+ nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+ nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
+ )
+
op_rewrite_list = [
set_tensor_equivalence,
- supported_operator_check,
- # then do any rewrites of supported operators
convert_depthwise_to_conv,
convert_conv_to_fc,
convert_softmax,
@@ -1106,7 +1139,7 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
for idx, sg in enumerate(nng.subgraphs):
# remove passthrough tensors and attempt further optimizations
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev, add_padding_fields]
+ nng, sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev, add_padding_fields],
)
# Post-optimisation operator debug tracing
@@ -1125,7 +1158,11 @@ def optimise_graph_b(nng, arch, verbose_graph=False):
for idx, sg in enumerate(nng.subgraphs):
# combined rewrite graph pass
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split], []
+ nng,
+ sg,
+ arch,
+ [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split],
+ [set_ifm_ofm_op_shapes],
)
if verbose_graph:
diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py
index c45bc4e5..bb4f1424 100644
--- a/ethosu/vela/high_level_command_stream.py
+++ b/ethosu/vela/high_level_command_stream.py
@@ -197,7 +197,7 @@ class NpuStripe(Command):
self.pad_top = pad_top
self.pad_bottom = pad_bottom
for i in range(len(self.ofm_box.end_coord)):
- assert self.ofm_box.end_coord[i] <= self.ofm_tensor.shape[i]
+ assert self.ofm_box.end_coord[i] <= ps.ofm_shapes[0][i]
def is_npu_pass_command(self):
return True
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index 905263d6..18a419c0 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -56,6 +56,7 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
# Ensure correct ifm and ifm2 order
if match_tensor(ps.inputs[0], ps.primary_op.inputs[1]) and match_tensor(ps.inputs[1], ps.primary_op.inputs[0]):
ps.ifm_tensor, ps.ifm2_tensor = ps.ifm2_tensor, ps.ifm_tensor
+ ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
for op in ps.ops:
if op.type == Op.SplitSliceRead:
@@ -77,13 +78,20 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
ifm_idx += 1
ifm_tensor = ps.ifm_tensor
+ ifm_shape = None
+ if ifm_tensor.shape != []:
+ ifm_shape = ps.ifm_shapes[0]
ifm2_tensor = ps.ifm2_tensor
+ ifm2_shape = None
+ if ifm2_tensor is not None and ifm2_tensor.shape != []:
+ ifm2_shape = ps.ifm_shapes[1]
ofm_tensor = ps.ofm_tensor
+ ofm_shape = ps.ofm_shapes[0]
weight_tensor = ps.weight_tensor
scale_tensor = ps.scale_tensor
- ofm_start = [0] * len(ofm_tensor.shape)
- ofm_end = list(ofm_tensor.shape)
+ ofm_start = [0] * len(ofm_shape)
+ ofm_end = list(ofm_shape)
strides = None
skirt = None
@@ -92,9 +100,9 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
strides = ps.primary_op.attrs.get("strides", None)
skirt = ps.primary_op.attrs.get("skirt", None)
if ps.primary_op.type == Op.Conv2DBackpropInputSwitchedBias:
- upscaling = ofm_tensor.shape[-3] // ifm_tensor.shape[-3]
+ upscaling = ofm_shape[-3] // ifm_shape[-3]
elif ps.primary_op.type == Op.ResizeBilinear:
- upscaling = round_up_divide(ofm_tensor.shape[-3], ifm_tensor.shape[-3])
+ upscaling = round_up_divide(ofm_shape[-3], ifm_shape[-3])
concat_axis = 0
concat_offset = 0
@@ -125,7 +133,7 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
ifm_box = None
ifm2_box = None
- if ifm_tensor.shape != []:
+ if ifm_shape is not None:
ifm_box, _, _ = ofm_box.transform_with_strides_and_skirt(
strides,
skirt,
@@ -138,16 +146,9 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
)
else:
ifm_box = Box([], [])
- if ifm2_tensor is not None and ifm2_tensor.shape != []:
+ if ifm2_shape is not None:
ifm2_box, _, _ = ofm_box.transform_with_strides_and_skirt(
- strides,
- skirt,
- ifm2_tensor.shape,
- npu_block_type,
- concat_axis,
- concat_offset,
- split_offsets[1],
- upscaling,
+ strides, skirt, ifm2_shape, npu_block_type, concat_axis, concat_offset, split_offsets[1], upscaling,
)
else:
ifm2_box = Box([], [])
@@ -212,19 +213,17 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
elif strat == SchedulingStrategy.IfmStream:
y_step = block_config[0]
- y_start = 0
- y_dim = 1
- if len(ofm_tensor.shape) >= 3:
- y_start = ofm_start[-3]
- y_dim = ofm_end[-3]
+ y_start = ofm_start[-3]
+ y_dim = ofm_end[-3]
+
if idx > 0:
ifm_y_present = 0
prev_pass = passes[idx - 1]
prev_pass_gen = generate_high_level_command_stream_for_pass(strat, passes, block_configs, idx - 1)
else:
ifm_y_present = 1
- if len(ifm_tensor.shape) >= 3:
- ifm_y_present = ifm_tensor.shape[-3]
+ if len(ifm_shape) >= 3:
+ ifm_y_present = ifm_shape[-3]
prev_pass_gen = []
prev_pass = None
@@ -243,9 +242,8 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
for start in range(y_start, y_dim, y_step):
end = min(start + y_step, y_dim)
- if len(ofm_tensor.shape) >= 3:
- ofm_start[-3] = start
- ofm_end[-3] = end
+ ofm_start[-3] = start
+ ofm_end[-3] = end
ofm_box = Box(ofm_start, ofm_end)
k_height = 1
@@ -259,7 +257,7 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
ifm_box, pad_top, pad_bottom = ofm_box.transform_with_strides_and_skirt(
strides,
skirt,
- ifm_tensor.shape,
+ ifm_shape,
npu_block_type,
concat_axis,
concat_offset,
@@ -381,11 +379,15 @@ def calc_allowed_ofm_ifm_overlap_for_pass_list(strat, passes, block_configs):
for cmd in generate_high_level_command_stream_for_pass_list(strat, passes, block_configs):
if cmd.is_npu_pass_command():
if cmd.is_first:
- ifm_read = cmd.ifm_tensor.address_offset_for_coordinate(cmd.ifm_box.start_coord, is_top_box=False)
+ ifm_read = cmd.ifm_tensor.address_offset_for_coordinate(
+ cmd.ifm_box.start_coord, shape=cmd.ps.ifm_shapes[0], is_top_box=False
+ )
if ifm_read is None:
return 0
if cmd.is_last:
- write_offset = cmd.ofm_tensor.address_offset_for_coordinate(cmd.ofm_box.end_coord, is_top_box=True)
+ write_offset = cmd.ofm_tensor.address_offset_for_coordinate(
+ cmd.ofm_box.end_coord, shape=cmd.ps.ofm_shapes[0], is_top_box=True
+ )
if write_offset is None:
return 0
highest_ofm_write = max(write_offset, highest_ofm_write)
@@ -396,7 +398,9 @@ def calc_allowed_ofm_ifm_overlap_for_pass_list(strat, passes, block_configs):
min_overlap = min(min_overlap, can_overwrite)
if cmd.is_first:
- ifm_read = cmd.ifm_tensor.address_offset_for_coordinate(cmd.ifm_box.end_coord, is_top_box=True)
+ ifm_read = cmd.ifm_tensor.address_offset_for_coordinate(
+ cmd.ifm_box.end_coord, shape=cmd.ps.ifm_shapes[0], is_top_box=True
+ )
min_overlap = max(min_overlap, 0)
return min_overlap
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 096a65cc..9380374e 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -231,7 +231,7 @@ def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point)
-def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures) -> NpuFeatureMap:
+def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, fm_shape: List[int]) -> NpuFeatureMap:
"""Creates feature map with common fields populated"""
fm = NpuFeatureMap()
fm.region = get_region(tens, arch)
@@ -242,7 +242,7 @@ def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures) -> Np
fm.layout = NpuLayout.NHCWB16
else:
assert 0, "Incorrect tensor format"
- height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(box.start_coord, box.end_coord)
+ height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(box.start_coord, box.end_coord, fm_shape)
for idx, addr in enumerate(addresses):
if addr is None:
addresses[idx] = 0
@@ -326,12 +326,12 @@ def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: Archit
ifm_width = Block.from_shape(cmd.ifm_tensor.shape).width
ifm_depth = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)
- npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch)
+ npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0])
npu_op.ifm.shape = NpuShape3D(height=ifm_height, width=ifm_width, depth=ifm_depth)
npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)
out_block = cmd.ofm_box.get_block()
- npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch)
+ npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0])
npu_op.ofm.shape = NpuShape3D(height=out_block.height, width=out_block.width, depth=out_block.depth)
npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)
@@ -397,13 +397,15 @@ def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> Npu
assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
elemwise_op = elementwise_op_map[op.type]
npu_op = NpuElementWiseOperation(elemwise_op)
+
if elemwise_op not in UNARY_ELEMWISE_OPS:
if not ifm_ifm2_correct_order(cmd.ifm_tensor.shape, cmd.ifm2_tensor.shape):
# The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
+ ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
npu_op.reversed_operands = True
- npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch)
+ npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
if cmd.ifm2_tensor.shape == []:
# scalar
diff --git a/ethosu/vela/insert_dma.py b/ethosu/vela/insert_dma.py
index fc1e7986..3797f43e 100644
--- a/ethosu/vela/insert_dma.py
+++ b/ethosu/vela/insert_dma.py
@@ -72,7 +72,7 @@ def insert_dma_cmd(op, arch, nng):
tens.purpose == TensorPurpose.FeatureMap
and op.type.is_binary_elementwise_op()
and tens.shape != []
- and tens.shape != op.outputs[0].shape
+ and op.ifm_shapes[0] != op.ofm_shapes[0]
and tens.storage_size() > max_ifm_shram_avail
):
only_vector_product_consumers = True
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 14e83a33..0cc89e27 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -181,12 +181,12 @@ def merge_elementwise_op_ranges(ps, lr_graph, target_mem_area, target_mem_type_s
inps.append(elem_op.ifm2)
if len(inps) > 0:
- for inp in inps:
+ for i, inp in enumerate(inps):
# check input format, dtype, broadcasting or if there are more input consumers
if (
inp.format == elem_op.ofm.format
and inp.dtype == elem_op.ofm.dtype
- and inp.shape == elem_op.ofm.shape
+ and elem_op.ifm_shapes[i] == elem_op.ofm_shapes[0]
and (len(inp.consumer_list) == 1 and len(inp.ops) == 1)
):
lr_graph.fuse_ranges(inp, elem_op.ofm)
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index 0ae3de9a..67925176 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -58,6 +58,8 @@ class Pass:
self.name = name
self.cascade = None
self.placement = placement
+ self.ifm_shapes = []
+ self.ofm_shapes = []
# TODO: rename is_element_wise because it is not the same as an ElementWise operator. It is used by the tensor
# allocation and requires that the OFM and IFM has the exact same address. Essentially complete overlap.
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py
index 9d83f6fb..c2ec4424 100644
--- a/ethosu/vela/npu_performance.py
+++ b/ethosu/vela/npu_performance.py
@@ -48,7 +48,7 @@ def rolling_buffer_dims_from_passes(arch, ps1, block_config_ps1, ps2, block_conf
if ps2.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
op = ps2.primary_op
- ifm_block_depth = arch.calc_ifm_block_depth(op.ifm.shape[-1], op.ifm.dtype.size_in_bits())
+ ifm_block_depth = arch.calc_ifm_block_depth(op.ifm_shapes[0][-1], op.ifm.dtype.size_in_bits())
else:
ifm_block_depth = block_config_ps2[-1]
@@ -224,8 +224,8 @@ def estimate_conv_pooling_cycles(
scale_tensor=None,
):
ofm_ublock = Block(arch.config.ofm_ublock.width, arch.config.ofm_ublock.height, arch.config.ofm_ublock.depth)
- ifm_tens_shape = numeric_util.full_shape(4, ifm_tensor.shape, 1)
- ofm_tens_shape = numeric_util.full_shape(4, ofm_tensor.shape, 1)
+ ifm_tens_shape = primary_op.ifm_shapes[0]
+ ofm_tens_shape = primary_op.ofm_shapes[0]
if (
arch.config.ofm_ublock.height == 2
@@ -420,8 +420,8 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None,
npu_block_type = primary_op.type.npu_block_type
ifm_tensor, _, weight_tensor, ofm_tensor = ps.get_primary_op_ifm_ifm2_weights_ofm()
- ifm_tensor_shape = numeric_util.full_shape(4, ifm_tensor.shape, 1)
- ofm_tensor_shape = numeric_util.full_shape(4, ofm_tensor.shape, 1)
+ ifm_tensor_shape = list(ps.primary_op.ifm_shapes[0])
+ ofm_tensor_shape = list(ps.primary_op.ofm_shapes[0])
if npu_block_type == NpuBlockType.ReduceSum:
block_traversal = TensorBlockTraversal.DepthFirst
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 30c32acc..be26a26b 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -27,6 +27,7 @@ from typing import TYPE_CHECKING
from .errors import VelaError
from .numeric_util import full_shape
+
if TYPE_CHECKING:
from .tensor import Tensor
@@ -129,7 +130,7 @@ class Op(Enum):
Concat = OperatorInfo(indices=CONCAT_INDICES)
ConcatEmbeddings = OperatorInfo()
ConcatSliceWrite = OperatorInfo(indices=IFM_INDICES)
- ConcatTFLite = OperatorInfo()
+ ConcatTFLite = OperatorInfo(indices=CONCAT_INDICES)
Const = OperatorInfo() # Constant tensor, only used in CPU subgraphs
Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=CONV2D_BACKPROP_INDICES)
@@ -197,7 +198,7 @@ class Op(Enum):
NonMaxSuppressionV5 = OperatorInfo()
NotEqual = OperatorInfo()
OneHot = OperatorInfo()
- Pack = OperatorInfo()
+ Pack = OperatorInfo(indices=IFM_INDICES)
PackReshaped = OperatorInfo(indices=IFM_INDICES)
Pad = OperatorInfo()
PadV2 = OperatorInfo()
@@ -260,7 +261,7 @@ class Op(Enum):
UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
Unique = OperatorInfo()
- Unpack = OperatorInfo()
+ Unpack = OperatorInfo(indices=IFM_INDICES)
UnpackReshaped = OperatorInfo(indices=IFM_INDICES)
Where = OperatorInfo()
While = OperatorInfo()
@@ -305,14 +306,17 @@ class Op(Enum):
return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT)
def is_split_op(self):
- return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped)
+ return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack)
def is_concat_op(self):
- return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped)
+ return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack)
def needs_bias(self):
return bool(self.info.indices.biases)
+ def needs_shapes(self):
+ return bool(self.info.indices.ifms)
+
@classmethod
def op_set(cls, predicate):
# Returns the set of all operator codes that fulfill the given predicate
@@ -400,6 +404,8 @@ class Operation:
"forced_output_quantization",
"activation_lut",
"_kernel",
+ "ifm_shapes",
+ "ofm_shapes",
)
def __init__(self, op_type: Op, name: str):
@@ -421,6 +427,8 @@ class Operation:
self.op_index = None # input network operator index
self.activation_lut = None
self._kernel = None
+ self.ifm_shapes = []
+ self.ofm_shapes = []
def clone(self, suffix="_clone"):
res = Operation(self.type, self.name + suffix)
@@ -697,3 +705,35 @@ class Operation:
lines += _print_tensors(self.outputs)
raise VelaError("\n".join(lines))
+
+ def set_ifm_ofm_shapes(self):
+ ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor = self.get_ifm_ifm2_weights_ofm()
+
+ # set all shapes to op, as 4D
+ if self.type == Op.FullyConnected:
+ n_in_elems = weight_tensor.shape[-2]
+ elms = ifm_tensor.elements()
+ batch_size = elms // n_in_elems
+ assert batch_size * n_in_elems == elms
+
+ self.ifm_shapes.append([batch_size, 1, 1, n_in_elems])
+ self.ofm_shapes.append(ofm_tensor.get_full_shape())
+ elif self.type == Op.Softmax:
+ self.ifm_shapes.append(ifm_tensor.get_full_shape())
+ self.ofm_shapes.append(ofm_tensor.get_full_shape())
+ elif self.type.is_split_op or self.type.is_concat_op():
+ for inp in self.inputs:
+ if inp is not None:
+ self.ifm_shapes.append(full_shape(4, inp.shape, 1))
+ else:
+ self.ifm_shapes.append(None)
+ for out in self.outputs:
+ if out is not None:
+ self.ofm_shapes.append(full_shape(4, out.shape, 1))
+ else:
+ self.ofm_shapes.append(None)
+ else:
+ self.ifm_shapes.append(full_shape(4, ifm_tensor.shape, 1))
+ if ifm2_tensor is not None:
+ self.ifm_shapes.append(full_shape(4, ifm2_tensor.shape, 1))
+ self.ofm_shapes.append(full_shape(4, ofm_tensor.shape, 1))
diff --git a/ethosu/vela/operation_util.py b/ethosu/vela/operation_util.py
index a267b2ad..a55b9548 100644
--- a/ethosu/vela/operation_util.py
+++ b/ethosu/vela/operation_util.py
@@ -61,6 +61,7 @@ def create_depthwise_maxpool(
ofm = Tensor([1, height, 1, 1], ifm.dtype, op.name + "_tens0")
ofm.quantization = quantization
op.set_output_tensor(ofm)
+ op.set_ifm_ofm_shapes()
return op
@@ -81,6 +82,7 @@ def create_reduce_sum(
sum_of_exp = Tensor(ofm_shape, DataType.int32, op.name + "_tens0")
sum_of_exp.quantization = quantization
op.set_output_tensor(sum_of_exp)
+ op.set_ifm_ofm_shapes()
return op
@@ -190,4 +192,5 @@ def create_binary_elementwise(
ofm = Tensor(ofm_shape, dtype, f"{op.name}_tens0")
ofm.quantization = quantization
op.set_output_tensor(ofm)
+ op.set_ifm_ofm_shapes()
return op
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 9bc04f29..095a78d4 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -397,11 +397,28 @@ def pack_into_passes(nng, arch, verbose_packing=False):
if len(ps.inputs) > 2:
ps.ifm_tensor = ps.inputs[-2]
+
+ # Get the corresponding ifm_shapes
+ for op in input_ops_list + [primary_op]:
+ if ps.ifm_tensor == op.ifm:
+ ps.ifm_shapes.append(op.ifm_shapes[0])
+ elif ps.ifm_tensor == op.ifm2:
+ ps.ifm_shapes.append(op.ifm_shapes[1])
+ for op in input_ops_list + [primary_op]:
+ if ps.ifm2_tensor == op.ifm:
+ ps.ifm_shapes.append(op.ifm_shapes[0])
+ elif ps.ifm2_tensor == op.ifm2:
+ ps.ifm_shapes.append(op.ifm_shapes[1])
else:
ps.ifm_tensor = ifm_tensor
ps.ifm2_tensor = None
+ if ps.primary_op is not None:
+ ps.ifm_shapes.append(ps.primary_op.ifm_shapes[0])
ps.ofm_tensor = ofm_tensor
+ if ps.primary_op is not None:
+ ps.ofm_shapes.append(ps.primary_op.ofm_shapes[0])
+
assert ps.placement != PassPlacement.Npu or ps.ofm_tensor is not None
ps.weight_tensor = ps.get_primary_op_ifm_weights()[1]
ps.scale_tensor = ps.get_primary_op_ifm_weights_biases_ofm()[2]
@@ -436,6 +453,8 @@ def pack_into_passes(nng, arch, verbose_packing=False):
avgpool_out = inp.clone("_avgpooled")
avgpool_out.consumer_list.append(op)
avgpool_op.set_output_tensor(avgpool_out)
+ avgpool_op.ifm_shapes = op.ifm_shapes
+ avgpool_op.ofm_shapes = op.ofm_shapes
op.inputs[0] = avgpool_out
op_list.insert(0, avgpool_op)
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 2c10640b..6cbff500 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -34,7 +34,6 @@ from .npu_performance import make_bandwidth_array
from .npu_performance import make_cycles_array
from .npu_performance import make_metrics_arrays
from .npu_performance import PassCycles
-from .numeric_util import full_shape
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
@@ -188,7 +187,7 @@ class StrategySet:
def __eq__(self, other):
if (self.bws != other.bws).any():
return False
- if (self.macs != other.macs).any():
+ if self.macs != other.macs:
return False
if (self.cycles != other.cycles).any():
return False
@@ -1000,10 +999,8 @@ class DynamicProgrammingScheduler:
rewrites.extend(get_rewrites(op))
# Detect no-op reshapes by comparing their full input and output tensor shapes.
- inshape = full_shape(4, op.inputs[0].shape, 1)
- compatible_shape = [
- (inshape == full_shape(4, oper.outputs[0].shape, 1)) for oper in get_rewrites(op)
- ]
+ inshape = op.ifm_shapes[0]
+ compatible_shape = [(inshape == oper.ofm_shapes[0]) for oper in get_rewrites(op)]
use_NHCWB16 = compatible_shape and all(compatible_shape)
else:
use_NHCWB16 = False
diff --git a/ethosu/vela/shared_buffer_allocation.py b/ethosu/vela/shared_buffer_allocation.py
index 600b3170..1f027d60 100644
--- a/ethosu/vela/shared_buffer_allocation.py
+++ b/ethosu/vela/shared_buffer_allocation.py
@@ -193,15 +193,16 @@ def shared_buffer_allocation_for_pass(arch, ps) -> SharedBufferAllocation:
if ifm_tensor:
ifm_resampling_mode = ifm_tensor.resampling_mode
ifm_bits = ifm_tensor.dtype.size_in_bits()
+ ifm_shape = ps.primary_op.ifm_shapes[0]
- if ifm_tensor.shape != []:
- ifm_depth = ifm_tensor.shape[-1]
+ if ifm_shape != []:
+ ifm_depth = ifm_shape[-1]
if is_elementwise:
ifm_count = 2
if ifm_tensor.shape == []: # Scalar in ifm1
assert ifm2_tensor
- ifm_depth = ifm2_tensor.shape[-1]
+ ifm_depth = ps.primary_op.ifm_shapes[1][-1]
ifm_count = 1
elif not ifm2_tensor or ifm2_tensor.shape == []: # Scalar in ifm2
ifm_count = 1
@@ -215,7 +216,7 @@ def shared_buffer_allocation_for_pass(arch, ps) -> SharedBufferAllocation:
ifm_bits=ifm_bits,
ifm_depth=ifm_depth,
ifm_count=ifm_count,
- ofm_shape=ofm_tensor.shape,
+ ofm_shape=ps.primary_op.ofm_shapes[0],
)
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index 8b061297..98496539 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -213,7 +213,7 @@ class SoftMax:
ofm = self.op.outputs[0]
# Reshape ifm/ofm (if needed)
- full_shape = ifm.get_full_shape()
+ full_shape = self.op.ifm_shapes[0]
if full_shape[0] > 1:
full_shape[1] *= full_shape[0]
full_shape[0] = 1
@@ -230,9 +230,6 @@ class SoftMax:
def get_graph_8bit(self, ifm, ofm):
exp_lut = self.generate_exp_table(self.op.attrs.get("beta", 1.0), ifm.quantization.scale_f32)
- ifm = create_reshape_tensor(ifm, ifm.get_full_shape())
- DebugDatabase.add_optimised(self.op, ifm.ops[0])
- ofm = create_reshape_tensor(ofm, ofm.get_full_shape(), False)
no_scale_quant = ifm.quantization.clone()
no_scale_quant.scale_f32 = None
no_scale_quant.zero_point = 0
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 69618d2c..df8f8868 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -37,6 +37,7 @@ from .data_type import DataType
from .errors import UnsupportedFeatureError
from .errors import VelaError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
+from .numeric_util import full_shape
from .operation import Op
from .operation import Operation
@@ -322,6 +323,8 @@ def create_reshape_tensor(tens, shape, ifm_reshape=True):
reshape_op.add_input_tensor(reshape_ifm)
reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, shape))
reshape_op.set_output_tensor(reshape_ofm)
+ reshape_op.ifm_shapes.append(full_shape(4, reshape_ifm.shape, 1))
+ reshape_op.ofm_shapes.append(full_shape(4, reshape_ofm.shape, 1))
return reshape_ofm if ifm_reshape else reshape_ifm
@@ -605,20 +608,20 @@ class Tensor:
def consumers(self) -> List[Operation]:
return self.consumer_list
- def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape) -> Tuple:
+ def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, fm_shape: Shape) -> Tuple:
# returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] )
- if len(start_coord) < 4:
- box_height0 = 1
- box_width = 1
-
- if len(start_coord) >= 2:
- box_width = end_coord[-2] - start_coord[-2]
-
- return box_height0, box_height0, box_width, [self.address_for_coordinate(start_coord), None, None, None]
+ if self.storage_shape == []:
+ return (
+ 1,
+ 1,
+ 1,
+ [self.address_for_coordinate(start_coord, shape=fm_shape), None, None, None],
+ )
- crossing_y = numeric_util.round_up(start_coord[1] + 1, self.storage_shape[1])
- crossing_x = numeric_util.round_up(start_coord[2] + 1, self.storage_shape[2])
+ storage_shape_4D = full_shape(4, self.storage_shape, 1)
+ crossing_y = numeric_util.round_up(start_coord[1] + 1, storage_shape_4D[1])
+ crossing_x = numeric_util.round_up(start_coord[2] + 1, storage_shape_4D[2])
crossing_y = min(crossing_y, end_coord[1])
crossing_x = min(crossing_x, end_coord[2])
@@ -627,20 +630,28 @@ class Tensor:
box_width = crossing_x - start_coord[2]
addresses: List = [None] * 4
- addresses[0] = self.address_for_coordinate(start_coord)
+ addresses[0] = self.address_for_coordinate(start_coord, shape=fm_shape)
if end_coord[2] > crossing_x:
- addresses[1] = self.address_for_coordinate([start_coord[0], start_coord[1], crossing_x, start_coord[3]])
+ addresses[1] = self.address_for_coordinate(
+ [start_coord[0], start_coord[1], crossing_x, start_coord[3]], shape=fm_shape
+ )
raise UnsupportedFeatureError("Striping in vertical direction is not supported")
if end_coord[1] > crossing_y:
- addresses[2] = self.address_for_coordinate([start_coord[0], crossing_y, start_coord[2], start_coord[3]])
+ addresses[2] = self.address_for_coordinate(
+ [start_coord[0], crossing_y, start_coord[2], start_coord[3]], shape=fm_shape
+ )
if end_coord[1] > crossing_y and end_coord[2] > crossing_x:
- addresses[3] = self.address_for_coordinate([start_coord[0], crossing_y, crossing_x, start_coord[3]])
+ addresses[3] = self.address_for_coordinate(
+ [start_coord[0], crossing_y, crossing_x, start_coord[3]], shape=fm_shape
+ )
return box_height0, box_height0, box_width, addresses
- def address_for_coordinate(self, coord: Shape, is_top_box: bool = False) -> int:
- offset = self.address_offset_for_coordinate(coord, is_top_box)
+ def address_for_coordinate(self, coord: Shape, is_top_box: bool = False, shape: Shape = None) -> int:
+ if shape is None:
+ shape = self.shape
+ offset = self.address_offset_for_coordinate(coord, shape, is_top_box)
assert offset is not None
return self.address + offset
@@ -752,18 +763,18 @@ class Tensor:
assert 0 <= index < len(self.compressed_values)
return index == len(self.compressed_values) - 1
- def address_offset_for_coordinate(self, orig_coord: Shape, is_top_box: bool = False) -> Optional[int]:
+ def address_offset_for_coordinate(self, orig_coord: Shape, shape: Shape, is_top_box: bool = False) -> Optional[int]:
address_offset = 0
coord = orig_coord
coord = coord[-len(self.storage_shape) :]
if self.sub_purpose == TensorSubPurpose.Standard:
- for idx, c in enumerate(coord):
+ for idx, c in enumerate(orig_coord):
if is_top_box:
- assert c > 0 and c <= self.shape[idx]
+ assert c > 0 and c <= shape[idx]
else:
- assert c >= 0 and c < self.shape[idx]
+ assert c >= 0 and c < shape[idx]
if self.format == TensorFormat.WeightsCompressed:
if len(self.weight_compressed_offsets) == 0:
@@ -830,7 +841,7 @@ class Tensor:
def get_full_shape(self) -> Shape:
d = len(self.shape)
if d in (1, 3):
- return numeric_util.full_shape(4, self.shape, 1)
+ return full_shape(4, self.shape, 1)
elif d == 2:
return [self.shape[0], 1, 1, self.shape[1]]
else:
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index 62a1b763..45377417 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -32,9 +32,16 @@ def test_convert_batched_fc():
weights = create_const_tensor("weight_in", shape, np.uint8, np.zeros(shape))
ofm = Tensor(ifm.shape, np.uint8, "test_out")
op = testutil.create_op(Op.FullyConnected, [ifm, weights], ofm)
+
ifm.consumer_list.append(op)
+ op.ifm_shapes.append([4, 1, 1, 8])
+ op.ofm_shapes.append([4, 1, 1, 8])
+
prev_op = op.clone()
+ prev_op.ifm_shapes = op.ifm_shapes
+ prev_op.ofm_shapes = op.ofm_shapes
+
conv_op = convert_batched_fc_shape(op, None, None)
assert conv_op.ifm != prev_op.ifm
@@ -51,7 +58,13 @@ def test_convert_batched_fc():
op = testutil.create_op(Op.FullyConnected, [ifm, weights], ofm)
ifm.consumer_list.append(op)
+ op.ifm_shapes.append([1, 1, 1, 8])
+ op.ofm_shapes.append([1, 1, 1, 8])
+
prev_op = op.clone()
+ prev_op.ifm_shapes = op.ifm_shapes
+ prev_op.ofm_shapes = op.ofm_shapes
+
conv_op = convert_batched_fc_shape(op, None, None)
assert conv_op.ifm == prev_op.ifm
diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py
index 9ba39bc5..63f841b4 100644
--- a/ethosu/vela/test/testutil.py
+++ b/ethosu/vela/test/testutil.py
@@ -69,6 +69,8 @@ def create_elemwise_op(
ofm = Tensor(ofm_shape, datatype, name + "_ofm")
ofm.quantization = ofm_quant
op.set_output_tensor(ofm)
+ op.set_ifm_ofm_shapes()
+
return op
@@ -104,6 +106,8 @@ def create_op_with_quant_tensors(
qp.zero_point = np.zeros(bias_shape)
bias = create_const_tensor("bias", bias_shape, DataType.int32, np.zeros(bias_shape), np.int32, quantization=qp)
op.add_input_tensor(bias)
+
+ op.set_ifm_ofm_shapes()
return op
@@ -113,6 +117,7 @@ def create_op(op_type, inputs, output, attrs=None):
op.outputs = [output]
if attrs is not None:
op.attrs = attrs
+ op.set_ifm_ofm_shapes()
return op