diff options
author | Samuel Yap <samuel.yap@arm.com> | 2022-08-19 11:14:38 +0100 |
---|---|---|
committer | Nikhil Raj <nikhil.raj@arm.com> | 2022-08-30 17:03:44 +0100 |
commit | a04f4a15575ddd778d3a330dbce629412e1ffc0c (patch) | |
tree | 5f9ec80ee2a2f941c475115a274a2ba18e3965ce /src/armnnSerializer/test/SerializerTests.cpp | |
parent | dc8ed9d75e54e914a970e137900930fa64a0782b (diff) | |
download | armnn-a04f4a15575ddd778d3a330dbce629412e1ffc0c.tar.gz |
IVGCVSW-7104: BatchMatMul Serializer/Deserializer Support
* Updated FlatBuffers schema for BatchMatMul layer type
* Added Serializer and Deserializer implementations for BatchMatMul
* Added unit tests for BatchMatMul serialization and deserialization
* Updated CMakeLists and docs
Signed-off-by: Samuel Yap <samuel.yap@arm.com>
Change-Id: Iad63afbd036a3eb648683eb7416a475561aa20cb
Diffstat (limited to 'src/armnnSerializer/test/SerializerTests.cpp')
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index 3c00fc43ae..a568bf15c9 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -92,6 +92,47 @@ TEST_CASE("SerializeArgMinMaxSigned64") SerializeArgMinMaxTest(armnn::DataType::Signed64); } +TEST_CASE("SerializeBatchMatMul") +{ + const std::string layerName("batchMatMul"); + const armnn::TensorInfo inputXInfo({2, 3, 4, 5}, armnn::DataType::Float32); + const armnn::TensorInfo inputYInfo({2, 4, 3, 5}, armnn::DataType::Float32); + + const armnn::TensorInfo outputInfo({2, 3, 3, 5}, armnn::DataType::Float32); + + armnn::BatchMatMulDescriptor descriptor(false, + false, + false, + false, + armnn::DataLayout::NHWC, + armnn::DataLayout::NHWC); + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputXLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const inputYLayer = network->AddInputLayer(1); + + armnn::IConnectableLayer* const batchMatMulLayer = + network->AddBatchMatMulLayer(descriptor, layerName.c_str()); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + inputXLayer->GetOutputSlot(0).Connect(batchMatMulLayer->GetInputSlot(0)); + inputYLayer->GetOutputSlot(0).Connect(batchMatMulLayer->GetInputSlot(1)); + batchMatMulLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + + inputXLayer->GetOutputSlot(0).SetTensorInfo(inputXInfo); + inputYLayer->GetOutputSlot(0).SetTensorInfo(inputYInfo); + batchMatMulLayer->GetOutputSlot(0).SetTensorInfo(outputInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + CHECK(deserializedNetwork); + + LayerVerifierBaseWithDescriptor<armnn::BatchMatMulDescriptor> verifier(layerName, + {inputXInfo, inputYInfo}, + {outputInfo}, + descriptor); + deserializedNetwork->ExecuteStrategy(verifier); +} + TEST_CASE("SerializeBatchNormalization") { const std::string layerName("batchNormalization"); |