aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tosa_reader.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tosa_reader.py')
-rw-r--r--ethosu/vela/tosa_reader.py19
1 files changed, 18 insertions, 1 deletions
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py
index eb317169..268d43ce 100644
--- a/ethosu/vela/tosa_reader.py
+++ b/ethosu/vela/tosa_reader.py
@@ -33,9 +33,11 @@ from .tensor import QuantizationParameters
from .tensor import shape_num_elements
from .tensor import Tensor
from .tflite_mapping import DataType
+from .tosa.Op import Op as TosaOp
from .tosa.TosaGraph import TosaGraph as TG
from .tosa_mapping import datatype_map
from .tosa_mapping import datatype_map_numpy
+from .tosa_mapping import TOSA_IFM_INDICES
from .tosa_mapping import tosa_operator_map
from .tosa_mapping import unsupported_tosa_operators
@@ -89,7 +91,7 @@ class TosaSubgraph:
op_code = op_data.Op()
if op_code in unsupported_tosa_operators:
print("Unsupported Operator", op_code)
- assert False
+ return
op_type, attr_serializer, quant_serializer, indices = tosa_operator_map[op_code]
inputs = []
@@ -104,6 +106,15 @@ class TosaSubgraph:
outputs.append(output_tens)
assert output_tens is not None
+ # Permutation attribute for TRANSPOSE is an input tensor in TOSA
+ # TODO In order to optimise Depthwise spawning from TFLite Support for removing
+ # Transpose of constant data.
+ # Moving permutation to an attribute, to match internal graph representation for now
+ perms = None
+ if op_code == TosaOp.TRANSPOSE:
+ perms = perms = inputs.pop(1)
+ indices = TOSA_IFM_INDICES
+
name = "unknown_op_name"
if len(outputs):
name = outputs[0].name
@@ -148,6 +159,7 @@ class TosaSubgraph:
stride = op.attrs["stride"]
if len(stride) == 2:
op.attrs["strides"] = (1, stride[0], stride[1], 1)
+ del op.attrs["stride"]
else:
# TODO CONV3D more to be done....
print("Unsupported kernel dimensions: ", len(stride))
@@ -167,6 +179,11 @@ class TosaSubgraph:
# TODO CONV3D more to be done....
print("Unsupported kernel dimensions: ", len(kernel))
assert False
+ if op.type.is_depthwise_conv2d_op():
+ op.attrs["depth_multiplier"] = op.weights.shape[3]
+
+ elif op.type == Op.Transpose:
+ op.attrs["perms"] = perms.values
if quant_serializer is not None:
quant_info = quant_serializer.deserialize(op_data)