diff options
author | Keith Davis <keith.davis@arm.com> | 2020-06-04 16:34:23 +0100 |
---|---|---|
committer | KeithARM <keith.davis@arm.com> | 2020-06-17 14:54:08 +0000 |
commit | 300ad5695e2a577d2a9292b3cd6d182aae3298a3 (patch) | |
tree | 3fb34c3dc50b62630538592e6964fc263d078921 /src/armnnDeserializer | |
parent | 6398a98ac273931cc0b3ab33222d255d1edf48b0 (diff) | |
download | armnn-300ad5695e2a577d2a9292b3cd6d182aae3298a3.tar.gz |
IVGCVSW-4908 Add Serializer/Deserializer Support for FILL operator
Signed-off-by: Keith Davis <keith.davis@arm.com>
Change-Id: Icae26505d0e378ee5ffb3e92b35d78d48b369d2e
Diffstat (limited to 'src/armnnDeserializer')
-rw-r--r-- | src/armnnDeserializer/Deserializer.cpp | 24 | ||||
-rw-r--r-- | src/armnnDeserializer/Deserializer.hpp | 1 | ||||
-rw-r--r-- | src/armnnDeserializer/DeserializerSupport.md | 1 | ||||
-rw-r--r-- | src/armnnDeserializer/test/DeserializeFill.cpp | 134 |
4 files changed, 160 insertions, 0 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index bea34e16e7..3b69ed17b8 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -204,6 +204,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer) m_ParserFunctions[Layer_ElementwiseUnaryLayer] = &Deserializer::ParseElementwiseUnary; m_ParserFunctions[Layer_EqualLayer] = &Deserializer::ParseEqual; m_ParserFunctions[Layer_FullyConnectedLayer] = &Deserializer::ParseFullyConnected; + m_ParserFunctions[Layer_FillLayer] = &Deserializer::ParseFill; m_ParserFunctions[Layer_FloorLayer] = &Deserializer::ParseFloor; m_ParserFunctions[Layer_GatherLayer] = &Deserializer::ParseGather; m_ParserFunctions[Layer_GreaterLayer] = &Deserializer::ParseGreater; @@ -283,6 +284,8 @@ Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPt return graphPtr->layers()->Get(layerIndex)->layer_as_EqualLayer()->base(); case Layer::Layer_FullyConnectedLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_FullyConnectedLayer()->base(); + case Layer::Layer_FillLayer: + return graphPtr->layers()->Get(layerIndex)->layer_as_FillLayer()->base(); case Layer::Layer_FloorLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_FloorLayer()->base(); case Layer::Layer_GatherLayer: @@ -1431,6 +1434,27 @@ void Deserializer::ParseEqual(GraphPtr graph, unsigned int layerIndex) RegisterOutputSlots(graph, layerIndex, layer); } +void Deserializer::ParseFill(GraphPtr graph, unsigned int layerIndex) +{ + CHECK_LAYERS(graph, 0, layerIndex); + auto inputs = GetInputs(graph, layerIndex); + CHECK_LOCATION(); + CHECK_VALID_SIZE(inputs.size(), 1); + + auto outputs = GetOutputs(graph, layerIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + auto layerName = GetLayerName(graph, layerIndex); + armnn::FillDescriptor descriptor(1.0f); + IConnectableLayer* layer = m_Network->AddFillLayer(descriptor, layerName.c_str()); + + armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + RegisterInputSlots(graph, layerIndex, layer); + RegisterOutputSlots(graph, layerIndex, layer); +} + void Deserializer::ParseGreater(GraphPtr graph, unsigned int layerIndex) { CHECK_LAYERS(graph, 0, layerIndex); diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp index d6ceced7c6..69868c210a 100644 --- a/src/armnnDeserializer/Deserializer.hpp +++ b/src/armnnDeserializer/Deserializer.hpp @@ -96,6 +96,7 @@ private: void ParseDivision(GraphPtr graph, unsigned int layerIndex); void ParseElementwiseUnary(GraphPtr graph, unsigned int layerIndex); void ParseEqual(GraphPtr graph, unsigned int layerIndex); + void ParseFill(GraphPtr graph, unsigned int layerIndex); void ParseFloor(GraphPtr graph, unsigned int layerIndex); void ParseFullyConnected(GraphPtr graph, unsigned int layerIndex); void ParseGather(GraphPtr graph, unsigned int layerIndex); diff --git a/src/armnnDeserializer/DeserializerSupport.md b/src/armnnDeserializer/DeserializerSupport.md index 4e83cc6733..b4982ec78a 100644 --- a/src/armnnDeserializer/DeserializerSupport.md +++ b/src/armnnDeserializer/DeserializerSupport.md @@ -22,6 +22,7 @@ The Arm NN SDK Deserialize parser currently supports the following layers: * DetectionPostProcess * Division * ElementwiseUnary +* Fill * Floor * FullyConnected * Gather diff --git a/src/armnnDeserializer/test/DeserializeFill.cpp b/src/armnnDeserializer/test/DeserializeFill.cpp new file mode 100644 index 0000000000..632734fa9e --- /dev/null +++ b/src/armnnDeserializer/test/DeserializeFill.cpp @@ -0,0 +1,134 @@ +// +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include <boost/test/unit_test.hpp> +#include "ParserFlatbuffersSerializeFixture.hpp" +#include "../Deserializer.hpp" + +#include <string> + +BOOST_AUTO_TEST_SUITE(Deserializer) + +struct FillFixture : public ParserFlatbuffersSerializeFixture +{ + explicit FillFixture() + { + m_JsonString = R"( + { + layers: [ + { + layer_type: "InputLayer", + layer: { + base: { + base: { + layerName: "InputLayer", + layerType: "Input", + inputSlots: [ + + ], + outputSlots: [ + { + tensorInfo: { + dimensions: [ + 4 + ], + dataType: "Float32", + quantizationScale: 0.0 + } + } + ] + } + } + } + }, + { + layer_type: "FillLayer", + layer: { + base: { + index: 1, + layerName: "FillLayer", + layerType: "Fill", + inputSlots: [ + { + connection: { + sourceLayerIndex: 0, + outputSlotIndex: 0 + } + } + ], + outputSlots: [ + { + tensorInfo: { + dimensions: [ + 1, + 3, + 3, + 1 + ], + dataType: "Float32", + quantizationScale: 0.0 + } + } + ] + }, + descriptor: { + value: 1.0 + } + } + }, + { + layer_type: "OutputLayer", + layer: { + base: { + base: { + index: 2, + layerName: "OutputLayer", + layerType: "Output", + inputSlots: [ + { + connection: { + sourceLayerIndex: 1, + outputSlotIndex: 0 + } + } + ], + outputSlots: [ + + ] + } + } + } + } + ], + inputIds: [ + 0 + ], + outputIds: [ + 0 + ], + featureVersions: { + bindingIdsScheme: 1 + } + } + )"; + Setup(); + } +}; + + +struct SimpleFillFixture : FillFixture +{ + SimpleFillFixture() : FillFixture() {} +}; + +BOOST_FIXTURE_TEST_CASE(Fill, SimpleFillFixture) +{ + RunTest<4, armnn::DataType::Float32>( + 0, + {{"InputLayer", { 1, 3, 3, 1 }}}, + {{"OutputLayer",{ 1, 1, 1, 1, 1, 1, 1, 1, 1}}}); +} + +BOOST_AUTO_TEST_SUITE_END() |