aboutsummaryrefslogtreecommitdiff
path: root/src/armnnSerializer/test/SerializerTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnSerializer/test/SerializerTests.cpp')
-rw-r--r--src/armnnSerializer/test/SerializerTests.cpp61
1 files changed, 61 insertions, 0 deletions
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index 4b6bf1ec53..77bf78683a 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -91,6 +91,67 @@ BOOST_AUTO_TEST_CASE(SimpleNetworkWithMultiplicationSerialization)
BOOST_TEST(stream.str().find(multLayerName) != stream.str().npos);
}
+BOOST_AUTO_TEST_CASE(SimpleReshapeIntegration)
+{
+ armnn::NetworkId networkIdentifier;
+ armnn::IRuntime::CreationOptions options; // default options
+ armnn::IRuntimePtr run = armnn::IRuntime::Create(options);
+
+ unsigned int inputShape[] = {1, 9};
+ unsigned int outputShape[] = {3, 3};
+
+ auto inputTensorInfo = armnn::TensorInfo(2, inputShape, armnn::DataType::Float32);
+ auto outputTensorInfo = armnn::TensorInfo(2, outputShape, armnn::DataType::Float32);
+ auto reshapeOutputTensorInfo = armnn::TensorInfo(2, outputShape, armnn::DataType::Float32);
+
+ armnn::ReshapeDescriptor reshapeDescriptor;
+ reshapeDescriptor.m_TargetShape = reshapeOutputTensorInfo.GetShape();
+
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ armnn::IConnectableLayer *const inputLayer = network->AddInputLayer(0);
+ armnn::IConnectableLayer *const reshapeLayer = network->AddReshapeLayer(reshapeDescriptor, "ReshapeLayer");
+ armnn::IConnectableLayer *const outputLayer = network->AddOutputLayer(0);
+
+ inputLayer->GetOutputSlot(0).Connect(reshapeLayer->GetInputSlot(0));
+ inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
+ reshapeLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+ reshapeLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ armnnSerializer::Serializer serializer;
+ serializer.Serialize(*network);
+ std::stringstream stream;
+ serializer.SaveSerializedToStream(stream);
+ std::string const serializerString{stream.str()};
+
+ //Deserialize network.
+ auto deserializedNetwork = DeserializeNetwork(serializerString);
+
+ //Optimize the deserialized network
+ auto deserializedOptimized = Optimize(*deserializedNetwork, {armnn::Compute::CpuRef},
+ run->GetDeviceSpec());
+
+ // Load graph into runtime
+ run->LoadNetwork(networkIdentifier, std::move(deserializedOptimized));
+
+ std::vector<float> input1Data(inputTensorInfo.GetNumElements());
+ std::iota(input1Data.begin(), input1Data.end(), 8);
+
+ armnn::InputTensors inputTensors
+ {
+ {0, armnn::ConstTensor(run->GetInputTensorInfo(networkIdentifier, 0), input1Data.data())}
+ };
+
+ std::vector<float> outputData(input1Data.size());
+ armnn::OutputTensors outputTensors
+ {
+ {0,armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 0), outputData.data())}
+ };
+
+ run->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
+
+ BOOST_CHECK_EQUAL_COLLECTIONS(outputData.begin(),outputData.end(), input1Data.begin(),input1Data.end());
+}
+
BOOST_AUTO_TEST_CASE(SimpleSoftmaxIntegration)
{
armnn::TensorInfo tensorInfo({1, 10}, armnn::DataType::Float32);