aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tosa_reader.py
diff options
context:
space:
mode:
authorPatrik Gustavsson <patrik.gustavsson@arm.com>2021-08-23 15:33:59 +0200
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2021-09-03 12:19:48 +0000
commitdf99510f04aef99d1b8e9be9bfcde8fc1738b65f (patch)
tree00668b0e74f95da5cc51a41b9340d8c88fbc7ffe /ethosu/vela/tosa_reader.py
parentcce872bc3de3ed5f9bf1aa1a8cf9ce41cf2b2520 (diff)
downloadethos-u-vela-df99510f04aef99d1b8e9be9bfcde8fc1738b65f.tar.gz
TOSA: Added Depthwise support
This is mainly to add support for depthwise conv2d with dephmultiplier = 1. (But there are no testcases suited, all I have sourced has depth_multiplier set to 2, which is not supported.) -Added support for depthwise conv2d. -Added support for removing Transpose of constant data -Added support for removing reshape Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com> Change-Id: I143e6246becfa78fd9f7510af0bf0d6b3fbbf2c7
Diffstat (limited to 'ethosu/vela/tosa_reader.py')
-rw-r--r--ethosu/vela/tosa_reader.py19
1 files changed, 18 insertions, 1 deletions
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py
index eb317169..268d43ce 100644
--- a/ethosu/vela/tosa_reader.py
+++ b/ethosu/vela/tosa_reader.py
@@ -33,9 +33,11 @@ from .tensor import QuantizationParameters
from .tensor import shape_num_elements
from .tensor import Tensor
from .tflite_mapping import DataType
+from .tosa.Op import Op as TosaOp
from .tosa.TosaGraph import TosaGraph as TG
from .tosa_mapping import datatype_map
from .tosa_mapping import datatype_map_numpy
+from .tosa_mapping import TOSA_IFM_INDICES
from .tosa_mapping import tosa_operator_map
from .tosa_mapping import unsupported_tosa_operators
@@ -89,7 +91,7 @@ class TosaSubgraph:
op_code = op_data.Op()
if op_code in unsupported_tosa_operators:
print("Unsupported Operator", op_code)
- assert False
+ return
op_type, attr_serializer, quant_serializer, indices = tosa_operator_map[op_code]
inputs = []
@@ -104,6 +106,15 @@ class TosaSubgraph:
outputs.append(output_tens)
assert output_tens is not None
+ # Permutation attribute for TRANSPOSE is an input tensor in TOSA
+ # TODO In order to optimise Depthwise spawning from TFLite Support for removing
+ # Transpose of constant data.
+ # Moving permutation to an attribute, to match internal graph representation for now
+ perms = None
+ if op_code == TosaOp.TRANSPOSE:
+ perms = perms = inputs.pop(1)
+ indices = TOSA_IFM_INDICES
+
name = "unknown_op_name"
if len(outputs):
name = outputs[0].name
@@ -148,6 +159,7 @@ class TosaSubgraph:
stride = op.attrs["stride"]
if len(stride) == 2:
op.attrs["strides"] = (1, stride[0], stride[1], 1)
+ del op.attrs["stride"]
else:
# TODO CONV3D more to be done....
print("Unsupported kernel dimensions: ", len(stride))
@@ -167,6 +179,11 @@ class TosaSubgraph:
# TODO CONV3D more to be done....
print("Unsupported kernel dimensions: ", len(kernel))
assert False
+ if op.type.is_depthwise_conv2d_op():
+ op.attrs["depth_multiplier"] = op.weights.shape[3]
+
+ elif op.type == Op.Transpose:
+ op.attrs["perms"] = perms.values
if quant_serializer is not None:
quant_info = quant_serializer.deserialize(op_data)