author     patrik.gustavsson <patrik.gustavsson@arm.com>  2020-12-21 16:56:26 +0000
committer  patrik.gustavsson <patrik.gustavsson@arm.com>  2020-12-21 16:56:26 +0000
commit     df0a5905177f3a1b836076bc3f9f39b2e86f1794 (patch)
tree       b5151d0f12428e47d64b1fb2ce4f2f8c19304a0d
parent     bf31d647dc5df47410ee577b12427ddf076d816b (diff)
download   ethos-u-vela-df0a5905177f3a1b836076bc3f9f39b2e86f1794.tar.gz
Revert "MLBEDSW-3645 4D class for op ifm/ofm shapes"
This reverts commit bf31d647dc5df47410ee577b12427ddf076d816b.

Reason for revert: <INSERT REASONING HERE>

Change-Id: I7b6c585b7658f94dbaa916c2b6bfe9fb463b8d37
-rw-r--r--  ethosu/vela/debug_database.py                          15
-rw-r--r--  ethosu/vela/graph_optimiser.py                         42
-rw-r--r--  ethosu/vela/high_level_command_stream.py              122
-rw-r--r--  ethosu/vela/high_level_command_stream_generator.py     32
-rw-r--r--  ethosu/vela/high_level_command_to_npu_op.py             3
-rw-r--r--  ethosu/vela/nn_graph.py                                 6
-rw-r--r--  ethosu/vela/npu_performance.py                         41
-rw-r--r--  ethosu/vela/operation.py                               28
-rw-r--r--  ethosu/vela/pass_packing.py                            11
-rw-r--r--  ethosu/vela/shape4d.py                                 77
-rw-r--r--  ethosu/vela/shared_buffer_allocation.py                20
-rw-r--r--  ethosu/vela/softmax.py                                  4
-rw-r--r--  ethosu/vela/tensor.py                                  17
-rw-r--r--  ethosu/vela/test/test_graph_optimiser.py                5
-rw-r--r--  ethosu/vela/test/test_supported_operators.py            2
-rw-r--r--  ethosu/vela/test/testutil.py                            6
16 files changed, 201 insertions, 230 deletions
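The change is mechanical throughout: the reverted commit bf31d647 had introduced a Shape4D wrapper with named batch/height/width/depth accessors (see the deleted shape4d.py below), and this revert returns ifm_shapes/ofm_shapes to plain Python lists padded to four dimensions and indexed positionally. A minimal sketch of the two styles, assuming full_shape pads a shape on the left as numeric_util.full_shape does; the shape values are made up for illustration and are not part of the commit:

def full_shape(n, shape, fill):
    # Assumed to mirror numeric_util.full_shape: left-pad `shape` with `fill`
    # until it has `n` dimensions.
    return [fill] * (n - len(shape)) + list(shape)

ifm_shape = [7, 7, 64]  # hypothetical 3D (HWC) tensor shape

# Before the revert: the Shape4D wrapper exposed named accessors, e.g.
#   s = Shape4D(ifm_shape); s.batch, s.height, s.width, s.depth -> 1, 7, 7, 64

# After the revert: plain 4D lists in NHWC order, indexed positionally.
s = full_shape(4, ifm_shape, 1)  # -> [1, 7, 7, 64]
batch, height, width, depth = s[0], s[-3], s[-2], s[-1]
assert (batch, height, width, depth) == (1, 7, 7, 64)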
diff --git a/ethosu/vela/debug_database.py b/ethosu/vela/debug_database.py
index 77e13eb0..203503f2 100644
--- a/ethosu/vela/debug_database.py
+++ b/ethosu/vela/debug_database.py
@@ -23,7 +23,7 @@ import lxml.etree as xml
from . import numeric_util
from .operation import Operation
-from .shape4d import Shape4D
+
UntypedDict = Dict[Any, Any]
UntypedList = List[Any]
@@ -79,18 +79,9 @@ class DebugDatabase:
src_uid = cls._sourceUID[parent]
uid = len(cls._optimisedUID)
cls._optimisedUID[op] = (uid, src_uid)
- ofm_shape = op.ofm_shapes[0] if op.ofm_shapes else Shape4D(op.outputs[0].shape)
+ ofm_shape = op.ofm_shapes[0] if op.ofm_shapes else numeric_util.full_shape(3, op.outputs[0].shape, 1)
cls._optimisedTable.append(
- [
- uid,
- src_uid,
- op.type,
- op.kernel.width,
- op.kernel.height,
- ofm_shape.width,
- ofm_shape.height,
- ofm_shape.depth,
- ]
+ [uid, src_uid, op.type, op.kernel.width, op.kernel.height, ofm_shape[-2], ofm_shape[-3], ofm_shape[-1]]
)
@classmethod
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 1128a311..fdb0fae0 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -37,7 +37,6 @@ from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_avgpool_nop
-from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
@@ -83,7 +82,6 @@ def rewrite_concat(tens, arch, nng):
new_op.run_on_npu = True
tens.ops.append(new_op)
DebugDatabase.add_optimised(concat_op, new_op)
- new_op.set_ifm_ofm_shapes()
assert tens.shape[axis] == offset
# If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
@@ -123,8 +121,7 @@ def rewrite_split(tens, arch, nng):
if out == tens:
break
axis_4D = axis + (4 - len(out.shape))
-
- offset_start[axis_4D] += split_op.ofm_shapes[idx].get_dim(axis_4D)
+ offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
# If start offset is not a multiple of 16 in the C-dimension, NHCWB16 need to be avoided in the input
if (offset_start[-1] % 16) != 0:
@@ -135,7 +132,6 @@ def rewrite_split(tens, arch, nng):
new_op.attrs["split_start"] = offset_start
new_op.run_on_npu = True
new_op.set_output_tensor(tens)
- new_op.set_ifm_ofm_shapes()
DebugDatabase.add_optimised(split_op, new_op)
return tens
@@ -193,7 +189,6 @@ def fixup_conv2d_backprop(op, arch, nng):
if op.type == Op.Conv2DBackpropInput:
# flip the inputs
op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
- op.set_ifm_ofm_shapes()
op.type = Op.Conv2DBackpropInputSwitchedBias
# Update strides
@@ -221,7 +216,8 @@ def convert_resizebilinear_1x1_to_add(op):
# Set the add inputs
op.inputs[1] = op.inputs[0]
op.inputs[0] = tens
- op.set_ifm_ofm_shapes()
+ op.ifm_shapes = []
+ op.ofm_shapes = []
return op
@@ -327,14 +323,14 @@ def convert_batched_fc_shape(op, arch, nng):
ofm = op.outputs[0]
# Check if the FC is 2D and first dimension indicates batching
# TOD0 op.ifm_shape[0] > 1 is enough when refactory is complete
- if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1 and op.ifm_shapes[0].batch > 1:
+ if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1 and op.ifm_shapes[0][0] > 1:
n = ifm.shape[0]
batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
h, w = batching_split.get(n, (1, n))
prev_op = ifm.ops[0]
desired_shape = [1, h, w, ifm.shape[-1]]
- op.ifm_shapes[0] = Shape4D(desired_shape)
+ op.ifm_shapes[0] = desired_shape
if len(ifm.consumer_list) == 1 and prev_op is not None and prev_op.type == Op.Reshape:
# There is a preceding Reshape
@@ -360,7 +356,7 @@ def convert_batched_fc_shape(op, arch, nng):
weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
desired_shape = [1, h, w, ofm.shape[-1]]
- op.ofm_shapes[0] = Shape4D(desired_shape)
+ op.ofm_shapes[0] = desired_shape
if (
len(ofm.consumer_list) == 1
@@ -399,7 +395,6 @@ def fixup_pack_input(op, arch, nng):
reshape_op.attrs["new_shape"] = desired_shape
reshape_op.inputs = [inp, new_shape_tens]
reshape_op.set_output_tensor(reshape_out)
- reshape_op.set_ifm_ofm_shapes()
DebugDatabase.add_optimised(op, reshape_op)
op.inputs[idx] = reshape_out
@@ -418,7 +413,6 @@ def unfuse_activation_function(op, arch, nng):
act_op.set_output_tensor(out_tens)
act_op.add_input_tensor(intermediate_tens)
op.set_output_tensor(intermediate_tens)
- act_op.set_ifm_ofm_shapes()
return op
@@ -463,7 +457,7 @@ def fixup_stridedslice_output(tens, arch, nng):
new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, tens.shape)
for idx, out_tens in enumerate(op.outputs):
- op.ofm_shapes[idx] = Shape4D(new_shape_tens.shape)
+ op.ofm_shapes[idx] = new_shape_tens
reshape_in = out_tens.clone("_reshaped")
reshape_in.set_all_shapes(reshape_input_shape)
reshape_in.ops = [op]
@@ -472,7 +466,6 @@ def fixup_stridedslice_output(tens, arch, nng):
reshape_op.attrs["new_shape"] = reshape_input_shape
reshape_op.inputs = [reshape_in, new_shape_tens]
reshape_op.set_output_tensor(out_tens)
- reshape_op.set_ifm_ofm_shapes()
op.outputs[idx] = reshape_in
@@ -500,7 +493,6 @@ def fixup_unpack_output(tens, arch, nng):
reshape_op.attrs["new_shape"] = reshape_input_shape
reshape_op.inputs = [reshape_in, new_shape_tens]
reshape_op.set_output_tensor(out_tens)
- reshape_op.set_ifm_ofm_shapes()
DebugDatabase.add_optimised(op, reshape_op)
op.outputs[idx] = reshape_in
@@ -596,8 +588,7 @@ def convert_conv_to_fc(op, arch, nng):
# caching/double buffering for the weights.
# (Weights dont need to be reloaded for convs when IFM H and W are 1)
if op.type == Op.Conv2DBias:
- h = op.ifm_shapes[0].height
- w = op.ifm_shapes[0].width
+ _, h, w, _ = op.ifm_shapes[0]
kh, kw, _, _ = op.inputs[1].shape
if h == 1 and w == 1 and kh == 1 and kw == 1:
# Overwrite this op as a Fully Connected Op
@@ -625,11 +616,9 @@ def convert_conv_to_fc(op, arch, nng):
reshape_op.attrs["new_shape"] = orig_ofm_tensor.shape
reshape_op.inputs = [fc_ofm_tensor, new_shape_tens]
reshape_op.set_output_tensor(orig_ofm_tensor)
- reshape_op.set_ifm_ofm_shapes()
# Replace this ops OFM to point to the 2D tensor
op.outputs[0] = fc_ofm_tensor
- op.set_ifm_ofm_shapes()
# Record optimisation in debug database
DebugDatabase.add_optimised(op, reshape_op)
DebugDatabase.add_optimised(op, op)
@@ -660,7 +649,6 @@ def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
relu_fused_op.add_input_tensor(ifm)
relu_fused_op.set_output_tensor(ofm)
- relu_fused_op.set_ifm_ofm_shapes()
op = relu_fused_op
return op
@@ -680,8 +668,8 @@ def fixup_act_reorder(op, arch, nng):
act_op_out = act_op.inputs[0].clone("_acted")
act_op_out.quantization = op.outputs[0].quantization.clone()
act_op.set_output_tensor(act_op_out)
- act_op.ifm_shapes[0] = Shape4D(prep_op.inputs[0].shape)
- act_op.ofm_shapes[0] = Shape4D(act_op_out.shape)
+ act_op.ifm_shapes[0] = full_shape(4, prep_op.inputs[0].shape, 1)
+ act_op.ofm_shapes[0] = full_shape(4, act_op_out.shape, 1)
# Update the consumer list
act_op_out.consumer_list = op.outputs[0].consumer_list.copy()
@@ -851,7 +839,6 @@ def convert_lrelu_to_mul_max(op, arch):
mul_alpha.add_input_tensor(alpha_tens)
fm_alpha = ofm.clone(op.name + "_alpha")
mul_alpha.set_output_tensor(fm_alpha)
- mul_alpha.set_ifm_ofm_shapes()
DebugDatabase.add_optimised(op, mul_alpha)
if check_quantized_tens_scaling_equal(ifm, ofm):
@@ -873,7 +860,6 @@ def convert_lrelu_to_mul_max(op, arch):
mul_identity.add_input_tensor(identity_tens)
fm_id = ofm.clone(op.name + "_id")
mul_identity.set_output_tensor(fm_id)
- mul_identity.set_ifm_ofm_shapes()
DebugDatabase.add_optimised(op, mul_identity)
# Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
@@ -904,7 +890,7 @@ def convert_to_lut(op, lut_values, lut_name):
quantization.zero_point = 0
tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], np.uint8, quantization=quantization)
op.add_input_tensor(tens)
- op.ifm_shapes.append(Shape4D(tens.shape))
+ op.ifm_shapes.append(full_shape(4, tens.shape, 1))
# The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
# so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
@@ -1172,7 +1158,11 @@ def optimise_graph_b(nng, arch, verbose_graph=False):
for idx, sg in enumerate(nng.subgraphs):
# combined rewrite graph pass
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split], [],
+ nng,
+ sg,
+ arch,
+ [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split],
+ [set_ifm_ofm_op_shapes],
)
if verbose_graph:
diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py
index 9cbda452..bb4f1424 100644
--- a/ethosu/vela/high_level_command_stream.py
+++ b/ethosu/vela/high_level_command_stream.py
@@ -15,14 +15,11 @@
# limitations under the License.
# Description:
# Contains classes that hold commands for the high-level command stream (one command per DMA or NPU stripe).
-from typing import List
-
import numpy as np
from .architecture_features import Block
from .numeric_util import round_up_divide
from .operation import NpuBlockType
-from .shape4d import Shape4D
class Box:
@@ -35,15 +32,15 @@ class Box:
def transform_with_strides_and_skirt(
self,
- strides: List[int],
- skirt: List[int],
- ifm_shape: Shape4D,
- npu_block_type: NpuBlockType,
- concat_axis: int = 0,
- concat_offset: int = 0,
- split_offset: int = None,
- k_height: int = 1,
- upscaling_factor: int = 1,
+ strides,
+ skirt,
+ ifm_shape,
+ npu_block_type,
+ concat_axis=0,
+ concat_offset=0,
+ split_offset=None,
+ k_height=1,
+ upscaling_factor=1,
):
new_start_coord = list(self.start_coord)
new_end_coord = list(self.end_coord)
@@ -61,15 +58,15 @@ class Box:
):
# these types of operations do a "dot product" or sum over the entire IFM
new_start_coord[-1] = 0
- new_end_coord[-1] = ifm_shape.depth
+ new_end_coord[-1] = ifm_shape[-1]
- if npu_block_type == NpuBlockType.ElementWise and len(new_end_coord) >= 1:
- new_end_coord[-1] = min(new_end_coord[-1], ifm_shape.depth)
- if len(new_end_coord) >= 2:
- new_end_coord[-2] = min(new_end_coord[-2], ifm_shape.width * upscaling_factor)
- if len(new_end_coord) >= 3:
+ if npu_block_type == NpuBlockType.ElementWise and min(len(new_end_coord), len(ifm_shape)) >= 1:
+ new_end_coord[-1] = min(new_end_coord[-1], ifm_shape[-1])
+ if min(len(new_end_coord), len(ifm_shape)) >= 2:
+ new_end_coord[-2] = min(new_end_coord[-2], ifm_shape[-2] * upscaling_factor)
+ if min(len(new_end_coord), len(ifm_shape)) >= 3:
original_end_coord = list(new_end_coord)
- new_end_coord[-3] = min(new_end_coord[-3], ifm_shape.height * upscaling_factor)
+ new_end_coord[-3] = min(new_end_coord[-3], ifm_shape[-3] * upscaling_factor)
pad_top = 0
pad_bottom = 0
@@ -77,7 +74,7 @@ class Box:
if len(new_start_coord) >= 2:
stride = strides[2]
new_start_coord[-2] = max(new_start_coord[-2] * stride - skirt[1], 0)
- new_end_coord[-2] = min(new_end_coord[-2] * stride + skirt[3], ifm_shape.width)
+ new_end_coord[-2] = min(new_end_coord[-2] * stride + skirt[3], ifm_shape[-2])
if len(new_start_coord) >= 3:
stride = strides[1]
@@ -89,20 +86,23 @@ class Box:
pad_top = max(0, 0 - new_start_coord[-3]) + skirt_top_remainder
new_start_coord[-3] = max(new_start_coord[-3], 0)
- if (new_end_coord[-3] * stride + skirt[2]) > (ifm_shape.height * upscaling_factor):
+ while len(ifm_shape) < 3:
+ ifm_shape = [1] + ifm_shape
+
+ if (new_end_coord[-3] * stride + skirt[2]) > (ifm_shape[-3] * upscaling_factor):
# pad_bottom is calculated based the diff between the end position of the weight kernel,
# after last stride and the ifm height.
- if upscaling_factor != 1 and original_end_coord[-3] > ifm_shape.height * upscaling_factor:
+ if upscaling_factor != 1 and original_end_coord[-3] > ifm_shape[-3] * upscaling_factor:
# Special case for Transpose Convolution with VALID padding.
- pad_bottom = original_end_coord[-3] - (ifm_shape.height * upscaling_factor)
+ pad_bottom = original_end_coord[-3] - (ifm_shape[-3] * upscaling_factor)
else:
k_start = new_start_coord[-3] - pad_top
- pad_bottom = max(0, k_start + total_stride + k_height - (ifm_shape.height * upscaling_factor))
+ pad_bottom = max(0, k_start + total_stride + k_height - (ifm_shape[-3] * upscaling_factor))
# Adjust for upscaling
new_start_coord[-3] = max(new_start_coord[-3] // upscaling_factor, 0)
new_end_coord[-3] = new_end_coord[-3] * stride + skirt[2] + (skirt[2] % upscaling_factor)
- new_end_coord[-3] = max(min(new_end_coord[-3] // upscaling_factor, ifm_shape.height), 1)
+ new_end_coord[-3] = max(min(new_end_coord[-3] // upscaling_factor, ifm_shape[-3]), 1)
return Box(new_start_coord, new_end_coord), pad_top, pad_bottom
@@ -197,7 +197,7 @@ class NpuStripe(Command):
self.pad_top = pad_top
self.pad_bottom = pad_bottom
for i in range(len(self.ofm_box.end_coord)):
- assert self.ofm_box.end_coord[i] <= ps.ofm_shapes[0].get_dim(i)
+ assert self.ofm_box.end_coord[i] <= ps.ofm_shapes[0][i]
def is_npu_pass_command(self):
return True
@@ -251,6 +251,76 @@ class NpuStripe(Command):
assert res >= 0
return res
+ def get_single_block_command(self, block_idx):
+ block_cfg = (self.block_config[0], self.block_config[1], self.block_config[3])
+ dims = self.get_block_dimensions()
+ strides = dims[1] * dims[2], dims[2], 1
+ coord = []
+ idx_left = block_idx
+ for s in strides:
+ c = idx_left // s
+ idx_left -= c * s
+ coord.append(c)
+
+ assert idx_left == 0
+
+ # put in dummy height/widths in case we're dealing with FC layers
+ ofm_start = list(self.ofm_box.start_coord)
+ ofm_end = list(self.ofm_box.end_coord)
+
+ # cut out a nice block shape
+ for idx in (-1, -2, -3):
+ if len(ofm_start) >= -idx:
+ ofm_start[idx] += block_cfg[idx] * coord[idx]
+ ofm_end[idx] = min(ofm_end[idx], ofm_start[idx] + block_cfg[idx])
+
+ ps = self.ps
+ strides = None
+ skirt = None
+ if ps.primary_op is not None:
+ strides = ps.primary_op.attrs.get("strides", None)
+ skirt = ps.primary_op.attrs.get("skirt", None)
+ npu_block_type = ps.npu_block_type
+
+ ofm_box = Box(ofm_start, ofm_end)
+ ifm_box, _, _ = ofm_box.transform_with_strides_and_skirt(
+ strides, skirt, self.ifm_tensor.shape, npu_block_type, self.concat_axis, self.concat_offset
+ )
+
+ weight_box = None
+ if self.weight_tensor is not None:
+ weight_oc_start = ofm_start[-1]
+ weight_oc_end = ofm_end[-1]
+ if self.concat_axis - len(self.weight_tensor.shape) == -1:
+ weight_oc_start -= self.concat_offset
+ weight_oc_end -= self.concat_offset
+
+ weight_box = Box.make_weight_box(
+ self.weight_tensor.shape,
+ npu_block_type,
+ weight_oc_start,
+ weight_oc_end,
+ self.weight_tensor.weight_transpose_depthwise,
+ )
+
+ return NpuStripe(
+ self.ps,
+ self.block_config,
+ self.is_first,
+ self.is_last,
+ self.is_first_h_stripe,
+ self.is_last_h_stripe,
+ self.ifm_tensor,
+ ifm_box,
+ self.ofm_tensor,
+ ofm_box,
+ self.weight_tensor,
+ weight_box,
+ self.scale_tensor,
+ self.concat_axis,
+ self.concat_offset,
+ )
+
class DMA(Command):
def __init__(self, ps, in_tensor, out_tensor, box):
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index 60e62aa6..18a419c0 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -27,7 +27,6 @@ from .numeric_util import round_up_divide
from .operation import create_activation_function
from .operation import NpuBlockType
from .operation import Op
-from .shape4d import Shape4D
from .tensor import TensorPurpose
@@ -91,8 +90,8 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
weight_tensor = ps.weight_tensor
scale_tensor = ps.scale_tensor
- ofm_start = [0, 0, 0, 0]
- ofm_end = ofm_shape.as_list()
+ ofm_start = [0] * len(ofm_shape)
+ ofm_end = list(ofm_shape)
strides = None
skirt = None
@@ -101,9 +100,9 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
strides = ps.primary_op.attrs.get("strides", None)
skirt = ps.primary_op.attrs.get("skirt", None)
if ps.primary_op.type == Op.Conv2DBackpropInputSwitchedBias:
- upscaling = ofm_shape.height // ifm_shape.height
+ upscaling = ofm_shape[-3] // ifm_shape[-3]
elif ps.primary_op.type == Op.ResizeBilinear:
- upscaling = round_up_divide(ofm_shape.height, ifm_shape.height)
+ upscaling = round_up_divide(ofm_shape[-3], ifm_shape[-3])
concat_axis = 0
concat_offset = 0
@@ -136,7 +135,14 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
if ifm_shape is not None:
ifm_box, _, _ = ofm_box.transform_with_strides_and_skirt(
- strides, skirt, ifm_shape, npu_block_type, concat_axis, concat_offset, split_offsets[0], upscaling,
+ strides,
+ skirt,
+ ifm_tensor.shape,
+ npu_block_type,
+ concat_axis,
+ concat_offset,
+ split_offsets[0],
+ upscaling,
)
else:
ifm_box = Box([], [])
@@ -157,7 +163,7 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
intermediate_box, _, _ = ofm_box.transform_with_strides_and_skirt(
strides,
skirt,
- Shape4D(intermediate.shape),
+ intermediate.shape,
npu_block_type,
concat_axis,
concat_offset,
@@ -206,7 +212,6 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
)
elif strat == SchedulingStrategy.IfmStream:
- assert ifm_shape is not None
y_step = block_config[0]
y_start = ofm_start[-3]
y_dim = ofm_end[-3]
@@ -217,7 +222,8 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
prev_pass_gen = generate_high_level_command_stream_for_pass(strat, passes, block_configs, idx - 1)
else:
ifm_y_present = 1
- ifm_y_present = ifm_shape.height
+ if len(ifm_shape) >= 3:
+ ifm_y_present = ifm_shape[-3]
prev_pass_gen = []
prev_pass = None
@@ -270,7 +276,7 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
intermediate_box, _, _ = ofm_box.transform_with_strides_and_skirt(
strides,
skirt,
- Shape4D(intermediate.shape),
+ intermediate.shape,
npu_block_type,
concat_axis,
concat_offset,
@@ -374,13 +380,13 @@ def calc_allowed_ofm_ifm_overlap_for_pass_list(strat, passes, block_configs):
if cmd.is_npu_pass_command():
if cmd.is_first:
ifm_read = cmd.ifm_tensor.address_offset_for_coordinate(
- cmd.ifm_box.start_coord, cmd.ps.ifm_shapes[0].as_list(), is_top_box=False
+ cmd.ifm_box.start_coord, shape=cmd.ps.ifm_shapes[0], is_top_box=False
)
if ifm_read is None:
return 0
if cmd.is_last:
write_offset = cmd.ofm_tensor.address_offset_for_coordinate(
- cmd.ofm_box.end_coord, cmd.ps.ofm_shapes[0].as_list(), is_top_box=True
+ cmd.ofm_box.end_coord, shape=cmd.ps.ofm_shapes[0], is_top_box=True
)
if write_offset is None:
return 0
@@ -393,7 +399,7 @@ def calc_allowed_ofm_ifm_overlap_for_pass_list(strat, passes, block_configs):
if cmd.is_first:
ifm_read = cmd.ifm_tensor.address_offset_for_coordinate(
- cmd.ifm_box.end_coord, cmd.ps.ifm_shapes[0].as_list(), is_top_box=True
+ cmd.ifm_box.end_coord, shape=cmd.ps.ifm_shapes[0], is_top_box=True
)
min_overlap = max(min_overlap, 0)
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 07117025..9380374e 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -58,7 +58,6 @@ from .register_command_stream_generator import generate_command_stream
from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
from .register_command_stream_util import to_npu_kernel
from .register_command_stream_util import UNARY_ELEMWISE_OPS
-from .shape4d import Shape4D
from .tensor import MemType
from .tensor import Tensor
from .tensor import TensorBlockTraversal
@@ -232,7 +231,7 @@ def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point)
-def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, fm_shape: Shape4D) -> NpuFeatureMap:
+def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, fm_shape: List[int]) -> NpuFeatureMap:
"""Creates feature map with common fields populated"""
fm = NpuFeatureMap()
fm.region = get_region(tens, arch)
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index d2c848ad..67925176 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -21,10 +21,8 @@
# Subgraph - Holds a neural network subgraph, pointing at Tensors, Operations, Passes, and CascadedPasses.
# Graph - A full neural network graph with one or more Subgraphs.
import enum
-from typing import List
from .operation import Op
-from .shape4d import Shape4D
class PassPlacement(enum.Enum):
@@ -60,8 +58,8 @@ class Pass:
self.name = name
self.cascade = None
self.placement = placement
- self.ifm_shapes: List[Shape4D] = []
- self.ofm_shapes: List[Shape4D] = []
+ self.ifm_shapes = []
+ self.ofm_shapes = []
# TODO: rename is_element_wise because it is not the same as an ElementWise operator. It is used by the tensor
# allocation and requires that the OFM and IFM has the exact same address. Essentially complete overlap.
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py
index 4ca46831..c2ec4424 100644
--- a/ethosu/vela/npu_performance.py
+++ b/ethosu/vela/npu_performance.py
@@ -48,7 +48,7 @@ def rolling_buffer_dims_from_passes(arch, ps1, block_config_ps1, ps2, block_conf
if ps2.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
op = ps2.primary_op
- ifm_block_depth = arch.calc_ifm_block_depth(op.ifm_shapes[0].depth, op.ifm.dtype.size_in_bits())
+ ifm_block_depth = arch.calc_ifm_block_depth(op.ifm_shapes[0][-1], op.ifm.dtype.size_in_bits())
else:
ifm_block_depth = block_config_ps2[-1]
@@ -231,9 +231,9 @@ def estimate_conv_pooling_cycles(
arch.config.ofm_ublock.height == 2
and npu_block_type
in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.VectorProduct)
- and ofm_tens_shape.height == 1
+ and ofm_tens_shape[1] == 1
# Optimisation only applies for even width tensors
- and ofm_tens_shape.width % 2 == 0
+ and ofm_tens_shape[2] % 2 == 0
and kernel_dims[0] == 1
):
ofm_ublock.width = 4
@@ -319,14 +319,14 @@ def estimate_conv_pooling_cycles(
cycles_dpu_blk += delay_cycles
if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
- cycles_dpu_blk *= numeric_util.round_up_divide(ifm_tens_shape.depth, ifm_block.depth)
+ cycles_dpu_blk *= numeric_util.round_up_divide(ifm_tens_shape[3], ifm_block.depth)
cycles_dpu_blk /= arch.ncores
num_ofm_blk = (
- numeric_util.round_up_divide(ofm_tens_shape.height, ofm_block.height)
- * numeric_util.round_up_divide(ofm_tens_shape.width, ofm_block.width)
- * numeric_util.round_up_divide(ofm_tens_shape.depth, ofm_block.depth)
+ numeric_util.round_up_divide(ofm_tens_shape[1], ofm_block.height)
+ * numeric_util.round_up_divide(ofm_tens_shape[2], ofm_block.width)
+ * numeric_util.round_up_divide(ofm_tens_shape[3], ofm_block.depth)
)
cycles_output_blk = estimate_output_cycles(
@@ -336,7 +336,7 @@ def estimate_conv_pooling_cycles(
if scale_tensor:
cycles_bias_blk = (
10
- * min(ofm_block.depth, ofm_tens_shape.depth)
+ * min(ofm_block.depth, ofm_tens_shape[3])
* arch.memory_latency[scale_tensor.mem_area][BandwidthDirection.Read]
/ 256
)
@@ -420,8 +420,8 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None,
npu_block_type = primary_op.type.npu_block_type
ifm_tensor, _, weight_tensor, ofm_tensor = ps.get_primary_op_ifm_ifm2_weights_ofm()
- ifm_tensor_shape = ps.primary_op.ifm_shapes[0].clone()
- ofm_tensor_shape = ps.primary_op.ofm_shapes[0].clone()
+ ifm_tensor_shape = list(ps.primary_op.ifm_shapes[0])
+ ofm_tensor_shape = list(ps.primary_op.ofm_shapes[0])
if npu_block_type == NpuBlockType.ReduceSum:
block_traversal = TensorBlockTraversal.DepthFirst
@@ -434,7 +434,7 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None,
else:
block_traversal = TensorBlockTraversal.Default
ifm_block_depth = get_ifm_block_depth(
- npu_block_type, ifm_tensor_shape.depth, ifm_tensor.dtype.size_in_bits(), block_traversal, ofm_block.depth
+ npu_block_type, ifm_tensor_shape[3], ifm_tensor.dtype.size_in_bits(), block_traversal, ofm_block.depth
)
ifm_block = arch.get_ifm_block_size(
ifm_block_depth, ofm_block, primary_op.kernel, ifm_resampling_mode=ifm_tensor.resampling_mode
@@ -448,12 +448,11 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None,
NpuBlockType.ReduceSum,
):
# extent the ifm to full dimension
-
- batch_size = ifm_tensor_shape.batch
+ batch_size = ifm_tensor_shape[0]
# add in padding
- ifm_tensor_shape.height += explicit_padding[0] + explicit_padding[2] # height += top and bottom
- ifm_tensor_shape.width += explicit_padding[1] + explicit_padding[3] # width += left and right
+ ifm_tensor_shape[1] += explicit_padding[0] + explicit_padding[2] # height += top and bottom
+ ifm_tensor_shape[2] += explicit_padding[1] + explicit_padding[3] # width += left and right
if npu_block_type != NpuBlockType.Pooling:
if npu_block_type == NpuBlockType.ReduceSum:
@@ -469,9 +468,9 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None,
weight_tensor_bandwidth_compression_scale = weight_tensor.bandwidth_compression_scale
nn_ops = (
- int(ofm_tensor_shape.batch)
- * int(ofm_tensor_shape.height)
- * int(ofm_tensor_shape.width)
+ int(ofm_tensor_shape[0])
+ * int(ofm_tensor_shape[1])
+ * int(ofm_tensor_shape[2])
* int(weight_tensor_shape[0])
* int(weight_tensor_shape[1])
* int(weight_tensor_shape[2])
@@ -482,7 +481,7 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None,
primary_op.attrs["ksize"][1],
primary_op.attrs["ksize"][2],
1,
- ifm_tensor_shape.depth,
+ ifm_tensor_shape[3],
]
weight_tensor_bandwidth_shape = weight_tensor_shape
weight_tensor_element_size = 0
@@ -505,8 +504,8 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None,
replacement_read_bws[ifm_tensor] = ifm_tensor.bandwidth() * ifm_read_multiple
weight_read_multiple = numeric_util.round_up_divide(
- ofm_tensor_shape.height, ofm_block.height
- ) * numeric_util.round_up_divide(ofm_tensor_shape.width, ofm_block.width)
+ ofm_tensor_shape[1], ofm_block.height
+ ) * numeric_util.round_up_divide(ofm_tensor_shape[2], ofm_block.width)
replacement_read_bws[weight_tensor] = (
batch_size
* shape_num_elements(weight_tensor_bandwidth_shape)
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index c80e18b5..be26a26b 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -26,7 +26,6 @@ from typing import TYPE_CHECKING
from .errors import VelaError
from .numeric_util import full_shape
-from .shape4d import Shape4D
if TYPE_CHECKING:
@@ -373,7 +372,7 @@ def create_activation_function(op_type: Op) -> ActivationFunction:
return act
-def get_slice_offsets(input_shape: List[int], offset_tens: int, offset_mask: int, is_begin: bool = True):
+def get_slice_offsets(input_shape, offset_tens, offset_mask, is_begin=True):
# For strided slice operator: get start or end offsets
offsets = len(input_shape) * [0] if is_begin else input_shape[:]
for idx in range(len(input_shape)):
@@ -428,8 +427,8 @@ class Operation:
self.op_index = None # input network operator index
self.activation_lut = None
self._kernel = None
- self.ifm_shapes: List[Shape4D] = []
- self.ofm_shapes: List[Shape4D] = []
+ self.ifm_shapes = []
+ self.ofm_shapes = []
def clone(self, suffix="_clone"):
res = Operation(self.type, self.name + suffix)
@@ -708,9 +707,6 @@ class Operation:
raise VelaError("\n".join(lines))
def set_ifm_ofm_shapes(self):
- self.ifm_shapes = []
- self.ofm_shapes = []
-
ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor = self.get_ifm_ifm2_weights_ofm()
# set all shapes to op, as 4D
@@ -720,24 +716,24 @@ class Operation:
batch_size = elms // n_in_elems
assert batch_size * n_in_elems == elms
- self.ifm_shapes.append(Shape4D([batch_size, 1, 1, n_in_elems]))
- self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
+ self.ifm_shapes.append([batch_size, 1, 1, n_in_elems])
+ self.ofm_shapes.append(ofm_tensor.get_full_shape())
elif self.type == Op.Softmax:
- self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape()))
- self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
+ self.ifm_shapes.append(ifm_tensor.get_full_shape())
+ self.ofm_shapes.append(ofm_tensor.get_full_shape())
elif self.type.is_split_op or self.type.is_concat_op():
for inp in self.inputs:
if inp is not None:
- self.ifm_shapes.append(Shape4D(full_shape(4, inp.shape, 1)))
+ self.ifm_shapes.append(full_shape(4, inp.shape, 1))
else:
self.ifm_shapes.append(None)
for out in self.outputs:
if out is not None:
- self.ofm_shapes.append(Shape4D(full_shape(4, out.shape, 1)))
+ self.ofm_shapes.append(full_shape(4, out.shape, 1))
else:
self.ofm_shapes.append(None)
else:
- self.ifm_shapes.append(Shape4D(full_shape(4, ifm_tensor.shape, 1)))
+ self.ifm_shapes.append(full_shape(4, ifm_tensor.shape, 1))
if ifm2_tensor is not None:
- self.ifm_shapes.append(Shape4D(full_shape(4, ifm2_tensor.shape, 1)))
- self.ofm_shapes.append(Shape4D(full_shape(4, ofm_tensor.shape, 1)))
+ self.ifm_shapes.append(full_shape(4, ifm2_tensor.shape, 1))
+ self.ofm_shapes.append(full_shape(4, ofm_tensor.shape, 1))
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 8f6660c2..095a78d4 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -231,9 +231,9 @@ def pack_into_passes(nng, arch, verbose_packing=False):
ofm_tensor = op.ofm
if ofm_tensor is None:
ofm_tensor = op.outputs[0]
- build_pass((op,), ofm_tensor, op.ofm_shapes[0].clone())
+ build_pass((op,), ofm_tensor)
- def build_pass(start_ops_to_process, ofm_tensor=None, ofm_shapes=None):
+ def build_pass(start_ops_to_process, ofm_tensor=None):
reverse_ops_list = []
curr_flags = PassFlags.Empty
npu_block_type = NpuBlockType.Default
@@ -416,7 +416,8 @@ def pack_into_passes(nng, arch, verbose_packing=False):
ps.ifm_shapes.append(ps.primary_op.ifm_shapes[0])
ps.ofm_tensor = ofm_tensor
- ps.ofm_shapes.append(ofm_shapes)
+ if ps.primary_op is not None:
+ ps.ofm_shapes.append(ps.primary_op.ofm_shapes[0])
assert ps.placement != PassPlacement.Npu or ps.ofm_tensor is not None
ps.weight_tensor = ps.get_primary_op_ifm_weights()[1]
@@ -452,11 +453,11 @@ def pack_into_passes(nng, arch, verbose_packing=False):
avgpool_out = inp.clone("_avgpooled")
avgpool_out.consumer_list.append(op)
avgpool_op.set_output_tensor(avgpool_out)
- avgpool_op.set_ifm_ofm_shapes()
+ avgpool_op.ifm_shapes = op.ifm_shapes
+ avgpool_op.ofm_shapes = op.ofm_shapes
op.inputs[0] = avgpool_out
op_list.insert(0, avgpool_op)
- op.set_ifm_ofm_shapes()
DebugDatabase.add_optimised(op, avgpool_op)
return avgpool_op
diff --git a/ethosu/vela/shape4d.py b/ethosu/vela/shape4d.py
deleted file mode 100644
index a1b4feaa..00000000
--- a/ethosu/vela/shape4d.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
-#
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the License); you may
-# not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an AS IS BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# Description:
-# Defines the class Shape4D.
-from .numeric_util import full_shape
-
-
-class Shape4D:
- """
- 4D Shape (in NHWC format)
- """
-
- def __init__(self, shape, base=1):
- assert shape is not None
- assert len(shape) <= 4
- self._shape4D = tuple(full_shape(4, shape, base))
-
- def __str__(self):
- return f"<Shape4D {self.as_list()}>"
-
- def __eq__(self, other):
- return self._shape4D == other._shape4D
-
- def clone(self):
- return Shape4D(self.as_list())
-
- @property
- def batch(self):
- return self._shape4D[0]
-
- @property
- def height(self):
- return self._shape4D[1]
-
- @property
- def width(self):
- return self._shape4D[2]
-
- @property
- def depth(self):
- return self._shape4D[3]
-
- @batch.setter
- def batch(self, new_batch):
- self._shape4D = (new_batch, self._shape4D[1], self._shape4D[2], self._shape4D[3])
-
- @height.setter
- def height(self, new_height):
- self._shape4D = (self._shape4D[0], new_height, self._shape4D[2], self._shape4D[3])
-
- @width.setter
- def width(self, new_width):
- self._shape4D = (self._shape4D[0], self._shape4D[1], new_width, self._shape4D[3])
-
- @depth.setter
- def depth(self, new_depth):
- self._shape4D = (self._shape4D[0], self._shape4D[1], self._shape4D[2], new_depth)
-
- def get_dim(self, dim):
- assert -4 <= dim < 4
- return self._shape4D[dim]
-
- def as_list(self):
- return list(self._shape4D)
diff --git a/ethosu/vela/shared_buffer_allocation.py b/ethosu/vela/shared_buffer_allocation.py
index d8faf369..1f027d60 100644
--- a/ethosu/vela/shared_buffer_allocation.py
+++ b/ethosu/vela/shared_buffer_allocation.py
@@ -32,7 +32,6 @@ from .operation import Kernel
from .operation import NpuBlockType
from .range_set import MemoryRangeSet
from .register_command_stream_util import to_kernel
-from .shape4d import Shape4D
from .tensor import MemArea
@@ -196,14 +195,14 @@ def shared_buffer_allocation_for_pass(arch, ps) -> SharedBufferAllocation:
ifm_bits = ifm_tensor.dtype.size_in_bits()
ifm_shape = ps.primary_op.ifm_shapes[0]
- if ifm_tensor.shape != []:
- ifm_depth = ifm_shape.depth
+ if ifm_shape != []:
+ ifm_depth = ifm_shape[-1]
if is_elementwise:
ifm_count = 2
if ifm_tensor.shape == []: # Scalar in ifm1
assert ifm2_tensor
- ifm_depth = ps.primary_op.ifm_shapes[1].depth
+ ifm_depth = ps.primary_op.ifm_shapes[1][-1]
ifm_count = 1
elif not ifm2_tensor or ifm2_tensor.shape == []: # Scalar in ifm2
ifm_count = 1
@@ -252,7 +251,7 @@ def shared_buffer_allocation_for_npu_op(
ifm_bits=ifm_bits,
ifm_depth=ifm_depth,
ifm_count=ifm_count,
- ofm_shape=Shape4D(ofm_shape),
+ ofm_shape=ofm_shape,
)
@@ -266,9 +265,14 @@ def find_suitable_block_configs(arch, alloc: SharedBufferAllocation) -> List[Tup
# Constrain the search space if the OFM is smaller than the max block size
# - Add other block search constraints here if required
- max_block_width = alloc.ofm_shape.width
- max_block_height = alloc.ofm_shape.height
- max_block_depth = alloc.ofm_shape.depth
+ if len(alloc.ofm_shape) <= 2:
+ max_block_height = max_block_width = alloc.ofm_shape[0]
+ else:
+ max_block_width = alloc.ofm_shape[-2]
+ max_block_height = alloc.ofm_shape[-3]
+
+ # Common block depth
+ max_block_depth = alloc.ofm_shape[-1]
# Constrain to valid ranges before search
max_block_width = min(arch.ofm_block_max.width, max_block_width)
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index 3b4bace9..98496539 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -213,7 +213,7 @@ class SoftMax:
ofm = self.op.outputs[0]
# Reshape ifm/ofm (if needed)
- full_shape = self.op.ifm_shapes[0].as_list()
+ full_shape = self.op.ifm_shapes[0]
if full_shape[0] > 1:
full_shape[1] *= full_shape[0]
full_shape[0] = 1
@@ -414,7 +414,6 @@ class SoftMax:
shr30_op.add_input_tensor(scaled_exp)
shr30_op.add_input_tensor(right_shift)
shr30_op.set_output_tensor(ofm)
- shr30_op.set_ifm_ofm_shapes()
DebugDatabase.add_optimised(self.op, shr30_op)
return shr30_op
@@ -536,7 +535,6 @@ class SoftMax:
shr13_op.add_input_tensor(mul_ofm)
shr13_op.add_input_tensor(reciprocal_right_shift)
shr13_op.set_output_tensor(ofm)
- shr13_op.set_ifm_ofm_shapes()
DebugDatabase.add_optimised(self.op, shr13_op)
return shr13_op
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 093e8771..df8f8868 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -40,7 +40,6 @@ from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .numeric_util import full_shape
from .operation import Op
from .operation import Operation
-from .shape4d import Shape4D
Shape = List
@@ -305,7 +304,6 @@ def create_const_tensor(
# Operator
const_op = Operation(Op.Const, name)
const_op.set_output_tensor(const_tensor)
- const_op.set_ifm_ofm_shapes()
return const_tensor
@@ -325,7 +323,8 @@ def create_reshape_tensor(tens, shape, ifm_reshape=True):
reshape_op.add_input_tensor(reshape_ifm)
reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, shape))
reshape_op.set_output_tensor(reshape_ofm)
- reshape_op.set_ifm_ofm_shapes()
+ reshape_op.ifm_shapes.append(full_shape(4, reshape_ifm.shape, 1))
+ reshape_op.ofm_shapes.append(full_shape(4, reshape_ofm.shape, 1))
return reshape_ofm if ifm_reshape else reshape_ifm
@@ -609,7 +608,7 @@ class Tensor:
def consumers(self) -> List[Operation]:
return self.consumer_list
- def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, fm_shape: Shape4D) -> Tuple:
+ def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, fm_shape: Shape) -> Tuple:
# returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] )
if self.storage_shape == []:
@@ -617,7 +616,7 @@ class Tensor:
1,
1,
1,
- [self.address_for_coordinate(start_coord, shape=fm_shape.as_list()), None, None, None],
+ [self.address_for_coordinate(start_coord, shape=fm_shape), None, None, None],
)
storage_shape_4D = full_shape(4, self.storage_shape, 1)
@@ -631,20 +630,20 @@ class Tensor:
box_width = crossing_x - start_coord[2]
addresses: List = [None] * 4
- addresses[0] = self.address_for_coordinate(start_coord, shape=fm_shape.as_list())
+ addresses[0] = self.address_for_coordinate(start_coord, shape=fm_shape)
if end_coord[2] > crossing_x:
addresses[1] = self.address_for_coordinate(
- [start_coord[0], start_coord[1], crossing_x, start_coord[3]], shape=fm_shape.as_list()
+ [start_coord[0], start_coord[1], crossing_x, start_coord[3]], shape=fm_shape
)
raise UnsupportedFeatureError("Striping in vertical direction is not supported")
if end_coord[1] > crossing_y:
addresses[2] = self.address_for_coordinate(
- [start_coord[0], crossing_y, start_coord[2], start_coord[3]], shape=fm_shape.as_list()
+ [start_coord[0], crossing_y, start_coord[2], start_coord[3]], shape=fm_shape
)
if end_coord[1] > crossing_y and end_coord[2] > crossing_x:
addresses[3] = self.address_for_coordinate(
- [start_coord[0], crossing_y, crossing_x, start_coord[3]], shape=fm_shape.as_list()
+ [start_coord[0], crossing_y, crossing_x, start_coord[3]], shape=fm_shape
)
return box_height0, box_height0, box_width, addresses
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index 7fdc4bd8..45377417 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -21,7 +21,6 @@ import numpy as np
from ethosu.vela.graph_optimiser import convert_batched_fc_shape
from ethosu.vela.operation import Op
from ethosu.vela.tensor import create_const_tensor
-from ethosu.vela.tensor import Shape4D
from ethosu.vela.tensor import Tensor
from ethosu.vela.test import testutil
@@ -36,8 +35,8 @@ def test_convert_batched_fc():
ifm.consumer_list.append(op)
- op.ifm_shapes.append(Shape4D([4, 1, 1, 8]))
- op.ofm_shapes.append(Shape4D([4, 1, 1, 8]))
+ op.ifm_shapes.append([4, 1, 1, 8])
+ op.ofm_shapes.append([4, 1, 1, 8])
prev_op = op.clone()
prev_op.ifm_shapes = op.ifm_shapes
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 973b820d..583821a2 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -62,7 +62,7 @@ def test_constraint_tens_input_scalar():
def test_constraint_tens_shape_size():
# Tensors cannot be > 4D
- op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 1, 8, 8, 8], [1, 1, 8, 8, 8], set_ifm_ofm_shapes=False)
+ op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 1, 8, 8, 8], [1, 1, 8, 8, 8])
assert not support.is_operator_supported(op)
diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py
index c3459501..63f841b4 100644
--- a/ethosu/vela/test/testutil.py
+++ b/ethosu/vela/test/testutil.py
@@ -75,7 +75,7 @@ def create_elemwise_op(
def create_op_with_quant_tensors(
- op_type, ifm_shape, ofm_shape, weights_shape=None, bias_shape=None, datatype=DataType.uint8, set_ifm_ofm_shapes=True
+ op_type, ifm_shape, ofm_shape, weights_shape=None, bias_shape=None, datatype=DataType.uint8
):
ifm = Tensor(ifm_shape, datatype, "in")
ifm.quantization = default_quant_params()
@@ -107,9 +107,7 @@ def create_op_with_quant_tensors(
bias = create_const_tensor("bias", bias_shape, DataType.int32, np.zeros(bias_shape), np.int32, quantization=qp)
op.add_input_tensor(bias)
- if set_ifm_ofm_shapes:
- op.set_ifm_ofm_shapes()
-
+ op.set_ifm_ofm_shapes()
return op