diff options
author | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2021-08-23 15:33:59 +0200 |
---|---|---|
committer | patrik.gustavsson <patrik.gustavsson@arm.com> | 2021-09-03 12:19:48 +0000 |
commit | df99510f04aef99d1b8e9be9bfcde8fc1738b65f (patch) | |
tree | 00668b0e74f95da5cc51a41b9340d8c88fbc7ffe /ethosu/vela/tosa_reader.py | |
parent | cce872bc3de3ed5f9bf1aa1a8cf9ce41cf2b2520 (diff) | |
download | ethos-u-vela-df99510f04aef99d1b8e9be9bfcde8fc1738b65f.tar.gz |
TOSA: Added Depthwise support
This is mainly to add support for depthwise conv2d
with dephmultiplier = 1.
(But there are no testcases suited, all I have sourced
has depth_multiplier set to 2, which is not supported.)
-Added support for depthwise conv2d.
-Added support for removing Transpose of constant data
-Added support for removing reshape
Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I143e6246becfa78fd9f7510af0bf0d6b3fbbf2c7
Diffstat (limited to 'ethosu/vela/tosa_reader.py')
-rw-r--r-- | ethosu/vela/tosa_reader.py | 19 |
1 files changed, 18 insertions, 1 deletions
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py index eb317169..268d43ce 100644 --- a/ethosu/vela/tosa_reader.py +++ b/ethosu/vela/tosa_reader.py @@ -33,9 +33,11 @@ from .tensor import QuantizationParameters from .tensor import shape_num_elements from .tensor import Tensor from .tflite_mapping import DataType +from .tosa.Op import Op as TosaOp from .tosa.TosaGraph import TosaGraph as TG from .tosa_mapping import datatype_map from .tosa_mapping import datatype_map_numpy +from .tosa_mapping import TOSA_IFM_INDICES from .tosa_mapping import tosa_operator_map from .tosa_mapping import unsupported_tosa_operators @@ -89,7 +91,7 @@ class TosaSubgraph: op_code = op_data.Op() if op_code in unsupported_tosa_operators: print("Unsupported Operator", op_code) - assert False + return op_type, attr_serializer, quant_serializer, indices = tosa_operator_map[op_code] inputs = [] @@ -104,6 +106,15 @@ class TosaSubgraph: outputs.append(output_tens) assert output_tens is not None + # Permutation attribute for TRANSPOSE is an input tensor in TOSA + # TODO In order to optimise Depthwise spawning from TFLite Support for removing + # Transpose of constant data. + # Moving permutation to an attribute, to match internal graph representation for now + perms = None + if op_code == TosaOp.TRANSPOSE: + perms = perms = inputs.pop(1) + indices = TOSA_IFM_INDICES + name = "unknown_op_name" if len(outputs): name = outputs[0].name @@ -148,6 +159,7 @@ class TosaSubgraph: stride = op.attrs["stride"] if len(stride) == 2: op.attrs["strides"] = (1, stride[0], stride[1], 1) + del op.attrs["stride"] else: # TODO CONV3D more to be done.... print("Unsupported kernel dimensions: ", len(stride)) @@ -167,6 +179,11 @@ class TosaSubgraph: # TODO CONV3D more to be done.... print("Unsupported kernel dimensions: ", len(kernel)) assert False + if op.type.is_depthwise_conv2d_op(): + op.attrs["depth_multiplier"] = op.weights.shape[3] + + elif op.type == Op.Transpose: + op.attrs["perms"] = perms.values if quant_serializer is not None: quant_info = quant_serializer.deserialize(op_data) |