aboutsummaryrefslogtreecommitdiff
path: root/tests/MultipleNetworksCifar10
diff options
context:
space:
mode:
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-01-24 17:05:36 +0000
committerAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-01-30 11:25:56 +0000
commit7cf0eaa26c1fb29ca9df97e4734ec7c1e10f81c4 (patch)
tree407b519ede76b235c54907fe80411970741e8a00 /tests/MultipleNetworksCifar10
parent28d3d63cc0a33f8396b32fa8347c03912c065911 (diff)
downloadarmnn-7cf0eaa26c1fb29ca9df97e4734ec7c1e10f81c4.tar.gz
IVGCVSW-2564 Add support for multiple input and output bindings in InferenceModel
Change-Id: I64d724367d42dca4b768b6c6e42acda714985950
Diffstat (limited to 'tests/MultipleNetworksCifar10')
-rw-r--r--tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp17
1 files changed, 13 insertions, 4 deletions
diff --git a/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp b/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp
index f31e0c95a9..44b8890fc2 100644
--- a/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp
+++ b/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp
@@ -173,14 +173,23 @@ int main(int argc, char* argv[])
// Loads test case data (including image data).
std::unique_ptr<Cifar10Database::TTestCaseData> testCaseData = cifar10.GetTestCaseData(i);
- // Tests inference.
- std::vector<std::array<float, 10>> outputs(networksCount);
+ using TInputContainer = std::vector<float>;
+ using TOutputContainer = std::array<float, 10>;
+ // Tests inference.
+ std::vector<TOutputContainer> outputs(networksCount);
for (unsigned int k = 0; k < networksCount; ++k)
{
+ using BindingPointInfo = InferenceModelInternal::BindingPointInfo;
+ std::vector<BindingPointInfo> inputBindings = { networks[k].m_InputBindingInfo };
+ std::vector<BindingPointInfo> outputBindings = { networks[k].m_OutputBindingInfo };
+
+ std::vector<TInputContainer> inputData = { testCaseData->m_InputImage };
+ std::vector<TOutputContainer> outputData = { outputs[k] };
+
status = runtime->EnqueueWorkload(networks[k].m_Network,
- MakeInputTensors(networks[k].m_InputBindingInfo, testCaseData->m_InputImage),
- MakeOutputTensors(networks[k].m_OutputBindingInfo, outputs[k]));
+ MakeInputTensors(inputBindings, inputData),
+ MakeOutputTensors(outputBindings, outputData));
if (status == armnn::Status::Failure)
{
BOOST_LOG_TRIVIAL(fatal) << "armnn::IRuntime: Failed to enqueue workload";