aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPatrik Gustavsson <patrik.gustavsson@arm.com>2021-08-23 15:33:59 +0200
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2021-09-03 12:19:48 +0000
commitdf99510f04aef99d1b8e9be9bfcde8fc1738b65f (patch)
tree00668b0e74f95da5cc51a41b9340d8c88fbc7ffe
parentcce872bc3de3ed5f9bf1aa1a8cf9ce41cf2b2520 (diff)
downloadethos-u-vela-df99510f04aef99d1b8e9be9bfcde8fc1738b65f.tar.gz
TOSA: Added Depthwise support
This is mainly to add support for depthwise conv2d with dephmultiplier = 1. (But there are no testcases suited, all I have sourced has depth_multiplier set to 2, which is not supported.) -Added support for depthwise conv2d. -Added support for removing Transpose of constant data -Added support for removing reshape Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com> Change-Id: I143e6246becfa78fd9f7510af0bf0d6b3fbbf2c7
-rw-r--r--ethosu/vela/graph_optimiser_util.py114
-rw-r--r--ethosu/vela/operation.py2
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py86
-rw-r--r--ethosu/vela/tosa_graph_optimiser.py60
-rw-r--r--ethosu/vela/tosa_mapping.py6
-rw-r--r--ethosu/vela/tosa_reader.py19
-rw-r--r--ethosu/vela/tosa_supported_operators.py41
-rw-r--r--ethosu/vela/vela.py2
8 files changed, 235 insertions, 95 deletions
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index 570c7244..d01d4a19 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -17,14 +17,17 @@
# Common functions and definitions used during the graph optimization.
from typing import Tuple
+import numpy as np
+
from .data_type import DataType
from .debug_database import DebugDatabase
+from .errors import UnsupportedFeatureError
from .errors import VelaError
from .operation import Op
+from .operation_util import create_avgpool_nop
from .shape4d import Shape4D
from .tensor import check_quantized_tens_scaling_equal
-
memory_only_ops = (
Op.Reshape,
Op.Squeeze,
@@ -174,6 +177,41 @@ def set_ifm_ofm_op_shapes(op, arch, nng):
return op
+def bypass_reshape_and_squeeze_ops(op):
+ assert op.type in (Op.Reshape, Op.Squeeze)
+ ofm = op.ofm
+ ifm = op.ifm
+ # Check if ifm/ofm are network ifm/ofm
+ ifm_is_sg_ifm = ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
+ ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list)
+ ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in ofm.consumer_list)
+ # Check if ifm/ofm is produced respectively consumed by CPU
+ ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
+ ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
+
+ # This case should be handled prior to this function
+ assert not ((ifm_is_sg_ifm or ifm_is_sg_ofm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed))
+
+ if ofm_is_sg_ofm or ofm_is_cpu_consumed:
+ # Bypassed by replacing ifm with ofm
+ ofm.ops = []
+ for prev_op in ifm.ops:
+ prev_op.outputs = [ofm]
+ ofm.ops.append(prev_op)
+
+ # All ifm consumers need to use ofm as input
+ for ifm_cons in ifm.consumer_list:
+ for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
+ if cons_ifm == ifm:
+ ifm_cons.set_input_tensor(ofm, ifm_idx)
+ else:
+ # Bypassed by replacing ofm with ifm
+ for cons in ofm.consumer_list:
+ for ifm_idx, cons_ifm in enumerate(cons.inputs):
+ if cons_ifm == ofm:
+ cons.set_input_tensor(ifm, ifm_idx)
+
+
def check_reshapes(op, arch):
if op.run_on_npu and op.type == Op.Reshape:
ofm = op.ofm
@@ -186,3 +224,77 @@ def check_reshapes(op, arch):
def record_optimised(op, arch):
if op.type != Op.Const:
DebugDatabase.add_optimised(op, op)
+
+
+def insert_copy_op_after_tens(tens):
+ tens_cons_list_copy = tens.consumer_list.copy()
+
+ # Create a avg_pool nop op with ifm as input
+ copy_tens = tens.clone()
+ copy_op = create_avgpool_nop(tens.name + "_avgpool")
+ copy_op.add_input_tensor(tens)
+ copy_op.set_output_tensor(copy_tens)
+ copy_op.set_ifm_ofm_shapes()
+ copy_op.run_on_npu = True
+
+ # Set copy_ifm consumers
+ for tens_cons in tens_cons_list_copy:
+ if tens_cons is not None:
+ for ifm_idx, cons_inp in enumerate(tens_cons.inputs):
+ if cons_inp == tens:
+ tens_cons.set_input_tensor(copy_tens, ifm_idx)
+
+ DebugDatabase.add_optimised(tens.ops[0], copy_op)
+
+
+def fix_sg_input_output(op, arch, nng):
+ if not op.run_on_npu or op.type not in (Op.Reshape, Op.Squeeze):
+ return op
+
+ # For the Reshape/Squeeze operators we want to remove, tensors are removed.
+ # But in order to to do this, they cannot be outputs of the sg,
+ # this need to be fixed prior to the removal.
+ # Solution is to add a avgpool NOP, to maintain the original tensor.
+ # This is also valid when reshape ifm/ofm is produced respectively
+ # consumed by CPU
+
+ # Check if operator ifm/ofm are sg ifm/ofm
+ ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
+ ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
+ ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
+ # Check if ifm/ofm is produced respectively consumed by CPU
+ ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
+ ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
+
+ if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
+ # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the Reshape/Squeeze
+ insert_copy_op_after_tens(op.ifm)
+
+ return op
+
+
+def convert_depthwise_to_conv(op, arch, nng):
+ # Depthwise is equivalent to a single conv2d if the ifm depth is 1 and
+ # the ofm depth equals the depth multipler.
+ # If those conditions are true, then we can perform a simple
+ # switch of the operator type (and weight order)
+
+ if op.type == Op.DepthwiseConv2DBias and (op.attrs["depth_multiplier"] != 1):
+ ifm_shape = op.ifm_shapes[0]
+ weight_tensor = op.inputs[1]
+ ofm_shape = op.ofm_shapes[0]
+ if (ifm_shape.depth == 1) and (ofm_shape.depth == op.attrs["depth_multiplier"]):
+ # Change op type to Conv2d
+ op.type = Op.Conv2DBias
+ del op.attrs["channel_multiplier"]
+ del op.attrs["depth_multiplier"]
+
+ weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
+ weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
+ else:
+ raise UnsupportedFeatureError(
+ f"Unsupported 'DEPTHWISE_CONV_2D' with depth_multiplier = {op.attrs['depth_multiplier']},",
+ f" ifm channels = {ifm_shape.depth}, ofm channels = {ofm_shape.depth}",
+ )
+ DebugDatabase.add_optimised(op, op)
+ return op
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index ca833e21..80be228b 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -283,7 +283,7 @@ class Op(Enum):
Tanh = OperatorInfo(indices=NNG_IFM_INDICES)
Tile = OperatorInfo()
TopKV2 = OperatorInfo()
- Transpose = OperatorInfo()
+ Transpose = OperatorInfo(indices=NNG_IFM_INDICES)
UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
Unique = OperatorInfo()
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 255a1f5e..ef39aea3 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -30,7 +30,10 @@ from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
+from .graph_optimiser_util import bypass_reshape_and_squeeze_ops
from .graph_optimiser_util import calc_explicit_padding
+from .graph_optimiser_util import convert_depthwise_to_conv
+from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
@@ -244,32 +247,6 @@ def insert_copy_op_after_tens(tens):
DebugDatabase.add_optimised(tens.ops[0], copy_op)
-def fix_sg_input_output(op, arch, nng):
- if not op.run_on_npu or op.type not in (Op.Reshape, Op.Squeeze):
- return op
-
- # For the Reshape/Squeeze operators we want to remove, tensors are removed.
- # But in order to to do this, they cannot be outputs of the sg,
- # this need to be fixed prior to the removal.
- # Solution is to add a avgpool NOP, to maintain the original tensor.
- # This is also valid when reshape ifm/ofm is produced respectively
- # consumed by CPU
-
- # Check if operator ifm/ofm are sg ifm/ofm
- ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
- ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
- ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
- # Check if ifm/ofm is produced respectively consumed by CPU
- ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
- ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
-
- if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
- # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the Reshape/Squeeze
- insert_copy_op_after_tens(op.ifm)
-
- return op
-
-
def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
k_w, k_h = kernel.dilated_wh()
s_x, s_y = kernel.stride
@@ -576,33 +553,6 @@ def add_padding_fields(op, arch, nng):
return op
-def convert_depthwise_to_conv(op, arch, nng):
- # Depthwise is equivalent to a single conv2d if the ifm depth is 1 and
- # the ofm depth equals the depth multipler.
- # If those conditions are true, then we can perform a simple
- # switch of the operator type (and weight order)
-
- if op.type == Op.DepthwiseConv2DBias and (op.attrs["depth_multiplier"] != 1):
- ifm_shape = op.ifm_shapes[0]
- weight_tensor = op.inputs[1]
- ofm_shape = op.ofm_shapes[0]
- if (ifm_shape.depth == 1) and (ofm_shape.depth == op.attrs["depth_multiplier"]):
- # Change op type to Conv2d
- op.type = Op.Conv2DBias
- del op.attrs["channel_multiplier"]
- del op.attrs["depth_multiplier"]
-
- weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
- weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
- else:
- raise UnsupportedFeatureError(
- f"Unsupported 'DEPTHWISE_CONV_2D' with depth_multiplier = {op.attrs['depth_multiplier']},",
- f" ifm channels = {ifm_shape.depth}, ofm channels = {ofm_shape.depth}",
- )
- DebugDatabase.add_optimised(op, op)
- return op
-
-
def reorder_depthwise_weights(op, arch, nng):
if op.type.is_depthwise_conv2d_op():
weight_tensor = op.inputs[1]
@@ -1058,35 +1008,7 @@ def remove_reshape_and_squeeze_ops(op, arch):
# or the reshape need to be replace with a NOP.
return
- # Check if ifm/ofm are network ifm/ofm
- ifm_is_sg_ifm = ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
- ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list)
- ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in ofm.consumer_list)
- # Check if ifm/ofm is produced respectively consumed by CPU
- ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
- ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
-
- # This case should be handled prior to this function
- assert not ((ifm_is_sg_ifm or ifm_is_sg_ofm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed))
-
- if ofm_is_sg_ofm or ofm_is_cpu_consumed:
- # Bypassed by replacing ifm with ofm
- ofm.ops = []
- for prev_op in ifm.ops:
- prev_op.outputs = [ofm]
- ofm.ops.append(prev_op)
-
- # All ifm consumers need to use ofm as input
- for ifm_cons in ifm.consumer_list:
- for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
- if cons_ifm == ifm:
- ifm_cons.set_input_tensor(ofm, ifm_idx)
- else:
- # Bypassed by replacing ofm with ifm
- for cons in ofm.consumer_list:
- for ifm_idx, cons_ifm in enumerate(cons.inputs):
- if cons_ifm == ofm:
- cons.set_input_tensor(ifm, ifm_idx)
+ bypass_reshape_and_squeeze_ops(op)
def fuse_activation_function_with_prev(op, arch, nng):
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 44e0f8ec..169da40d 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -19,7 +19,10 @@ from . import rewrite_graph
from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
+from .graph_optimiser_util import bypass_reshape_and_squeeze_ops
from .graph_optimiser_util import calc_explicit_padding
+from .graph_optimiser_util import convert_depthwise_to_conv
+from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
@@ -74,6 +77,43 @@ def add_padding_fields(op, arch, nng):
return op
+def remove_const_transpose(op, arch, nng):
+ if op.type == Op.Transpose:
+ removed = False
+ if len(op.ifm.ops) == 1:
+ prev_op = op.ifm.ops[0]
+ if prev_op.type == Op.Const:
+ # Transpose the Tensor and data and remove Transpose
+ # TODO move to Tensor?
+ reorder = op.attrs["perms"]
+ shape = op.ifm.shape.copy()
+ tens = op.ifm
+
+ tens.shape = [shape[idx] for idx in reorder]
+ tens.bandwidth_shape = tens.shape
+ tens.storage_shape = tens.shape
+
+ if tens.values is not None:
+ tens.values = tens.values.transpose(reorder)
+
+ op.ofm.values = tens.values
+ # Bypass the Transpose op
+ prev_op.set_output_tensor(op.ofm)
+ DebugDatabase.add_optimised(op, prev_op)
+ removed = True
+
+ if not removed:
+ print("Cannot remove Transpose, and handling of Transpose is not supported")
+ assert False
+
+ return op
+
+
+def remove_reshapes(op, arch):
+ if op.run_on_npu and op.type == Op.Reshape:
+ bypass_reshape_and_squeeze_ops(op)
+
+
def rewrite_activation(op, arch, nng):
if op.type not in (Op.ReluN, Op.Clamp):
return op
@@ -206,6 +246,7 @@ def fixup_quantization(op, arch, nng):
def supported_operator_check(op, arch, nng):
op.run_on_npu = arch.tosa_supported_operators.is_operator_supported(op)
+ assert op.run_on_npu or op.type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
return op
@@ -221,8 +262,25 @@ def tosa_optimise_graph(nng, arch):
nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
)
+ # Removal of Transpose
+ for idx, sg in enumerate(nng.subgraphs):
+ nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+ nng, sg, arch, [], [remove_const_transpose], rewrite_unsupported=False,
+ )
+
+ # Handle sg input output
+ for idx, sg in enumerate(nng.subgraphs):
+ nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+ nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False,
+ )
+
+ # Removal of reshapes
+ for sg in nng.subgraphs:
+ rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
+ sg.refresh_after_modification()
+
# Rewite Operators step
- op_rewrite_list = [set_tensor_equivalence, rewrite_rescale]
+ op_rewrite_list = [set_tensor_equivalence, rewrite_rescale, convert_depthwise_to_conv]
for idx, sg in enumerate(nng.subgraphs):
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
diff --git a/ethosu/vela/tosa_mapping.py b/ethosu/vela/tosa_mapping.py
index 75ca43ef..5d0dd33d 100644
--- a/ethosu/vela/tosa_mapping.py
+++ b/ethosu/vela/tosa_mapping.py
@@ -221,11 +221,9 @@ unsupported_tosa_operators = {
TosaOp.REDUCE_SUM,
TosaOp.CONCAT,
TosaOp.PAD,
- TosaOp.RESHAPE,
TosaOp.REVERSE,
TosaOp.SLICE,
TosaOp.TILE,
- TosaOp.TRANSPOSE,
TosaOp.GATHER,
TosaOp.SCATTER,
TosaOp.RESIZE,
@@ -304,11 +302,11 @@ tosa_operator_map = {
# TODO TosaOp.REDUCE_SUM
# TODO TosaOp.CONCAT
# TODO TosaOp.PAD
- # TODO TosaOp.RESHAPE
+ TosaOp.RESHAPE: (Op.Reshape, reshape_attrs, None, TOSA_IFM_INDICES),
# TODO TosaOp.REVERSE
# TODO TosaOp.SLICE
# TODO TosaOp.TILE
- # TODO TosaOp.TRANSPOSE
+ TosaOp.TRANSPOSE: (Op.Transpose, None, None, TOSA_IFM_IFM2_INDICES),
# TODO TosaOp.GATHER
# TODO TosaOp.SCATTER
# TODO TosaOp.RESIZE
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py
index eb317169..268d43ce 100644
--- a/ethosu/vela/tosa_reader.py
+++ b/ethosu/vela/tosa_reader.py
@@ -33,9 +33,11 @@ from .tensor import QuantizationParameters
from .tensor import shape_num_elements
from .tensor import Tensor
from .tflite_mapping import DataType
+from .tosa.Op import Op as TosaOp
from .tosa.TosaGraph import TosaGraph as TG
from .tosa_mapping import datatype_map
from .tosa_mapping import datatype_map_numpy
+from .tosa_mapping import TOSA_IFM_INDICES
from .tosa_mapping import tosa_operator_map
from .tosa_mapping import unsupported_tosa_operators
@@ -89,7 +91,7 @@ class TosaSubgraph:
op_code = op_data.Op()
if op_code in unsupported_tosa_operators:
print("Unsupported Operator", op_code)
- assert False
+ return
op_type, attr_serializer, quant_serializer, indices = tosa_operator_map[op_code]
inputs = []
@@ -104,6 +106,15 @@ class TosaSubgraph:
outputs.append(output_tens)
assert output_tens is not None
+ # Permutation attribute for TRANSPOSE is an input tensor in TOSA
+ # TODO In order to optimise Depthwise spawning from TFLite Support for removing
+ # Transpose of constant data.
+ # Moving permutation to an attribute, to match internal graph representation for now
+ perms = None
+ if op_code == TosaOp.TRANSPOSE:
+ perms = perms = inputs.pop(1)
+ indices = TOSA_IFM_INDICES
+
name = "unknown_op_name"
if len(outputs):
name = outputs[0].name
@@ -148,6 +159,7 @@ class TosaSubgraph:
stride = op.attrs["stride"]
if len(stride) == 2:
op.attrs["strides"] = (1, stride[0], stride[1], 1)
+ del op.attrs["stride"]
else:
# TODO CONV3D more to be done....
print("Unsupported kernel dimensions: ", len(stride))
@@ -167,6 +179,11 @@ class TosaSubgraph:
# TODO CONV3D more to be done....
print("Unsupported kernel dimensions: ", len(kernel))
assert False
+ if op.type.is_depthwise_conv2d_op():
+ op.attrs["depth_multiplier"] = op.weights.shape[3]
+
+ elif op.type == Op.Transpose:
+ op.attrs["perms"] = perms.values
if quant_serializer is not None:
quant_info = quant_serializer.deserialize(op_data)
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index 3b0e6b39..90d54687 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -28,19 +28,24 @@ class TosaSupportedOperators:
# TODO currently sparsely populated
# Categorised lists of supported operators
convolution_ops = set((Op.Conv2DBias,))
- convolution_like_ops = convolution_ops
+ depthwise_convolution_ops = set((Op.DepthwiseConv2DBias,))
+ convolution_like_ops = convolution_ops | depthwise_convolution_ops
+
+ # TODO depending on what will be committed
max_pooling_ops = Op.op_set(Op.is_maxpool_op)
avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
- pooling_ops = set((Op.ReduceSum,)) | max_pooling_ops | avg_pooling_ops
+ pooling_ops = max_pooling_ops | avg_pooling_ops
+ fc_vector_products = set((Op.FullyConnected,))
- mac_main_ops = convolution_like_ops | pooling_ops
+ mac_main_ops = convolution_like_ops | pooling_ops | fc_vector_products
+ memory_only_ops = set((Op.Reshape, Op.Transpose,))
type_conversion_ops = set((Op.Rescale,))
relu_ops = set((Op.Clamp, Op.ReluN,))
activation_ops = relu_ops
npu_post_ops = activation_ops
- supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops
+ supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops
# Supported data types
# TODO will differ compared to TensorFlow Lite, currently set to the same
@@ -54,6 +59,12 @@ class TosaSupportedOperators:
# Setup specific constraints. Note: the order matters
self.specific_constraints = defaultdict(list)
+ self.specific_constraints[Op.Transpose].append(TosaSupportedOperators.constraint_ifm_producer)
+
+ # Depthwise Conv specific checks:
+ for op_type in TosaSupportedOperators.depthwise_convolution_ops:
+ self.specific_constraints[op_type].append(TosaSupportedOperators.constraint_depth_multiplier)
+
def is_operator_supported(self, op):
ext_type = optype_to_tosa_op_type(op.type)
if op.type not in TosaSupportedOperators.supported_operators:
@@ -87,3 +98,25 @@ class TosaSupportedOperators:
valid = False
extra.append(f"Tensor '{tens.name}' has data type: {tens.dtype}")
return valid, ", ".join(extra)
+
+ @staticmethod
+ def constraint_ifm_producer(cls, op):
+ "Input must be constant data"
+ valid = op.ifm.ops and op.ifm.ops[0].type == Op.Const
+ return valid, "Op has ifm with non-constant data"
+
+ # TODO duplicates tflite_supported operators, but support for depth multiplier should be added at a later stage
+ @staticmethod
+ def constraint_depth_multiplier(op):
+ "For depth multipliers > 1, IFM channels must be 1 and OFM channels must be equal to the depth multiplier"
+ depth_multiplier = op.attrs.get("depth_multiplier", 1)
+ if depth_multiplier > 1:
+ ifm_channels = op.ifm.shape[3]
+ ofm_channels = op.ofm.shape[3]
+ valid = (ifm_channels == 1) and (ofm_channels == depth_multiplier)
+ extra = (
+ f"Op has ifm_channels={ifm_channels}, ofm_channels={ofm_channels}"
+ f" and depth_multiplier={depth_multiplier}"
+ )
+ return valid, extra
+ return True, "Op has depth_multiplier=1"
diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py
index 94487499..6c9fbce2 100644
--- a/ethosu/vela/vela.py
+++ b/ethosu/vela/vela.py
@@ -85,7 +85,7 @@ def process(input_name, enable_debug_db, arch, model_reader_options, compiler_op
)
output_tfl_filename = output_basename + "_vela.tflite"
- if input_name.endswith(".tflite"):
+ if input_name.endswith(".tflite") or input_name.endswith(".tosa"):
tflite_writer.write_tflite(nng, output_tfl_filename)
if input_name.endswith(".tosa"):
rawdata_writer.write_rawdata_output(nng, arch, output_basename)