diff options
Diffstat (limited to 'tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp')
-rw-r--r-- | tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp | 17 |
1 files changed, 13 insertions, 4 deletions
diff --git a/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp b/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp index f31e0c95a9..44b8890fc2 100644 --- a/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp +++ b/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp @@ -173,14 +173,23 @@ int main(int argc, char* argv[]) // Loads test case data (including image data). std::unique_ptr<Cifar10Database::TTestCaseData> testCaseData = cifar10.GetTestCaseData(i); - // Tests inference. - std::vector<std::array<float, 10>> outputs(networksCount); + using TInputContainer = std::vector<float>; + using TOutputContainer = std::array<float, 10>; + // Tests inference. + std::vector<TOutputContainer> outputs(networksCount); for (unsigned int k = 0; k < networksCount; ++k) { + using BindingPointInfo = InferenceModelInternal::BindingPointInfo; + std::vector<BindingPointInfo> inputBindings = { networks[k].m_InputBindingInfo }; + std::vector<BindingPointInfo> outputBindings = { networks[k].m_OutputBindingInfo }; + + std::vector<TInputContainer> inputData = { testCaseData->m_InputImage }; + std::vector<TOutputContainer> outputData = { outputs[k] }; + status = runtime->EnqueueWorkload(networks[k].m_Network, - MakeInputTensors(networks[k].m_InputBindingInfo, testCaseData->m_InputImage), - MakeOutputTensors(networks[k].m_OutputBindingInfo, outputs[k])); + MakeInputTensors(inputBindings, inputData), + MakeOutputTensors(outputBindings, outputData)); if (status == armnn::Status::Failure) { BOOST_LOG_TRIVIAL(fatal) << "armnn::IRuntime: Failed to enqueue workload"; |