diff options
author | Jonas Ohlsson <jonas.ohlsson@arm.com> | 2021-07-26 16:13:12 +0200 |
---|---|---|
committer | Jonas Ohlsson <jonas.ohlsson@arm.com> | 2021-07-27 11:06:27 +0200 |
commit | 45e653dbd81633b8d78215b16a9b2205e39dd8e2 (patch) | |
tree | 18b3073eac45e9e8d69a616ae96d7a3fbdef9663 /ethosu/vela/model_reader.py | |
parent | c2449827ec55f49b6087e3e385fb3c4f6776dc6a (diff) | |
download | ethos-u-vela-45e653dbd81633b8d78215b16a9b2205e39dd8e2.tar.gz |
MLBEDSW-4853: Refactor supported operators
Refactor supported operators by breaking out model semantics
into its own class. Model semantics checked right after model
read.
Signed-off-by: Jonas Ohlsson <jonas.ohlsson@arm.com>
Change-Id: If442b189efcd91dda01af60b2b3adedfacdf2fad
Diffstat (limited to 'ethosu/vela/model_reader.py')
-rw-r--r-- | ethosu/vela/model_reader.py | 41 |
1 files changed, 22 insertions, 19 deletions
diff --git a/ethosu/vela/model_reader.py b/ethosu/vela/model_reader.py index f48645d3..3b094361 100644 --- a/ethosu/vela/model_reader.py +++ b/ethosu/vela/model_reader.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved. # # SPDX-License-Identifier: Apache-2.0 # @@ -15,7 +15,9 @@ # limitations under the License. # Description: # Dispatcher for reading a neural network model. +from . import tflite_model_semantic from . import tflite_reader +from . import tosa_model_semantic from . import tosa_reader from .errors import InputFileError from .nn_graph import NetworkType @@ -39,16 +41,17 @@ def read_model(fname, options, feed_dict=None, output_node_names=None, initialis output_node_names = [] if initialisation_nodes is None: initialisation_nodes = [] - return ( - tflite_reader.read_tflite( - fname, - options.batch_size, - feed_dict=feed_dict, - output_node_names=output_node_names, - initialisation_nodes=initialisation_nodes, - ), - NetworkType.TFLite, + + nng = tflite_reader.read_tflite( + fname, + options.batch_size, + feed_dict=feed_dict, + output_node_names=output_node_names, + initialisation_nodes=initialisation_nodes, ) + nng = tflite_model_semantic.tflite_semantic_checker(nng) + + return (nng, NetworkType.TFLite) elif fname.endswith(".tosa"): if feed_dict is None: feed_dict = {} @@ -57,15 +60,15 @@ def read_model(fname, options, feed_dict=None, output_node_names=None, initialis if initialisation_nodes is None: initialisation_nodes = [] - return ( - tosa_reader.read_tosa( - fname, - options.batch_size, - feed_dict=feed_dict, - output_node_names=output_node_names, - initialisation_nodes=initialisation_nodes, - ), - NetworkType.TOSA, + nng = tosa_reader.read_tosa( + fname, + options.batch_size, + feed_dict=feed_dict, + output_node_names=output_node_names, + initialisation_nodes=initialisation_nodes, ) + nng = tosa_model_semantic.tosa_semantic_checker(nng) + + return (nng, NetworkType.TOSA) else: raise InputFileError(fname, "Unsupported file extension. Only .tflite and .tosa files are supported") |