aboutsummaryrefslogtreecommitdiff
path: root/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp')
-rw-r--r--tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp17
1 files changed, 13 insertions, 4 deletions
diff --git a/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp b/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp
index f31e0c95a9..44b8890fc2 100644
--- a/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp
+++ b/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp
@@ -173,14 +173,23 @@ int main(int argc, char* argv[])
// Loads test case data (including image data).
std::unique_ptr<Cifar10Database::TTestCaseData> testCaseData = cifar10.GetTestCaseData(i);
- // Tests inference.
- std::vector<std::array<float, 10>> outputs(networksCount);
+ using TInputContainer = std::vector<float>;
+ using TOutputContainer = std::array<float, 10>;
+ // Tests inference.
+ std::vector<TOutputContainer> outputs(networksCount);
for (unsigned int k = 0; k < networksCount; ++k)
{
+ using BindingPointInfo = InferenceModelInternal::BindingPointInfo;
+ std::vector<BindingPointInfo> inputBindings = { networks[k].m_InputBindingInfo };
+ std::vector<BindingPointInfo> outputBindings = { networks[k].m_OutputBindingInfo };
+
+ std::vector<TInputContainer> inputData = { testCaseData->m_InputImage };
+ std::vector<TOutputContainer> outputData = { outputs[k] };
+
status = runtime->EnqueueWorkload(networks[k].m_Network,
- MakeInputTensors(networks[k].m_InputBindingInfo, testCaseData->m_InputImage),
- MakeOutputTensors(networks[k].m_OutputBindingInfo, outputs[k]));
+ MakeInputTensors(inputBindings, inputData),
+ MakeOutputTensors(outputBindings, outputData));
if (status == armnn::Status::Failure)
{
BOOST_LOG_TRIVIAL(fatal) << "armnn::IRuntime: Failed to enqueue workload";