diff options
Diffstat (limited to 'tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp')
-rw-r--r-- | tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp | 30 |
1 files changed, 16 insertions, 14 deletions
diff --git a/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp b/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp index 37138f4a78..ca6ff45b1b 100644 --- a/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp +++ b/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp @@ -30,25 +30,26 @@ int main(int argc, char* argv[]) try { - // Configure logging for both the ARMNN library and this test program + // Configures logging for both the ARMNN library and this test program. armnn::ConfigureLogging(true, true, level); armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level); namespace po = boost::program_options; - armnn::Compute computeDevice; + std::vector<armnn::Compute> computeDevice; std::string modelDir; std::string dataDir; po::options_description desc("Options"); try { - // Add generic options needed for all inference tests + // Adds generic options needed for all inference tests. desc.add_options() ("help", "Display help messages") ("model-dir,m", po::value<std::string>(&modelDir)->required(), "Path to directory containing the Cifar10 model file") - ("compute,c", po::value<armnn::Compute>(&computeDevice)->default_value(armnn::Compute::CpuAcc), + ("compute,c", po::value<std::vector<armnn::Compute>>(&computeDevice)->default_value + ({armnn::Compute::CpuAcc, armnn::Compute::CpuRef}), "Which device to run layers on by default. Possible choices: CpuAcc, CpuRef, GpuAcc") ("data-dir,d", po::value<std::string>(&dataDir)->required(), "Path to directory containing the Cifar10 test data"); @@ -91,9 +92,10 @@ int main(int argc, char* argv[]) string modelPath = modelDir + "cifar10_full_iter_60000.caffemodel"; // Create runtime - armnn::IRuntimePtr runtime(armnn::IRuntime::Create(computeDevice)); + armnn::IRuntime::CreationOptions options; + armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options)); - // Load networks + // Loads networks. armnn::Status status; struct Net { @@ -116,14 +118,14 @@ int main(int argc, char* argv[]) const int networksCount = 4; for (int i = 0; i < networksCount; ++i) { - // Create a network from a file on disk + // Creates a network from a file on the disk. armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile(modelPath.c_str(), {}, { "prob" }); - // optimize the network + // Optimizes the network. armnn::IOptimizedNetworkPtr optimizedNet(nullptr, nullptr); try { - optimizedNet = armnn::Optimize(*network, runtime->GetDeviceSpec()); + optimizedNet = armnn::Optimize(*network, computeDevice, runtime->GetDeviceSpec()); } catch (armnn::Exception& e) { @@ -133,7 +135,7 @@ int main(int argc, char* argv[]) return 1; } - // Load the network into the runtime + // Loads the network into the runtime. armnn::NetworkId networkId; status = runtime->LoadNetwork(networkId, std::move(optimizedNet)); if (status == armnn::Status::Failure) @@ -147,7 +149,7 @@ int main(int argc, char* argv[]) parser->GetNetworkOutputBindingInfo("prob")); } - // Load a test case and test inference + // Loads a test case and tests inference. if (!ValidateDirectory(dataDir)) { return 1; @@ -156,10 +158,10 @@ int main(int argc, char* argv[]) for (unsigned int i = 0; i < 3; ++i) { - // Load test case data (including image data) + // Loads test case data (including image data). std::unique_ptr<Cifar10Database::TTestCaseData> testCaseData = cifar10.GetTestCaseData(i); - // Test inference + // Tests inference. std::vector<std::array<float, 10>> outputs(networksCount); for (unsigned int k = 0; k < networksCount; ++k) @@ -174,7 +176,7 @@ int main(int argc, char* argv[]) } } - // Compare outputs + // Compares outputs. for (unsigned int k = 1; k < networksCount; ++k) { if (!std::equal(outputs[0].begin(), outputs[0].end(), outputs[k].begin(), outputs[k].end())) |