From 45e653dbd81633b8d78215b16a9b2205e39dd8e2 Mon Sep 17 00:00:00 2001 From: Jonas Ohlsson Date: Mon, 26 Jul 2021 16:13:12 +0200 Subject: MLBEDSW-4853: Refactor supported operators Refactor supported operators by breaking out model semantics into its own class. Model semantics checked right after model read. Signed-off-by: Jonas Ohlsson Change-Id: If442b189efcd91dda01af60b2b3adedfacdf2fad --- ethosu/vela/model_reader.py | 41 ++++++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 19 deletions(-) (limited to 'ethosu/vela/model_reader.py') diff --git a/ethosu/vela/model_reader.py b/ethosu/vela/model_reader.py index f48645d3..3b094361 100644 --- a/ethosu/vela/model_reader.py +++ b/ethosu/vela/model_reader.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved. # # SPDX-License-Identifier: Apache-2.0 # @@ -15,7 +15,9 @@ # limitations under the License. # Description: # Dispatcher for reading a neural network model. +from . import tflite_model_semantic from . import tflite_reader +from . import tosa_model_semantic from . import tosa_reader from .errors import InputFileError from .nn_graph import NetworkType @@ -39,16 +41,17 @@ def read_model(fname, options, feed_dict=None, output_node_names=None, initialis output_node_names = [] if initialisation_nodes is None: initialisation_nodes = [] - return ( - tflite_reader.read_tflite( - fname, - options.batch_size, - feed_dict=feed_dict, - output_node_names=output_node_names, - initialisation_nodes=initialisation_nodes, - ), - NetworkType.TFLite, + + nng = tflite_reader.read_tflite( + fname, + options.batch_size, + feed_dict=feed_dict, + output_node_names=output_node_names, + initialisation_nodes=initialisation_nodes, ) + nng = tflite_model_semantic.tflite_semantic_checker(nng) + + return (nng, NetworkType.TFLite) elif fname.endswith(".tosa"): if feed_dict is None: feed_dict = {} @@ -57,15 +60,15 @@ def read_model(fname, options, feed_dict=None, output_node_names=None, initialis if initialisation_nodes is None: initialisation_nodes = [] - return ( - tosa_reader.read_tosa( - fname, - options.batch_size, - feed_dict=feed_dict, - output_node_names=output_node_names, - initialisation_nodes=initialisation_nodes, - ), - NetworkType.TOSA, + nng = tosa_reader.read_tosa( + fname, + options.batch_size, + feed_dict=feed_dict, + output_node_names=output_node_names, + initialisation_nodes=initialisation_nodes, ) + nng = tosa_model_semantic.tosa_semantic_checker(nng) + + return (nng, NetworkType.TOSA) else: raise InputFileError(fname, "Unsupported file extension. Only .tflite and .tosa files are supported") -- cgit v1.2.1