From 116a635581f292cb4882ea1a086f842904f85c3c Mon Sep 17 00:00:00 2001 From: Kristofer Jonsson Date: Thu, 20 Aug 2020 17:25:23 +0200 Subject: Initial commit Change-Id: I14b6becc908a0ac215769c32ee9c43db192ae6c8 --- utils/inference_runner/inference_runner.cpp | 228 ++++++++++++++++++++++++++++ 1 file changed, 228 insertions(+) create mode 100644 utils/inference_runner/inference_runner.cpp (limited to 'utils/inference_runner/inference_runner.cpp') diff --git a/utils/inference_runner/inference_runner.cpp b/utils/inference_runner/inference_runner.cpp new file mode 100644 index 0000000..3388170 --- /dev/null +++ b/utils/inference_runner/inference_runner.cpp @@ -0,0 +1,228 @@ +/* + * Copyright (c) 2020 Arm Limited. All rights reserved. + * + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include +#include +#include +#include +#include +#include + +using namespace std; +using namespace EthosU; + +namespace +{ +int defaultTimeout = 60; + +void help(const string exe) +{ + cerr << "Usage: " << exe << " [ARGS]\n"; + cerr << "\n"; + cerr << "Arguments:\n"; + cerr << " -h --help Print this help message.\n"; + cerr << " -n --network File to read network from.\n"; + cerr << " -i --ifm File to read IFM from.\n"; + cerr << " -o --ofm File to write IFM to.\n"; + cerr << " -t --timeout Timeout in seconds (default " << defaultTimeout << ").\n"; + cerr << " -p Print OFM.\n"; + cerr << endl; +} + +void rangeCheck(const int i, const int argc, const string arg) +{ + if (i >= argc) + { + cerr << "Error: Missing argument to '" << arg << "'" << endl; + exit(1); + } +} + +shared_ptr allocAndFill(Device &device, const string filename) +{ + ifstream stream(filename, ios::binary); + if (!stream.is_open()) + { + cerr << "Error: Failed to open '" << filename << "'" << endl; + exit(1); + } + + stream.seekg(0, ios_base::end); + size_t size = stream.tellg(); + stream.seekg(0, ios_base::beg); + + shared_ptr buffer = make_shared(device, size); + buffer->resize(size); + stream.read(buffer->data(), size); + + return buffer; +} + +std::ostream &operator<<(std::ostream &os, Buffer &buf) +{ + char *c = buf.data(); + const char *end = c + buf.size(); + + while (c < end) + { + os << hex << setw(2) << static_cast(*c++) << " " << dec; + } + + return os; +} + +} + +int main(int argc, char *argv[]) +{ + const string exe = argv[0]; + string networkArg; + list ifmArg; + string ofmArg; + int timeout = defaultTimeout; + bool print = false; + + for (int i = 1; i < argc; ++i) + { + const string arg(argv[i]); + + if (arg == "-h" || arg == "--help") + { + help(exe); + exit(1); + } + else if (arg == "--network" || arg == "-n") + { + rangeCheck(++i, argc, arg); + networkArg = argv[i]; + } + else if (arg == "--ifm" || arg == "-i") + { + rangeCheck(++i, argc, arg); + ifmArg.push_back(argv[i]); + } + else if (arg == "--ofm" || arg == "-o") + { + rangeCheck(++i, argc, arg); + ofmArg = argv[i]; + } + else if (arg == "--timeout" || arg == "-t") + { + rangeCheck(++i, argc, arg); + timeout = std::stoi(argv[i]); + } + else if (arg == "-p") + { + print = true; + } + else + { + cerr << "Error: Invalid argument '" << arg << "'" << endl; + help(exe); + exit(1); + } + } + + if (networkArg.empty()) + { + cerr << "Error: Missing 'network' argument" << endl; + exit(1); + } + + if (ifmArg.empty()) + { + cerr << "Error: Missing 'ifm' argument" << endl; + exit(1); + } + + if (ofmArg.empty()) + { + cerr << "Error: Missing 'ofm' argument" << endl; + exit(1); + } + + try + { + Device device; + + cout << "Send ping" << endl; + device.ioctl(ETHOSU_IOCTL_PING); + + cout << "Create network" << endl; + shared_ptr networkBuffer = allocAndFill(device, networkArg); + shared_ptr network = make_shared(device, networkBuffer); + + cout << "Queue inferences" << endl; + list> inferences; + + for (auto &filename: ifmArg) + { + cout << "Create inference" << endl; + shared_ptr ifmBuffer = allocAndFill(device, filename); + shared_ptr ofmBuffer = make_shared(device, 128 * 1024); + shared_ptr inference = make_shared(network, ifmBuffer, ofmBuffer); + inferences.push_back(inference); + } + + cout << "Wait for inferences" << endl; + + int ofmIndex = 0; + for (auto &inference: inferences) + { + inference->wait(timeout); + + string status = inference->failed() ? "failed" : "success"; + cout << "Inference status: " << status << endl; + + string ofmFilename = ofmArg + "." + to_string(ofmIndex); + ofstream ofmStream(ofmFilename, ios::binary); + if (!ofmStream.is_open()) + { + cerr << "Error: Failed to open '" << ofmFilename << "'" << endl; + exit(1); + } + + if (!inference->failed()) + { + shared_ptr ofmBuffer = inference->getOfmBuffer(); + + cout << "OFM size: " << ofmBuffer->size() << endl; + + if (print) + { + cout << "OFM data: " << *ofmBuffer << endl; + } + + ofmStream.write(ofmBuffer->data(), ofmBuffer->size()); + } + + ofmIndex++; + } + } + catch (Exception &e) + { + cerr << "Error: " << e.what() << endl; + return 1; + } + + return 0; +} -- cgit v1.2.1