aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/model_reader.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/model_reader.py')
-rw-r--r--ethosu/vela/model_reader.py27
1 files changed, 16 insertions, 11 deletions
diff --git a/ethosu/vela/model_reader.py b/ethosu/vela/model_reader.py
index d1cdc9bd..6deb2538 100644
--- a/ethosu/vela/model_reader.py
+++ b/ethosu/vela/model_reader.py
@@ -15,6 +15,9 @@
# limitations under the License.
# Description:
# Dispatcher for reading a neural network model.
+from . import tflite_reader
+from .errors import InputFileError
+from .errors import VelaError
class ModelReaderOptions:
@@ -29,15 +32,17 @@ class ModelReaderOptions:
def read_model(fname, options, feed_dict={}, output_node_names=[], initialisation_nodes=[]):
if fname.endswith(".tflite"):
- from . import tflite_reader
-
- nng = tflite_reader.read_tflite(
- fname,
- options.batch_size,
- feed_dict=feed_dict,
- output_node_names=output_node_names,
- initialisation_nodes=initialisation_nodes,
- )
+ try:
+ return tflite_reader.read_tflite(
+ fname,
+ options.batch_size,
+ feed_dict=feed_dict,
+ output_node_names=output_node_names,
+ initialisation_nodes=initialisation_nodes,
+ )
+ except VelaError as e:
+ raise e
+ except Exception as e:
+ raise InputFileError(fname, str(e))
else:
- assert 0, "Unknown model format"
- return nng
+ raise InputFileError(fname, "Unknown input file format. Only .tflite files are supported")