diff options
Diffstat (limited to 'ethosu/vela/tosa_reader.py')
-rw-r--r-- | ethosu/vela/tosa_reader.py | 66 |
1 files changed, 31 insertions, 35 deletions
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py index 56e0b1c..56af59d 100644 --- a/ethosu/vela/tosa_reader.py +++ b/ethosu/vela/tosa_reader.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com> +# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com> # # SPDX-License-Identifier: Apache-2.0 # @@ -40,7 +40,6 @@ from .tosa.Op import Op as TosaOp from .tosa.TosaGraph import TosaGraph as TG from .tosa_mapping import datatype_map from .tosa_mapping import datatype_map_numpy -from .tosa_mapping import TOSA_IFM_INDICES from .tosa_mapping import tosa_operator_map from .tosa_mapping import unsupported_tosa_operators @@ -94,30 +93,30 @@ class TosaSubgraph: op_code = op_data.Op() if op_code in unsupported_tosa_operators: print("Unsupported Operator", op_code) + for opname in dir(TosaOp): + if op_code == getattr(TosaOp, opname): + print(f" {opname}") return op_type, attr_serializer, quant_serializer, indices = tosa_operator_map[op_code] inputs = [] outputs = [] for idx in range(op_data.InputsLength()): - input_tens = self.get_tensor_by_name(decode_str(op_data.Inputs(idx))) + input = decode_str(op_data.Inputs(idx)) + input_tens = self.get_tensor_by_name(input) inputs.append(input_tens) + if input_tens is None: + print(f"could not find named input tensor {input}::{input_tens}") assert input_tens is not None for idx in range(op_data.OutputsLength()): - output_tens = self.get_tensor_by_name(decode_str(op_data.Outputs(idx))) + output = decode_str(op_data.Outputs(idx)) + output_tens = self.get_tensor_by_name(output) outputs.append(output_tens) + if output_tens is None: + print(f"could not find named output tensor {output}::{output_tens}") assert output_tens is not None - # Permutation attribute for TRANSPOSE is an input tensor in TOSA - # TODO In order to optimise Depthwise spawning from TFLite Support for removing - # Transpose of constant data. - # Moving permutation to an attribute, to match internal graph representation for now - perms = None - if op_code == TosaOp.TRANSPOSE: - perms = inputs.pop(1) - indices = TOSA_IFM_INDICES - name = "unknown_op_name" if len(outputs): name = outputs[0].name @@ -189,27 +188,21 @@ class TosaSubgraph: if op.type.is_depthwise_conv2d_op(): op.attrs["depth_multiplier"] = op.weights.shape[3] if op.type == Op.SplitSliceRead: - op.read_offsets[0] = Shape4D.from_list(list(op.attrs["begin"]), 0) + op.read_offsets[0] = Shape4D.from_list(list(op.attrs["start"]), 0) op.read_shapes[0] = op.attrs["size"] - elif op.type == Op.Transpose: - op.attrs["perms"] = perms.values - - if quant_serializer is not None: - quant_info = quant_serializer.deserialize(op_data) - # TODO tensor zero points currently set here # zero points part of Rescale operation, handled in tosa_graph_optimizer - if "input_zp" in quant_info: - self.set_tensor_zp(op.ifm, quant_info["input_zp"]) - if "weight_zp" in quant_info: - self.set_tensor_zp(op.weights, quant_info["weight_zp"]) - if "output_zp" in quant_info: - self.set_tensor_zp(op.ofm, quant_info["output_zp"]) - if "a_zp" in quant_info: - self.set_tensor_zp(op.ifm, quant_info["a_zp"]) - if "b_zp" in quant_info: - self.set_tensor_zp(op.ifm2, quant_info["b_zp"]) + if "input_zp" in op.attrs: + self.set_tensor_zp(op.ifm, op.attrs["input_zp"]) + if "weight_zp" in op.attrs: + self.set_tensor_zp(op.weights, op.attrs["weight_zp"]) + if "output_zp" in op.attrs: + self.set_tensor_zp(op.ofm, op.attrs["output_zp"]) + if "a_zp" in op.attrs: + self.set_tensor_zp(op.ifm, op.attrs["a_zp"]) + if "b_zp" in op.attrs: + self.set_tensor_zp(op.ifm2, op.attrs["b_zp"]) def parse_tensor(self, tens_data): name = decode_str(tens_data.Name()) @@ -260,7 +253,6 @@ class TosaSubgraph: class TosaGraph: def __init__(self, filename, batch_size, feed_dict, output_node_names, initialisation_nodes): - self.op_times = {} if batch_size is None: batch_size = 1 @@ -278,11 +270,15 @@ class TosaGraph: parsing_step = "parsing version" self.check_version(tosa_graph) + parsing_step = "parsing single main region" + assert 1 == tosa_graph.RegionsLength() + assert b"main" == tosa_graph.Regions(0).Name() + parsing_step = "parsing blocks length" self.subgraphs = [] - for b_idx in range(tosa_graph.BlocksLength()): + for b_idx in range(tosa_graph.Regions(0).BlocksLength()): parsing_step = f"parsing block {b_idx}" - self.subgraphs.append(TosaSubgraph(self, tosa_graph.Blocks(b_idx))) + self.subgraphs.append(TosaSubgraph(self, tosa_graph.Regions(0).Blocks(b_idx))) self.nng = Graph(self.name, self.batch_size) for tosa_sg in self.subgraphs: @@ -297,8 +293,8 @@ class TosaGraph: def check_version(self, tosa_graph): version = tosa_graph.Version() - version_str = f"{version._major()}.{version._minor()}.{version._patch()}" - if version_str != "0.22.0": + version_str = f"{version._Major()}.{version._Minor()}.{version._Patch()}" + if version_str != "0.80.0": print(f"Unsupported TOSA version: {version_str}") assert False |