From b35f0c681a1153e935bb8e40cd2cca2b04c7b5c0 Mon Sep 17 00:00:00 2001 From: Davide Grohmann Date: Wed, 15 Jun 2022 11:23:25 +0200 Subject: Add message_handler tests for inference cancellation Change-Id: Ifdacc47024250e34549d45377795501c371c69f5 --- applications/message_handler/test/CMakeLists.txt | 53 ++- .../test/cancel_reject_inference_test.cpp | 255 ++++++++++ applications/message_handler/test/main.cpp | 512 --------------------- .../message_handler/test/run_inference_test.cpp | 418 +++++++++++++++++ .../message_handler/test/test_assertions.hpp | 33 ++ applications/message_handler/test/test_helpers.hpp | 131 ++++++ 6 files changed, 866 insertions(+), 536 deletions(-) create mode 100644 applications/message_handler/test/cancel_reject_inference_test.cpp delete mode 100644 applications/message_handler/test/main.cpp create mode 100644 applications/message_handler/test/run_inference_test.cpp create mode 100644 applications/message_handler/test/test_assertions.hpp create mode 100644 applications/message_handler/test/test_helpers.hpp diff --git a/applications/message_handler/test/CMakeLists.txt b/applications/message_handler/test/CMakeLists.txt index 44fd471..7cba4c3 100644 --- a/applications/message_handler/test/CMakeLists.txt +++ b/applications/message_handler/test/CMakeLists.txt @@ -21,30 +21,35 @@ set(TEST_MESSAGE_HANDLER_MODEL_1 "" CACHE STRING "Path to built in model 1") set(TEST_MESSAGE_HANDLER_MODEL_2 "" CACHE STRING "Path to built in model 2") set(TEST_MESSAGE_HANDLER_MODEL_3 "" CACHE STRING "Path to built in model 3") -if(TARGET ethosu_core_driver) - file(GLOB models LIST_DIRECTORIES true "${CMAKE_CURRENT_SOURCE_DIR}/../../baremetal/models/${ETHOSU_TARGET_NPU_CONFIG}/*") -endif() +function(ethosu_add_message_handler_test testname) + if(TARGET ethosu_core_driver) + file(GLOB models LIST_DIRECTORIES true "${CMAKE_CURRENT_SOURCE_DIR}/../../baremetal/models/${ETHOSU_TARGET_NPU_CONFIG}/*") + endif() -foreach(model ${models}) - get_filename_component(modelname ${model} NAME) - ethosu_add_executable_test(message_handler_test_${modelname} - SOURCES - main.cpp - message_client.cpp - LIBRARIES - message_handler_lib - freertos_kernel - ethosu_mhu_dummy) + foreach(model ${models}) + get_filename_component(modelname ${model} NAME) + ethosu_add_executable_test(mh_${testname}_${modelname} + SOURCES + ${testname}.cpp + message_client.cpp + LIBRARIES + message_handler_lib + freertos_kernel + ethosu_mhu_dummy) - target_include_directories(message_handler_test_${modelname} PRIVATE - ../indexed_networks - ${model} - ${LINUX_DRIVER_STACK_PATH}/kernel) + target_include_directories(mh_${testname}_${modelname} PRIVATE + ../indexed_networks + ${model} + ${LINUX_DRIVER_STACK_PATH}/kernel) - target_compile_definitions(message_handler_test_${modelname} PRIVATE - TENSOR_ARENA_SIZE=${MESSAGE_HANDLER_ARENA_SIZE} - $<$:MODEL_0=${TEST_MESSAGE_HANDLER_MODEL_0}> - $<$:MODEL_1=${TEST_MESSAGE_HANDLER_MODEL_1}> - $<$:MODEL_2=${TEST_MESSAGE_HANDLER_MODEL_2}> - $<$:MODEL_3=${TEST_MESSAGE_HANDLER_MODEL_3}>) -endforeach() + target_compile_definitions(mh_${testname}_${modelname} PRIVATE + TENSOR_ARENA_SIZE=${MESSAGE_HANDLER_ARENA_SIZE} + $<$:MODEL_0=${TEST_MESSAGE_HANDLER_MODEL_0}> + $<$:MODEL_1=${TEST_MESSAGE_HANDLER_MODEL_1}> + $<$:MODEL_2=${TEST_MESSAGE_HANDLER_MODEL_2}> + $<$:MODEL_3=${TEST_MESSAGE_HANDLER_MODEL_3}>) + endforeach() +endfunction() + +ethosu_add_message_handler_test(run_inference_test) +ethosu_add_message_handler_test(cancel_reject_inference_test) diff --git a/applications/message_handler/test/cancel_reject_inference_test.cpp b/applications/message_handler/test/cancel_reject_inference_test.cpp new file mode 100644 index 0000000..9f4f9b4 --- /dev/null +++ b/applications/message_handler/test/cancel_reject_inference_test.cpp @@ -0,0 +1,255 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/**************************************************************************** + * Includes + ****************************************************************************/ + +#include "FreeRTOS.h" +#include "queue.h" +#include "semphr.h" +#include "task.h" + +#include +#include + +#include "ethosu_core_interface.h" +#include "indexed_networks.hpp" +#include "message_client.hpp" +#include "message_handler.hpp" +#include "message_queue.hpp" +#include "networks.hpp" +#include "test_assertions.hpp" +#include "test_helpers.hpp" + +#include +#include + +/* Disable semihosting */ +__asm(".global __use_no_semihosting\n\t"); + +using namespace EthosU; +using namespace MessageHandler; + +/**************************************************************************** + * Defines + ****************************************************************************/ + +// TensorArena static initialisation +constexpr size_t arenaSize = TENSOR_ARENA_SIZE; + +__attribute__((section(".bss.tensor_arena"), aligned(16))) uint8_t tensorArena[arenaSize]; + +// Message queue from remote host +__attribute__((section("ethosu_core_in_queue"))) MessageQueue::Queue<1000> inputMessageQueue; + +// Message queue to remote host +__attribute__((section("ethosu_core_out_queue"))) MessageQueue::Queue<1000> outputMessageQueue; + +namespace { +Mailbox::MHUDummy mailbox; +} // namespace + +/**************************************************************************** + * Application + ****************************************************************************/ +namespace { + +struct TaskParams { + TaskParams() : + messageNotify(xSemaphoreCreateBinary()), + inferenceInputQueue(std::make_shared>()), + inferenceOutputQueue(xQueueCreate(5, sizeof(ethosu_core_inference_rsp))), + networks(std::make_shared()) {} + + SemaphoreHandle_t messageNotify; + // Used to pass inference requests to the inference runner task + std::shared_ptr> inferenceInputQueue; + // Queue for message responses to the remote host + QueueHandle_t inferenceOutputQueue; + // Networks provider + std::shared_ptr networks; +}; + +void messageTask(void *pvParameters) { + printf("Starting message task\n"); + TaskParams *params = reinterpret_cast(pvParameters); + + IncomingMessageHandler process(*inputMessageQueue.toQueue(), + *outputMessageQueue.toQueue(), + mailbox, + params->inferenceInputQueue, + params->inferenceOutputQueue, + params->messageNotify, + params->networks); + process.run(); +} + +void testCancelInference(MessageClient client) { + const uint64_t fake_inference_user_arg = 42; + const uint32_t network_index = 0; + ethosu_core_inference_req inference_req = + inferenceIndexedRequest(fake_inference_user_arg, network_index, nullptr, 0, nullptr, 0); + + const uint64_t fake_cancel_inference_user_arg = 55; + ethosu_core_cancel_inference_req cancel_req = {fake_cancel_inference_user_arg, fake_inference_user_arg}; + + ethosu_core_inference_rsp inference_rsp; + ethosu_core_cancel_inference_rsp cancel_rsp; + + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_INFERENCE_REQ, inference_req)); + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_CANCEL_INFERENCE_REQ, cancel_req)); + + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_INFERENCE_RSP, inference_rsp)); + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_CANCEL_INFERENCE_RSP, cancel_rsp)); + + TEST_ASSERT(inference_req.user_arg == inference_rsp.user_arg); + TEST_ASSERT(inference_rsp.status == ETHOSU_CORE_STATUS_ABORTED); + + TEST_ASSERT(cancel_req.user_arg == cancel_rsp.user_arg); + TEST_ASSERT(cancel_rsp.status == ETHOSU_CORE_STATUS_OK); +} + +void testCancelNonExistentInference(MessageClient client) { + const uint64_t fake_inference_user_arg = 42; + const uint64_t fake_cancel_inference_user_arg = 55; + ethosu_core_cancel_inference_req cancel_req = {fake_cancel_inference_user_arg, fake_inference_user_arg}; + ethosu_core_cancel_inference_rsp cancel_rsp; + + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_CANCEL_INFERENCE_REQ, cancel_req)); + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_CANCEL_INFERENCE_RSP, cancel_rsp)); + + TEST_ASSERT(cancel_req.user_arg == cancel_rsp.user_arg); + TEST_ASSERT(cancel_rsp.status == ETHOSU_CORE_STATUS_ERROR); +} + +void testCannotCancelRunningInference(MessageClient client, + std::shared_ptr> inferenceInputQueue) { + const uint64_t fake_inference_user_arg = 42; + const uint32_t network_index = 0; + ethosu_core_inference_req inference_req = + inferenceIndexedRequest(fake_inference_user_arg, network_index, nullptr, 0, nullptr, 0); + + const uint64_t fake_cancel_inference_user_arg = 55; + ethosu_core_cancel_inference_req cancel_req = {fake_cancel_inference_user_arg, fake_inference_user_arg}; + ethosu_core_cancel_inference_rsp cancel_rsp; + + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_INFERENCE_REQ, inference_req)); + + // fake start of the inference by removing the inference from the queue + ethosu_core_inference_req start_req; + inferenceInputQueue->pop(start_req); + TEST_ASSERT(inference_req.user_arg == start_req.user_arg); + + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_CANCEL_INFERENCE_REQ, cancel_req)); + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_CANCEL_INFERENCE_RSP, cancel_rsp)); + + TEST_ASSERT(cancel_req.user_arg == cancel_rsp.user_arg); + TEST_ASSERT(cancel_rsp.status == ETHOSU_CORE_STATUS_ERROR); +} + +void testRejectInference(MessageClient client) { + int runs = 6; + const uint64_t fake_inference_user_arg = 42; + const uint32_t network_index = 0; + const uint64_t fake_cancel_inference_user_arg = 55; + ethosu_core_inference_req req; + ethosu_core_inference_rsp rsp; + + for (int i = 0; i < runs; i++) { + + req = inferenceIndexedRequest(fake_inference_user_arg + i, network_index, nullptr, 0, nullptr, 0); + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_INFERENCE_REQ, req)); + vTaskDelay(150); + } + + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_INFERENCE_RSP, rsp)); + TEST_ASSERT(uint64_t(fake_inference_user_arg + runs - 1) == rsp.user_arg); + TEST_ASSERT(rsp.status == ETHOSU_CORE_STATUS_REJECTED); + + // let's cleanup the queue + ethosu_core_cancel_inference_req cancel_req = {0, 0}; + ethosu_core_cancel_inference_rsp cancel_rsp; + ethosu_core_inference_rsp inference_rsp; + + for (int i = 0; i < runs - 1; i++) { + cancel_req.user_arg = fake_cancel_inference_user_arg + i; + cancel_req.inference_handle = fake_inference_user_arg + i; + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_CANCEL_INFERENCE_REQ, cancel_req)); + + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_INFERENCE_RSP, inference_rsp)); + TEST_ASSERT(inference_rsp.user_arg = cancel_req.inference_handle); + + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_CANCEL_INFERENCE_RSP, cancel_rsp)); + TEST_ASSERT(cancel_req.user_arg == cancel_rsp.user_arg); + TEST_ASSERT(cancel_rsp.status == ETHOSU_CORE_STATUS_OK); + } +} + +void clientTask(void *pvParameters) { + printf("Starting client task\n"); + TaskParams *params = reinterpret_cast(pvParameters); + + MessageClient client(*inputMessageQueue.toQueue(), *outputMessageQueue.toQueue(), mailbox); + + vTaskDelay(50); + + testCancelInference(client); + testCancelNonExistentInference(client); + testCannotCancelRunningInference(client, params->inferenceInputQueue); + testRejectInference(client); + + exit(0); +} + +/* + * Keep task parameters as global data as FreeRTOS resets the stack when the + * scheduler is started. + */ +TaskParams taskParams; + +} // namespace + +// FreeRTOS application. NOTE: Additional tasks may require increased heap size. +int main() { + BaseType_t ret; + + if (!mailbox.verifyHardware()) { + printf("Failed to verify mailbox hardware\n"); + return 1; + } + + // Task for handling incoming /outgoing messages from the remote host + ret = xTaskCreate(messageTask, "messageTask", 1024, &taskParams, 2, nullptr); + if (ret != pdPASS) { + printf("Failed to create 'messageTask'\n"); + return ret; + } + + // Task for handling incoming /outgoing messages from the remote host + ret = xTaskCreate(clientTask, "clientTask", 1024, &taskParams, 2, nullptr); + if (ret != pdPASS) { + printf("Failed to create 'messageTask'\n"); + return ret; + } + + // Start Scheduler + vTaskStartScheduler(); + + return 1; +} diff --git a/applications/message_handler/test/main.cpp b/applications/message_handler/test/main.cpp deleted file mode 100644 index 6a4d26d..0000000 --- a/applications/message_handler/test/main.cpp +++ /dev/null @@ -1,512 +0,0 @@ -/* - * Copyright (c) 2022 Arm Limited. - * - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the License); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/**************************************************************************** - * Includes - ****************************************************************************/ - -#include "FreeRTOS.h" -#include "queue.h" -#include "semphr.h" -#include "task.h" - -#include -#include - -#include "ethosu_core_interface.h" -#include "indexed_networks.hpp" -#include "input.h" -#include "message_client.hpp" -#include "message_handler.hpp" -#include "message_queue.hpp" -#include "networks.hpp" -#include "output.h" - -#include -#include - -/* Disable semihosting */ -__asm(".global __use_no_semihosting\n\t"); - -using namespace EthosU; -using namespace MessageHandler; - -/**************************************************************************** - * Defines - ****************************************************************************/ - -#define TEST_ASSERT(v) \ - do { \ - if (!(v)) { \ - fprintf(stderr, "%s:%d ERROR test failed: '%s'\n", __FILE__, __LINE__, #v); \ - exit(1); \ - } \ - } while (0) - -// TensorArena static initialisation -constexpr size_t arenaSize = TENSOR_ARENA_SIZE; - -__attribute__((section(".bss.tensor_arena"), aligned(16))) uint8_t tensorArena[arenaSize]; - -// Message queue from remote host -__attribute__((section("ethosu_core_in_queue"))) MessageQueue::Queue<1000> inputMessageQueue; - -// Message queue to remote host -__attribute__((section("ethosu_core_out_queue"))) MessageQueue::Queue<1000> outputMessageQueue; - -namespace { -Mailbox::MHUDummy mailbox; -} // namespace - -/**************************************************************************** - * Application - ****************************************************************************/ -namespace { - -struct TaskParams { - TaskParams() : - messageNotify(xSemaphoreCreateBinary()), - inferenceInputQueue(std::make_shared>()), - inferenceOutputQueue(xQueueCreate(5, sizeof(ethosu_core_inference_rsp))), - networks(std::make_shared()) {} - - SemaphoreHandle_t messageNotify; - // Used to pass inference requests to the inference runner task - std::shared_ptr> inferenceInputQueue; - // Queue for message responses to the remote host - QueueHandle_t inferenceOutputQueue; - // Networks provider - std::shared_ptr networks; -}; - -void inferenceTask(void *pvParameters) { - printf("Starting inference task\n"); - TaskParams *params = reinterpret_cast(pvParameters); - - InferenceHandler process(tensorArena, - arenaSize, - params->inferenceInputQueue, - params->inferenceOutputQueue, - params->messageNotify, - params->networks); - - process.run(); -} - -void messageTask(void *pvParameters) { - printf("Starting message task\n"); - TaskParams *params = reinterpret_cast(pvParameters); - - IncomingMessageHandler process(*inputMessageQueue.toQueue(), - *outputMessageQueue.toQueue(), - mailbox, - params->inferenceInputQueue, - params->inferenceOutputQueue, - params->messageNotify, - params->networks); - process.run(); -} - -ethosu_core_network_info_req networkInfoIndexedRequest(uint64_t user_arg, uint32_t index) { - ethosu_core_network_info_req req = {user_arg, // user_arg - { // network - ETHOSU_CORE_NETWORK_INDEX, // type - {{ - index, // index - 0 // ignored padding of union - }}}}; - return req; -} - -ethosu_core_network_info_req networkInfoBufferRequest(uint64_t user_arg, unsigned char *ptr, uint32_t ptr_size) { - ethosu_core_network_info_req req = {user_arg, // user_arg - { // network - ETHOSU_CORE_NETWORK_BUFFER, // type - {{ - reinterpret_cast(ptr), // ptr - ptr_size // size - }}}}; - return req; -} - -ethosu_core_network_info_rsp networkInfoResponse(uint64_t user_arg) { - ethosu_core_network_info_rsp rsp = { - user_arg, // user_arg - "Vela Optimised", // description - 1, // ifm_count - {/* not comparable */}, // ifm_sizes - 1, // ofm_count - {/* not comparable */}, // ofm_sizes - ETHOSU_CORE_STATUS_OK // status - }; - return rsp; -} - -ethosu_core_inference_req -inferenceIndexedRequest(uint64_t user_arg, uint32_t index, uint8_t *data, uint32_t data_size) { - ethosu_core_inference_req req = { - user_arg, // user_arg - 1, // ifm_count - { // ifm: - { - reinterpret_cast(&inputData[0]), // ptr - sizeof(inputData) // size - }}, - 1, // ofm_count - { // ofm - { - reinterpret_cast(data), // ptr - data_size // size - }}, - { // network - ETHOSU_CORE_NETWORK_INDEX, // type - {{ - index, // index - 0 // ignored padding of union - }}}, - {0, 0, 0, 0, 0, 0, 0, 0}, // pmu_event_config - 0 // pmu_cycle_counter_enable - }; - return req; -} - -ethosu_core_inference_req -inferenceBufferRequest(uint64_t user_arg, unsigned char *ptr, uint32_t ptr_size, uint8_t *data, uint32_t data_size) { - ethosu_core_inference_req req = { - user_arg, // user_arg - 1, // ifm_count - { // ifm: - { - reinterpret_cast(&inputData[0]), // ptr - sizeof(inputData) // size - }}, - 1, // ofm_count - { // ofm - { - reinterpret_cast(data), // ptr - data_size // size - }}, - { // network - ETHOSU_CORE_NETWORK_BUFFER, // type - {{ - reinterpret_cast(ptr), // ptr - ptr_size // size - }}}, - {0, 0, 0, 0, 0, 0, 0, 0}, // pmu_event_config - 0 // pmu_cycle_counter_enable - }; - return req; -} - -void testPing(MessageClient client) { - TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_PING)); - TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_PONG)); -} - -void testVersion(MessageClient client) { - ethosu_core_msg_version ver; - TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_VERSION_REQ)); - TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_VERSION_RSP, ver)); - - TEST_ASSERT(ver.major == ETHOSU_CORE_MSG_VERSION_MAJOR); - TEST_ASSERT(ver.minor == ETHOSU_CORE_MSG_VERSION_MINOR); - TEST_ASSERT(ver.patch == ETHOSU_CORE_MSG_VERSION_PATCH); -} - -void readCapabilities(ethosu_core_msg_capabilities_rsp &rsp) { -#ifdef ETHOSU - struct ethosu_driver_version version; - ethosu_get_driver_version(&version); - - struct ethosu_hw_info info; - struct ethosu_driver *drv = ethosu_reserve_driver(); - ethosu_get_hw_info(drv, &info); - ethosu_release_driver(drv); - - rsp.version_status = info.version.version_status; - rsp.version_minor = info.version.version_minor; - rsp.version_major = info.version.version_major; - rsp.product_major = info.version.product_major; - rsp.arch_patch_rev = info.version.arch_patch_rev; - rsp.arch_minor_rev = info.version.arch_minor_rev; - rsp.arch_major_rev = info.version.arch_major_rev; - rsp.driver_patch_rev = version.patch; - rsp.driver_minor_rev = version.minor; - rsp.driver_major_rev = version.major; - rsp.macs_per_cc = info.cfg.macs_per_cc; - rsp.cmd_stream_version = info.cfg.cmd_stream_version; - rsp.custom_dma = info.cfg.custom_dma; -#endif -} - -void testCapabilities(MessageClient client) { - const uint64_t fake_user_arg = 42; - ethosu_core_capabilities_req req = {fake_user_arg}; - ethosu_core_msg_capabilities_rsp expected_rsp; - ethosu_core_msg_capabilities_rsp rsp; - - readCapabilities(expected_rsp); - expected_rsp.user_arg = req.user_arg; - - TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_CAPABILITIES_REQ, req)); - TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_CAPABILITIES_RSP, rsp)); - - TEST_ASSERT(expected_rsp.version_status == rsp.version_status); - TEST_ASSERT(expected_rsp.version_minor == rsp.version_minor); - TEST_ASSERT(expected_rsp.version_major == rsp.version_major); - TEST_ASSERT(expected_rsp.product_major == rsp.product_major); - TEST_ASSERT(expected_rsp.arch_patch_rev == rsp.arch_patch_rev); - TEST_ASSERT(expected_rsp.arch_minor_rev == rsp.arch_minor_rev); - TEST_ASSERT(expected_rsp.arch_major_rev == rsp.arch_major_rev); - TEST_ASSERT(expected_rsp.driver_patch_rev == rsp.driver_patch_rev); - TEST_ASSERT(expected_rsp.driver_minor_rev == rsp.driver_minor_rev); - TEST_ASSERT(expected_rsp.driver_major_rev == rsp.driver_major_rev); - TEST_ASSERT(expected_rsp.macs_per_cc == rsp.macs_per_cc); - TEST_ASSERT(expected_rsp.cmd_stream_version == rsp.cmd_stream_version); - TEST_ASSERT(expected_rsp.custom_dma == rsp.custom_dma); - -#ifdef ETHOSU - TEST_ASSERT(rsp.version_status > 0); - TEST_ASSERT(rsp.product_major > 0); - TEST_ASSERT(rsp.arch_major_rev > 0 || rsp.arch_minor_rev > 0 || rsp.arch_patch_rev > 0); - TEST_ASSERT(rsp.driver_major_rev > 0 || rsp.driver_minor_rev > 0 || rsp.driver_patch_rev > 0); - TEST_ASSERT(rsp.macs_per_cc > 0); -#endif -} - -void testNetworkInfoIndex(MessageClient client) { - const uint64_t fake_user_arg = 42; - const uint32_t network_index = 0; - ethosu_core_network_info_req req = networkInfoIndexedRequest(fake_user_arg, network_index); - ethosu_core_network_info_rsp rsp; - ethosu_core_network_info_rsp expected_rsp = networkInfoResponse(fake_user_arg); - - TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_NETWORK_INFO_REQ, req)); - TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_NETWORK_INFO_RSP, rsp)); - - TEST_ASSERT(expected_rsp.user_arg == rsp.user_arg); - TEST_ASSERT(std::strncmp(expected_rsp.desc, rsp.desc, sizeof(rsp.desc)) == 0); - TEST_ASSERT(expected_rsp.ifm_count == rsp.ifm_count); - TEST_ASSERT(expected_rsp.ofm_count == rsp.ofm_count); - TEST_ASSERT(expected_rsp.status == rsp.status); -} - -void testNetworkInfoNonExistantIndex(MessageClient client) { - const uint64_t fake_user_arg = 42; - const uint32_t network_index = 1; - ethosu_core_network_info_req req = networkInfoIndexedRequest(fake_user_arg, network_index); - ethosu_core_network_info_rsp rsp; - - TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_NETWORK_INFO_REQ, req)); - TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_NETWORK_INFO_RSP, rsp)); - - TEST_ASSERT(fake_user_arg == rsp.user_arg); - TEST_ASSERT(ETHOSU_CORE_STATUS_ERROR == rsp.status); -} - -void testNetworkInfoBuffer(MessageClient client) { - const uint64_t fake_user_arg = 42; - uint32_t size = sizeof(Model0::networkModelData); - unsigned char *ptr = Model0::networkModelData; - ethosu_core_network_info_req req = networkInfoBufferRequest(fake_user_arg, ptr, size); - ethosu_core_network_info_rsp rsp; - ethosu_core_network_info_rsp expected_rsp = networkInfoResponse(fake_user_arg); - - TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_NETWORK_INFO_REQ, req)); - TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_NETWORK_INFO_RSP, rsp)); - - TEST_ASSERT(expected_rsp.user_arg == rsp.user_arg); - TEST_ASSERT(std::strncmp(expected_rsp.desc, rsp.desc, sizeof(rsp.desc)) == 0); - TEST_ASSERT(expected_rsp.ifm_count == rsp.ifm_count); - TEST_ASSERT(expected_rsp.ofm_count == rsp.ofm_count); - TEST_ASSERT(expected_rsp.status == rsp.status); -} - -void testNetworkInfoUnparsableBuffer(MessageClient client) { - const uint64_t fake_user_arg = 42; - uint32_t size = sizeof(Model0::networkModelData) / 4; - unsigned char *ptr = Model0::networkModelData + size; - ethosu_core_network_info_req req = networkInfoBufferRequest(fake_user_arg, ptr, size); - ethosu_core_network_info_rsp rsp; - - TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_NETWORK_INFO_REQ, req)); - TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_NETWORK_INFO_RSP, rsp)); - - TEST_ASSERT(42 == rsp.user_arg); - TEST_ASSERT(ETHOSU_CORE_STATUS_ERROR == rsp.status); -} - -void testInferenceRunIndex(MessageClient client) { - const uint64_t fake_user_arg = 42; - const uint32_t network_index = 0; - uint8_t data[sizeof(expectedOutputData)]; - ethosu_core_inference_req req = inferenceIndexedRequest(fake_user_arg, network_index, data, sizeof(data)); - ethosu_core_inference_rsp rsp; - - TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_INFERENCE_REQ, req)); - TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_INFERENCE_RSP, rsp)); - - TEST_ASSERT(req.user_arg == rsp.user_arg); - TEST_ASSERT(rsp.ofm_count == 1); - TEST_ASSERT(std::memcmp(expectedOutputData, data, sizeof(expectedOutputData)) == 0); - TEST_ASSERT(rsp.status == ETHOSU_CORE_STATUS_OK); - TEST_ASSERT(rsp.pmu_cycle_counter_enable == req.pmu_cycle_counter_enable); - TEST_ASSERT(std::memcmp(rsp.pmu_event_config, req.pmu_event_config, sizeof(req.pmu_event_config)) == 0); -} - -void testInferenceRunNonExistingIndex(MessageClient client) { - const uint64_t fake_user_arg = 42; - const uint32_t network_index = 1; - uint8_t data[sizeof(expectedOutputData)]; - ethosu_core_inference_req req = inferenceIndexedRequest(fake_user_arg, network_index, data, sizeof(data)); - ethosu_core_inference_rsp rsp; - - TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_INFERENCE_REQ, req)); - TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_INFERENCE_RSP, rsp)); - - TEST_ASSERT(req.user_arg == rsp.user_arg); - TEST_ASSERT(rsp.status == ETHOSU_CORE_STATUS_ERROR); -} - -void testInferenceRunBuffer(MessageClient client) { - const uint64_t fake_user_arg = 42; - uint32_t network_size = sizeof(Model0::networkModelData); - unsigned char *network_ptr = Model0::networkModelData; - uint8_t data[sizeof(expectedOutputData)]; - ethosu_core_inference_req req = - inferenceBufferRequest(fake_user_arg, network_ptr, network_size, data, sizeof(data)); - ethosu_core_inference_rsp rsp; - - TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_INFERENCE_REQ, req)); - TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_INFERENCE_RSP, rsp)); - - TEST_ASSERT(req.user_arg == rsp.user_arg); - TEST_ASSERT(rsp.ofm_count == 1); - TEST_ASSERT(std::memcmp(expectedOutputData, data, sizeof(expectedOutputData)) == 0); - TEST_ASSERT(rsp.status == ETHOSU_CORE_STATUS_OK); - TEST_ASSERT(rsp.pmu_cycle_counter_enable == req.pmu_cycle_counter_enable); - TEST_ASSERT(std::memcmp(rsp.pmu_event_config, req.pmu_event_config, sizeof(req.pmu_event_config)) == 0); -} - -void testInferenceRunUnparsableBuffer(MessageClient client) { - const uint64_t fake_user_arg = 42; - uint32_t network_size = sizeof(Model0::networkModelData) / 4; - unsigned char *network_ptr = Model0::networkModelData + network_size; - uint8_t data[sizeof(expectedOutputData)]; - ethosu_core_inference_req req = - inferenceBufferRequest(fake_user_arg, network_ptr, network_size, data, sizeof(data)); - ethosu_core_inference_rsp rsp; - - TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_INFERENCE_REQ, req)); - TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_INFERENCE_RSP, rsp)); - - TEST_ASSERT(req.user_arg == rsp.user_arg); - TEST_ASSERT(rsp.status == ETHOSU_CORE_STATUS_ERROR); -} - -void testSequentiallyQueuedInferenceRuns(MessageClient client) { - int runs = 5; - uint8_t data[runs][sizeof(expectedOutputData)]; - const uint64_t fake_user_arg = 42; - const uint32_t network_index = 0; - ethosu_core_inference_req req; - ethosu_core_inference_rsp rsp[runs]; - - for (int i = 0; i < runs; i++) { - vTaskDelay(150); - - req = inferenceIndexedRequest(fake_user_arg + i, network_index, data[i], sizeof(expectedOutputData)); - TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_INFERENCE_REQ, req)); - } - - for (int i = 0; i < runs; i++) { - TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_INFERENCE_RSP, rsp[i])); - TEST_ASSERT(uint64_t(fake_user_arg + i) == rsp[i].user_arg); - TEST_ASSERT(rsp[i].ofm_count == 1); - TEST_ASSERT(std::memcmp(expectedOutputData, data[i], sizeof(expectedOutputData)) == 0); - TEST_ASSERT(rsp[i].status == ETHOSU_CORE_STATUS_OK); - TEST_ASSERT(rsp[i].pmu_cycle_counter_enable == req.pmu_cycle_counter_enable); - TEST_ASSERT(std::memcmp(rsp[i].pmu_event_config, req.pmu_event_config, sizeof(req.pmu_event_config)) == 0); - } -} - -void clientTask(void *) { - printf("Starting client task\n"); - - MessageClient client(*inputMessageQueue.toQueue(), *outputMessageQueue.toQueue(), mailbox); - - vTaskDelay(50); - - testPing(client); - testVersion(client); - testCapabilities(client); - testNetworkInfoIndex(client); - testNetworkInfoNonExistantIndex(client); - testNetworkInfoBuffer(client); - testNetworkInfoUnparsableBuffer(client); - testInferenceRunIndex(client); - testInferenceRunNonExistingIndex(client); - testInferenceRunBuffer(client); - testInferenceRunUnparsableBuffer(client); - testSequentiallyQueuedInferenceRuns(client); - - exit(0); -} - -/* - * Keep task parameters as global data as FreeRTOS resets the stack when the - * scheduler is started. - */ -TaskParams taskParams; - -} // namespace - -// FreeRTOS application. NOTE: Additional tasks may require increased heap size. -int main() { - BaseType_t ret; - - if (!mailbox.verifyHardware()) { - printf("Failed to verify mailbox hardware\n"); - return 1; - } - - // Task for handling incoming /outgoing messages from the remote host - ret = xTaskCreate(messageTask, "messageTask", 1024, &taskParams, 2, nullptr); - if (ret != pdPASS) { - printf("Failed to create 'messageTask'\n"); - return ret; - } - - ret = xTaskCreate(inferenceTask, "inferenceTask", 8 * 1024, &taskParams, 3, nullptr); - if (ret != pdPASS) { - printf("Failed to create 'inferenceTask'\n"); - return ret; - } - - // Task for handling incoming /outgoing messages from the remote host - ret = xTaskCreate(clientTask, "clientTask", 1024, nullptr, 2, nullptr); - if (ret != pdPASS) { - printf("Failed to create 'messageTask'\n"); - return ret; - } - - // Start Scheduler - vTaskStartScheduler(); - - return 1; -} diff --git a/applications/message_handler/test/run_inference_test.cpp b/applications/message_handler/test/run_inference_test.cpp new file mode 100644 index 0000000..d05224f --- /dev/null +++ b/applications/message_handler/test/run_inference_test.cpp @@ -0,0 +1,418 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/**************************************************************************** + * Includes + ****************************************************************************/ + +#include "FreeRTOS.h" +#include "queue.h" +#include "semphr.h" +#include "task.h" + +#include +#include + +#include "ethosu_core_interface.h" +#include "indexed_networks.hpp" +#include "input.h" +#include "message_client.hpp" +#include "message_handler.hpp" +#include "message_queue.hpp" +#include "networks.hpp" +#include "output.h" +#include "test_assertions.hpp" +#include "test_helpers.hpp" + +#include +#include + +/* Disable semihosting */ +__asm(".global __use_no_semihosting\n\t"); + +using namespace EthosU; +using namespace MessageHandler; + +/**************************************************************************** + * Defines + ****************************************************************************/ + +// TensorArena static initialisation +constexpr size_t arenaSize = TENSOR_ARENA_SIZE; + +__attribute__((section(".bss.tensor_arena"), aligned(16))) uint8_t tensorArena[arenaSize]; + +// Message queue from remote host +__attribute__((section("ethosu_core_in_queue"))) MessageQueue::Queue<1000> inputMessageQueue; + +// Message queue to remote host +__attribute__((section("ethosu_core_out_queue"))) MessageQueue::Queue<1000> outputMessageQueue; + +namespace { +Mailbox::MHUDummy mailbox; +} // namespace + +/**************************************************************************** + * Application + ****************************************************************************/ +namespace { + +struct TaskParams { + TaskParams() : + messageNotify(xSemaphoreCreateBinary()), + inferenceInputQueue(std::make_shared>()), + inferenceOutputQueue(xQueueCreate(5, sizeof(ethosu_core_inference_rsp))), + networks(std::make_shared()) {} + + SemaphoreHandle_t messageNotify; + // Used to pass inference requests to the inference runner task + std::shared_ptr> inferenceInputQueue; + // Queue for message responses to the remote host + QueueHandle_t inferenceOutputQueue; + // Networks provider + std::shared_ptr networks; +}; + +void inferenceTask(void *pvParameters) { + printf("Starting inference task\n"); + TaskParams *params = reinterpret_cast(pvParameters); + + InferenceHandler process(tensorArena, + arenaSize, + params->inferenceInputQueue, + params->inferenceOutputQueue, + params->messageNotify, + params->networks); + + process.run(); +} + +void messageTask(void *pvParameters) { + printf("Starting message task\n"); + TaskParams *params = reinterpret_cast(pvParameters); + + IncomingMessageHandler process(*inputMessageQueue.toQueue(), + *outputMessageQueue.toQueue(), + mailbox, + params->inferenceInputQueue, + params->inferenceOutputQueue, + params->messageNotify, + params->networks); + process.run(); +} + +void testPing(MessageClient client) { + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_PING)); + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_PONG)); +} + +void testVersion(MessageClient client) { + ethosu_core_msg_version ver; + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_VERSION_REQ)); + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_VERSION_RSP, ver)); + + TEST_ASSERT(ver.major == ETHOSU_CORE_MSG_VERSION_MAJOR); + TEST_ASSERT(ver.minor == ETHOSU_CORE_MSG_VERSION_MINOR); + TEST_ASSERT(ver.patch == ETHOSU_CORE_MSG_VERSION_PATCH); +} + +void readCapabilities(ethosu_core_msg_capabilities_rsp &rsp) { +#ifdef ETHOSU + struct ethosu_driver_version version; + ethosu_get_driver_version(&version); + + struct ethosu_hw_info info; + struct ethosu_driver *drv = ethosu_reserve_driver(); + ethosu_get_hw_info(drv, &info); + ethosu_release_driver(drv); + + rsp.version_status = info.version.version_status; + rsp.version_minor = info.version.version_minor; + rsp.version_major = info.version.version_major; + rsp.product_major = info.version.product_major; + rsp.arch_patch_rev = info.version.arch_patch_rev; + rsp.arch_minor_rev = info.version.arch_minor_rev; + rsp.arch_major_rev = info.version.arch_major_rev; + rsp.driver_patch_rev = version.patch; + rsp.driver_minor_rev = version.minor; + rsp.driver_major_rev = version.major; + rsp.macs_per_cc = info.cfg.macs_per_cc; + rsp.cmd_stream_version = info.cfg.cmd_stream_version; + rsp.custom_dma = info.cfg.custom_dma; +#endif +} + +void testCapabilities(MessageClient client) { + const uint64_t fake_user_arg = 42; + ethosu_core_capabilities_req req = {fake_user_arg}; + ethosu_core_msg_capabilities_rsp expected_rsp; + ethosu_core_msg_capabilities_rsp rsp; + + readCapabilities(expected_rsp); + expected_rsp.user_arg = req.user_arg; + + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_CAPABILITIES_REQ, req)); + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_CAPABILITIES_RSP, rsp)); + + TEST_ASSERT(expected_rsp.version_status == rsp.version_status); + TEST_ASSERT(expected_rsp.version_minor == rsp.version_minor); + TEST_ASSERT(expected_rsp.version_major == rsp.version_major); + TEST_ASSERT(expected_rsp.product_major == rsp.product_major); + TEST_ASSERT(expected_rsp.arch_patch_rev == rsp.arch_patch_rev); + TEST_ASSERT(expected_rsp.arch_minor_rev == rsp.arch_minor_rev); + TEST_ASSERT(expected_rsp.arch_major_rev == rsp.arch_major_rev); + TEST_ASSERT(expected_rsp.driver_patch_rev == rsp.driver_patch_rev); + TEST_ASSERT(expected_rsp.driver_minor_rev == rsp.driver_minor_rev); + TEST_ASSERT(expected_rsp.driver_major_rev == rsp.driver_major_rev); + TEST_ASSERT(expected_rsp.macs_per_cc == rsp.macs_per_cc); + TEST_ASSERT(expected_rsp.cmd_stream_version == rsp.cmd_stream_version); + TEST_ASSERT(expected_rsp.custom_dma == rsp.custom_dma); + +#ifdef ETHOSU + TEST_ASSERT(rsp.version_status > 0); + TEST_ASSERT(rsp.product_major > 0); + TEST_ASSERT(rsp.arch_major_rev > 0 || rsp.arch_minor_rev > 0 || rsp.arch_patch_rev > 0); + TEST_ASSERT(rsp.driver_major_rev > 0 || rsp.driver_minor_rev > 0 || rsp.driver_patch_rev > 0); + TEST_ASSERT(rsp.macs_per_cc > 0); +#endif +} + +void testNetworkInfoIndex(MessageClient client) { + const uint64_t fake_user_arg = 42; + const uint32_t network_index = 0; + ethosu_core_network_info_req req = networkInfoIndexedRequest(fake_user_arg, network_index); + ethosu_core_network_info_rsp rsp; + ethosu_core_network_info_rsp expected_rsp = networkInfoResponse(fake_user_arg); + + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_NETWORK_INFO_REQ, req)); + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_NETWORK_INFO_RSP, rsp)); + + TEST_ASSERT(expected_rsp.user_arg == rsp.user_arg); + TEST_ASSERT(std::strncmp(expected_rsp.desc, rsp.desc, sizeof(rsp.desc)) == 0); + TEST_ASSERT(expected_rsp.ifm_count == rsp.ifm_count); + TEST_ASSERT(expected_rsp.ofm_count == rsp.ofm_count); + TEST_ASSERT(expected_rsp.status == rsp.status); +} + +void testNetworkInfoNonExistantIndex(MessageClient client) { + const uint64_t fake_user_arg = 42; + const uint32_t network_index = 1; + ethosu_core_network_info_req req = networkInfoIndexedRequest(fake_user_arg, network_index); + ethosu_core_network_info_rsp rsp; + + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_NETWORK_INFO_REQ, req)); + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_NETWORK_INFO_RSP, rsp)); + + TEST_ASSERT(fake_user_arg == rsp.user_arg); + TEST_ASSERT(ETHOSU_CORE_STATUS_ERROR == rsp.status); +} + +void testNetworkInfoBuffer(MessageClient client) { + const uint64_t fake_user_arg = 42; + uint32_t size = sizeof(Model0::networkModelData); + unsigned char *ptr = Model0::networkModelData; + ethosu_core_network_info_req req = networkInfoBufferRequest(fake_user_arg, ptr, size); + ethosu_core_network_info_rsp rsp; + ethosu_core_network_info_rsp expected_rsp = networkInfoResponse(fake_user_arg); + + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_NETWORK_INFO_REQ, req)); + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_NETWORK_INFO_RSP, rsp)); + + TEST_ASSERT(expected_rsp.user_arg == rsp.user_arg); + TEST_ASSERT(std::strncmp(expected_rsp.desc, rsp.desc, sizeof(rsp.desc)) == 0); + TEST_ASSERT(expected_rsp.ifm_count == rsp.ifm_count); + TEST_ASSERT(expected_rsp.ofm_count == rsp.ofm_count); + TEST_ASSERT(expected_rsp.status == rsp.status); +} + +void testNetworkInfoUnparsableBuffer(MessageClient client) { + const uint64_t fake_user_arg = 42; + uint32_t size = sizeof(Model0::networkModelData) / 4; + unsigned char *ptr = Model0::networkModelData + size; + ethosu_core_network_info_req req = networkInfoBufferRequest(fake_user_arg, ptr, size); + ethosu_core_network_info_rsp rsp; + + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_NETWORK_INFO_REQ, req)); + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_NETWORK_INFO_RSP, rsp)); + + TEST_ASSERT(42 == rsp.user_arg); + TEST_ASSERT(ETHOSU_CORE_STATUS_ERROR == rsp.status); +} + +void testInferenceRunIndex(MessageClient client) { + const uint64_t fake_user_arg = 42; + const uint32_t network_index = 0; + uint8_t data[sizeof(expectedOutputData)]; + ethosu_core_inference_req req = + inferenceIndexedRequest(fake_user_arg, network_index, inputData, sizeof(inputData), data, sizeof(data)); + ethosu_core_inference_rsp rsp; + + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_INFERENCE_REQ, req)); + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_INFERENCE_RSP, rsp)); + + TEST_ASSERT(req.user_arg == rsp.user_arg); + TEST_ASSERT(rsp.ofm_count == 1); + TEST_ASSERT(std::memcmp(expectedOutputData, data, sizeof(expectedOutputData)) == 0); + TEST_ASSERT(rsp.status == ETHOSU_CORE_STATUS_OK); + TEST_ASSERT(rsp.pmu_cycle_counter_enable == req.pmu_cycle_counter_enable); + TEST_ASSERT(std::memcmp(rsp.pmu_event_config, req.pmu_event_config, sizeof(req.pmu_event_config)) == 0); +} + +void testInferenceRunNonExistingIndex(MessageClient client) { + const uint64_t fake_user_arg = 42; + const uint32_t network_index = 1; + uint8_t data[sizeof(expectedOutputData)]; + ethosu_core_inference_req req = + inferenceIndexedRequest(fake_user_arg, network_index, inputData, sizeof(inputData), data, sizeof(data)); + ethosu_core_inference_rsp rsp; + + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_INFERENCE_REQ, req)); + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_INFERENCE_RSP, rsp)); + + TEST_ASSERT(req.user_arg == rsp.user_arg); + TEST_ASSERT(rsp.status == ETHOSU_CORE_STATUS_ERROR); +} + +void testInferenceRunBuffer(MessageClient client) { + const uint64_t fake_user_arg = 42; + uint32_t network_size = sizeof(Model0::networkModelData); + unsigned char *network_ptr = Model0::networkModelData; + uint8_t data[sizeof(expectedOutputData)]; + ethosu_core_inference_req req = inferenceBufferRequest( + fake_user_arg, network_ptr, network_size, inputData, sizeof(inputData), data, sizeof(data)); + ethosu_core_inference_rsp rsp; + + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_INFERENCE_REQ, req)); + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_INFERENCE_RSP, rsp)); + + TEST_ASSERT(req.user_arg == rsp.user_arg); + TEST_ASSERT(rsp.ofm_count == 1); + TEST_ASSERT(std::memcmp(expectedOutputData, data, sizeof(expectedOutputData)) == 0); + TEST_ASSERT(rsp.status == ETHOSU_CORE_STATUS_OK); + TEST_ASSERT(rsp.pmu_cycle_counter_enable == req.pmu_cycle_counter_enable); + TEST_ASSERT(std::memcmp(rsp.pmu_event_config, req.pmu_event_config, sizeof(req.pmu_event_config)) == 0); +} + +void testInferenceRunUnparsableBuffer(MessageClient client) { + const uint64_t fake_user_arg = 42; + uint32_t network_size = sizeof(Model0::networkModelData) / 4; + unsigned char *network_ptr = Model0::networkModelData + network_size; + uint8_t data[sizeof(expectedOutputData)]; + ethosu_core_inference_req req = inferenceBufferRequest( + fake_user_arg, network_ptr, network_size, inputData, sizeof(inputData), data, sizeof(data)); + ethosu_core_inference_rsp rsp; + + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_INFERENCE_REQ, req)); + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_INFERENCE_RSP, rsp)); + + TEST_ASSERT(req.user_arg == rsp.user_arg); + TEST_ASSERT(rsp.status == ETHOSU_CORE_STATUS_ERROR); +} + +void testSequentiallyQueuedInferenceRuns(MessageClient client) { + int runs = 5; + uint8_t data[runs][sizeof(expectedOutputData)]; + const uint64_t fake_user_arg = 42; + const uint32_t network_index = 0; + ethosu_core_inference_req req; + ethosu_core_inference_rsp rsp[runs]; + + for (int i = 0; i < runs; i++) { + vTaskDelay(150); + + req = inferenceIndexedRequest( + fake_user_arg + i, network_index, inputData, sizeof(inputData), data[i], sizeof(data[i])); + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_INFERENCE_REQ, req)); + } + + for (int i = 0; i < runs; i++) { + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_INFERENCE_RSP, rsp[i])); + TEST_ASSERT(uint64_t(fake_user_arg + i) == rsp[i].user_arg); + TEST_ASSERT(rsp[i].ofm_count == 1); + TEST_ASSERT(std::memcmp(expectedOutputData, data[i], sizeof(expectedOutputData)) == 0); + TEST_ASSERT(rsp[i].status == ETHOSU_CORE_STATUS_OK); + TEST_ASSERT(rsp[i].pmu_cycle_counter_enable == req.pmu_cycle_counter_enable); + TEST_ASSERT(std::memcmp(rsp[i].pmu_event_config, req.pmu_event_config, sizeof(req.pmu_event_config)) == 0); + } +} + +void clientTask(void *) { + printf("Starting client task\n"); + + MessageClient client(*inputMessageQueue.toQueue(), *outputMessageQueue.toQueue(), mailbox); + + vTaskDelay(50); + + testPing(client); + testVersion(client); + testCapabilities(client); + testNetworkInfoIndex(client); + testNetworkInfoNonExistantIndex(client); + testNetworkInfoBuffer(client); + testNetworkInfoUnparsableBuffer(client); + testInferenceRunIndex(client); + testInferenceRunNonExistingIndex(client); + testInferenceRunBuffer(client); + testInferenceRunUnparsableBuffer(client); + testSequentiallyQueuedInferenceRuns(client); + + exit(0); +} + +/* + * Keep task parameters as global data as FreeRTOS resets the stack when the + * scheduler is started. + */ +TaskParams taskParams; + +} // namespace + +// FreeRTOS application. NOTE: Additional tasks may require increased heap size. +int main() { + BaseType_t ret; + + if (!mailbox.verifyHardware()) { + printf("Failed to verify mailbox hardware\n"); + return 1; + } + + // Task for handling incoming /outgoing messages from the remote host + ret = xTaskCreate(messageTask, "messageTask", 1024, &taskParams, 2, nullptr); + if (ret != pdPASS) { + printf("Failed to create 'messageTask'\n"); + return ret; + } + + ret = xTaskCreate(inferenceTask, "inferenceTask", 8 * 1024, &taskParams, 3, nullptr); + if (ret != pdPASS) { + printf("Failed to create 'inferenceTask'\n"); + return ret; + } + + // Task for handling incoming /outgoing messages from the remote host + ret = xTaskCreate(clientTask, "clientTask", 1024, nullptr, 2, nullptr); + if (ret != pdPASS) { + printf("Failed to create 'messageTask'\n"); + return ret; + } + + // Start Scheduler + vTaskStartScheduler(); + + return 1; +} diff --git a/applications/message_handler/test/test_assertions.hpp b/applications/message_handler/test/test_assertions.hpp new file mode 100644 index 0000000..7c4cb5c --- /dev/null +++ b/applications/message_handler/test/test_assertions.hpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TEST_ASSERTIONS_H +#define TEST_ASSERTIONS_H + +#include +#include + +#define TEST_ASSERT(v) \ + do { \ + if (!(v)) { \ + fprintf(stderr, "%s:%d ERROR test failed: '%s'\n", __FILE__, __LINE__, #v); \ + exit(1); \ + } \ + } while (0) + +#endif diff --git a/applications/message_handler/test/test_helpers.hpp b/applications/message_handler/test/test_helpers.hpp new file mode 100644 index 0000000..0440b58 --- /dev/null +++ b/applications/message_handler/test/test_helpers.hpp @@ -0,0 +1,131 @@ + +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TEST_HELPERS_H +#define TEST_HELPERS_H + +#include +#include + +#include "ethosu_core_interface.h" + +namespace MessageHandler { + +ethosu_core_network_info_req networkInfoIndexedRequest(uint64_t user_arg, uint32_t index) { + ethosu_core_network_info_req req = {user_arg, // user_arg + { // network + ETHOSU_CORE_NETWORK_INDEX, // type + {{ + index, // index + 0 // ignored padding of union + }}}}; + return req; +} + +ethosu_core_network_info_req networkInfoBufferRequest(uint64_t user_arg, unsigned char *ptr, uint32_t ptr_size) { + ethosu_core_network_info_req req = {user_arg, // user_arg + { // network + ETHOSU_CORE_NETWORK_BUFFER, // type + {{ + reinterpret_cast(ptr), // ptr + ptr_size // size + }}}}; + return req; +} + +ethosu_core_network_info_rsp networkInfoResponse(uint64_t user_arg) { + ethosu_core_network_info_rsp rsp = { + user_arg, // user_arg + "Vela Optimised", // description + 1, // ifm_count + {/* not comparable */}, // ifm_sizes + 1, // ofm_count + {/* not comparable */}, // ofm_sizes + ETHOSU_CORE_STATUS_OK // status + }; + return rsp; +} + +ethosu_core_inference_req inferenceIndexedRequest(uint64_t user_arg, + uint32_t index, + unsigned char *input_data, + uint32_t input_data_size, + uint8_t *output_data, + uint32_t output_data_size) { + ethosu_core_inference_req req = { + user_arg, // user_arg + 1, // ifm_count + { // ifm + { + reinterpret_cast(input_data), // ptr + input_data_size // size + }}, + 1, // ofm_count + { // ofm + { + reinterpret_cast(output_data), // ptr + output_data_size // size + }}, + { // network + ETHOSU_CORE_NETWORK_INDEX, // type + {{ + index, // index + 0 // ignored padding of union + }}}, + {0, 0, 0, 0, 0, 0, 0, 0}, // pmu_event_config + 0 // pmu_cycle_counter_enable + }; + return req; +} + +ethosu_core_inference_req inferenceBufferRequest(uint64_t user_arg, + unsigned char *ptr, + uint32_t ptr_size, + unsigned char *input_data, + uint32_t input_data_size, + uint8_t *output_data, + uint32_t output_data_size) { + ethosu_core_inference_req req = { + user_arg, // user_arg + 1, // ifm_count + { // ifm + { + reinterpret_cast(input_data), // ptr + input_data_size // size + }}, + 1, // ofm_count + { // ofm + { + reinterpret_cast(output_data), // ptr + output_data_size // size + }}, + { // network + ETHOSU_CORE_NETWORK_BUFFER, // type + {{ + reinterpret_cast(ptr), // ptr + ptr_size // size + }}}, + {0, 0, 0, 0, 0, 0, 0, 0}, // pmu_event_config + 0 // pmu_cycle_counter_enable + }; + return req; +} +} // namespace MessageHandler + +#endif -- cgit v1.2.1