aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/model_reader.py
diff options
context:
space:
mode:
authorLouis Verhaard <louis.verhaard@arm.com>2020-05-25 15:05:26 +0200
committerTim Hall <tim.hall@arm.com>2020-06-18 17:53:52 +0100
commit7db78969dc8ead72f3ded81b6d2a6a7ed798ea62 (patch)
tree011bcf579cc8e0f007f9564a98cc5c05df34322b /ethosu/vela/model_reader.py
parent78792223369fa34dacd0e69e189af035283da2ae (diff)
downloadethos-u-vela-7db78969dc8ead72f3ded81b6d2a6a7ed798ea62.tar.gz
MLBEDSW-2067: added custom exceptions
Added custom exceptions to handle different types of input errors. Also performed minor formatting changes using flake8/black. Change-Id: Ie5b05361507d5e569aff045757aec0a4a755ae98 Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Diffstat (limited to 'ethosu/vela/model_reader.py')
-rw-r--r--ethosu/vela/model_reader.py27
1 files changed, 16 insertions, 11 deletions
diff --git a/ethosu/vela/model_reader.py b/ethosu/vela/model_reader.py
index d1cdc9bd..6deb2538 100644
--- a/ethosu/vela/model_reader.py
+++ b/ethosu/vela/model_reader.py
@@ -15,6 +15,9 @@
# limitations under the License.
# Description:
# Dispatcher for reading a neural network model.
+from . import tflite_reader
+from .errors import InputFileError
+from .errors import VelaError
class ModelReaderOptions:
@@ -29,15 +32,17 @@ class ModelReaderOptions:
def read_model(fname, options, feed_dict={}, output_node_names=[], initialisation_nodes=[]):
if fname.endswith(".tflite"):
- from . import tflite_reader
-
- nng = tflite_reader.read_tflite(
- fname,
- options.batch_size,
- feed_dict=feed_dict,
- output_node_names=output_node_names,
- initialisation_nodes=initialisation_nodes,
- )
+ try:
+ return tflite_reader.read_tflite(
+ fname,
+ options.batch_size,
+ feed_dict=feed_dict,
+ output_node_names=output_node_names,
+ initialisation_nodes=initialisation_nodes,
+ )
+ except VelaError as e:
+ raise e
+ except Exception as e:
+ raise InputFileError(fname, str(e))
else:
- assert 0, "Unknown model format"
- return nng
+ raise InputFileError(fname, "Unknown input file format. Only .tflite files are supported")