diff options
Diffstat (limited to 'src/armnnSerializer/test/SerializerTests.cpp')
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index 3c00fc43ae..a568bf15c9 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -92,6 +92,47 @@ TEST_CASE("SerializeArgMinMaxSigned64") SerializeArgMinMaxTest(armnn::DataType::Signed64); } +TEST_CASE("SerializeBatchMatMul") +{ + const std::string layerName("batchMatMul"); + const armnn::TensorInfo inputXInfo({2, 3, 4, 5}, armnn::DataType::Float32); + const armnn::TensorInfo inputYInfo({2, 4, 3, 5}, armnn::DataType::Float32); + + const armnn::TensorInfo outputInfo({2, 3, 3, 5}, armnn::DataType::Float32); + + armnn::BatchMatMulDescriptor descriptor(false, + false, + false, + false, + armnn::DataLayout::NHWC, + armnn::DataLayout::NHWC); + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputXLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const inputYLayer = network->AddInputLayer(1); + + armnn::IConnectableLayer* const batchMatMulLayer = + network->AddBatchMatMulLayer(descriptor, layerName.c_str()); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + inputXLayer->GetOutputSlot(0).Connect(batchMatMulLayer->GetInputSlot(0)); + inputYLayer->GetOutputSlot(0).Connect(batchMatMulLayer->GetInputSlot(1)); + batchMatMulLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + + inputXLayer->GetOutputSlot(0).SetTensorInfo(inputXInfo); + inputYLayer->GetOutputSlot(0).SetTensorInfo(inputYInfo); + batchMatMulLayer->GetOutputSlot(0).SetTensorInfo(outputInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + CHECK(deserializedNetwork); + + LayerVerifierBaseWithDescriptor<armnn::BatchMatMulDescriptor> verifier(layerName, + {inputXInfo, inputYInfo}, + {outputInfo}, + descriptor); + deserializedNetwork->ExecuteStrategy(verifier); +} + TEST_CASE("SerializeBatchNormalization") { const std::string layerName("batchNormalization"); |