aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--kernel/ethosu_network.c12
-rw-r--r--tests/run_inference_test.cpp15
2 files changed, 24 insertions, 3 deletions
diff --git a/kernel/ethosu_network.c b/kernel/ethosu_network.c
index 94354ed..f7871de 100644
--- a/kernel/ethosu_network.c
+++ b/kernel/ethosu_network.c
@@ -15,7 +15,6 @@
* You should have received a copy of the GNU General Public License
* along with this program; if not, you can access it online at
* http://www.gnu.org/licenses/gpl-2.0.html.
- *
*/
/****************************************************************************
@@ -163,14 +162,21 @@ int ethosu_network_create(struct device *dev,
net->buf = NULL;
kref_init(&net->kref);
- if (uapi->type == ETHOSU_UAPI_NETWORK_BUFFER) {
+ switch (uapi->type) {
+ case ETHOSU_UAPI_NETWORK_BUFFER:
net->buf = ethosu_buffer_get_from_fd(uapi->fd);
if (IS_ERR(net->buf)) {
ret = PTR_ERR(net->buf);
goto free_net;
}
- } else {
+
+ break;
+ case ETHOSU_UAPI_NETWORK_INDEX:
net->index = uapi->index;
+ break;
+ default:
+ ret = -EINVAL;
+ goto free_net;
}
ret = anon_inode_getfd("ethosu-network", &ethosu_network_fops, net,
diff --git a/tests/run_inference_test.cpp b/tests/run_inference_test.cpp
index 480e26f..6075d7a 100644
--- a/tests/run_inference_test.cpp
+++ b/tests/run_inference_test.cpp
@@ -105,6 +105,20 @@ void testNetworkInfoUnparsableBuffer(const Device &device) {
} catch (std::exception &e) { throw TestFailureException("NetworkInfo unparsable buffer test: ", e.what()); }
}
+void testNetworkInvalidType(const Device &device) {
+ const std::string expected_error =
+ std::string("IOCTL cmd=") + std::to_string(ETHOSU_IOCTL_NETWORK_CREATE) + " failed: " + std::strerror(EINVAL);
+ struct ethosu_uapi_network_create net_req = {};
+ net_req.type = ETHOSU_UAPI_NETWORK_INDEX + 1;
+ try {
+ int r = device.ioctl(ETHOSU_IOCTL_NETWORK_CREATE, &net_req);
+ FAIL();
+ } catch (Exception &e) {
+ // The call is expected to throw
+ TEST_ASSERT(expected_error.compare(e.what()) == 0);
+ } catch (std::exception &e) { throw TestFailureException("NetworkCreate invalid type test: ", e.what()); }
+}
+
void testRunInferenceBuffer(const Device &device) {
try {
auto networkBuffer = std::make_shared<Buffer>(device, sizeof(networkModelData));
@@ -154,6 +168,7 @@ int main() {
testPing(device);
testDriverVersion(device);
testCapabilties(device);
+ testNetworkInvalidType(device);
testNetworkInfoNotExistentIndex(device);
testNetworkInfoBuffer(device);
testNetworkInfoUnparsableBuffer(device);