diff options
Diffstat (limited to 'applications/message_handler/test/main.cpp')
-rw-r--r-- | applications/message_handler/test/main.cpp | 345 |
1 files changed, 345 insertions, 0 deletions
diff --git a/applications/message_handler/test/main.cpp b/applications/message_handler/test/main.cpp new file mode 100644 index 0000000..0d88611 --- /dev/null +++ b/applications/message_handler/test/main.cpp @@ -0,0 +1,345 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/**************************************************************************** + * Includes + ****************************************************************************/ + +#include "FreeRTOS.h" +#include "queue.h" +#include "semphr.h" +#include "task.h" + +#include <inttypes.h> +#include <stdio.h> + +#include "ethosu_core_interface.h" +#include "indexed_networks.hpp" +#include "input.h" +#include "message_client.hpp" +#include "message_handler.hpp" +#include "message_queue.hpp" +#include "networks.hpp" +#include "output.h" + +#include <mailbox.hpp> +#include <mhu_dummy.hpp> + +/* Disable semihosting */ +__asm(".global __use_no_semihosting\n\t"); + +using namespace EthosU; +using namespace MessageHandler; + +/**************************************************************************** + * Defines + ****************************************************************************/ + +#define TEST_ASSERT(v) \ + do { \ + if (!(v)) { \ + fprintf(stderr, "%s:%d ERROR test failed: '%s'\n", __FILE__, __LINE__, #v); \ + exit(1); \ + } \ + } while (0) + +// Nr. of tasks to process inferences with, reserves driver & runs inference (Normally 1 per NPU, but not a must) +#if defined(ETHOSU) && defined(ETHOSU_NPU_COUNT) && ETHOSU_NPU_COUNT > 0 +constexpr size_t NUM_PARALLEL_TASKS = ETHOSU_NPU_COUNT; +#else +constexpr size_t NUM_PARALLEL_TASKS = 1; +#endif + +// TensorArena static initialisation +constexpr size_t arenaSize = TENSOR_ARENA_SIZE; + +__attribute__((section(".bss.tensor_arena"), aligned(16))) uint8_t tensorArena[NUM_PARALLEL_TASKS][arenaSize]; + +// Message queue from remote host +__attribute__((section("ethosu_core_in_queue"))) MessageQueue::Queue<1000> inputMessageQueue; + +// Message queue to remote host +__attribute__((section("ethosu_core_out_queue"))) MessageQueue::Queue<1000> outputMessageQueue; + +namespace { +Mailbox::MHUDummy mailbox; +} // namespace + +/**************************************************************************** + * Application + ****************************************************************************/ +namespace { + +struct TaskParams { + TaskParams() : + messageNotify(xSemaphoreCreateBinary()), + inferenceInputQueue(std::make_shared<Queue<ethosu_core_inference_req>>()), + inferenceOutputQueue(xQueueCreate(10, sizeof(ethosu_core_inference_rsp))), + networks(std::make_shared<WithIndexedNetworks>()) {} + + SemaphoreHandle_t messageNotify; + // Used to pass inference requests to the inference runner task + std::shared_ptr<Queue<ethosu_core_inference_req>> inferenceInputQueue; + // Queue for message responses to the remote host + QueueHandle_t inferenceOutputQueue; + // Networks provider + std::shared_ptr<Networks> networks; +}; + +struct InferenceTaskParams { + TaskParams *taskParams; + uint8_t *arena; +}; + +void inferenceTask(void *pvParameters) { + printf("Starting inference task\n"); + InferenceTaskParams *params = reinterpret_cast<InferenceTaskParams *>(pvParameters); + + InferenceHandler process(params->arena, + arenaSize, + params->taskParams->inferenceInputQueue, + params->taskParams->inferenceOutputQueue, + params->taskParams->messageNotify, + params->taskParams->networks); + + process.run(); +} + +void messageTask(void *pvParameters) { + printf("Starting message task\n"); + TaskParams *params = reinterpret_cast<TaskParams *>(pvParameters); + + IncomingMessageHandler process(*inputMessageQueue.toQueue(), + *outputMessageQueue.toQueue(), + mailbox, + params->inferenceInputQueue, + params->inferenceOutputQueue, + params->messageNotify, + params->networks); + process.run(); +} + +void testPing(MessageClient client) { + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_PING)); + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_PONG)); +} + +void testVersion(MessageClient client) { + ethosu_core_msg_version ver; + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_VERSION_REQ)); + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_VERSION_RSP, ver)); + + TEST_ASSERT(ver.major == ETHOSU_CORE_MSG_VERSION_MAJOR); + TEST_ASSERT(ver.minor == ETHOSU_CORE_MSG_VERSION_MINOR); + TEST_ASSERT(ver.patch == ETHOSU_CORE_MSG_VERSION_PATCH); +} + +void readCapabilities(ethosu_core_msg_capabilities_rsp &rsp) { +#ifdef ETHOSU + struct ethosu_driver_version version; + ethosu_get_driver_version(&version); + + struct ethosu_hw_info info; + struct ethosu_driver *drv = ethosu_reserve_driver(); + ethosu_get_hw_info(drv, &info); + ethosu_release_driver(drv); + + rsp.version_status = info.version.version_status; + rsp.version_minor = info.version.version_minor; + rsp.version_major = info.version.version_major; + rsp.product_major = info.version.product_major; + rsp.arch_patch_rev = info.version.arch_patch_rev; + rsp.arch_minor_rev = info.version.arch_minor_rev; + rsp.arch_major_rev = info.version.arch_major_rev; + rsp.driver_patch_rev = version.patch; + rsp.driver_minor_rev = version.minor; + rsp.driver_major_rev = version.major; + rsp.macs_per_cc = info.cfg.macs_per_cc; + rsp.cmd_stream_version = info.cfg.cmd_stream_version; + rsp.custom_dma = info.cfg.custom_dma; +#endif +} + +void testCapabilities(MessageClient client) { + const uint64_t fake_user_arg = 42; + ethosu_core_capabilities_req req = {fake_user_arg}; + ethosu_core_msg_capabilities_rsp expected_rsp; + ethosu_core_msg_capabilities_rsp rsp; + + readCapabilities(expected_rsp); + expected_rsp.user_arg = req.user_arg; + + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_CAPABILITIES_REQ, req)); + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_CAPABILITIES_RSP, rsp)); + + TEST_ASSERT(expected_rsp.version_status == rsp.version_status); + TEST_ASSERT(expected_rsp.version_minor == rsp.version_minor); + TEST_ASSERT(expected_rsp.version_major == rsp.version_major); + TEST_ASSERT(expected_rsp.product_major == rsp.product_major); + TEST_ASSERT(expected_rsp.arch_patch_rev == rsp.arch_patch_rev); + TEST_ASSERT(expected_rsp.arch_minor_rev == rsp.arch_minor_rev); + TEST_ASSERT(expected_rsp.arch_major_rev == rsp.arch_major_rev); + TEST_ASSERT(expected_rsp.driver_patch_rev == rsp.driver_patch_rev); + TEST_ASSERT(expected_rsp.driver_minor_rev == rsp.driver_minor_rev); + TEST_ASSERT(expected_rsp.driver_major_rev == rsp.driver_major_rev); + TEST_ASSERT(expected_rsp.macs_per_cc == rsp.macs_per_cc); + TEST_ASSERT(expected_rsp.cmd_stream_version == rsp.cmd_stream_version); + TEST_ASSERT(expected_rsp.custom_dma == rsp.custom_dma); + +#ifdef ETHOSU + TEST_ASSERT(rsp.version_major > 0 || rsp.version_minor > 0); + TEST_ASSERT(rsp.product_major > 0); + TEST_ASSERT(rsp.arch_major_rev > 0 || rsp.arch_minor_rev > 0 || rsp.arch_patch_rev > 0); + TEST_ASSERT(rsp.driver_major_rev > 0 || rsp.driver_minor_rev > 0 || rsp.driver_patch_rev > 0); + TEST_ASSERT(rsp.macs_per_cc > 0); +#endif +} + +void testNetworkInfo(MessageClient client) { + const uint64_t fake_user_arg = 42; + ethosu_core_network_info_req req = {fake_user_arg, // user_arg + { // network + ETHOSU_CORE_NETWORK_INDEX, // type + {{ + 0, // index + 0 // ignored padding of union + }}}}; + ethosu_core_network_info_rsp rsp; + ethosu_core_network_info_rsp expected_rsp = { + req.user_arg, // user_arg + "Vela Optimised", // description + 1, // ifm_count + {/* not comparable */}, // ifm_sizes + 1, // ofm_count + {/* not comparable */}, // ofm_sizes + 0 // status + }; + + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_NETWORK_INFO_REQ, req)); + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_NETWORK_INFO_RSP, rsp)); + + TEST_ASSERT(expected_rsp.user_arg == rsp.user_arg); + TEST_ASSERT(std::strncmp(expected_rsp.desc, rsp.desc, sizeof(rsp.desc)) == 0); + TEST_ASSERT(expected_rsp.ifm_count == rsp.ifm_count); + TEST_ASSERT(expected_rsp.ofm_count == rsp.ofm_count); + TEST_ASSERT(expected_rsp.status == rsp.status); +} + +void testInferenceRun(MessageClient client) { + uint8_t data[sizeof(expectedOutputData)]; + const uint64_t fake_user_arg = 42; + ethosu_core_inference_req req = { + fake_user_arg, // user_arg + 1, // ifm_count + { // ifm: + { + reinterpret_cast<uint32_t>(&inputData[0]), // ptr + sizeof(inputData) // size + }}, + 1, // ofm_count + { // ofm + { + reinterpret_cast<uint32_t>(&data[0]), // ptr + sizeof(data) // size + }}, + { // network + ETHOSU_CORE_NETWORK_INDEX, // type + {{ + 0, // index + 0 // ignored padding of union + }}}, + {0, 0, 0, 0, 0, 0, 0, 0}, // pmu_event_config + 0 // pmu_cycle_counter_enable + }; + ethosu_core_inference_rsp rsp; + + TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_INFERENCE_REQ, req)); + TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_INFERENCE_RSP, rsp)); + + TEST_ASSERT(req.user_arg == rsp.user_arg); + TEST_ASSERT(rsp.ofm_count == 1); + TEST_ASSERT(std::memcmp(expectedOutputData, data, sizeof(expectedOutputData)) == 0); + TEST_ASSERT(rsp.status == ETHOSU_CORE_STATUS_OK); + TEST_ASSERT(rsp.pmu_cycle_counter_enable == req.pmu_cycle_counter_enable); + TEST_ASSERT(std::memcmp(rsp.pmu_event_config, req.pmu_event_config, sizeof(req.pmu_event_config)) == 0); +} + +void clientTask(void *) { + printf("Starting client task\n"); + + MessageClient client(*inputMessageQueue.toQueue(), *outputMessageQueue.toQueue(), mailbox); + + vTaskDelay(10); + + testPing(client); + testVersion(client); + testCapabilities(client); + testNetworkInfo(client); + testInferenceRun(client); + + exit(0); +} + +/* + * Keep task parameters as global data as FreeRTOS resets the stack when the + * scheduler is started. + */ +TaskParams taskParams; +InferenceTaskParams infParams[NUM_PARALLEL_TASKS]; + +} // namespace + +// FreeRTOS application. NOTE: Additional tasks may require increased heap size. +int main() { + BaseType_t ret; + + if (!mailbox.verifyHardware()) { + printf("Failed to verify mailbox hardware\n"); + return 1; + } + + // Task for handling incoming /outgoing messages from the remote host + ret = xTaskCreate(messageTask, "messageTask", 1024, &taskParams, 2, nullptr); + if (ret != pdPASS) { + printf("Failed to create 'messageTask'\n"); + return ret; + } + + // One inference task for each NPU + for (size_t n = 0; n < NUM_PARALLEL_TASKS; n++) { + infParams[n].taskParams = &taskParams; + infParams[n].arena = reinterpret_cast<uint8_t *>(&tensorArena[n]); + ret = xTaskCreate(inferenceTask, "inferenceTask", 8 * 1024, &infParams[n], 3, nullptr); + if (ret != pdPASS) { + printf("Failed to create 'inferenceTask%d'\n", n); + return ret; + } + } + + // Task for handling incoming /outgoing messages from the remote host + ret = xTaskCreate(clientTask, "clientTask", 512, nullptr, 2, nullptr); + if (ret != pdPASS) { + printf("Failed to create 'messageTask'\n"); + return ret; + } + + // Start Scheduler + vTaskStartScheduler(); + + return 1; +} |