diff options
Diffstat (limited to 'ethosu/vela/model_reader.py')
-rw-r--r-- | ethosu/vela/model_reader.py | 41 |
1 files changed, 22 insertions, 19 deletions
diff --git a/ethosu/vela/model_reader.py b/ethosu/vela/model_reader.py index f48645d3..3b094361 100644 --- a/ethosu/vela/model_reader.py +++ b/ethosu/vela/model_reader.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved. # # SPDX-License-Identifier: Apache-2.0 # @@ -15,7 +15,9 @@ # limitations under the License. # Description: # Dispatcher for reading a neural network model. +from . import tflite_model_semantic from . import tflite_reader +from . import tosa_model_semantic from . import tosa_reader from .errors import InputFileError from .nn_graph import NetworkType @@ -39,16 +41,17 @@ def read_model(fname, options, feed_dict=None, output_node_names=None, initialis output_node_names = [] if initialisation_nodes is None: initialisation_nodes = [] - return ( - tflite_reader.read_tflite( - fname, - options.batch_size, - feed_dict=feed_dict, - output_node_names=output_node_names, - initialisation_nodes=initialisation_nodes, - ), - NetworkType.TFLite, + + nng = tflite_reader.read_tflite( + fname, + options.batch_size, + feed_dict=feed_dict, + output_node_names=output_node_names, + initialisation_nodes=initialisation_nodes, ) + nng = tflite_model_semantic.tflite_semantic_checker(nng) + + return (nng, NetworkType.TFLite) elif fname.endswith(".tosa"): if feed_dict is None: feed_dict = {} @@ -57,15 +60,15 @@ def read_model(fname, options, feed_dict=None, output_node_names=None, initialis if initialisation_nodes is None: initialisation_nodes = [] - return ( - tosa_reader.read_tosa( - fname, - options.batch_size, - feed_dict=feed_dict, - output_node_names=output_node_names, - initialisation_nodes=initialisation_nodes, - ), - NetworkType.TOSA, + nng = tosa_reader.read_tosa( + fname, + options.batch_size, + feed_dict=feed_dict, + output_node_names=output_node_names, + initialisation_nodes=initialisation_nodes, ) + nng = tosa_model_semantic.tosa_semantic_checker(nng) + + return (nng, NetworkType.TOSA) else: raise InputFileError(fname, "Unsupported file extension. Only .tflite and .tosa files are supported") |