diff options
Diffstat (limited to 'utils/inference_runner/inference_runner.cpp')
-rw-r--r-- | utils/inference_runner/inference_runner.cpp | 78 |
1 files changed, 62 insertions, 16 deletions
diff --git a/utils/inference_runner/inference_runner.cpp b/utils/inference_runner/inference_runner.cpp index 3388170..ae8d2a7 100644 --- a/utils/inference_runner/inference_runner.cpp +++ b/utils/inference_runner/inference_runner.cpp @@ -16,8 +16,8 @@ * limitations under the License. */ -#include <ethosu.hpp> +#include <ethosu.hpp> #include <uapi/ethosu.h> #include <unistd.h> @@ -77,7 +77,55 @@ shared_ptr<Buffer> allocAndFill(Device &device, const string filename) return buffer; } -std::ostream &operator<<(std::ostream &os, Buffer &buf) +shared_ptr<Inference> createInference(Device &device, shared_ptr<Network> &network, const string &filename) +{ + // Open IFM file + ifstream stream(filename, ios::binary); + if (!stream.is_open()) + { + cerr << "Error: Failed to open '" << filename << "'" << endl; + exit(1); + } + + // Get IFM file size + stream.seekg(0, ios_base::end); + size_t size = stream.tellg(); + stream.seekg(0, ios_base::beg); + + if (size != network->getIfmSize()) + { + cerr << "Error: IFM size does not match network size. filename=" << filename << ", size=" << size << ", network=" << network->getIfmSize() << endl; + exit(1); + } + + // Create IFM buffers + vector<shared_ptr<Buffer>> ifm; + for (auto size: network->getIfmDims()) + { + shared_ptr<Buffer> buffer = make_shared<Buffer>(device, size); + buffer->resize(size); + stream.read(buffer->data(), size); + + if (!stream) + { + cerr << "Error: Failed to read IFM" << endl; + exit(1); + } + + ifm.push_back(buffer); + } + + // Create OFM buffers + vector<shared_ptr<Buffer>> ofm; + for (auto size: network->getOfmDims()) + { + ofm.push_back(make_shared<Buffer>(device, size)); + } + + return make_shared<Inference>(network, ifm.begin(), ifm.end(), ofm.begin(), ofm.end()); +} + +ostream &operator<<(ostream &os, Buffer &buf) { char *c = buf.data(); const char *end = c + buf.size(); @@ -128,7 +176,7 @@ int main(int argc, char *argv[]) else if (arg == "--timeout" || arg == "-t") { rangeCheck(++i, argc, arg); - timeout = std::stoi(argv[i]); + timeout = stoi(argv[i]); } else if (arg == "-p") { @@ -167,20 +215,17 @@ int main(int argc, char *argv[]) cout << "Send ping" << endl; device.ioctl(ETHOSU_IOCTL_PING); + /* Create network */ cout << "Create network" << endl; shared_ptr<Buffer> networkBuffer = allocAndFill(device, networkArg); shared_ptr<Network> network = make_shared<Network>(device, networkBuffer); - cout << "Queue inferences" << endl; + /* Create one inference per IFM */ list<shared_ptr<Inference>> inferences; - for (auto &filename: ifmArg) { cout << "Create inference" << endl; - shared_ptr<Buffer> ifmBuffer = allocAndFill(device, filename); - shared_ptr<Buffer> ofmBuffer = make_shared<Buffer>(device, 128 * 1024); - shared_ptr<Inference> inference = make_shared<Inference>(network, ifmBuffer, ofmBuffer); - inferences.push_back(inference); + inferences.push_back(createInference(device, network, filename)); } cout << "Wait for inferences" << endl; @@ -203,16 +248,17 @@ int main(int argc, char *argv[]) if (!inference->failed()) { - shared_ptr<Buffer> ofmBuffer = inference->getOfmBuffer(); + for (auto &ofmBuffer: inference->getOfmBuffers()) + { + cout << "OFM size: " << ofmBuffer->size() << endl; - cout << "OFM size: " << ofmBuffer->size() << endl; + if (print) + { + cout << "OFM data: " << *ofmBuffer << endl; + } - if (print) - { - cout << "OFM data: " << *ofmBuffer << endl; + ofmStream.write(ofmBuffer->data(), ofmBuffer->size()); } - - ofmStream.write(ofmBuffer->data(), ofmBuffer->size()); } ofmIndex++; |