diff options
Diffstat (limited to 'tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp')
-rw-r--r-- | tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp | 28 |
1 files changed, 19 insertions, 9 deletions
diff --git a/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp b/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp index 44b8890fc2..2bbfb69c8d 100644 --- a/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp +++ b/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp @@ -168,28 +168,34 @@ int main(int argc, char* argv[]) } Cifar10Database cifar10(dataDir); + using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>; + for (unsigned int i = 0; i < 3; ++i) { // Loads test case data (including image data). std::unique_ptr<Cifar10Database::TTestCaseData> testCaseData = cifar10.GetTestCaseData(i); - using TInputContainer = std::vector<float>; - using TOutputContainer = std::array<float, 10>; - // Tests inference. - std::vector<TOutputContainer> outputs(networksCount); + std::vector<TContainer> outputs; + outputs.reserve(networksCount); + + for (unsigned int j = 0; j < networksCount; ++j) + { + outputs.push_back(std::vector<float>(10)); + } + for (unsigned int k = 0; k < networksCount; ++k) { using BindingPointInfo = InferenceModelInternal::BindingPointInfo; std::vector<BindingPointInfo> inputBindings = { networks[k].m_InputBindingInfo }; std::vector<BindingPointInfo> outputBindings = { networks[k].m_OutputBindingInfo }; - std::vector<TInputContainer> inputData = { testCaseData->m_InputImage }; - std::vector<TOutputContainer> outputData = { outputs[k] }; + std::vector<TContainer> inputDataContainers = { testCaseData->m_InputImage }; + std::vector<TContainer> outputDataContainers = { outputs[k] }; status = runtime->EnqueueWorkload(networks[k].m_Network, - MakeInputTensors(inputBindings, inputData), - MakeOutputTensors(outputBindings, outputData)); + MakeInputTensors(inputBindings, inputDataContainers), + MakeOutputTensors(outputBindings, outputDataContainers)); if (status == armnn::Status::Failure) { BOOST_LOG_TRIVIAL(fatal) << "armnn::IRuntime: Failed to enqueue workload"; @@ -198,9 +204,13 @@ int main(int argc, char* argv[]) } // Compares outputs. + std::vector<float> output0 = boost::get<std::vector<float>>(outputs[0]); + for (unsigned int k = 1; k < networksCount; ++k) { - if (!std::equal(outputs[0].begin(), outputs[0].end(), outputs[k].begin(), outputs[k].end())) + std::vector<float> outputK = boost::get<std::vector<float>>(outputs[k]); + + if (!std::equal(output0.begin(), output0.end(), outputK.begin(), outputK.end())) { BOOST_LOG_TRIVIAL(error) << "Multiple networks inference failed!"; return 1; |