path: root/ethosu/vela/tosa_reader.py
author     Rob Elliott <robert.elliott@arm.com>        2023-08-17 14:27:06 +0000
committer  Rickard Bolin <rickard.bolin@arm.com>       2023-08-21 16:14:51 +0000
commit     00a15db3e1a188b25065d095152d701f4394cdc5 (patch)
tree       96761b9f7ac3ad759f9f0ffbf63a6d0ef115ad14 /ethosu/vela/tosa_reader.py
parent     8ea90edb75e5d2353aa91c264356fc9d460ca308 (diff)
download   ethos-u-vela-00a15db3e1a188b25065d095152d701f4394cdc5.tar.gz
Moving Vela to use TOSA v0.80.0 specification
* Using serialization_lib main branch to update statically copied files,
  sha 5f920211ac23393a7b98a0d358bfbfc3232d5c8f (v0.80.0)
* All files within ethosu/vela/tosa are copied from that revision
* Note: hope to move to serialization_lib as a pip module in future
* Modified ethosu/vela/{tosa_mapping,tosa_reader}.py to use the v0.80.0
  TOSA FlatBuffers implementation
* These are the additional changes made to support this new version, with
  changes in the format of the FlatBuffers file and where various values are
  stored: either changing from input to attribute, or moving to different
  attributes.

Signed-off-by: Rob Elliott <robert.elliott@arm.com>
Change-Id: I5e1fcc2a9964148619be3477adf1e88e84cbae2d
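For readers unfamiliar with the new layout, the sketch below shows the v0.80.0 traversal this patch switches to, using the same RegionsLength/Regions/Name/BlocksLength/Blocks calls that appear in the diff; the model.tosa file name and the GetRootAsTosaGraph entry point are illustrative assumptions rather than part of the change.

    # Minimal sketch (not part of this patch): walking a serialized TOSA
    # v0.80.0 graph with the FlatBuffers bindings copied under ethosu/vela/tosa.
    # The file name and the GetRootAsTosaGraph entry point are assumptions.
    from ethosu.vela.tosa.TosaGraph import TosaGraph as TG

    with open("model.tosa", "rb") as f:   # hypothetical serialized TOSA graph
        buf = bytearray(f.read())
    tosa_graph = TG.GetRootAsTosaGraph(buf, 0)

    # v0.80.0 nests blocks inside regions; older versions exposed Blocks()
    # directly on the graph, which is what the reader previously relied on.
    assert tosa_graph.RegionsLength() == 1
    main_region = tosa_graph.Regions(0)
    assert main_region.Name() == b"main"
    for b_idx in range(main_region.BlocksLength()):
        block = main_region.Blocks(b_idx)  # parsed as a TosaSubgraph in tosa_reader.py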
Diffstat (limited to 'ethosu/vela/tosa_reader.py')
-rw-r--r--  ethosu/vela/tosa_reader.py  66
1 file changed, 31 insertions, 35 deletions
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py
index 56e0b1cb..56af59d8 100644
--- a/ethosu/vela/tosa_reader.py
+++ b/ethosu/vela/tosa_reader.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -40,7 +40,6 @@ from .tosa.Op import Op as TosaOp
from .tosa.TosaGraph import TosaGraph as TG
from .tosa_mapping import datatype_map
from .tosa_mapping import datatype_map_numpy
-from .tosa_mapping import TOSA_IFM_INDICES
from .tosa_mapping import tosa_operator_map
from .tosa_mapping import unsupported_tosa_operators
@@ -94,30 +93,30 @@ class TosaSubgraph:
op_code = op_data.Op()
if op_code in unsupported_tosa_operators:
print("Unsupported Operator", op_code)
+ for opname in dir(TosaOp):
+ if op_code == getattr(TosaOp, opname):
+ print(f" {opname}")
return
op_type, attr_serializer, quant_serializer, indices = tosa_operator_map[op_code]
inputs = []
outputs = []
for idx in range(op_data.InputsLength()):
- input_tens = self.get_tensor_by_name(decode_str(op_data.Inputs(idx)))
+ input = decode_str(op_data.Inputs(idx))
+ input_tens = self.get_tensor_by_name(input)
inputs.append(input_tens)
+ if input_tens is None:
+ print(f"could not find named input tensor {input}::{input_tens}")
assert input_tens is not None
for idx in range(op_data.OutputsLength()):
- output_tens = self.get_tensor_by_name(decode_str(op_data.Outputs(idx)))
+ output = decode_str(op_data.Outputs(idx))
+ output_tens = self.get_tensor_by_name(output)
outputs.append(output_tens)
+ if output_tens is None:
+ print(f"could not find named output tensor {output}::{output_tens}")
assert output_tens is not None
- # Permutation attribute for TRANSPOSE is an input tensor in TOSA
- # TODO In order to optimise Depthwise spawning from TFLite Support for removing
- # Transpose of constant data.
- # Moving permutation to an attribute, to match internal graph representation for now
- perms = None
- if op_code == TosaOp.TRANSPOSE:
- perms = inputs.pop(1)
- indices = TOSA_IFM_INDICES
-
name = "unknown_op_name"
if len(outputs):
name = outputs[0].name
@@ -189,27 +188,21 @@ class TosaSubgraph:
if op.type.is_depthwise_conv2d_op():
op.attrs["depth_multiplier"] = op.weights.shape[3]
if op.type == Op.SplitSliceRead:
- op.read_offsets[0] = Shape4D.from_list(list(op.attrs["begin"]), 0)
+ op.read_offsets[0] = Shape4D.from_list(list(op.attrs["start"]), 0)
op.read_shapes[0] = op.attrs["size"]
- elif op.type == Op.Transpose:
- op.attrs["perms"] = perms.values
-
- if quant_serializer is not None:
- quant_info = quant_serializer.deserialize(op_data)
-
# TODO tensor zero points currently set here
# zero points part of Rescale operation, handled in tosa_graph_optimizer
- if "input_zp" in quant_info:
- self.set_tensor_zp(op.ifm, quant_info["input_zp"])
- if "weight_zp" in quant_info:
- self.set_tensor_zp(op.weights, quant_info["weight_zp"])
- if "output_zp" in quant_info:
- self.set_tensor_zp(op.ofm, quant_info["output_zp"])
- if "a_zp" in quant_info:
- self.set_tensor_zp(op.ifm, quant_info["a_zp"])
- if "b_zp" in quant_info:
- self.set_tensor_zp(op.ifm2, quant_info["b_zp"])
+ if "input_zp" in op.attrs:
+ self.set_tensor_zp(op.ifm, op.attrs["input_zp"])
+ if "weight_zp" in op.attrs:
+ self.set_tensor_zp(op.weights, op.attrs["weight_zp"])
+ if "output_zp" in op.attrs:
+ self.set_tensor_zp(op.ofm, op.attrs["output_zp"])
+ if "a_zp" in op.attrs:
+ self.set_tensor_zp(op.ifm, op.attrs["a_zp"])
+ if "b_zp" in op.attrs:
+ self.set_tensor_zp(op.ifm2, op.attrs["b_zp"])
def parse_tensor(self, tens_data):
name = decode_str(tens_data.Name())
@@ -260,7 +253,6 @@ class TosaSubgraph:
class TosaGraph:
def __init__(self, filename, batch_size, feed_dict, output_node_names, initialisation_nodes):
-
self.op_times = {}
if batch_size is None:
batch_size = 1
@@ -278,11 +270,15 @@ class TosaGraph:
parsing_step = "parsing version"
self.check_version(tosa_graph)
+ parsing_step = "parsing single main region"
+ assert 1 == tosa_graph.RegionsLength()
+ assert b"main" == tosa_graph.Regions(0).Name()
+
parsing_step = "parsing blocks length"
self.subgraphs = []
- for b_idx in range(tosa_graph.BlocksLength()):
+ for b_idx in range(tosa_graph.Regions(0).BlocksLength()):
parsing_step = f"parsing block {b_idx}"
- self.subgraphs.append(TosaSubgraph(self, tosa_graph.Blocks(b_idx)))
+ self.subgraphs.append(TosaSubgraph(self, tosa_graph.Regions(0).Blocks(b_idx)))
self.nng = Graph(self.name, self.batch_size)
for tosa_sg in self.subgraphs:
@@ -297,8 +293,8 @@ class TosaGraph:
def check_version(self, tosa_graph):
version = tosa_graph.Version()
- version_str = f"{version._major()}.{version._minor()}.{version._patch()}"
- if version_str != "0.22.0":
+ version_str = f"{version._Major()}.{version._Minor()}.{version._Patch()}"
+ if version_str != "0.80.0":
print(f"Unsupported TOSA version: {version_str}")
assert False
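Two of the smaller changes above, sketched as standalone helpers for clarity: the unsupported-operator path now resolves the numeric op code back to its enum name by scanning the generated TosaOp constants, and the version check uses the capitalised accessors generated for v0.80.0. Both helper names below are hypothetical and do not exist in tosa_reader.py.

    # Sketch only: hypothetical helpers mirroring the lookup and version check
    # added in this patch.
    from ethosu.vela.tosa.Op import Op as TosaOp

    def op_code_name(op_code):
        # dir() lists the integer constants flatc generates on the Op class;
        # match the numeric code back to its symbolic name.
        for opname in dir(TosaOp):
            if not opname.startswith("_") and getattr(TosaOp, opname) == op_code:
                return opname
        return "UNKNOWN"

    def tosa_version_string(tosa_graph):
        # v0.80.0 bindings capitalise the version accessors (_Major vs _major).
        version = tosa_graph.Version()
        return f"{version._Major()}.{version._Minor()}.{version._Patch()}"

    print(op_code_name(TosaOp.TRANSPOSE))  # -> "TRANSPOSE"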