diff options
Diffstat (limited to 'utils')
-rw-r--r-- | utils/inference_runner/inference_runner.cpp | 123 |
1 files changed, 41 insertions, 82 deletions
diff --git a/utils/inference_runner/inference_runner.cpp b/utils/inference_runner/inference_runner.cpp index ae8d2a7..c5a62f6 100644 --- a/utils/inference_runner/inference_runner.cpp +++ b/utils/inference_runner/inference_runner.cpp @@ -16,26 +16,23 @@ * limitations under the License. */ - #include <ethosu.hpp> #include <uapi/ethosu.h> -#include <unistd.h> #include <fstream> #include <iomanip> #include <iostream> #include <list> #include <string> +#include <unistd.h> using namespace std; using namespace EthosU; -namespace -{ +namespace { int defaultTimeout = 60; -void help(const string exe) -{ +void help(const string exe) { cerr << "Usage: " << exe << " [ARGS]\n"; cerr << "\n"; cerr << "Arguments:\n"; @@ -48,20 +45,16 @@ void help(const string exe) cerr << endl; } -void rangeCheck(const int i, const int argc, const string arg) -{ - if (i >= argc) - { +void rangeCheck(const int i, const int argc, const string arg) { + if (i >= argc) { cerr << "Error: Missing argument to '" << arg << "'" << endl; exit(1); } } -shared_ptr<Buffer> allocAndFill(Device &device, const string filename) -{ +shared_ptr<Buffer> allocAndFill(Device &device, const string filename) { ifstream stream(filename, ios::binary); - if (!stream.is_open()) - { + if (!stream.is_open()) { cerr << "Error: Failed to open '" << filename << "'" << endl; exit(1); } @@ -77,12 +70,10 @@ shared_ptr<Buffer> allocAndFill(Device &device, const string filename) return buffer; } -shared_ptr<Inference> createInference(Device &device, shared_ptr<Network> &network, const string &filename) -{ +shared_ptr<Inference> createInference(Device &device, shared_ptr<Network> &network, const string &filename) { // Open IFM file ifstream stream(filename, ios::binary); - if (!stream.is_open()) - { + if (!stream.is_open()) { cerr << "Error: Failed to open '" << filename << "'" << endl; exit(1); } @@ -92,22 +83,20 @@ shared_ptr<Inference> createInference(Device &device, shared_ptr<Network> &netwo size_t size = stream.tellg(); stream.seekg(0, ios_base::beg); - if (size != network->getIfmSize()) - { - cerr << "Error: IFM size does not match network size. filename=" << filename << ", size=" << size << ", network=" << network->getIfmSize() << endl; + if (size != network->getIfmSize()) { + cerr << "Error: IFM size does not match network size. filename=" << filename << ", size=" << size + << ", network=" << network->getIfmSize() << endl; exit(1); } // Create IFM buffers vector<shared_ptr<Buffer>> ifm; - for (auto size: network->getIfmDims()) - { + for (auto size : network->getIfmDims()) { shared_ptr<Buffer> buffer = make_shared<Buffer>(device, size); buffer->resize(size); stream.read(buffer->data(), size); - if (!stream) - { + if (!stream) { cerr << "Error: Failed to read IFM" << endl; exit(1); } @@ -117,99 +106,77 @@ shared_ptr<Inference> createInference(Device &device, shared_ptr<Network> &netwo // Create OFM buffers vector<shared_ptr<Buffer>> ofm; - for (auto size: network->getOfmDims()) - { + for (auto size : network->getOfmDims()) { ofm.push_back(make_shared<Buffer>(device, size)); } return make_shared<Inference>(network, ifm.begin(), ifm.end(), ofm.begin(), ofm.end()); } -ostream &operator<<(ostream &os, Buffer &buf) -{ - char *c = buf.data(); +ostream &operator<<(ostream &os, Buffer &buf) { + char *c = buf.data(); const char *end = c + buf.size(); - while (c < end) - { + while (c < end) { os << hex << setw(2) << static_cast<int>(*c++) << " " << dec; } return os; } -} +} // namespace -int main(int argc, char *argv[]) -{ +int main(int argc, char *argv[]) { const string exe = argv[0]; string networkArg; list<string> ifmArg; string ofmArg; int timeout = defaultTimeout; - bool print = false; + bool print = false; - for (int i = 1; i < argc; ++i) - { + for (int i = 1; i < argc; ++i) { const string arg(argv[i]); - if (arg == "-h" || arg == "--help") - { + if (arg == "-h" || arg == "--help") { help(exe); exit(1); - } - else if (arg == "--network" || arg == "-n") - { + } else if (arg == "--network" || arg == "-n") { rangeCheck(++i, argc, arg); networkArg = argv[i]; - } - else if (arg == "--ifm" || arg == "-i") - { + } else if (arg == "--ifm" || arg == "-i") { rangeCheck(++i, argc, arg); ifmArg.push_back(argv[i]); - } - else if (arg == "--ofm" || arg == "-o") - { + } else if (arg == "--ofm" || arg == "-o") { rangeCheck(++i, argc, arg); ofmArg = argv[i]; - } - else if (arg == "--timeout" || arg == "-t") - { + } else if (arg == "--timeout" || arg == "-t") { rangeCheck(++i, argc, arg); timeout = stoi(argv[i]); - } - else if (arg == "-p") - { + } else if (arg == "-p") { print = true; - } - else - { + } else { cerr << "Error: Invalid argument '" << arg << "'" << endl; help(exe); exit(1); } } - if (networkArg.empty()) - { + if (networkArg.empty()) { cerr << "Error: Missing 'network' argument" << endl; exit(1); } - if (ifmArg.empty()) - { + if (ifmArg.empty()) { cerr << "Error: Missing 'ifm' argument" << endl; exit(1); } - if (ofmArg.empty()) - { + if (ofmArg.empty()) { cerr << "Error: Missing 'ofm' argument" << endl; exit(1); } - try - { + try { Device device; cout << "Send ping" << endl; @@ -218,12 +185,11 @@ int main(int argc, char *argv[]) /* Create network */ cout << "Create network" << endl; shared_ptr<Buffer> networkBuffer = allocAndFill(device, networkArg); - shared_ptr<Network> network = make_shared<Network>(device, networkBuffer); + shared_ptr<Network> network = make_shared<Network>(device, networkBuffer); /* Create one inference per IFM */ list<shared_ptr<Inference>> inferences; - for (auto &filename: ifmArg) - { + for (auto &filename : ifmArg) { cout << "Create inference" << endl; inferences.push_back(createInference(device, network, filename)); } @@ -231,8 +197,7 @@ int main(int argc, char *argv[]) cout << "Wait for inferences" << endl; int ofmIndex = 0; - for (auto &inference: inferences) - { + for (auto &inference : inferences) { inference->wait(timeout); string status = inference->failed() ? "failed" : "success"; @@ -240,20 +205,16 @@ int main(int argc, char *argv[]) string ofmFilename = ofmArg + "." + to_string(ofmIndex); ofstream ofmStream(ofmFilename, ios::binary); - if (!ofmStream.is_open()) - { + if (!ofmStream.is_open()) { cerr << "Error: Failed to open '" << ofmFilename << "'" << endl; exit(1); } - if (!inference->failed()) - { - for (auto &ofmBuffer: inference->getOfmBuffers()) - { + if (!inference->failed()) { + for (auto &ofmBuffer : inference->getOfmBuffers()) { cout << "OFM size: " << ofmBuffer->size() << endl; - if (print) - { + if (print) { cout << "OFM data: " << *ofmBuffer << endl; } @@ -263,9 +224,7 @@ int main(int argc, char *argv[]) ofmIndex++; } - } - catch (Exception &e) - { + } catch (Exception &e) { cerr << "Error: " << e.what() << endl; return 1; } |