From 7db78969dc8ead72f3ded81b6d2a6a7ed798ea62 Mon Sep 17 00:00:00 2001 From: Louis Verhaard Date: Mon, 25 May 2020 15:05:26 +0200 Subject: MLBEDSW-2067: added custom exceptions Added custom exceptions to handle different types of input errors. Also performed minor formatting changes using flake8/black. Change-Id: Ie5b05361507d5e569aff045757aec0a4a755ae98 Signed-off-by: Louis Verhaard --- ethosu/vela/test/test_model_reader.py | 40 +++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 ethosu/vela/test/test_model_reader.py (limited to 'ethosu/vela/test') diff --git a/ethosu/vela/test/test_model_reader.py b/ethosu/vela/test/test_model_reader.py new file mode 100644 index 00000000..ee9a51e8 --- /dev/null +++ b/ethosu/vela/test/test_model_reader.py @@ -0,0 +1,40 @@ +# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Description: +# Unit tests for model_reader. +import pytest +from ethosu.vela import model_reader +from ethosu.vela.errors import InputFileError + + +def test_read_model_incorrect_extension(tmpdir): + # Tests read_model with a file name that does not end with .tflite + with pytest.raises(InputFileError): + model_reader.read_model("no_tflite_file.txt", model_reader.ModelReaderOptions()) + + +def test_read_model_corrupt_contents(tmpdir): + # Tests read_model with a corrupt .tflite file + fname = tmpdir.join("corrupt.tflite") + fname.write("abcde1234") + with pytest.raises(InputFileError): + model_reader.read_model(fname.strpath, model_reader.ModelReaderOptions()) + + +def test_read_model_file_not_found(tmpdir): + # Tests read_model with a .tflite file that does not exist + with pytest.raises(InputFileError): + model_reader.read_model("non_existing.tflite", model_reader.ModelReaderOptions()) -- cgit v1.2.1