diff options
Diffstat (limited to 'src/armnnSerializer/test/SerializerTests.cpp')
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 61 |
1 files changed, 61 insertions, 0 deletions
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index 4b6bf1ec53..77bf78683a 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -91,6 +91,67 @@ BOOST_AUTO_TEST_CASE(SimpleNetworkWithMultiplicationSerialization) BOOST_TEST(stream.str().find(multLayerName) != stream.str().npos); } +BOOST_AUTO_TEST_CASE(SimpleReshapeIntegration) +{ + armnn::NetworkId networkIdentifier; + armnn::IRuntime::CreationOptions options; // default options + armnn::IRuntimePtr run = armnn::IRuntime::Create(options); + + unsigned int inputShape[] = {1, 9}; + unsigned int outputShape[] = {3, 3}; + + auto inputTensorInfo = armnn::TensorInfo(2, inputShape, armnn::DataType::Float32); + auto outputTensorInfo = armnn::TensorInfo(2, outputShape, armnn::DataType::Float32); + auto reshapeOutputTensorInfo = armnn::TensorInfo(2, outputShape, armnn::DataType::Float32); + + armnn::ReshapeDescriptor reshapeDescriptor; + reshapeDescriptor.m_TargetShape = reshapeOutputTensorInfo.GetShape(); + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer *const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer *const reshapeLayer = network->AddReshapeLayer(reshapeDescriptor, "ReshapeLayer"); + armnn::IConnectableLayer *const outputLayer = network->AddOutputLayer(0); + + inputLayer->GetOutputSlot(0).Connect(reshapeLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + reshapeLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + reshapeLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + armnnSerializer::Serializer serializer; + serializer.Serialize(*network); + std::stringstream stream; + serializer.SaveSerializedToStream(stream); + std::string const serializerString{stream.str()}; + + //Deserialize network. + auto deserializedNetwork = DeserializeNetwork(serializerString); + + //Optimize the deserialized network + auto deserializedOptimized = Optimize(*deserializedNetwork, {armnn::Compute::CpuRef}, + run->GetDeviceSpec()); + + // Load graph into runtime + run->LoadNetwork(networkIdentifier, std::move(deserializedOptimized)); + + std::vector<float> input1Data(inputTensorInfo.GetNumElements()); + std::iota(input1Data.begin(), input1Data.end(), 8); + + armnn::InputTensors inputTensors + { + {0, armnn::ConstTensor(run->GetInputTensorInfo(networkIdentifier, 0), input1Data.data())} + }; + + std::vector<float> outputData(input1Data.size()); + armnn::OutputTensors outputTensors + { + {0,armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 0), outputData.data())} + }; + + run->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors); + + BOOST_CHECK_EQUAL_COLLECTIONS(outputData.begin(),outputData.end(), input1Data.begin(),input1Data.end()); +} + BOOST_AUTO_TEST_CASE(SimpleSoftmaxIntegration) { armnn::TensorInfo tensorInfo({1, 10}, armnn::DataType::Float32); |