From 8d69bbc67566c09553c4afec32e829efa2cb50df Mon Sep 17 00:00:00 2001 From: Nattapat Chaimanowong Date: Wed, 27 Feb 2019 16:52:29 +0000 Subject: IVGCVSW-2766 Modify CheckDeserializedNetworkAgainstOriginal to work with multiple inputs and outputs Change-Id: I90d1701d5bfd8ced32720e495e0126de0014aff9 Signed-off-by: Nattapat Chaimanowong --- src/armnnSerializer/test/SerializerTests.cpp | 102 ++++++++++++++++----------- 1 file changed, 59 insertions(+), 43 deletions(-) diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index 13fb0b22f8..282b87bc85 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -63,11 +63,14 @@ static std::vector GenerateRandomData(size_t size) void CheckDeserializedNetworkAgainstOriginal(const armnn::INetwork& deserializedNetwork, const armnn::INetwork& originalNetwork, - const armnn::TensorShape& inputShape, - const armnn::TensorShape& outputShape, - armnn::LayerBindingId inputBindingId = 0, - armnn::LayerBindingId outputBindingId = 0) + const std::vector& inputShapes, + const std::vector& outputShapes, + const std::vector& inputBindingIds = {0}, + const std::vector& outputBindingIds = {0}) { + BOOST_CHECK(inputShapes.size() == inputBindingIds.size()); + BOOST_CHECK(outputShapes.size() == outputBindingIds.size()); + armnn::IRuntime::CreationOptions options; armnn::IRuntimePtr runtime = armnn::IRuntime::Create(options); @@ -94,38 +97,51 @@ void CheckDeserializedNetworkAgainstOriginal(const armnn::INetwork& deserialized BOOST_CHECK(status2 == armnn::Status::Success); // Generate some input data - std::vector inputData = GenerateRandomData(inputShape.GetNumElements()); + armnn::InputTensors inputTensors1; + armnn::InputTensors inputTensors2; + std::vector> inputData; + inputData.reserve(inputShapes.size()); - armnn::InputTensors inputTensors1 + for (unsigned int i = 0; i < inputShapes.size(); i++) { - { 0, armnn::ConstTensor(runtime->GetInputTensorInfo(networkId1, inputBindingId), inputData.data()) } - }; + inputData.push_back(GenerateRandomData(inputShapes[i].GetNumElements())); - armnn::InputTensors inputTensors2 - { - { 0, armnn::ConstTensor(runtime->GetInputTensorInfo(networkId2, inputBindingId), inputData.data()) } - }; + inputTensors1.emplace_back( + i, armnn::ConstTensor(runtime->GetInputTensorInfo(networkId1, inputBindingIds[i]), inputData[i].data())); - std::vector outputData1(outputShape.GetNumElements()); - std::vector outputData2(outputShape.GetNumElements()); + inputTensors2.emplace_back( + i, armnn::ConstTensor(runtime->GetInputTensorInfo(networkId2, inputBindingIds[i]), inputData[i].data())); + } - armnn::OutputTensors outputTensors1 - { - { 0, armnn::Tensor(runtime->GetOutputTensorInfo(networkId1, outputBindingId), outputData1.data()) } - }; + armnn::OutputTensors outputTensors1; + armnn::OutputTensors outputTensors2; + std::vector> outputData1; + std::vector> outputData2; + outputData1.reserve(outputShapes.size()); + outputData2.reserve(outputShapes.size()); - armnn::OutputTensors outputTensors2 + for (unsigned int i = 0; i < outputShapes.size(); i++) { - { 0, armnn::Tensor(runtime->GetOutputTensorInfo(networkId2, outputBindingId), outputData2.data()) } - }; + outputData1.emplace_back(outputShapes[i].GetNumElements()); + outputData2.emplace_back(outputShapes[i].GetNumElements()); + + outputTensors1.emplace_back( + i, armnn::Tensor(runtime->GetOutputTensorInfo(networkId1, outputBindingIds[i]), outputData1[i].data())); + + outputTensors2.emplace_back( + i, armnn::Tensor(runtime->GetOutputTensorInfo(networkId2, outputBindingIds[i]), outputData2[i].data())); + } // Run original and deserialized network runtime->EnqueueWorkload(networkId1, inputTensors1, outputTensors1); runtime->EnqueueWorkload(networkId2, inputTensors2, outputTensors2); // Compare output data - BOOST_CHECK_EQUAL_COLLECTIONS(outputData1.begin(), outputData1.end(), - outputData2.begin(), outputData2.end()); + for (unsigned int i = 0; i < outputShapes.size(); i++) + { + BOOST_CHECK_EQUAL_COLLECTIONS( + outputData1[i].begin(), outputData1[i].end(), outputData2[i].begin(), outputData2[i].end()); + } } } // anonymous namespace @@ -235,8 +251,8 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeConstant) CheckDeserializedNetworkAgainstOriginal(*net, *deserializedNetwork, - commonTensorInfo.GetShape(), - commonTensorInfo.GetShape()); + {commonTensorInfo.GetShape()}, + {commonTensorInfo.GetShape()}); } BOOST_AUTO_TEST_CASE(SerializeMultiplication) @@ -344,8 +360,8 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeConvolution2d) CheckDeserializedNetworkAgainstOriginal(*network, *deserializedNetwork, - inputInfo.GetShape(), - outputInfo.GetShape()); + {inputInfo.GetShape()}, + {outputInfo.GetShape()}); } BOOST_AUTO_TEST_CASE(SerializeDeserializeReshape) @@ -387,8 +403,8 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeReshape) CheckDeserializedNetworkAgainstOriginal(*network, *deserializedNetwork, - inputTensorInfo.GetShape(), - outputTensorInfo.GetShape()); + {inputTensorInfo.GetShape()}, + {outputTensorInfo.GetShape()}); } BOOST_AUTO_TEST_CASE(SerializeDeserializeDepthwiseConvolution2d) @@ -443,8 +459,8 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeDepthwiseConvolution2d) CheckDeserializedNetworkAgainstOriginal(*network, *deserializedNetwork, - inputInfo.GetShape(), - outputInfo.GetShape()); + {inputInfo.GetShape()}, + {outputInfo.GetShape()}); } BOOST_AUTO_TEST_CASE(SerializeDeserializeSoftmax) @@ -481,8 +497,8 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeSoftmax) CheckDeserializedNetworkAgainstOriginal(*network, *deserializedNetwork, - tensorInfo.GetShape(), - tensorInfo.GetShape()); + {tensorInfo.GetShape()}, + {tensorInfo.GetShape()}); } BOOST_AUTO_TEST_CASE(SerializeDeserializePooling2d) @@ -533,8 +549,8 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializePooling2d) CheckDeserializedNetworkAgainstOriginal(*network, *deserializedNetwork, - inputInfo.GetShape(), - outputInfo.GetShape()); + {inputInfo.GetShape()}, + {outputInfo.GetShape()}); } BOOST_AUTO_TEST_CASE(SerializeDeserializePermute) @@ -575,8 +591,8 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializePermute) CheckDeserializedNetworkAgainstOriginal(*network, *deserializedNetwork, - inputTensorInfo.GetShape(), - outputTensorInfo.GetShape()); + {inputTensorInfo.GetShape()}, + {outputTensorInfo.GetShape()}); } BOOST_AUTO_TEST_CASE(SerializeDeserializeFullyConnected) @@ -632,8 +648,8 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeFullyConnected) CheckDeserializedNetworkAgainstOriginal(*network, *deserializedNetwork, - inputInfo.GetShape(), - outputInfo.GetShape()); + {inputInfo.GetShape()}, + {outputInfo.GetShape()}); } BOOST_AUTO_TEST_CASE(SerializeDeserializeSpaceToBatchNd) @@ -678,8 +694,8 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeSpaceToBatchNd) CheckDeserializedNetworkAgainstOriginal(*network, *deserializedNetwork, - inputTensorInfo.GetShape(), - outputTensorInfo.GetShape()); + {inputTensorInfo.GetShape()}, + {outputTensorInfo.GetShape()}); } BOOST_AUTO_TEST_CASE(SerializeDeserializeBatchToSpaceNd) @@ -724,8 +740,8 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeBatchToSpaceNd) CheckDeserializedNetworkAgainstOriginal(*network, *deserializedNetwork, - inputTensorInfo.GetShape(), - outputTensorInfo.GetShape()); + {inputTensorInfo.GetShape()}, + {outputTensorInfo.GetShape()}); } BOOST_AUTO_TEST_CASE(SerializeDivision) -- cgit v1.2.1