author     Jonas Ohlsson <jonas.ohlsson@arm.com>  2021-09-01 15:57:21 +0200
committer  Jonas Ohlsson <jonas.ohlsson@arm.com>  2021-09-15 10:48:08 +0100
commit     0957e3ef4b94f17efb67429c88bab8ba650f78e8 (patch)
tree       f3e2600367bc7c89145657023b45b9dde2c316c2
parent     1a7527cd4ad56b49f120b10dc5e87a1e8f5a8122 (diff)
download   ethos-u-vela-0957e3ef4b94f17efb67429c88bab8ba650f78e8.tar.gz
MLBEDSW-5102 Update removal of memory only operators
Memory only operators such as Reshape, Squeeze and ExpandDims are removed
in the graph optimiser step.

- Added semantic check that memory only operators have same quantisation
  parameters on ifm/ofm.
- Added support for the ExpandDims operator.
- Addition and cleanup of related unit tests.
- Removed TOSA from the generated SUPPORTED_OPS.md documentation.

Signed-off-by: Jonas Ohlsson <jonas.ohlsson@arm.com>
Change-Id: If848d8afc58c18806e10997ed94e4dae83f30879
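The "removal" the commit message describes is easiest to see in isolation. Below is a minimal, hypothetical sketch of the idea: the `Tensor` and `Operation` classes and the `bypass_memory_only_op` helper are simplified stand-ins, not Vela's actual data model. Removing a memory only operator amounts to rewiring the consumers of its output tensor to read its input tensor directly, which is only safe when both tensors share the same quantisation parameters.

```python
from dataclasses import dataclass, field


@dataclass
class Tensor:
    name: str
    consumers: list = field(default_factory=list)  # ops that read this tensor


@dataclass
class Operation:
    op_type: str
    ifm: Tensor
    ofm: Tensor


MEMORY_ONLY_OPS = {"Reshape", "QuantizedReshape", "Squeeze", "ExpandDims"}


def bypass_memory_only_op(op: Operation) -> None:
    """Splice a memory only op out of the graph by letting the readers of
    its output tensor read directly from its input tensor instead."""
    assert op.op_type in MEMORY_ONLY_OPS
    for consumer in op.ofm.consumers:
        consumer.ifm = op.ifm
        op.ifm.consumers.append(consumer)
    op.ifm.consumers.remove(op)


# Usage: a Reshape feeding a Conv2D collapses so the Conv2D reads the
# original tensor directly.
ifm = Tensor("input")
mid = Tensor("reshaped")
out = Tensor("output")
reshape = Operation("Reshape", ifm, mid)
conv = Operation("Conv2D", mid, out)
ifm.consumers.append(reshape)
mid.consumers.append(conv)

bypass_memory_only_op(reshape)
print(conv.ifm.name)  # -> "input"
```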
-rw-r--r--  SUPPORTED_OPS.md                                 76
-rw-r--r--  ethosu/vela/graph_optimiser.py                    4
-rw-r--r--  ethosu/vela/graph_optimiser_util.py              25
-rw-r--r--  ethosu/vela/mark_tensors.py                       5
-rw-r--r--  ethosu/vela/test/test_graph_optimiser.py        205
-rw-r--r--  ethosu/vela/test/test_tflite_model_semantic.py   54
-rw-r--r--  ethosu/vela/tflite_graph_optimiser.py            46
-rw-r--r--  ethosu/vela/tflite_model_semantic.py             15
-rw-r--r--  ethosu/vela/tflite_supported_operators.py         2
-rw-r--r--  ethosu/vela/tosa_graph_optimiser.py               4
-rw-r--r--  ethosu/vela/vela.py                              10
11 files changed, 283 insertions, 163 deletions
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index 9c2a9f4e..f96bd4af 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -1,14 +1,13 @@
# Supported Ops
This file was automatically generated by Vela using the `--supported-ops-report` parameter.
-Vela version: `3.1.0rc2.dev6+g4f87092`
+Vela version: `3.1.1.dev13+gf54e94a.d20210914`
This file complies with
[**Gitiles Markdown syntax**](https://github.com/google/gitiles/blob/master/Documentation/markdown.md)
Summary table of constraints for:
- [TFLite](#tflite-summary-table)
-- [TOSA](#tosa-summary-table)
## TFLite Summary Table
@@ -25,6 +24,7 @@ Please check the supported operator list for your chosen runtime for further inf
| CONCATENATION | [Generic](#tflite-generic-constraints), [Specific](#tflite-concatenation-constraints) |
| CONV_2D | [Generic](#tflite-generic-constraints), [Specific](#tflite-conv_2d-constraints) |
| DEPTHWISE_CONV_2D | [Generic](#tflite-generic-constraints), [Specific](#tflite-depthwise_conv_2d-constraints) |
+| EXPAND_DIMS | [Generic](#tflite-generic-constraints), [Specific](#tflite-expand_dims-constraints) |
| FULLY_CONNECTED | [Generic](#tflite-generic-constraints), [Specific](#tflite-fully_connected-constraints) |
| HARD_SWISH | [Generic](#tflite-generic-constraints), [Specific](#tflite-hard_swish-constraints) |
| LEAKY_RELU | [Generic](#tflite-generic-constraints), [Specific](#tflite-leaky_relu-constraints) |
@@ -40,13 +40,13 @@ Please check the supported operator list for your chosen runtime for further inf
| RELU | [Generic](#tflite-generic-constraints) |
| RELU6 | [Generic](#tflite-generic-constraints) |
| RELU_N1_TO_1 | [Generic](#tflite-generic-constraints) |
-| RESHAPE | [Generic](#tflite-generic-constraints) |
+| RESHAPE | [Generic](#tflite-generic-constraints), [Specific](#tflite-reshape-constraints) |
| RESIZE_BILINEAR | [Generic](#tflite-generic-constraints), [Specific](#tflite-resize_bilinear-constraints) |
| SLICE | [Generic](#tflite-generic-constraints) |
| SOFTMAX | [Generic](#tflite-generic-constraints), [Specific](#tflite-softmax-constraints) |
| SPLIT | [Generic](#tflite-generic-constraints) |
| SPLIT_V | [Generic](#tflite-generic-constraints), [Specific](#tflite-split_v-constraints) |
-| SQUEEZE | [Generic](#tflite-generic-constraints) |
+| SQUEEZE | [Generic](#tflite-generic-constraints), [Specific](#tflite-squeeze-constraints) |
| STRIDED_SLICE | [Generic](#tflite-generic-constraints), [Specific](#tflite-strided_slice-constraints) |
| SUB | [Generic](#tflite-generic-constraints), [Specific](#tflite-sub-constraints) |
| TANH | [Generic](#tflite-generic-constraints) |
@@ -60,7 +60,7 @@ This is a list of constraints that all NPU operators must satisfy in order to be
- Input(s) and Output tensors must not be dynamic
- Input(s) and Output tensors must have a defined shape
- Output tensors cannot be scalar
-- Scalar Input tensors are only valid for op type: ADD, MAXIMUM, MEAN, MINIMUM, MUL, SPLIT, SPLIT_V, SUB
+- Scalar Input tensors are only valid for op type: ADD, EXPAND_DIMS, MAXIMUM, MEAN, MINIMUM, MUL, SPLIT, SPLIT_V, SUB
- Input(s) and Output tensors must not be greater than 4D
- Input(s), Output and Weight tensors must have quantization parameters
- Input(s), Output and Weight tensors with quantization scales must be finite
@@ -148,6 +148,12 @@ This is a list of constraints that the DEPTHWISE_CONV_2D operator must satisfy i
- IFM Tensor batch size must be 1
- For depth multipliers > 1, IFM channels must be 1 and OFM channels must be equal to the depth multiplier
+### TFLite EXPAND_DIMS Constraints
+
+This is a list of constraints that the EXPAND_DIMS operator must satisfy in order to be scheduled on the NPU.
+
+- Input and output quantisation must match.
+
### TFLite FULLY_CONNECTED Constraints
This is a list of constraints that the FULLY_CONNECTED operator must satisfy in order to be scheduled on the NPU.
@@ -244,6 +250,12 @@ This is a list of constraints that the PAD operator must satisfy in order to be
- The pad tensor can only pad width and height
- Pad tensor must be of type: int32, int64
+### TFLite RESHAPE Constraints
+
+This is a list of constraints that the RESHAPE operator must satisfy in order to be scheduled on the NPU.
+
+- Input and output quantisation must match.
+
### TFLite RESIZE_BILINEAR Constraints
This is a list of constraints that the RESIZE_BILINEAR operator must satisfy in order to be scheduled on the NPU.
@@ -268,6 +280,12 @@ This is a list of constraints that the SPLIT_V operator must satisfy in order to
- Only one size is allowed to be inferred
+### TFLite SQUEEZE Constraints
+
+This is a list of constraints that the SQUEEZE operator must satisfy in order to be scheduled on the NPU.
+
+- Input and output quantisation must match.
+
### TFLite STRIDED_SLICE Constraints
This is a list of constraints that the STRIDED_SLICE operator must satisfy in order to be scheduled on the NPU.
@@ -310,51 +328,3 @@ This is a list of constraints that the TRANSPOSE_CONV operator must satisfy in o
- SAME padding: OFM dimensions must equal IFM dimensions multiplied by stride
- VALID padding: OFM dimensions must equal IFM dimensions multiplied by stride,
minus difference between kernel size and stride
-
-## TOSA Summary Table
-
-The table below contains TOSA operators that can be placed on the Ethos-U NPU.
-Note: There is limited support for compiling a TOSA neural network (EXPERIMENTAL).
-The related constraints have not yet been populated in the list.
-
-| Operator | TOSA Constraints |
-| --- | --- |
-| ABS | [Generic](#tosa-generic-constraints) |
-| ADD | [Generic](#tosa-generic-constraints) |
-| AVERAGE_POOL_2D | [Generic](#tosa-generic-constraints) |
-| CONCATENATION | [Generic](#tosa-generic-constraints) |
-| CONV_2D | [Generic](#tosa-generic-constraints) |
-| DEPTHWISE_CONV_2D | [Generic](#tosa-generic-constraints) |
-| FULLY_CONNECTED | [Generic](#tosa-generic-constraints) |
-| HARD_SWISH | [Generic](#tosa-generic-constraints) |
-| LEAKY_RELU | [Generic](#tosa-generic-constraints) |
-| LOGISTIC | [Generic](#tosa-generic-constraints) |
-| MAXIMUM | [Generic](#tosa-generic-constraints) |
-| MAX_POOL_2D | [Generic](#tosa-generic-constraints) |
-| MEAN | [Generic](#tosa-generic-constraints) |
-| MINIMUM | [Generic](#tosa-generic-constraints) |
-| MUL | [Generic](#tosa-generic-constraints) |
-| PACK | [Generic](#tosa-generic-constraints) |
-| PAD | [Generic](#tosa-generic-constraints) |
-| QUANTIZE | [Generic](#tosa-generic-constraints) |
-| RELU | [Generic](#tosa-generic-constraints) |
-| RELU6 | [Generic](#tosa-generic-constraints) |
-| RELU_N1_TO_1 | [Generic](#tosa-generic-constraints) |
-| RESHAPE | [Generic](#tosa-generic-constraints) |
-| RESIZE_BILINEAR | [Generic](#tosa-generic-constraints) |
-| SLICE | [Generic](#tosa-generic-constraints) |
-| SOFTMAX | [Generic](#tosa-generic-constraints) |
-| SPLIT | [Generic](#tosa-generic-constraints) |
-| SPLIT_V | [Generic](#tosa-generic-constraints) |
-| SQUEEZE | [Generic](#tosa-generic-constraints) |
-| STRIDED_SLICE | [Generic](#tosa-generic-constraints) |
-| SUB | [Generic](#tosa-generic-constraints) |
-| TANH | [Generic](#tosa-generic-constraints) |
-| TRANSPOSE_CONV | [Generic](#tosa-generic-constraints) |
-| UNPACK | [Generic](#tosa-generic-constraints) |
-
-### TOSA Generic Constraints
-
-This is a list of constraints that all NPU operators must satisfy in order to be scheduled on the NPU.
-
-- Tensors must be of type: int16, int32, int8, uint8
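The new RESHAPE, SQUEEZE and EXPAND_DIMS constraint sections added above all carry the same rule: input and output quantisation must match. A rough way to check a model against that rule before handing it to Vela is to dump each tensor's per-tensor scale and zero point with the TFLite interpreter. This is only a sketch: it assumes TensorFlow is installed and `model.tflite` is a placeholder path for your own network.

```python
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="model.tflite")
interpreter.allocate_tensors()

# Print name, scale and zero point for every tensor in the model.
for detail in interpreter.get_tensor_details():
    scale, zero_point = detail["quantization"]
    print(f"{detail['name']}: scale={scale}, zero_point={zero_point}")

# Compare the line for the reshape's input tensor with the line for its
# output tensor: the constraint requires both scale and zero point to match.
```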
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 87e3bc8d..0f5636ed 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -17,7 +17,7 @@
# Early optimisation of the network graph, using the rewrite_graph module to do the traversal of the graph.
from . import rewrite_graph
from .graph_optimiser_util import check_format_restrictions
-from .graph_optimiser_util import check_reshapes
+from .graph_optimiser_util import check_memory_only_removed
from .graph_optimiser_util import record_optimised
from .nn_graph import NetworkType
from .tflite_graph_optimiser import tflite_optimise_graph
@@ -38,7 +38,7 @@ def optimise_graph(nng, arch, network_type, verbose_graph=False):
# Post-optimisation operator debug tracing, and checking that no undesired reshapes are left in the graph
for sg in nng.subgraphs:
rewrite_graph.visit_graph_post_order(
- sg.output_tensors, arch, [check_format_restrictions], [check_reshapes, record_optimised]
+ sg.output_tensors, arch, [check_format_restrictions], [check_memory_only_removed, record_optimised]
)
if verbose_graph:
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index 8095f082..dafd2849 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -26,11 +26,12 @@ from .errors import VelaError
from .operation import Op
from .operation_util import create_avgpool_nop
from .shape4d import Shape4D
-from .tensor import check_quantized_tens_scaling_equal
memory_only_ops = (
Op.Reshape,
+ Op.QuantizedReshape,
Op.Squeeze,
+ Op.ExpandDims,
)
@@ -177,10 +178,11 @@ def set_ifm_ofm_op_shapes(op, arch, nng):
return op
-def bypass_reshape_and_squeeze_ops(op):
- assert op.type in (Op.Reshape, Op.Squeeze)
+def bypass_memory_only_ops(op):
+ assert op.type in memory_only_ops
ofm = op.ofm
ifm = op.ifm
+
# Check if ifm/ofm are network ifm/ofm
ifm_is_sg_ifm = ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list)
@@ -235,13 +237,10 @@ def move_splitsliceread_to_consumer(op, cons_op):
op.ifm.consumer_list.remove(op)
-def check_reshapes(op, arch):
- if op.run_on_npu and op.type == Op.Reshape:
- ofm = op.ofm
-
- if check_quantized_tens_scaling_equal(op.ifm, ofm):
- # Reshape should have been removed
- raise VelaError(f"Reshape op {op} expected to have been removed, still remains")
+def check_memory_only_removed(op, arch):
+ if op.run_on_npu and op.type in memory_only_ops:
+ # Memory only operators should have been removed
+ raise VelaError(f"Memory only {op.type} op {op} expected to have been removed, still remains")
def record_optimised(op, arch):
@@ -271,10 +270,10 @@ def insert_copy_op_after_tens(tens):
def fix_sg_input_output(op, arch, nng):
- if not op.run_on_npu or op.type not in (Op.Reshape, Op.Squeeze):
+ if not op.run_on_npu or op.type not in memory_only_ops:
return op
- # For the Reshape/Squeeze operators we want to remove, tensors are removed.
+ # For the memory only operators we want to remove, tensors are removed.
# But in order to to do this, they cannot be outputs of the sg,
# this need to be fixed prior to the removal.
# Solution is to add a avgpool NOP, to maintain the original tensor.
@@ -290,7 +289,7 @@ def fix_sg_input_output(op, arch, nng):
ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
- # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the Reshape/Squeeze
+ # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the memory only operator.
insert_copy_op_after_tens(op.ifm)
return op
diff --git a/ethosu/vela/mark_tensors.py b/ethosu/vela/mark_tensors.py
index f3d5e855..f76c59d7 100644
--- a/ethosu/vela/mark_tensors.py
+++ b/ethosu/vela/mark_tensors.py
@@ -15,6 +15,7 @@
# limitations under the License.
# Description:
# Mark purpose and select formats for Tensors.
+from .graph_optimiser_util import memory_only_ops
from .operation import CustomType
from .operation import Op
from .rewrite_graph import visit_graph_post_order
@@ -72,8 +73,8 @@ def rewrite_mark_tensor_purpose(op, arch):
else:
purpose = TensorPurpose.FeatureMap
mark_purpose(tens, arch, purpose)
- if op.type == Op.Reshape:
- # Reshape's input and output point to same data
+ if op.type in memory_only_ops:
+ # Memory only operator input and output point to same data
op.ofm.mem_area = op.ifm.mem_area
if op.type == Op.Custom and op.attrs.get("custom_type") == CustomType.ExistingNpuOp:
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index 2f965724..b8655c97 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -21,7 +21,6 @@ import pytest
from ethosu.vela.data_type import DataType
from ethosu.vela.graph_optimiser import optimise_graph
-from ethosu.vela.nn_graph import Graph
from ethosu.vela.nn_graph import NetworkType
from ethosu.vela.operation import Op
from ethosu.vela.operation import Padding
@@ -323,27 +322,23 @@ def test_pad_followed_by_avg_pool(k_size, padding, expect_pad_removed):
assert pool_op.attrs["padding"] == Padding.VALID
-# Setup network to test removal of op with op_type Op.Reshape or Op.Squeeze
-# op_type should be Op.Reshape or Op.Squeeze
-def setup_network(op_type):
- assert op_type == Op.Reshape or op_type == Op.Squeeze
- if op_type == Op.Reshape:
- op_str = "reshape"
- elif op_type == Op.Squeeze:
- op_str = "squeeze"
+def test_remove_reshape():
+ """
+ Test that the expected reshape are removed in graph_optimisation
+ """
+ # Create tensors and operators Test1
quant = testutil.default_quant_params()
+
# create reshape1 op
ifm_shape = [64, 16]
reshape1_ofm_shape = [1, 4, 16, 16]
- reshape1_ifm = create_const_tensor(f"{op_str}1_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
+ reshape1_ifm = create_const_tensor("reshape1_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
reshape1_ifm.quantization = quant
- reshape1_ofm = create_const_tensor(
- f"{op_str}1_out", reshape1_ofm_shape, DataType.uint8, np.zeros(reshape1_ofm_shape)
- )
+ reshape1_ofm = create_const_tensor("reshape1_out", reshape1_ofm_shape, DataType.uint8, np.zeros(reshape1_ofm_shape))
reshape1_ofm.quantization = quant
- shape_tens = create_const_tensor(f"{op_str}1_shape", [1], DataType.int32, reshape1_ofm_shape)
- reshape1_op = testutil.create_op(op_type, [reshape1_ifm, shape_tens], reshape1_ofm, set_ifm_ofm_shapes=False)
+ shape_tens = create_const_tensor("reshape1_shape", [1], DataType.int32, reshape1_ofm_shape)
+ reshape1_op = testutil.create_op(Op.Reshape, [reshape1_ifm, shape_tens], reshape1_ofm, set_ifm_ofm_shapes=False)
reshape1_op.attrs["new_shape"] = reshape1_ofm_shape
reshape1_op.run_on_npu = True
@@ -365,43 +360,42 @@ def setup_network(op_type):
# create reshape2 op
ofm_shape = [8, 8, 16]
- reshape2_ofm = create_const_tensor(f"{op_str}2_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
+ reshape2_ofm = create_const_tensor("reshape2_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
reshape2_ofm.quantization = quant
- shape_tens = create_const_tensor(f"{op_str}2_shape", [1], DataType.int32, ofm_shape)
- reshape2_op = testutil.create_op(op_type, [conv_ofm, shape_tens], reshape2_ofm, set_ifm_ofm_shapes=False)
+ shape_tens = create_const_tensor("reshape2_shape", [1], DataType.int32, ofm_shape)
+ reshape2_op = testutil.create_op(Op.Reshape, [conv_ofm, shape_tens], reshape2_ofm, set_ifm_ofm_shapes=False)
reshape2_op.attrs["new_shape"] = ofm_shape
reshape2_op.run_on_npu = True
- nng = Graph()
- sg = testutil.create_subgraph([reshape1_op, conv2d_op, reshape2_op])
- nng.subgraphs.append(sg)
-
- return nng, reshape1_op, conv2d_op, reshape2_op
-
-
-def test_remove_reshape():
- """
- Tests that the expected reshape are removed in graph_optimisation
- """
# Test1 no Reshape op is expected to remain in the NPU subgrapgh
# but first one will be put on CPU
# Network is Reshape-Conv-Reshape
# Result is Conv
- nng, reshape1_op, conv2d_op, reshape2_op = setup_network(Op.Reshape)
+ nng = testutil.create_graph([reshape1_op, conv2d_op, reshape2_op])
arch = testutil.create_arch()
assert verify_graph_health(nng)
- nng = optimise_graph(nng, arch, NetworkType.TFLite)
+ nng = optimise_graph(nng, arch, NetworkType.TFLite, True)
assert verify_graph_health(nng)
- # Test2 reshape1 with different quantisation, this Reshape op is expected to remain
- # Network is Reshape-Conv-Reshape
- # expected is Reshape-Conv
- nng, reshape1_op, conv2d_op, reshape2_op = setup_network(Op.Reshape)
- quant_zp32 = testutil.default_quant_params()
- quant_zp32.zero_point = 32
- reshape1_op.ofm.quantization = quant_zp32
+ # Create tensors and operator Test2
+ # create reshape op
+ reshape_ifm = create_const_tensor("reshape_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
+ reshape_ifm.quantization = quant
+ reshape_ofm = create_const_tensor("reshape1_out", reshape1_ofm_shape, DataType.uint8, np.zeros(reshape1_ofm_shape))
+ reshape_ofm.quantization = quant
+ shape_tens = create_const_tensor("reshape1_shape", [1], DataType.int32, reshape1_ofm_shape)
+ reshape_op = testutil.create_op(Op.Reshape, [reshape_ifm, shape_tens], reshape_ofm, set_ifm_ofm_shapes=False)
+ reshape_op.attrs["new_shape"] = reshape1_ofm_shape
+ reshape_op.run_on_npu = True
+
+ # Test2 Reshape ifm/ofm is sg input/output.
+ # Reshape op is expected to be replaced by a AvgPool 'NOP'.
+ #
+ # Network is Reshape
+ # expected is AvgPool
+ nng = testutil.create_graph([reshape_op])
assert verify_graph_health(nng)
- nng = optimise_graph(nng, arch, NetworkType.TFLite)
+ nng = optimise_graph(nng, arch, NetworkType.TFLite, True)
assert verify_graph_health(nng)
@@ -410,23 +404,132 @@ def test_remove_squeeze():
Tests that the expected squeeze are removed in graph_optimisation
"""
+ # Create tensors and operators Test1
+ quant = testutil.default_quant_params()
+
+ # create conv op
+ ifm_shape = [1, 1, 1, 1024]
+ conv_ifm = create_const_tensor("conv_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
+ conv_ifm.quantization = quant
+ conv_ofm = Tensor([1, 1, 1, 1001], DataType.uint8, "output")
+ conv_ofm.quantization = quant.clone()
+ weight_tens = Tensor([1, 1, 1024, 1001], DataType.uint8, "weights")
+ weight_tens.values = np.zeros(weight_tens.shape, np.uint8)
+ weight_tens.quantization = quant.clone()
+ bias_tens = Tensor([1001], DataType.int32, "biases")
+
+ attrs = {"padding": Padding.SAME, "stride_w": 1, "stride_h": 1, "dilation_w_factor": 1, "dilation_h_factor": 1}
+ attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
+
+ conv2d_op = testutil.create_op(
+ Op.Conv2D, [conv_ifm, weight_tens, bias_tens], conv_ofm, attrs=attrs, set_ifm_ofm_shapes=False
+ )
+ conv2d_op.run_on_npu = True
+
+ # create squeeze op
+ ofm_shape = [1, 1001]
+ squeeze_ofm = create_const_tensor("squeeze_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
+ squeeze_ofm.quantization = quant.clone()
+ squeeze_op = testutil.create_op(Op.Squeeze, [conv_ofm], squeeze_ofm, set_ifm_ofm_shapes=False)
+ squeeze_op.attrs["squeeze_dims"] = [1, 2]
+ squeeze_op.run_on_npu = True
+
# Test1 no Squeeze op is expected to remain in the NPU subgrapgh
- # but first one will be put on CPU
- # Network is Squeeze-Conv-Squeeze
+ #
+ # Network is Conv-Squeeze
# Result is Conv
- nng, squeeze1_op, conv2d_op, squeeze2_op = setup_network(Op.Squeeze)
+ nng = testutil.create_graph([conv2d_op, squeeze_op])
arch = testutil.create_arch()
assert verify_graph_health(nng)
- nng = optimise_graph(nng, arch, NetworkType.TFLite)
+ nng = optimise_graph(nng, arch, NetworkType.TFLite, True)
assert verify_graph_health(nng)
- # Test2 squeeze1 with different quantisation, this Squeeze op is expected to remain
- # Network is Squeeze-Conv-Squeeze
- # expected is Squeeze-Conv
- nng, squeeze1_op, conv2d_op, squeeze2_op = setup_network(Op.Squeeze)
- quant_zp32 = testutil.default_quant_params()
- quant_zp32.zero_point = 32
- squeeze1_op.ofm.quantization = quant_zp32
+ # Create tensors and operator Test2
+ # create squeeze op
+ ifm_shape = [1, 1, 1, 1001]
+ squeeze_ifm = create_const_tensor("squeeze_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
+ squeeze_ifm.quantization = quant
+ squeeze_ofm = create_const_tensor("squeeze_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
+ squeeze_ofm.quantization = quant.clone()
+ squeeze_op = testutil.create_op(Op.Squeeze, [squeeze_ifm], squeeze_ofm, set_ifm_ofm_shapes=False)
+ squeeze_op.attrs["squeeze_dims"] = [1, 2]
+ squeeze_op.run_on_npu = True
+
+ # Test2 Squeeze ifm/ofm is sg input/output.
+ # Squeeze op is expected to be replaced by a AvgPool 'NOP'.
+ #
+ # Network is Squeeze
+ # expected is AvgPool
+ nng = testutil.create_graph([squeeze_op])
assert verify_graph_health(nng)
- nng = optimise_graph(nng, arch, NetworkType.TFLite)
+ nng = optimise_graph(nng, arch, NetworkType.TFLite, True)
+ assert verify_graph_health(nng)
+
+
+def test_remove_expand_dims():
+ """
+ Tests that the expected ExpandDims are removed in graph_optimisation
+ """
+
+ # Create tensors and operators Test1
+ quant = testutil.default_quant_params()
+
+ # create ExpandDims op
+ ifm_shape = [4, 16, 16]
+ ofm_shape = [1, 4, 16, 16]
+ expand_dims_ifm = create_const_tensor("expand_dims_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
+ expand_dims_ifm.quantization = quant
+ expand_dims_ofm = create_const_tensor("expand_dims_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
+ expand_dims_ofm.quantization = quant.clone()
+ dim_tens = create_const_tensor("dim_tens", [], DataType.uint8, 1)
+ expand_dims_op = testutil.create_op(
+ Op.ExpandDims, [expand_dims_ifm, dim_tens], expand_dims_ofm, set_ifm_ofm_shapes=False
+ )
+ expand_dims_op.run_on_npu = True
+
+ # create conv op
+ conv_ofm = Tensor([1, 8, 8, 16], DataType.uint8, "output")
+ conv_ofm.quantization = quant.clone()
+ weight_tens = Tensor([1, 1, 16, 16], DataType.uint8, "weights")
+ weight_tens.values = np.zeros(weight_tens.shape, np.uint8)
+ weight_tens.quantization = quant.clone()
+ bias_tens = Tensor([16], DataType.int32, "biases")
+
+ attrs = {"padding": Padding.SAME, "stride_w": 1, "stride_h": 1, "dilation_w_factor": 1, "dilation_h_factor": 1}
+ attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
+
+ conv2d_op = testutil.create_op(
+ Op.Conv2D, [expand_dims_ofm, weight_tens, bias_tens], conv_ofm, attrs=attrs, set_ifm_ofm_shapes=False
+ )
+ conv2d_op.run_on_npu = True
+
+ # Test1 no ExpandDims op is expected to remain in the NPU subgrapgh
+ #
+ # Network is ExpandDims-Conv
+ # Result is Conv
+ nng = testutil.create_graph([expand_dims_op, conv2d_op])
+ arch = testutil.create_arch()
+ assert verify_graph_health(nng)
+ nng = optimise_graph(nng, arch, NetworkType.TFLite, True)
+ assert verify_graph_health(nng)
+
+ # create ExpandDims op
+ expand_dims_ifm = create_const_tensor("expand_dims_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
+ expand_dims_ifm.quantization = quant
+ expand_dims_ofm = create_const_tensor("expand_dims_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
+ expand_dims_ofm.quantization = quant.clone()
+ dim_tens = create_const_tensor("dim_tens", [], DataType.uint8, 1)
+ expand_dims_op = testutil.create_op(
+ Op.ExpandDims, [expand_dims_ifm, dim_tens], expand_dims_ofm, set_ifm_ofm_shapes=False
+ )
+ expand_dims_op.run_on_npu = True
+
+ # Test2 ExpandDims ifm/ofm is sg input/output.
+ # ExpandDims op is expected to be replaced by a AvgPool 'NOP'.
+ #
+ # Network is ExpandDims
+ # expected is AvgPool
+ nng = testutil.create_graph([expand_dims_op])
+ assert verify_graph_health(nng)
+ nng = optimise_graph(nng, arch, NetworkType.TFLite, True)
assert verify_graph_health(nng)
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index 4c329844..84f99160 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -458,3 +458,57 @@ def test_mean_axis():
assert semantic_checker.is_operator_semantic_valid(op)
op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [2, 1], DataType.int8, {"keep_dims": True})
assert semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_matching_in_out_quant():
+ # quantisation parameters of ifm and ofm should match.
+ quant = testutil.default_quant_params()
+ # create reshape op
+ ifm_shape = [64, 16]
+ ofm_shape = [1, 4, 16, 16]
+ ifm = create_const_tensor("reshape_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
+ ifm.quantization = quant
+ ofm = create_const_tensor("reshape_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
+ ofm.quantization = quant.clone()
+ shape_tens = create_const_tensor("shape", [1], DataType.int32, ofm_shape)
+ op = testutil.create_op(Op.Reshape, [ifm, shape_tens], ofm, set_ifm_ofm_shapes=False)
+ op.attrs["new_shape"] = ofm_shape
+
+ # Matching quantisation parameters
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+ # Different zp
+ ofm.quantization.zero_point = 32
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+ # Different scale
+ ofm.quantization.zero_point = 0
+ ofm.quantization.scale_f32 = 0.9
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+ # Squeeze op diff quant
+ # create squeeze op
+ ifm_shape = [1, 1, 1, 1001]
+ ofm_shape = [1, 1001]
+ ifm = create_const_tensor("squeeze_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
+ ifm.quantization = quant
+ ofm = create_const_tensor("squeeze_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
+ ofm.quantization = quant.clone()
+ ofm.quantization.zero_point = 32
+ op = testutil.create_op(Op.Squeeze, [ifm], ofm, set_ifm_ofm_shapes=False)
+ op.attrs["squeeze_dims"] = [1, 2]
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+ # ExpandDims diff quant
+ quant = testutil.default_quant_params()
+ # create expand_dims op
+ ifm_shape = [4, 16, 16]
+ ofm_shape = [1, 4, 16, 16]
+ ifm = create_const_tensor("expand_dims_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
+ ifm.quantization = quant
+ ofm = create_const_tensor("expand_dims_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
+ ofm.quantization = quant.clone()
+ ofm.quantization.zero_point = 32
+ dim = create_const_tensor("expand_dims_dim", [], DataType.uint8, 0)
+ op = testutil.create_op(Op.ExpandDims, [ifm, dim], ofm, set_ifm_ofm_shapes=False)
+ assert not semantic_checker.is_operator_semantic_valid(op)
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 15b82c7e..b48cc7af 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -30,10 +30,11 @@ from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
-from .graph_optimiser_util import bypass_reshape_and_squeeze_ops
+from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import fix_sg_input_output
+from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
@@ -190,7 +191,7 @@ def remove_SplitSliceRead(op, arch):
len(op.ofm.consumer_list) == 1
and op.ofm.consumer_list[0] is not None
and op.ofm.consumer_list[0].run_on_npu
- and op.ofm.consumer_list[0].type not in (Op.Reshape, Op.Squeeze)
+ and op.ofm.consumer_list[0].type not in memory_only_ops
and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
):
# SplitSliceRead can be performed by tensor consumer
@@ -211,27 +212,6 @@ def remove_SplitSliceRead(op, arch):
DebugDatabase.add_optimised(op, avgpool_op)
-def insert_copy_op_after_tens(tens):
- tens_cons_list_copy = tens.consumer_list.copy()
-
- # Create a avg_pool nop op with ifm as input
- copy_tens = tens.clone()
- copy_op = create_avgpool_nop(tens.name + "_avgpool")
- copy_op.add_input_tensor(tens)
- copy_op.set_output_tensor(copy_tens)
- copy_op.set_ifm_ofm_shapes()
- copy_op.run_on_npu = True
-
- # Set copy_ifm consumers
- for tens_cons in tens_cons_list_copy:
- if tens_cons is not None:
- for ifm_idx, cons_inp in enumerate(tens_cons.inputs):
- if cons_inp == tens:
- tens_cons.set_input_tensor(copy_tens, ifm_idx)
-
- DebugDatabase.add_optimised(tens.ops[0], copy_op)
-
-
def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
k_w, k_h = kernel.dilated_wh()
s_x, s_y = kernel.stride
@@ -985,19 +965,9 @@ def convert_tanh_sigmoid_to_lut(op, arch, nng):
return op
-def remove_reshape_and_squeeze_ops(op, arch):
- if op.run_on_npu and op.type in (Op.Reshape, Op.Squeeze):
- ofm = op.ofm
- ifm = op.ifm
-
- # Check if quantization is the same in the input and output for the reshape ops
- if not check_quantized_tens_scaling_equal(ifm, ofm):
- # TODO Both tensors are needed, since quantisation properties currently are linked to Tensors.
- # In order to remove this reshape either quantization properties need to be moved to Operator,
- # or the reshape need to be replace with a NOP.
- return
-
- bypass_reshape_and_squeeze_ops(op)
+def remove_memory_only_ops(op, arch):
+ if op.run_on_npu and op.type in memory_only_ops:
+ bypass_memory_only_ops(op)
def fuse_activation_function_with_prev(op, arch, nng):
@@ -1463,9 +1433,9 @@ def tflite_optimise_graph(nng, arch):
nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False,
)
- # Removal of reshapes and squeeze
+ # Removal of memory only operators
for sg in nng.subgraphs:
- rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshape_and_squeeze_ops])
+ rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
sg.refresh_after_modification()
# Rewrite of operators
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index c8b373a3..6e2467bb 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -26,6 +26,7 @@ from .operation import get_slice_offsets
from .operation import Op
from .supported_operators_util import docstring_format_args
from .supported_operators_util import list_formatter
+from .tensor import check_quantized_tens_scaling_equal
from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN
from .tflite_mapping import optype_to_builtintype
@@ -53,7 +54,8 @@ class TFLiteSemantic:
binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.Sub,))
binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
- shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV, Op.Mean))
+ shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV, Op.Mean, Op.ExpandDims))
+ reshape_ops = set((Op.Reshape, Op.QuantizedReshape, Op.Squeeze, Op.ExpandDims,))
def __init__(self):
# Setup the generic constraints. Note: the order matters
@@ -110,6 +112,10 @@ class TFLiteSemantic:
self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_signed)
self.specific_constraints[op_type].append(TFLiteSemantic.constraint_unsigned_valid)
+ # Ops reshaping dimensions: Reshape, Squeeze and ExpandDims
+ for op_type in TFLiteSemantic.reshape_ops:
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_quant)
+
# Softmax specific checks:
self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_matching_shapes)
self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_matching_in_out_types)
@@ -518,6 +524,13 @@ class TFLiteSemantic:
valid = axis in (1, 2, [1], [2], [1, 2], [2, 1])
return valid, f"Axis is {axis}"
+ @staticmethod
+ def constraint_matching_in_out_quant(op):
+ "Input and output quantisation must match."
+ if not check_quantized_tens_scaling_equal(op.ifm, op.ofm):
+ return False, "IFM and OFM quantisation parameters are not equal."
+ return True, "IFM and OFM quantisation parameters matches."
+
def tflite_semantic_checker(nng):
semantic_checker = TFLiteSemantic()
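The constraint above delegates the actual comparison to check_quantized_tens_scaling_equal. The following is a simplified, hypothetical version of that comparison, assuming it reduces to comparing the scale_f32 and zero_point fields that the new unit test manipulates; the real helper in ethosu/vela/tensor.py may handle further cases such as missing quantisation, per-axis scales or float tolerances.

```python
from dataclasses import dataclass


@dataclass
class QuantParams:
    scale_f32: float
    zero_point: int


def scaling_equal(ifm_quant: QuantParams, ofm_quant: QuantParams) -> bool:
    # Both the scale and the zero point must match for a memory only
    # operator to be removable without changing numerical results.
    return (ifm_quant.scale_f32 == ofm_quant.scale_f32
            and ifm_quant.zero_point == ofm_quant.zero_point)


print(scaling_equal(QuantParams(0.5, 0), QuantParams(0.5, 0)))   # True
print(scaling_equal(QuantParams(0.5, 0), QuantParams(0.5, 32)))  # False, the "different zp" case
```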
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 016d44e5..933302f5 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -86,7 +86,7 @@ class TFLiteSupportedOperators:
)
split_ops = set((Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack,))
concat_ops = set((Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack,))
- memory_only_ops = set((Op.Reshape, Op.QuantizedReshape, Op.Squeeze,)) | concat_ops | split_ops
+ memory_only_ops = set((Op.Reshape, Op.QuantizedReshape, Op.Squeeze, Op.ExpandDims,)) | concat_ops | split_ops
per_axis_quant_ops = convolution_like_ops # per-axis/channel quantization only currently supported for conv ops
supported_fused_activations = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.LUT,))
supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | pad_ops | npu_post_ops | memory_only_ops
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index bade4a97..a298ddbb 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -21,7 +21,7 @@ from . import rewrite_graph
from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
-from .graph_optimiser_util import bypass_reshape_and_squeeze_ops
+from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import move_splitsliceread_to_consumer
@@ -294,7 +294,7 @@ def rewrite_concat_ops(op, arch):
def remove_reshapes(op, arch):
if op.run_on_npu and op.type == Op.Reshape:
- bypass_reshape_and_squeeze_ops(op)
+ bypass_memory_only_ops(op)
def rewrite_activation(op, arch, nng):
diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py
index 94487499..8a808276 100644
--- a/ethosu/vela/vela.py
+++ b/ethosu/vela/vela.py
@@ -167,6 +167,10 @@ def print_subgraph_io_summary(nng):
def generate_supported_ops():
+ # Exclude network type from generation by adding value to exclude list.
+ # To easily exclude NetworkType from generated documentation.
+ exclude_generation_network_type_value = [NetworkType.TOSA.value]
+
lines = [
"# Supported Ops",
"",
@@ -180,11 +184,17 @@ def generate_supported_ops():
]
for network_type in NetworkType:
+ if network_type.value in exclude_generation_network_type_value:
+ continue
+
lines += [
f"- [{network_type.name}](#{network_type.name.lower()}-summary-table)",
]
for network_type in NetworkType:
+ if network_type.value in exclude_generation_network_type_value:
+ continue
+
lines += [
"",
f"## {network_type.name} Summary Table",