author    Patrik Gustavsson <patrik.gustavsson@arm.com>  2021-09-01 12:43:02 +0200
committer Patrik Gustavsson <patrik.gustavsson@arm.com>  2021-09-07 15:35:49 +0200
commit    f1580f0167d7e9539a17ac8e33b0b595300f8090 (patch)
tree      2fe1d9d4715ac38be9cfdc5fbe049c07ab5b9563
parent    94292fe32c34357a8935a42c77b759a499eb0db9 (diff)
download  ethos-u-vela-f1580f0167d7e9539a17ac8e33b0b595300f8090.tar.gz
TOSA: Added RESHAPE, SLICE and CONCAT
Added support for Data layout ops RESHAPE, SLICE and CONCAT.
 - No support for bool_t
 - Support limited to Rank <= 4 and N = 1

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I487ac494b6506a2a6ba947ee758aa193194dd796
-rw-r--r--  ethosu/vela/graph_optimiser_util.py       23
-rw-r--r--  ethosu/vela/operation_util.py              6
-rw-r--r--  ethosu/vela/tflite_graph_optimiser.py     20
-rw-r--r--  ethosu/vela/tosa_graph_optimiser.py      165
-rw-r--r--  ethosu/vela/tosa_mapping.py               10
-rw-r--r--  ethosu/vela/tosa_reader.py                 4
-rw-r--r--  ethosu/vela/tosa_supported_operators.py    6
7 files changed, 204 insertions(+), 30 deletions(-)
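
The pattern running through this commit: the data layout ops are lowered to elementwise adds with a zero-scalar second operand, so the NPU's add kernel performs the data movement, while write_offset/write_shape (CONCAT) and read_offsets/read_shapes (SLICE) describe where each piece is written or read. A standalone numpy sketch of the CONCAT decomposition, illustrative only and not part of the commit:

import numpy as np

def concat_via_offset_writes(inputs, axis):
    # Compute the concatenated output shape.
    out_shape = list(inputs[0].shape)
    out_shape[axis] = sum(t.shape[axis] for t in inputs)
    ofm = np.zeros(out_shape, dtype=inputs[0].dtype)

    offset = 0
    for ifm in inputs:
        # Each input is "added" (copied, since the other operand is zero)
        # into the output at its write offset along the concat axis.
        index = [slice(None)] * ofm.ndim
        index[axis] = slice(offset, offset + ifm.shape[axis])
        ofm[tuple(index)] = ifm + 0  # the zero-add NOP
        offset += ifm.shape[axis]
    return ofm

a = np.arange(6, dtype=np.int8).reshape(1, 2, 3)
b = np.arange(6, 12, dtype=np.int8).reshape(1, 2, 3)
assert np.array_equal(concat_via_offset_writes([a, b], axis=1), np.concatenate([a, b], axis=1))
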
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index d01d4a1..8095f08 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -212,6 +212,29 @@ def bypass_reshape_and_squeeze_ops(op):
            cons.set_input_tensor(ifm, ifm_idx)


+def move_splitsliceread_to_consumer(op, cons_op):
+    assert op.type == Op.SplitSliceRead
+
+    if cons_op.ifm == op.ofm:
+        cons_op.read_offsets[0] = op.read_offsets[0]
+        cons_op.read_shapes[0] = op.read_shapes[0]
+        cons_op.set_input_tensor(op.ifm, cons_op.type.info.indices.ifms[0])
+        cons_op.ifm_shapes[0] = op.ifm_shapes[0]
+    elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == op.ofm:
+        cons_op.read_offsets[1] = op.read_offsets[0]
+        cons_op.read_shapes[1] = op.read_shapes[0]
+        cons_op.set_input_tensor(op.ifm, cons_op.type.info.indices.ifms[1])
+        cons_op.ifm_shapes[1] = op.ifm_shapes[0]
+
+    if "skirt" in cons_op.attrs:
+        assert cons_op.attrs["explicit_padding"] == cons_op.attrs["skirt"]
+        cons_op.attrs["skirt"] = None
+        cons_op.attrs["force_padding"] = True
+    op.ofm.consumer_list.remove(cons_op)
+    op.ofm.ops = []
+    op.ifm.consumer_list.remove(op)
+
+
def check_reshapes(op, arch):
    if op.run_on_npu and op.type == Op.Reshape:
        ofm = op.ofm
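
The helper above fuses a SplitSliceRead into the op consuming it: the consumer keeps the original ifm and records a read offset and read shape instead of reading a separately materialised slice. A numpy sketch of the assumed read-window semantics, with a trivial ReLU standing in for the consumer kernel:

import numpy as np

def consumer_with_read_window(ifm, read_offset, read_shape):
    # The consumer only touches the window [offset : offset + shape]
    # in each dimension of the original ifm.
    window = tuple(slice(o, o + s) for o, s in zip(read_offset, read_shape))
    return np.maximum(ifm[window], 0)

ifm = np.arange(-8, 8).reshape(4, 4)
# Equivalent to slicing first, then running the consumer:
assert np.array_equal(
    consumer_with_read_window(ifm, (1, 1), (2, 2)),
    np.maximum(ifm[1:3, 1:3], 0),
)
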
diff --git a/ethosu/vela/operation_util.py b/ethosu/vela/operation_util.py
index 4a4fd33..0fbed46 100644
--- a/ethosu/vela/operation_util.py
+++ b/ethosu/vela/operation_util.py
@@ -44,6 +44,12 @@ def create_avgpool_nop(name: str) -> Operation:
    return op


+def create_add_nop(name: str) -> Operation:
+    op = Operation(Op.Add, name)
+    op.run_on_npu = True
+    return op
+
+
def create_depthwise_maxpool(
    name: str,
    ifm: Tensor,
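
create_add_nop builds the identity op this commit leans on: since x + 0 == x, and the callers create the zero scalar with the same quantization as the data tensor, the elementwise add kernel doubles as a copy. A minimal demonstration of that identity:

import numpy as np

def add_nop(ifm: np.ndarray) -> np.ndarray:
    zero_scalar = np.zeros([1], dtype=ifm.dtype)  # the ifm2 constant, broadcast over ifm
    return ifm + zero_scalar

x = np.array([[-3, 0, 7]], dtype=np.int8)
assert np.array_equal(add_nop(x), x)  # values pass through unchanged
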
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index ef39aea..7526f46 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -34,6 +34,7 @@ from .graph_optimiser_util import bypass_reshape_and_squeeze_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import fix_sg_input_output
+from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
@@ -193,24 +194,7 @@ def remove_SplitSliceRead(op, arch):
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
-            if cons_op.ifm == op.ofm:
-                cons_op.read_offsets[0] = op.read_offsets[0]
-                cons_op.read_shapes[0] = op.read_shapes[0]
-                cons_op.set_input_tensor(op.ifm, cons_op.type.info.indices.ifms[0])
-                cons_op.ifm_shapes[0] = op.ifm_shapes[0]
-            elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == op.ofm:
-                cons_op.read_offsets[1] = op.read_offsets[0]
-                cons_op.read_shapes[1] = op.read_shapes[0]
-                cons_op.set_input_tensor(op.ifm, cons_op.type.info.indices.ifms[1])
-                cons_op.ifm_shapes[1] = op.ifm_shapes[0]
-
-            if "skirt" in cons_op.attrs:
-                assert cons_op.attrs["explicit_padding"] == cons_op.attrs["skirt"]
-                cons_op.attrs["skirt"] = None
-                cons_op.attrs["force_padding"] = True
-            op.ofm.consumer_list.remove(cons_op)
-            op.ofm.ops = []
-            op.ifm.consumer_list.remove(op)
+            move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
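
The TFLite path keeps its avgpool fallback but now delegates the fusion case to the shared helper. A sketch of the guarding precondition, restated from the condition list visible in remove_splitsliceread further down (hypothetical helper and parameter names, not Vela code):

SUPPORTED_DTYPES = ("uint8", "int8", "int16")  # assumption mirroring the dtype check below

def can_fuse_into_consumer(consumers, ofm_dtype, ofm_shape_is_full_4d, consumer_is_reshape, consumer_on_npu):
    return (
        len(consumers) == 1             # exactly one op to absorb the read
        and consumer_on_npu             # a CPU op cannot honour read offsets
        and not consumer_is_reshape     # reshapes are bypassed, not real readers
        and ofm_shape_is_full_4d        # no implicit reshape hidden in the slice
        and ofm_dtype in SUPPORTED_DTYPES
    )

assert can_fuse_into_consumer(["conv"], "int8", True, False, True)
assert not can_fuse_into_consumer(["conv", "mul"], "int8", True, False, True)
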
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 169da40..f3cddad 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -22,14 +22,17 @@ from .debug_database import DebugDatabase
from .graph_optimiser_util import bypass_reshape_and_squeeze_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
-from .graph_optimiser_util import fix_sg_input_output
+from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
+from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
+from .shape4d import Shape4D
+from .tensor import create_const_tensor
def replace_rescale_with_avg_pool(rescale_op):
@@ -103,12 +106,157 @@ def remove_const_transpose(op, arch, nng):
                removed = True

    if not removed:
-        print("Cannot remove Transpose, and handling of Transpose is not supported")
+        print("Warning: Cannot remove Transpose, and handling of Transpose is not supported")
        assert False

    return op
+# TODO can we change to add for both TFLite and TOSA?
+def insert_add_copy_op_after_tens(tens):
+    tens_cons_list_copy = tens.consumer_list.copy()
+    copy_tens = tens.clone()
+
+    name = tens.name + "_add"
+    ifm2 = create_const_tensor(
+        name + "_zero_scalar",
+        [1],
+        copy_tens.dtype,
+        [0],
+        copy_tens.dtype.as_numpy_type(),
+        quantization=copy_tens.quantization,
+    )
+    copy_op = create_add_nop(name)
+    copy_op.add_input_tensor(tens)
+    copy_op.add_input_tensor(ifm2)
+    copy_op.set_output_tensor(copy_tens)
+    copy_op.set_ifm_ofm_shapes()
+    copy_op.run_on_npu = True
+
+    # Set copy_ifm consumers
+    for tens_cons in tens_cons_list_copy:
+        if tens_cons is not None:
+            for ifm_idx, cons_inp in enumerate(tens_cons.inputs):
+                if cons_inp == tens:
+                    tens_cons.set_input_tensor(copy_tens, ifm_idx)
+
+    DebugDatabase.add_optimised(tens.ops[0], copy_op)
+
+
+def fix_sg_input_output_tosa(op, arch, nng):
+    if not op.run_on_npu or op.type != Op.Reshape:
+        return op
+
+    # For the Reshape operators we want to remove, tensors are removed.
+    # But in order to do this, they cannot be outputs of the sg;
+    # this needs to be fixed prior to the removal.
+    # The solution is to add a copy op, to maintain the original tensor.
+    # This is also valid when the reshape ifm/ofm is produced or
+    # consumed by the CPU.
+
+    # Check if operator ifm/ofm are sg ifm/ofm
+    ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
+    ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
+    ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
+    # Check if ifm/ofm is produced or consumed by the CPU
+    ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
+    ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
+
+    if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
+        # Both ifm and ofm need to persist, but only the ifm needs a copy, in order to remove the Reshape
+        insert_add_copy_op_after_tens(op.ifm)
+
+    return op
+
+
+def create_add_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
+    """Creates an add op for the given concat op/input feature map"""
+    ofm = concat_op.ofm
+    ifm2 = create_const_tensor(
+        name + "_zero_scalar", [1], ofm.dtype, [0], ofm.dtype.as_numpy_type(), quantization=ofm.quantization
+    )
+    add_op = create_add_nop(name)
+
+    add_op.inputs = [ifm, ifm2]
+    add_op.outputs = [ofm]
+    add_op.write_offset = write_offset
+    add_op.write_shape = ifm_shape
+    ofm.ops.append(add_op)
+    DebugDatabase.add_optimised(concat_op, add_op)
+    add_op.ifm_shapes.append(ifm_shape)
+    add_op.ifm_shapes.append(Shape4D(ifm2.shape))
+    add_op.ofm_shapes.append(concat_op.ofm_shapes[0])
+    add_op.memory_function = Op.ConcatSliceWrite
+    return add_op
+
+
+# TODO Could be further optimized by checking the type of the consumer,
+# rather than just mimicking the TFLite behaviour depending on type.
+# TOSA bool_t not considered yet
+def remove_splitsliceread(op, arch):
+
+    if op.type == Op.SplitSliceRead:
+        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an add op needs to be inserted
+        if (
+            len(op.ofm.consumer_list) == 1
+            and op.ofm.consumer_list[0] is not None
+            and op.ofm.consumer_list[0].run_on_npu
+            and op.ofm.consumer_list[0].type != Op.Reshape
+            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
+            and op.ofm.dtype in (DataType.uint8, DataType.int8, DataType.int16)
+        ):
+            # SplitSliceRead can be performed by tensor consumer
+            cons_op = op.ofm.consumer_list[0]
+            move_splitsliceread_to_consumer(op, cons_op)
+        else:
+            name = op.name + "_add"
+            ofm = op.ofm
+            ifm2 = create_const_tensor(
+                name + "_zero_scalar", [1], ofm.dtype, [0], ofm.dtype.as_numpy_type(), quantization=ofm.quantization
+            )
+            add_op = create_add_nop(name)
+            add_op.inputs = [op.ifm, ifm2]
+            add_op.outputs = [ofm]
+            op.ofm.ops.remove(op)
+            op.ofm.ops.append(add_op)
+            add_op.ifm_shapes.append(op.ifm_shapes[0])
+            add_op.ifm_shapes.append(Shape4D(ifm2.shape))
+            add_op.ofm_shapes.append(op.ofm_shapes[0])
+            add_op.read_offsets[0] = op.read_offsets[0]
+            add_op.read_shapes[0] = op.read_shapes[0]
+
+            op.ifm.consumer_list.remove(op)
+            DebugDatabase.add_optimised(op, add_op)
+
+
+def rewrite_concat_ops(op, arch):
+    if not op.run_on_npu or not op.type == Op.Concat:
+        return
+
+    axis_4D = 0
+    ofm = op.ofm
+    ofm.ops = []
+    offset = 0
+
+    inputs = op.inputs
+    axis = op.attrs["axis"]
+
+    for idx, inp in enumerate(inputs):
+        op.ifm_shapes[idx] = Shape4D(inp.shape)
+        if axis >= 0:
+            axis_4D = axis + (4 - len(inp.shape))
+        else:
+            axis_4D = axis
+        write_offset = [0, 0, 0, 0]
+        write_offset[axis_4D] = offset
+        concat_end = offset + op.ifm_shapes[idx][axis_4D]
+        create_add_for_concat(op, op.name + str(idx) + "_add", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset))
+        offset = concat_end
+    assert ofm.shape[axis] == offset
+
+    return op
+
+
def remove_reshapes(op, arch):
    if op.run_on_npu and op.type == Op.Reshape:
        bypass_reshape_and_squeeze_ops(op)
@@ -271,9 +419,14 @@ def tosa_optimise_graph(nng, arch):
    # Handle sg input output
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False,
+            nng, sg, arch, [], [fix_sg_input_output_tosa], rewrite_unsupported=False,
        )

+    # Rewrite concat ops
+    for idx, sg in enumerate(nng.subgraphs):
+        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
+        sg.refresh_after_modification()
+
    # Removal of reshapes
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
@@ -293,6 +446,12 @@ def tosa_optimise_graph(nng, arch):
            nng, sg, arch, [], [rewrite_activation, add_padding_fields],
        )

+    # Removal of Slice needs to be done after optimisation has been performed,
+    # since ifm/ofm_shapes are of importance to this function
+    for sg in nng.subgraphs:
+        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_splitsliceread])
+        sg.refresh_after_modification()
+
    # Post-processing step 2
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [], [fixup_quantization],)
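
rewrite_concat_ops maps the rank-N concat axis into the 4D shape Vela works with via axis + (4 - len(shape)), since Shape4D left-pads shapes with leading 1s, and then advances a write offset along that axis per input. A worked check of that arithmetic:

def to_4d(shape):
    return [1] * (4 - len(shape)) + list(shape)

def axis_to_4d(axis, shape):
    return axis + (4 - len(shape)) if axis >= 0 else axis

shape = [2, 3, 4]
axis = 1
axis_4d = axis_to_4d(axis, shape)           # 1 + (4 - 3) = 2
assert to_4d(shape)[axis_4d] == shape[axis]  # same extent either way

# The write offset for each concat input advances along that 4D axis:
offsets = []
offset = 0
for extent in (3, 5, 2):  # per-input sizes along the concat axis
    write_offset = [0, 0, 0, 0]
    write_offset[axis_4d] = offset
    offsets.append(write_offset)
    offset += extent
assert offsets == [[0, 0, 0, 0], [0, 0, 3, 0], [0, 0, 8, 0]]
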
diff --git a/ethosu/vela/tosa_mapping.py b/ethosu/vela/tosa_mapping.py
index 377f455..6efc479 100644
--- a/ethosu/vela/tosa_mapping.py
+++ b/ethosu/vela/tosa_mapping.py
@@ -154,7 +154,7 @@ transpose_conv2d_attrs = AttrSerializer(
"TransposeConv2dAttribute", (("outpad", is_vec), ("stride", is_vec), ("dilation", is_vec), ("out_shape", is_vec))
)
relun_attrs = AttrSerializer("ReluNAttribute", ("max_int"))
-axis_attrs = AttrSerializer("AxisAttribute", ("axis"))
+axis_attrs = AttrSerializer("AxisAttribute", ("axis",))
reshape_attrs = AttrSerializer("ReshapeAttribute", (("shape", is_vec),))
slice_attrs = AttrSerializer("SliceAttribute", (("begin", is_vec), ("size", is_vec)))
tile_attrs = AttrSerializer("TileAttribute", (("multiplies", is_vec),))
@@ -218,10 +218,8 @@ unsupported_tosa_operators = {
    TosaOp.REDUCE_MIN,
    TosaOp.REDUCE_PRODUCT,
    TosaOp.REDUCE_SUM,
-    TosaOp.CONCAT,
    TosaOp.PAD,
    TosaOp.REVERSE,
-    TosaOp.SLICE,
    TosaOp.TILE,
    TosaOp.GATHER,
    TosaOp.SCATTER,
@@ -241,7 +239,7 @@ TOSA_IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
TOSA_IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
# TOSA_CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
# TOSA_TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
-# TOSA_CONCAT_INDICES = TensorIndices([1, 2], [], [])
+TOSA_CONCAT_INDICES = TensorIndices([1, 2], [], [])
# TOSA_SPLIT_IFM_INDICES = TensorIndices([1], [], [])
# TOSA_BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])
@@ -299,11 +297,11 @@ tosa_operator_map = {
    # TODO TosaOp.REDUCE_MIN
    # TODO TosaOp.REDUCE_PRODUCT
    # TODO TosaOp.REDUCE_SUM
-    # TODO TosaOp.CONCAT
+    TosaOp.CONCAT: (Op.Concat, axis_attrs, None, TOSA_CONCAT_INDICES),
    # TODO TosaOp.PAD
    TosaOp.RESHAPE: (Op.Reshape, reshape_attrs, None, TOSA_IFM_INDICES),
    # TODO TosaOp.REVERSE
-    # TODO TosaOp.SLICE
+    TosaOp.SLICE: (Op.SplitSliceRead, slice_attrs, None, TOSA_IFM_INDICES),
    # TODO TosaOp.TILE
    TosaOp.TRANSPOSE: (Op.Transpose, None, None, TOSA_IFM_IFM2_INDICES),
    # TODO TosaOp.GATHER
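
The axis_attrs change above is a classic one-element-tuple fix: ("axis") is just a parenthesised string, so iterating the serializer's attribute names would yield single characters rather than one name (note the unchanged relun_attrs line above still shows the same ("max_int") pattern). A quick demonstration:

not_a_tuple = ("axis")
a_tuple = ("axis",)

assert not_a_tuple == "axis"                      # parentheses alone do not make a tuple
assert list(not_a_tuple) == ["a", "x", "i", "s"]  # iterating the string gives characters
assert list(a_tuple) == ["axis"]                  # the trailing comma gives one attribute name
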
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py
index 2925ab4..94ba350 100644
--- a/ethosu/vela/tosa_reader.py
+++ b/ethosu/vela/tosa_reader.py
@@ -29,6 +29,7 @@ from .reader_util import align_tensor_indices_to_nng
from .reader_util import clone_and_reshape_tensor
from .reader_util import decode_str
from .reader_util import fixup_tensors
+from .shape4d import Shape4D
from .tensor import QuantizationParameters
from .tensor import shape_num_elements
from .tensor import Tensor
@@ -186,6 +187,9 @@ class TosaSubgraph:
            op.rescale = [1, shift]
        if op.type.is_depthwise_conv2d_op():
            op.attrs["depth_multiplier"] = op.weights.shape[3]
+        if op.type == Op.SplitSliceRead:
+            op.read_offsets[0] = Shape4D.from_list(list(op.attrs["begin"]), 0)
+            op.read_shapes[0] = op.attrs["size"]
        elif op.type == Op.Transpose:
            op.attrs["perms"] = perms.values
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index d7a1ebc..c619f2f 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -38,7 +38,7 @@ class TosaSupportedOperators:
    fc_vector_products = set((Op.FullyConnected,))

    mac_main_ops = convolution_like_ops | pooling_ops | fc_vector_products
-    memory_only_ops = set((Op.Reshape, Op.Transpose,))
+    memory_only_ops = set((Op.Reshape, Op.Transpose, Op.Concat, Op.SplitSliceRead,))

    binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.RescaleMul, Op.Sub,))
@@ -53,7 +53,7 @@ class TosaSupportedOperators:
    # Supported data types
    # TODO will differ compared to TensorFlow Lite, currently set to the same
-    supported_op_dtypes = set((DataType.uint8, DataType.int8, DataType.int16, DataType.int32))
+    supported_op_dtypes = set((DataType.uint8, DataType.int8, DataType.int16, DataType.int32))  # TODO add bool

    def __init__(self):
        # Setup the generic constraints. Note: the order matters
@@ -109,7 +109,7 @@ class TosaSupportedOperators:
        valid = op.ifm.ops and op.ifm.ops[0].type == Op.Const
        return valid, "Op has ifm with non-constant data"

-    # TODO duplicates tflite_supported operators, but support for depth multiplier should be added at a later stage
+    # TODO duplicates TFLite_supported operators, but support for depth multiplier should be added at a later stage
    @staticmethod
    def constraint_depth_multiplier(op):
        "For depth multipliers > 1, IFM channels must be 1 and OFM channels must be equal to the depth multiplier"