diff options
Diffstat (limited to 'ethosu/vela/tosa_reader.py')
-rw-r--r-- | ethosu/vela/tosa_reader.py | 19 |
1 files changed, 18 insertions, 1 deletions
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py index eb317169..268d43ce 100644 --- a/ethosu/vela/tosa_reader.py +++ b/ethosu/vela/tosa_reader.py @@ -33,9 +33,11 @@ from .tensor import QuantizationParameters from .tensor import shape_num_elements from .tensor import Tensor from .tflite_mapping import DataType +from .tosa.Op import Op as TosaOp from .tosa.TosaGraph import TosaGraph as TG from .tosa_mapping import datatype_map from .tosa_mapping import datatype_map_numpy +from .tosa_mapping import TOSA_IFM_INDICES from .tosa_mapping import tosa_operator_map from .tosa_mapping import unsupported_tosa_operators @@ -89,7 +91,7 @@ class TosaSubgraph: op_code = op_data.Op() if op_code in unsupported_tosa_operators: print("Unsupported Operator", op_code) - assert False + return op_type, attr_serializer, quant_serializer, indices = tosa_operator_map[op_code] inputs = [] @@ -104,6 +106,15 @@ class TosaSubgraph: outputs.append(output_tens) assert output_tens is not None + # Permutation attribute for TRANSPOSE is an input tensor in TOSA + # TODO In order to optimise Depthwise spawning from TFLite Support for removing + # Transpose of constant data. + # Moving permutation to an attribute, to match internal graph representation for now + perms = None + if op_code == TosaOp.TRANSPOSE: + perms = perms = inputs.pop(1) + indices = TOSA_IFM_INDICES + name = "unknown_op_name" if len(outputs): name = outputs[0].name @@ -148,6 +159,7 @@ class TosaSubgraph: stride = op.attrs["stride"] if len(stride) == 2: op.attrs["strides"] = (1, stride[0], stride[1], 1) + del op.attrs["stride"] else: # TODO CONV3D more to be done.... print("Unsupported kernel dimensions: ", len(stride)) @@ -167,6 +179,11 @@ class TosaSubgraph: # TODO CONV3D more to be done.... print("Unsupported kernel dimensions: ", len(kernel)) assert False + if op.type.is_depthwise_conv2d_op(): + op.attrs["depth_multiplier"] = op.weights.shape[3] + + elif op.type == Op.Transpose: + op.attrs["perms"] = perms.values if quant_serializer is not None: quant_info = quant_serializer.deserialize(op_data) |