aboutsummaryrefslogtreecommitdiff
path: root/tests/MultipleNetworksCifar10
diff options
context:
space:
mode:
authortelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
committertelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
commitc577f2c6a3b4ddb6ba87a882723c53a248afbeba (patch)
treebd7d4c148df27f8be6649d313efb24f536b7cf34 /tests/MultipleNetworksCifar10
parent4c7098bfeab1ffe1cdc77f6c15548d3e73274746 (diff)
downloadarmnn-c577f2c6a3b4ddb6ba87a882723c53a248afbeba.tar.gz
Release 18.08
Diffstat (limited to 'tests/MultipleNetworksCifar10')
-rw-r--r--tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp30
1 files changed, 16 insertions, 14 deletions
diff --git a/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp b/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp
index 37138f4a78..ca6ff45b1b 100644
--- a/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp
+++ b/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp
@@ -30,25 +30,26 @@ int main(int argc, char* argv[])
try
{
- // Configure logging for both the ARMNN library and this test program
+ // Configures logging for both the ARMNN library and this test program.
armnn::ConfigureLogging(true, true, level);
armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
namespace po = boost::program_options;
- armnn::Compute computeDevice;
+ std::vector<armnn::Compute> computeDevice;
std::string modelDir;
std::string dataDir;
po::options_description desc("Options");
try
{
- // Add generic options needed for all inference tests
+ // Adds generic options needed for all inference tests.
desc.add_options()
("help", "Display help messages")
("model-dir,m", po::value<std::string>(&modelDir)->required(),
"Path to directory containing the Cifar10 model file")
- ("compute,c", po::value<armnn::Compute>(&computeDevice)->default_value(armnn::Compute::CpuAcc),
+ ("compute,c", po::value<std::vector<armnn::Compute>>(&computeDevice)->default_value
+ ({armnn::Compute::CpuAcc, armnn::Compute::CpuRef}),
"Which device to run layers on by default. Possible choices: CpuAcc, CpuRef, GpuAcc")
("data-dir,d", po::value<std::string>(&dataDir)->required(),
"Path to directory containing the Cifar10 test data");
@@ -91,9 +92,10 @@ int main(int argc, char* argv[])
string modelPath = modelDir + "cifar10_full_iter_60000.caffemodel";
// Create runtime
- armnn::IRuntimePtr runtime(armnn::IRuntime::Create(computeDevice));
+ armnn::IRuntime::CreationOptions options;
+ armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
- // Load networks
+ // Loads networks.
armnn::Status status;
struct Net
{
@@ -116,14 +118,14 @@ int main(int argc, char* argv[])
const int networksCount = 4;
for (int i = 0; i < networksCount; ++i)
{
- // Create a network from a file on disk
+ // Creates a network from a file on the disk.
armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile(modelPath.c_str(), {}, { "prob" });
- // optimize the network
+ // Optimizes the network.
armnn::IOptimizedNetworkPtr optimizedNet(nullptr, nullptr);
try
{
- optimizedNet = armnn::Optimize(*network, runtime->GetDeviceSpec());
+ optimizedNet = armnn::Optimize(*network, computeDevice, runtime->GetDeviceSpec());
}
catch (armnn::Exception& e)
{
@@ -133,7 +135,7 @@ int main(int argc, char* argv[])
return 1;
}
- // Load the network into the runtime
+ // Loads the network into the runtime.
armnn::NetworkId networkId;
status = runtime->LoadNetwork(networkId, std::move(optimizedNet));
if (status == armnn::Status::Failure)
@@ -147,7 +149,7 @@ int main(int argc, char* argv[])
parser->GetNetworkOutputBindingInfo("prob"));
}
- // Load a test case and test inference
+ // Loads a test case and tests inference.
if (!ValidateDirectory(dataDir))
{
return 1;
@@ -156,10 +158,10 @@ int main(int argc, char* argv[])
for (unsigned int i = 0; i < 3; ++i)
{
- // Load test case data (including image data)
+ // Loads test case data (including image data).
std::unique_ptr<Cifar10Database::TTestCaseData> testCaseData = cifar10.GetTestCaseData(i);
- // Test inference
+ // Tests inference.
std::vector<std::array<float, 10>> outputs(networksCount);
for (unsigned int k = 0; k < networksCount; ++k)
@@ -174,7 +176,7 @@ int main(int argc, char* argv[])
}
}
- // Compare outputs
+ // Compares outputs.
for (unsigned int k = 1; k < networksCount; ++k)
{
if (!std::equal(outputs[0].begin(), outputs[0].end(), outputs[k].begin(), outputs[k].end()))