diff options
-rw-r--r-- | kernel/ethosu_network.c | 12 | ||||
-rw-r--r-- | tests/run_inference_test.cpp | 15 |
2 files changed, 24 insertions, 3 deletions
diff --git a/kernel/ethosu_network.c b/kernel/ethosu_network.c index 94354ed..f7871de 100644 --- a/kernel/ethosu_network.c +++ b/kernel/ethosu_network.c @@ -15,7 +15,6 @@ * You should have received a copy of the GNU General Public License * along with this program; if not, you can access it online at * http://www.gnu.org/licenses/gpl-2.0.html. - * */ /**************************************************************************** @@ -163,14 +162,21 @@ int ethosu_network_create(struct device *dev, net->buf = NULL; kref_init(&net->kref); - if (uapi->type == ETHOSU_UAPI_NETWORK_BUFFER) { + switch (uapi->type) { + case ETHOSU_UAPI_NETWORK_BUFFER: net->buf = ethosu_buffer_get_from_fd(uapi->fd); if (IS_ERR(net->buf)) { ret = PTR_ERR(net->buf); goto free_net; } - } else { + + break; + case ETHOSU_UAPI_NETWORK_INDEX: net->index = uapi->index; + break; + default: + ret = -EINVAL; + goto free_net; } ret = anon_inode_getfd("ethosu-network", ðosu_network_fops, net, diff --git a/tests/run_inference_test.cpp b/tests/run_inference_test.cpp index 480e26f..6075d7a 100644 --- a/tests/run_inference_test.cpp +++ b/tests/run_inference_test.cpp @@ -105,6 +105,20 @@ void testNetworkInfoUnparsableBuffer(const Device &device) { } catch (std::exception &e) { throw TestFailureException("NetworkInfo unparsable buffer test: ", e.what()); } } +void testNetworkInvalidType(const Device &device) { + const std::string expected_error = + std::string("IOCTL cmd=") + std::to_string(ETHOSU_IOCTL_NETWORK_CREATE) + " failed: " + std::strerror(EINVAL); + struct ethosu_uapi_network_create net_req = {}; + net_req.type = ETHOSU_UAPI_NETWORK_INDEX + 1; + try { + int r = device.ioctl(ETHOSU_IOCTL_NETWORK_CREATE, &net_req); + FAIL(); + } catch (Exception &e) { + // The call is expected to throw + TEST_ASSERT(expected_error.compare(e.what()) == 0); + } catch (std::exception &e) { throw TestFailureException("NetworkCreate invalid type test: ", e.what()); } +} + void testRunInferenceBuffer(const Device &device) { try { auto networkBuffer = std::make_shared<Buffer>(device, sizeof(networkModelData)); @@ -154,6 +168,7 @@ int main() { testPing(device); testDriverVersion(device); testCapabilties(device); + testNetworkInvalidType(device); testNetworkInfoNotExistentIndex(device); testNetworkInfoBuffer(device); testNetworkInfoUnparsableBuffer(device); |