From 35de9e63d9c2fe0a557637ac104d7d73382d2d4a Mon Sep 17 00:00:00 2001 From: Kristofer Jonsson Date: Tue, 8 Mar 2022 13:25:45 +0100 Subject: Firmware resident model Support referencing a network model by index that has been built into the firmware binary. Change-Id: Idd5294376ea82503dfeafe1203dcc0694d296dfe --- utils/inference_runner/inference_runner.cpp | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) (limited to 'utils') diff --git a/utils/inference_runner/inference_runner.cpp b/utils/inference_runner/inference_runner.cpp index a72a954..08a47b7 100644 --- a/utils/inference_runner/inference_runner.cpp +++ b/utils/inference_runner/inference_runner.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. All rights reserved. + * Copyright (c) 2020-2022 Arm Limited. All rights reserved. * * SPDX-License-Identifier: Apache-2.0 * @@ -39,6 +39,7 @@ void help(const string exe) { cerr << "Arguments:\n"; cerr << " -h --help Print this help message.\n"; cerr << " -n --network File to read network from.\n"; + cerr << " --index Network model index, stored in firmware binary.\n"; cerr << " -i --ifm File to read IFM from.\n"; cerr << " -o --ofm File to write IFM to.\n"; cerr << " -P --pmu [0.." << Inference::getMaxPmuEventCounters() << "] eventid.\n"; @@ -138,6 +139,7 @@ ostream &operator<<(ostream &os, Buffer &buf) { int main(int argc, char *argv[]) { const string exe = argv[0]; string networkArg; + int networkIndex = -1; list ifmArg; vector enabledCounters(Inference::getMaxPmuEventCounters()); string ofmArg; @@ -154,6 +156,9 @@ int main(int argc, char *argv[]) { } else if (arg == "--network" || arg == "-n") { rangeCheck(++i, argc, arg); networkArg = argv[i]; + } else if (arg == "--index") { + rangeCheck(++i, argc, arg); + networkIndex = stoi(argv[i]); } else if (arg == "--ifm" || arg == "-i") { rangeCheck(++i, argc, arg); ifmArg.push_back(argv[i]); @@ -228,8 +233,15 @@ int main(int argc, char *argv[]) { /* Create network */ cout << "Create network" << endl; - shared_ptr networkBuffer = allocAndFill(device, networkArg); - shared_ptr network = make_shared(device, networkBuffer); + + shared_ptr network; + + if (networkIndex < 0) { + shared_ptr networkBuffer = allocAndFill(device, networkArg); + network = make_shared(device, networkBuffer); + } else { + network = make_shared(device, networkArg, networkIndex); + } /* Create one inference per IFM */ list> inferences; -- cgit v1.2.1