aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tosa_reader.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tosa_reader.py')
-rw-r--r--ethosu/vela/tosa_reader.py66
1 files changed, 31 insertions, 35 deletions
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py
index 56e0b1c..56af59d 100644
--- a/ethosu/vela/tosa_reader.py
+++ b/ethosu/vela/tosa_reader.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -40,7 +40,6 @@ from .tosa.Op import Op as TosaOp
from .tosa.TosaGraph import TosaGraph as TG
from .tosa_mapping import datatype_map
from .tosa_mapping import datatype_map_numpy
-from .tosa_mapping import TOSA_IFM_INDICES
from .tosa_mapping import tosa_operator_map
from .tosa_mapping import unsupported_tosa_operators
@@ -94,30 +93,30 @@ class TosaSubgraph:
op_code = op_data.Op()
if op_code in unsupported_tosa_operators:
print("Unsupported Operator", op_code)
+ for opname in dir(TosaOp):
+ if op_code == getattr(TosaOp, opname):
+ print(f" {opname}")
return
op_type, attr_serializer, quant_serializer, indices = tosa_operator_map[op_code]
inputs = []
outputs = []
for idx in range(op_data.InputsLength()):
- input_tens = self.get_tensor_by_name(decode_str(op_data.Inputs(idx)))
+ input = decode_str(op_data.Inputs(idx))
+ input_tens = self.get_tensor_by_name(input)
inputs.append(input_tens)
+ if input_tens is None:
+ print(f"could not find named input tensor {input}::{input_tens}")
assert input_tens is not None
for idx in range(op_data.OutputsLength()):
- output_tens = self.get_tensor_by_name(decode_str(op_data.Outputs(idx)))
+ output = decode_str(op_data.Outputs(idx))
+ output_tens = self.get_tensor_by_name(output)
outputs.append(output_tens)
+ if output_tens is None:
+ print(f"could not find named output tensor {output}::{output_tens}")
assert output_tens is not None
- # Permutation attribute for TRANSPOSE is an input tensor in TOSA
- # TODO In order to optimise Depthwise spawning from TFLite Support for removing
- # Transpose of constant data.
- # Moving permutation to an attribute, to match internal graph representation for now
- perms = None
- if op_code == TosaOp.TRANSPOSE:
- perms = inputs.pop(1)
- indices = TOSA_IFM_INDICES
-
name = "unknown_op_name"
if len(outputs):
name = outputs[0].name
@@ -189,27 +188,21 @@ class TosaSubgraph:
if op.type.is_depthwise_conv2d_op():
op.attrs["depth_multiplier"] = op.weights.shape[3]
if op.type == Op.SplitSliceRead:
- op.read_offsets[0] = Shape4D.from_list(list(op.attrs["begin"]), 0)
+ op.read_offsets[0] = Shape4D.from_list(list(op.attrs["start"]), 0)
op.read_shapes[0] = op.attrs["size"]
- elif op.type == Op.Transpose:
- op.attrs["perms"] = perms.values
-
- if quant_serializer is not None:
- quant_info = quant_serializer.deserialize(op_data)
-
# TODO tensor zero points currently set here
# zero points part of Rescale operation, handled in tosa_graph_optimizer
- if "input_zp" in quant_info:
- self.set_tensor_zp(op.ifm, quant_info["input_zp"])
- if "weight_zp" in quant_info:
- self.set_tensor_zp(op.weights, quant_info["weight_zp"])
- if "output_zp" in quant_info:
- self.set_tensor_zp(op.ofm, quant_info["output_zp"])
- if "a_zp" in quant_info:
- self.set_tensor_zp(op.ifm, quant_info["a_zp"])
- if "b_zp" in quant_info:
- self.set_tensor_zp(op.ifm2, quant_info["b_zp"])
+ if "input_zp" in op.attrs:
+ self.set_tensor_zp(op.ifm, op.attrs["input_zp"])
+ if "weight_zp" in op.attrs:
+ self.set_tensor_zp(op.weights, op.attrs["weight_zp"])
+ if "output_zp" in op.attrs:
+ self.set_tensor_zp(op.ofm, op.attrs["output_zp"])
+ if "a_zp" in op.attrs:
+ self.set_tensor_zp(op.ifm, op.attrs["a_zp"])
+ if "b_zp" in op.attrs:
+ self.set_tensor_zp(op.ifm2, op.attrs["b_zp"])
def parse_tensor(self, tens_data):
name = decode_str(tens_data.Name())
@@ -260,7 +253,6 @@ class TosaSubgraph:
class TosaGraph:
def __init__(self, filename, batch_size, feed_dict, output_node_names, initialisation_nodes):
-
self.op_times = {}
if batch_size is None:
batch_size = 1
@@ -278,11 +270,15 @@ class TosaGraph:
parsing_step = "parsing version"
self.check_version(tosa_graph)
+ parsing_step = "parsing single main region"
+ assert 1 == tosa_graph.RegionsLength()
+ assert b"main" == tosa_graph.Regions(0).Name()
+
parsing_step = "parsing blocks length"
self.subgraphs = []
- for b_idx in range(tosa_graph.BlocksLength()):
+ for b_idx in range(tosa_graph.Regions(0).BlocksLength()):
parsing_step = f"parsing block {b_idx}"
- self.subgraphs.append(TosaSubgraph(self, tosa_graph.Blocks(b_idx)))
+ self.subgraphs.append(TosaSubgraph(self, tosa_graph.Regions(0).Blocks(b_idx)))
self.nng = Graph(self.name, self.batch_size)
for tosa_sg in self.subgraphs:
@@ -297,8 +293,8 @@ class TosaGraph:
def check_version(self, tosa_graph):
version = tosa_graph.Version()
- version_str = f"{version._major()}.{version._minor()}.{version._patch()}"
- if version_str != "0.22.0":
+ version_str = f"{version._Major()}.{version._Minor()}.{version._Patch()}"
+ if version_str != "0.80.0":
print(f"Unsupported TOSA version: {version_str}")
assert False