diff options
author | Samuel Yap <samuel.yap@arm.com> | 2022-08-19 11:14:38 +0100 |
---|---|---|
committer | Sam Yap <samuel.yap@arm.com> | 2022-08-30 09:52:48 +0000 |
commit | b9e6b5c3b96792f40201315c831db0aa257f286c (patch) | |
tree | 5f9ec80ee2a2f941c475115a274a2ba18e3965ce /src/armnnSerializer | |
parent | 75d2cb12de43ad308161ed49a8d87762c0ff873e (diff) | |
download | armnn-b9e6b5c3b96792f40201315c831db0aa257f286c.tar.gz |
IVGCVSW-7104: BatchMatMul Serializer/Deserializer Support
* Updated FlatBuffers schema for BatchMatMul layer type
* Added Serializer and Deserializer implementations for BatchMatMul
* Added unit tests for BatchMatMul serialization and deserialization
* Updated CMakeLists and docs
Signed-off-by: Samuel Yap <samuel.yap@arm.com>
Change-Id: Iad63afbd036a3eb648683eb7416a475561aa20cb
Diffstat (limited to 'src/armnnSerializer')
-rw-r--r-- | src/armnnSerializer/ArmnnSchema.fbs | 16 | ||||
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 36 | ||||
-rw-r--r-- | src/armnnSerializer/Serializer.hpp | 4 | ||||
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 41 |
4 files changed, 97 insertions, 0 deletions
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs index f301fce818..2dbfd85b23 100644 --- a/src/armnnSerializer/ArmnnSchema.fbs +++ b/src/armnnSerializer/ArmnnSchema.fbs @@ -182,6 +182,7 @@ enum LayerType : uint { Convolution3d = 65, Pooling3d = 66, GatherNd = 67, + BatchMatMul = 68, } // Base layer table to be used as part of other layers @@ -1009,6 +1010,20 @@ table UnidirectionalSequenceLstmLayer { inputParams:LstmInputParams; } +table BatchMatMulDescriptor { + transposeX:bool = false; + transposeY:bool = false; + adjointX:bool = false; + adjointY:bool = false; + dataLayoutX:DataLayout = NCHW; + dataLayoutY:DataLayout = NCHW; +} + +table BatchMatMulLayer { + base:LayerBase; + descriptor:BatchMatMulDescriptor; +} + union Layer { ActivationLayer, AdditionLayer, @@ -1078,6 +1093,7 @@ union Layer { Convolution3dLayer, Pooling3dLayer, GatherNdLayer, + BatchMatMulLayer, } table AnyLayer { diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index 488dac6186..c9a3022b8d 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -218,6 +218,33 @@ void SerializerStrategy::SerializeArgMinMaxLayer(const armnn::IConnectableLayer CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ArgMinMaxLayer); } +void SerializerStrategy::SerializeBatchMatMulLayer(const armnn::IConnectableLayer* layer, + const armnn::BatchMatMulDescriptor& descriptor, + const char* name) +{ + IgnoreUnused(name); + + // Create FlatBuffer BaseLayer + auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_BatchMatMul); + + // Create the FlatBuffer BatchMatMulDescriptor + auto flatBufferDescriptor = CreateBatchMatMulDescriptor(m_flatBufferBuilder, + descriptor.m_TransposeX, + descriptor.m_TransposeY, + descriptor.m_AdjointX, + descriptor.m_AdjointY, + GetFlatBufferDataLayout(descriptor.m_DataLayoutX), + GetFlatBufferDataLayout(descriptor.m_DataLayoutY)); +
+ // Create the FlatBuffer BatchMatMulLayer + auto flatBufferBatchMatMulLayer = CreateBatchMatMulLayer(m_flatBufferBuilder, + flatBufferBaseLayer, + flatBufferDescriptor); + + // Add the AnyLayer to the FlatBufferLayers + CreateAnyLayer(flatBufferBatchMatMulLayer.o, serializer::Layer::Layer_BatchMatMulLayer); +} + // Build FlatBuffer for BatchToSpaceNd Layer void SerializerStrategy::SerializeBatchToSpaceNdLayer(const armnn::IConnectableLayer* layer, const armnn::BatchToSpaceNdDescriptor& descriptor, @@ -1971,6 +1998,15 @@ void SerializerStrategy::ExecuteStrategy(const armnn::IConnectableLayer* layer, SerializeArgMinMaxLayer(layer, layerDescriptor, name); break; } + case armnn::LayerType::BatchMatMul: + { + const armnn::BatchMatMulDescriptor& layerDescriptor = + static_cast<const armnn::BatchMatMulDescriptor&>(descriptor); + SerializeBatchMatMulLayer(layer, + layerDescriptor, + name); + break; + } case armnn::LayerType::BatchNormalization : { const armnn::BatchNormalizationDescriptor& layerDescriptor = diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp index 216f4dc016..60fed4f6df 100644 --- a/src/armnnSerializer/Serializer.hpp +++ b/src/armnnSerializer/Serializer.hpp @@ -113,6 +113,10 @@ private: const armnn::ArgMinMaxDescriptor& argMinMaxDescriptor, const char* name = nullptr); + void SerializeBatchMatMulLayer(const armnn::IConnectableLayer* layer, + const armnn::BatchMatMulDescriptor& descriptor, + const char* name = nullptr); + void SerializeBatchToSpaceNdLayer(const armnn::IConnectableLayer* layer, const armnn::BatchToSpaceNdDescriptor& descriptor, const char* name = nullptr); diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index 3c00fc43ae..a568bf15c9 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -92,6 +92,47 @@ TEST_CASE("SerializeArgMinMaxSigned64") { SerializeArgMinMaxTest(armnn::DataType::Signed64); } 
+TEST_CASE("SerializeBatchMatMul") +{ + const std::string layerName("batchMatMul"); + const armnn::TensorInfo inputXInfo({2, 3, 4, 5}, armnn::DataType::Float32); + const armnn::TensorInfo inputYInfo({2, 4, 3, 5}, armnn::DataType::Float32); + + const armnn::TensorInfo outputInfo({2, 3, 3, 5}, armnn::DataType::Float32); + + armnn::BatchMatMulDescriptor descriptor(false, + false, + false, + false, + armnn::DataLayout::NHWC, + armnn::DataLayout::NHWC); + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputXLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const inputYLayer = network->AddInputLayer(1); + + armnn::IConnectableLayer* const batchMatMulLayer = + network->AddBatchMatMulLayer(descriptor, layerName.c_str()); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + inputXLayer->GetOutputSlot(0).Connect(batchMatMulLayer->GetInputSlot(0)); + inputYLayer->GetOutputSlot(0).Connect(batchMatMulLayer->GetInputSlot(1)); + batchMatMulLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + + inputXLayer->GetOutputSlot(0).SetTensorInfo(inputXInfo); + inputYLayer->GetOutputSlot(0).SetTensorInfo(inputYInfo); + batchMatMulLayer->GetOutputSlot(0).SetTensorInfo(outputInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + CHECK(deserializedNetwork); + + LayerVerifierBaseWithDescriptor<armnn::BatchMatMulDescriptor> verifier(layerName, + {inputXInfo, inputYInfo}, + {outputInfo}, + descriptor); + deserializedNetwork->ExecuteStrategy(verifier); +} + TEST_CASE("SerializeBatchNormalization") { const std::string layerName("batchNormalization"); |