From 45d47991f745094e328f32e769c22d811d397b1d Mon Sep 17 00:00:00 2001 From: Mikael Olsson Date: Thu, 12 Oct 2023 15:32:56 +0200 Subject: Fix type validation in the network create UAPI Currently, the network create UAPI will assume that any network type that isn't a buffer is an index. This means that the Linux kernel NPU driver will accept any network type value and the user won't get any feedback that they have specified an incorrect type. To resolve this, the Linux kernel NPU driver will now return -EINVAL if an unknown network type is given and a test has been added to validate this behavior. Change-Id: Ib7d9f5d5451897787981aae61a4e0a6650a73e05 Signed-off-by: Mikael Olsson --- kernel/ethosu_network.c | 12 +++++++++--- tests/run_inference_test.cpp | 15 +++++++++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/kernel/ethosu_network.c b/kernel/ethosu_network.c index 94354ed..f7871de 100644 --- a/kernel/ethosu_network.c +++ b/kernel/ethosu_network.c @@ -15,7 +15,6 @@ * You should have received a copy of the GNU General Public License * along with this program; if not, you can access it online at * http://www.gnu.org/licenses/gpl-2.0.html. - * */ /**************************************************************************** @@ -163,14 +162,21 @@ int ethosu_network_create(struct device *dev, net->buf = NULL; kref_init(&net->kref); - if (uapi->type == ETHOSU_UAPI_NETWORK_BUFFER) { + switch (uapi->type) { + case ETHOSU_UAPI_NETWORK_BUFFER: net->buf = ethosu_buffer_get_from_fd(uapi->fd); if (IS_ERR(net->buf)) { ret = PTR_ERR(net->buf); goto free_net; } - } else { + + break; + case ETHOSU_UAPI_NETWORK_INDEX: net->index = uapi->index; + break; + default: + ret = -EINVAL; + goto free_net; } ret = anon_inode_getfd("ethosu-network", ðosu_network_fops, net, diff --git a/tests/run_inference_test.cpp b/tests/run_inference_test.cpp index 480e26f..6075d7a 100644 --- a/tests/run_inference_test.cpp +++ b/tests/run_inference_test.cpp @@ -105,6 +105,20 @@ void testNetworkInfoUnparsableBuffer(const Device &device) { } catch (std::exception &e) { throw TestFailureException("NetworkInfo unparsable buffer test: ", e.what()); } } +void testNetworkInvalidType(const Device &device) { + const std::string expected_error = + std::string("IOCTL cmd=") + std::to_string(ETHOSU_IOCTL_NETWORK_CREATE) + " failed: " + std::strerror(EINVAL); + struct ethosu_uapi_network_create net_req = {}; + net_req.type = ETHOSU_UAPI_NETWORK_INDEX + 1; + try { + int r = device.ioctl(ETHOSU_IOCTL_NETWORK_CREATE, &net_req); + FAIL(); + } catch (Exception &e) { + // The call is expected to throw + TEST_ASSERT(expected_error.compare(e.what()) == 0); + } catch (std::exception &e) { throw TestFailureException("NetworkCreate invalid type test: ", e.what()); } +} + void testRunInferenceBuffer(const Device &device) { try { auto networkBuffer = std::make_shared(device, sizeof(networkModelData)); @@ -154,6 +168,7 @@ int main() { testPing(device); testDriverVersion(device); testCapabilties(device); + testNetworkInvalidType(device); testNetworkInfoNotExistentIndex(device); testNetworkInfoBuffer(device); testNetworkInfoUnparsableBuffer(device); -- cgit v1.2.1