From 585ce694dbbebfe5ba737fe94888343cb8976ac3 Mon Sep 17 00:00:00 2001 From: Kristofer Jonsson Date: Tue, 8 Mar 2022 13:28:05 +0100 Subject: Firmware resident model Support building a network model into the firmware binary. The model can be placed by a scatter file or linker script in for example a low lateny high bandwidth memory like SRAM. Change-Id: Ic742abed65e20f0da4ded7adefb039389b68b767 --- applications/message_handler/CMakeLists.txt | 12 +- applications/message_handler/message_handler.cpp | 168 +++++++++++++++++++---- applications/message_handler/message_handler.hpp | 2 + applications/message_handler/model_template.hpp | 23 ++++ 4 files changed, 174 insertions(+), 31 deletions(-) create mode 100644 applications/message_handler/model_template.hpp diff --git a/applications/message_handler/CMakeLists.txt b/applications/message_handler/CMakeLists.txt index 0cb95cc..72d930f 100644 --- a/applications/message_handler/CMakeLists.txt +++ b/applications/message_handler/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright (c) 2020-2021 Arm Limited. All rights reserved. +# Copyright (c) 2020-2022 Arm Limited. All rights reserved. # # SPDX-License-Identifier: Apache-2.0 # @@ -22,6 +22,10 @@ if (NOT TARGET freertos_kernel) endif() set(MESSAGE_HANDLER_ARENA_SIZE 2000000 CACHE STRING "Size of message handler tensor arena") +set(MESSAGE_HANDLER_MODEL_0 FALSE CACHE STRING "Path to built in model 0") +set(MESSAGE_HANDLER_MODEL_1 FALSE CACHE STRING "Path to built in model 1") +set(MESSAGE_HANDLER_MODEL_2 FALSE CACHE STRING "Path to built in model 2") +set(MESSAGE_HANDLER_MODEL_3 FALSE CACHE STRING "Path to built in model 3") ethosu_add_executable(message_handler SOURCES @@ -39,4 +43,8 @@ target_include_directories(message_handler PRIVATE ${LINUX_DRIVER_STACK_PATH}/kernel) target_compile_definitions(message_handler PRIVATE - TENSOR_ARENA_SIZE=${MESSAGE_HANDLER_ARENA_SIZE}) + TENSOR_ARENA_SIZE=${MESSAGE_HANDLER_ARENA_SIZE} + $<$:MODEL_0=${MESSAGE_HANDLER_MODEL_0}> + $<$:MODEL_1=${MESSAGE_HANDLER_MODEL_1}> + $<$:MODEL_2=${MESSAGE_HANDLER_MODEL_2}> + $<$:MODEL_3=${MESSAGE_HANDLER_MODEL_3}>) diff --git a/applications/message_handler/message_handler.cpp b/applications/message_handler/message_handler.cpp index f95d5f6..e530712 100644 --- a/applications/message_handler/message_handler.cpp +++ b/applications/message_handler/message_handler.cpp @@ -32,15 +32,110 @@ #include #include +#define XSTRINGIFY(src) #src +#define STRINGIFY(src) XSTRINGIFY(src) + using namespace EthosU; using namespace MessageQueue; +/**************************************************************************** + * Models + ****************************************************************************/ + +namespace { +#if defined(__has_include) + +#if __has_include(STRINGIFY(MODEL_0)) +namespace Model0 { +#include STRINGIFY(MODEL_0) +} +#endif + +#if __has_include(STRINGIFY(MODEL_1)) +namespace Model1 { +#include STRINGIFY(MODEL_1) +} +#endif + +#if __has_include(STRINGIFY(MODEL_2)) +namespace Model2 { +#include STRINGIFY(MODEL_2) +} +#endif + +#if __has_include(STRINGIFY(MODEL_3)) +namespace Model3 { +#include STRINGIFY(MODEL_3) +} +#endif + +#endif +} // namespace + namespace MessageHandler { /**************************************************************************** * IncomingMessageHandler ****************************************************************************/ +namespace { +bool getNetwork(const ethosu_core_buffer &buffer, void *&data, size_t &size) { + data = reinterpret_cast(buffer.ptr); + size = buffer.size; + return false; +} + +bool getNetwork(const uint32_t index, void *&data, size_t &size) { + switch (index) { +#if __has_include(STRINGIFY(MODEL_0)) + case 0: + data = reinterpret_cast(Model0::networkModel); + size = sizeof(Model0::networkModel); + break; +#endif + +#if __has_include(STRINGIFY(MODEL_1)) + case 1: + data = reinterpret_cast(Model1::networkModel); + size = sizeof(Model1::networkModel); + break; +#endif + +#if __has_include(STRINGIFY(MODEL_2)) + case 2: + data = reinterpret_cast(Model2::networkModel); + size = sizeof(Model2::networkModel); + break; +#endif + +#if __has_include(STRINGIFY(MODEL_3)) + case 3: + data = reinterpret_cast(Model3::networkModel); + size = sizeof(Model3::networkModel); + break; +#endif + + default: + printf("Error: Network model index out of range. index=%u\n", index); + return true; + } + + return false; +} + +bool getNetwork(const ethosu_core_network_buffer &buffer, void *&data, size_t &size) { + switch (buffer.type) { + case ETHOSU_CORE_NETWORK_BUFFER: + return getNetwork(buffer.buffer, data, size); + case ETHOSU_CORE_NETWORK_INDEX: + return getNetwork(buffer.index, data, size); + default: + printf("Error: Unsupported network model type. type=%u\n", buffer.type); + return true; + } +} +}; // namespace + IncomingMessageHandler::IncomingMessageHandler(ethosu_core_queue &_messageQueue, Mailbox::Mailbox &_mailbox, QueueHandle_t _inferenceQueue, @@ -151,10 +246,17 @@ bool IncomingMessageHandler::handleMessage() { break; } - printf("Msg: InferenceReq. user_arg=0x%" PRIx64 ", network={0x%" PRIx32 ", %" PRIu32 "}, \n", + printf("Msg: InferenceReq. user_arg=0x%" PRIx64 ", network_type=%" PRIu32 ", ", inference.user_arg, - inference.network.ptr, - inference.network.size); + inference.network.type); + + if (inference.network.type == ETHOSU_CORE_NETWORK_BUFFER) { + printf("network.buffer={0x%" PRIx32 ", %" PRIu32 "},\n", + inference.network.buffer.ptr, + inference.network.buffer.size); + } else { + printf("network.index=%" PRIu32 ",\n", inference.network.index); + } printf("ifm_count=%" PRIu32 ", ifm=[", inference.ifm_count); for (uint32_t i = 0; i < inference.ifm_count; ++i) { @@ -228,43 +330,32 @@ void InferenceHandler::runInference(ethosu_core_inference_req &req, ethosu_core_ currentRsp = &rsp; /* - * Setup inference job + * Run inference */ - InferenceProcess::DataPtr networkModel(reinterpret_cast(req.network.ptr), req.network.size); + InferenceProcess::InferenceJob job; + bool failed = getInferenceJob(req, job); - std::vector ifm; - for (uint32_t i = 0; i < req.ifm_count; ++i) { - ifm.push_back(InferenceProcess::DataPtr(reinterpret_cast(req.ifm[i].ptr), req.ifm[i].size)); + if (!failed) { + job.invalidate(); + failed = inference.runJob(job); + job.clean(); } - std::vector ofm; - for (uint32_t i = 0; i < req.ofm_count; ++i) { - ofm.push_back(InferenceProcess::DataPtr(reinterpret_cast(req.ofm[i].ptr), req.ofm[i].size)); - } - - InferenceProcess::InferenceJob job("job", networkModel, ifm, ofm, {}, 0, this); - - /* - * Run inference - */ - - job.invalidate(); - bool failed = inference.runJob(job); - job.clean(); - /* * Print PMU counters */ - const int numEvents = std::min(static_cast(ETHOSU_PMU_Get_NumEventCounters()), ETHOSU_CORE_PMU_MAX); + if (!failed) { + const int numEvents = std::min(static_cast(ETHOSU_PMU_Get_NumEventCounters()), ETHOSU_CORE_PMU_MAX); - for (int i = 0; i < numEvents; i++) { - printf("ethosu_pmu_cntr%d : %" PRIu32 "\n", i, rsp.pmu_event_count[i]); - } + for (int i = 0; i < numEvents; i++) { + printf("ethosu_pmu_cntr%d : %" PRIu32 "\n", i, rsp.pmu_event_count[i]); + } - if (rsp.pmu_cycle_counter_enable) { - printf("ethosu_pmu_cycle_cntr : %" PRIu64 " cycles\n", rsp.pmu_cycle_counter_count); + if (rsp.pmu_cycle_counter_enable) { + printf("ethosu_pmu_cycle_cntr : %" PRIu64 " cycles\n", rsp.pmu_cycle_counter_count); + } } /* @@ -283,6 +374,25 @@ void InferenceHandler::runInference(ethosu_core_inference_req &req, ethosu_core_ currentRsp = nullptr; } +bool InferenceHandler::getInferenceJob(const ethosu_core_inference_req &req, InferenceProcess::InferenceJob &job) { + bool failed = getNetwork(req.network, job.networkModel.data, job.networkModel.size); + if (failed) { + return true; + } + + for (uint32_t i = 0; i < req.ifm_count; ++i) { + job.input.push_back(InferenceProcess::DataPtr(reinterpret_cast(req.ifm[i].ptr), req.ifm[i].size)); + } + + for (uint32_t i = 0; i < req.ofm_count; ++i) { + job.output.push_back(InferenceProcess::DataPtr(reinterpret_cast(req.ofm[i].ptr), req.ofm[i].size)); + } + + job.externalContext = this; + + return false; +} + /**************************************************************************** * OutgoingMessageHandler ****************************************************************************/ diff --git a/applications/message_handler/message_handler.hpp b/applications/message_handler/message_handler.hpp index 36768ee..ee063de 100644 --- a/applications/message_handler/message_handler.hpp +++ b/applications/message_handler/message_handler.hpp @@ -63,6 +63,8 @@ public: private: void runInference(EthosU::ethosu_core_inference_req &req, EthosU::ethosu_core_inference_rsp &rsp); + bool getInferenceJob(const EthosU::ethosu_core_inference_req &req, InferenceProcess::InferenceJob &job); + friend void ::ethosu_inference_begin(struct ethosu_driver *drv, void *userArg); friend void ::ethosu_inference_end(struct ethosu_driver *drv, void *userArg); diff --git a/applications/message_handler/model_template.hpp b/applications/message_handler/model_template.hpp new file mode 100644 index 0000000..06636b2 --- /dev/null +++ b/applications/message_handler/model_template.hpp @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2022 Arm Limited. All rights reserved. + * + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +__attribute__((section(".sram.data"), aligned(16))) uint8_t networkModel[] = { + /* Add network model here */ +}; -- cgit v1.2.1