aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFinnWilliamsArm <Finn.Williams@arm.com>2019-06-28 15:07:10 +0100
committerFinnWilliamsArm <Finn.Williams@arm.com>2019-07-01 10:53:47 +0100
commit6fb339a7d202a9c64d8c7843d630fe8ab7be9f33 (patch)
tree08f7119b14c042bc009cae4ca90c0c49c5c15387
parenta9075df5b704e4f4432bf26027e3ba671d4596f0 (diff)
downloadarmnn-6fb339a7d202a9c64d8c7843d630fe8ab7be9f33.tar.gz
IVGCVSW-3364 Add serialization support for Resize layer
Signed-off-by: FinnWilliamsArm <Finn.Williams@arm.com> Change-Id: I3b1af816cefc1760f63324f365de93f899c9c9ee
-rw-r--r--src/armnnDeserializer/Deserializer.cpp44
-rw-r--r--src/armnnDeserializer/Deserializer.hpp1
-rw-r--r--src/armnnDeserializer/DeserializerSupport.md1
-rw-r--r--src/armnnSerializer/ArmnnSchema.fbs23
-rw-r--r--src/armnnSerializer/Serializer.cpp15
-rw-r--r--src/armnnSerializer/SerializerSupport.md1
-rw-r--r--src/armnnSerializer/SerializerUtils.cpp13
-rw-r--r--src/armnnSerializer/SerializerUtils.hpp2
-rw-r--r--src/armnnSerializer/test/SerializerTests.cpp59
9 files changed, 156 insertions, 3 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 5372606689..d853a08264 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -217,6 +217,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer)
m_ParserFunctions[Layer_QuantizeLayer] = &Deserializer::ParseQuantize;
m_ParserFunctions[Layer_ReshapeLayer] = &Deserializer::ParseReshape;
m_ParserFunctions[Layer_ResizeBilinearLayer] = &Deserializer::ParseResizeBilinear;
+ m_ParserFunctions[Layer_ResizeLayer] = &Deserializer::ParseResize;
m_ParserFunctions[Layer_RsqrtLayer] = &Deserializer::ParseRsqrt;
m_ParserFunctions[Layer_SoftmaxLayer] = &Deserializer::ParseSoftmax;
m_ParserFunctions[Layer_SpaceToBatchNdLayer] = &Deserializer::ParseSpaceToBatchNd;
@@ -302,6 +303,8 @@ Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPt
return graphPtr->layers()->Get(layerIndex)->layer_as_ReshapeLayer()->base();
case Layer::Layer_ResizeBilinearLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_ResizeBilinearLayer()->base();
+ case Layer::Layer_ResizeLayer:
+ return graphPtr->layers()->Get(layerIndex)->layer_as_ResizeLayer()->base();
case Layer::Layer_RsqrtLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_RsqrtLayer()->base();
case Layer::Layer_SoftmaxLayer:
@@ -389,6 +392,19 @@ armnn::ActivationFunction ToActivationFunction(armnnSerializer::ActivationFuncti
}
}
+armnn::ResizeMethod ToResizeMethod(armnnSerializer::ResizeMethod method)
+{
+ switch (method)
+ {
+ case armnnSerializer::ResizeMethod_NearestNeighbor:
+ return armnn::ResizeMethod::NearestNeighbor;
+ case armnnSerializer::ResizeMethod_Bilinear:
+ return armnn::ResizeMethod::NearestNeighbor;
+ default:
+ return armnn::ResizeMethod::NearestNeighbor;
+ }
+}
+
armnn::TensorInfo ToTensorInfo(Deserializer::TensorRawPtr tensorPtr)
{
armnn::DataType type;
@@ -1643,6 +1659,34 @@ void Deserializer::ParseReshape(GraphPtr graph, unsigned int layerIndex)
RegisterOutputSlots(graph, layerIndex, layer);
}
+void Deserializer::ParseResize(GraphPtr graph, unsigned int layerIndex)
+{
+ CHECK_LAYERS(graph, 0, layerIndex);
+
+ Deserializer::TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
+ CHECK_VALID_SIZE(inputs.size(), 1);
+
+ Deserializer::TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
+ CHECK_VALID_SIZE(outputs.size(), 1);
+
+ auto flatBufferDescriptor = graph->layers()->Get(layerIndex)->layer_as_ResizeLayer()->descriptor();
+
+ armnn::ResizeDescriptor descriptor;
+ descriptor.m_TargetWidth = flatBufferDescriptor->targetWidth();
+ descriptor.m_TargetHeight = flatBufferDescriptor->targetHeight();
+ descriptor.m_Method = ToResizeMethod(flatBufferDescriptor->method());
+ descriptor.m_DataLayout = ToDataLayout(flatBufferDescriptor->dataLayout());
+
+ auto layerName = GetLayerName(graph, layerIndex);
+ IConnectableLayer* layer = m_Network->AddResizeLayer(descriptor, layerName.c_str());
+
+ armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ RegisterInputSlots(graph, layerIndex, layer);
+ RegisterOutputSlots(graph, layerIndex, layer);
+}
+
void Deserializer::ParseResizeBilinear(GraphPtr graph, unsigned int layerIndex)
{
CHECK_LAYERS(graph, 0, layerIndex);
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index 1807e86f38..6a63f0fb6a 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -106,6 +106,7 @@ private:
void ParsePrelu(GraphPtr graph, unsigned int layerIndex);
void ParseQuantize(GraphPtr graph, unsigned int layerIndex);
void ParseReshape(GraphPtr graph, unsigned int layerIndex);
+ void ParseResize(GraphPtr graph, unsigned int layerIndex);
void ParseResizeBilinear(GraphPtr graph, unsigned int layerIndex);
void ParseRsqrt(GraphPtr graph, unsigned int layerIndex);
void ParseSoftmax(GraphPtr graph, unsigned int layerIndex);
diff --git a/src/armnnDeserializer/DeserializerSupport.md b/src/armnnDeserializer/DeserializerSupport.md
index d074be4eb4..1f51cd3fa0 100644
--- a/src/armnnDeserializer/DeserializerSupport.md
+++ b/src/armnnDeserializer/DeserializerSupport.md
@@ -46,5 +46,6 @@ The Arm NN SDK Deserialize parser currently supports the following layers:
* Subtraction
* Switch
* TransposeConvolution2d
+* Resize
More machine learning layers will be supported in future releases.
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index 7969d10598..09187927ae 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -36,6 +36,11 @@ enum DataLayout : byte {
NCHW = 1
}
+enum ResizeMethod: byte {
+ NearestNeighbor = 0,
+ Bilinear = 1,
+}
+
table TensorInfo {
dimensions:[uint];
dataType:DataType;
@@ -124,7 +129,8 @@ enum LayerType : uint {
Concat = 39,
SpaceToDepth = 40,
Prelu = 41,
- TransposeConvolution2d = 42
+ TransposeConvolution2d = 42,
+ Resize = 43
}
// Base layer table to be used as part of other layers
@@ -581,6 +587,18 @@ table TransposeConvolution2dDescriptor {
dataLayout:DataLayout = NCHW;
}
+table ResizeLayer {
+ base:LayerBase;
+ descriptor:ResizeDescriptor;
+}
+
+table ResizeDescriptor {
+ targetHeight:uint;
+ targetWidth:uint;
+ method:ResizeMethod = NearestNeighbor;
+ dataLayout:DataLayout;
+}
+
union Layer {
ActivationLayer,
AdditionLayer,
@@ -624,7 +642,8 @@ union Layer {
ConcatLayer,
SpaceToDepthLayer,
PreluLayer,
- TransposeConvolution2dLayer
+ TransposeConvolution2dLayer,
+ ResizeLayer
}
table AnyLayer {
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index 2d5877db63..57d674095b 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -660,7 +660,20 @@ void SerializerVisitor::VisitResizeLayer(const armnn::IConnectableLayer* layer,
const armnn::ResizeDescriptor& resizeDescriptor,
const char* name)
{
- throw armnn::Exception("SerializerVisitor::VisitResizeLayer is not yet implemented");
+ auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Resize);
+
+ auto flatBufferDescriptor =
+ CreateResizeDescriptor(m_flatBufferBuilder,
+ resizeDescriptor.m_TargetHeight,
+ resizeDescriptor.m_TargetWidth,
+ GetFlatBufferResizeMethod(resizeDescriptor.m_Method),
+ GetFlatBufferDataLayout(resizeDescriptor.m_DataLayout));
+
+ auto flatBufferLayer = serializer::CreateResizeLayer(m_flatBufferBuilder,
+ flatBufferBaseLayer,
+ flatBufferDescriptor);
+
+ CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ResizeLayer);
}
void SerializerVisitor::VisitRsqrtLayer(const armnn::IConnectableLayer* layer, const char* name)
diff --git a/src/armnnSerializer/SerializerSupport.md b/src/armnnSerializer/SerializerSupport.md
index 99bc332ac8..924bab423f 100644
--- a/src/armnnSerializer/SerializerSupport.md
+++ b/src/armnnSerializer/SerializerSupport.md
@@ -46,6 +46,7 @@ The Arm NN SDK Serializer currently supports the following layers:
* Subtraction
* Switch
* TransposeConvolution2d
+* Resize
More machine learning layers will be supported in future releases.
diff --git a/src/armnnSerializer/SerializerUtils.cpp b/src/armnnSerializer/SerializerUtils.cpp
index bfe795c8c4..9790d6e23b 100644
--- a/src/armnnSerializer/SerializerUtils.cpp
+++ b/src/armnnSerializer/SerializerUtils.cpp
@@ -124,4 +124,17 @@ armnnSerializer::NormalizationAlgorithmMethod GetFlatBufferNormalizationAlgorith
}
}
+armnnSerializer::ResizeMethod GetFlatBufferResizeMethod(armnn::ResizeMethod method)
+{
+ switch (method)
+ {
+ case armnn::ResizeMethod::NearestNeighbor:
+ return armnnSerializer::ResizeMethod_NearestNeighbor;
+ case armnn::ResizeMethod::Bilinear:
+ return armnnSerializer::ResizeMethod_Bilinear;
+ default:
+ return armnnSerializer::ResizeMethod_NearestNeighbor;
+ }
+}
+
} // namespace armnnSerializer \ No newline at end of file
diff --git a/src/armnnSerializer/SerializerUtils.hpp b/src/armnnSerializer/SerializerUtils.hpp
index 29cda0d629..578689b2dd 100644
--- a/src/armnnSerializer/SerializerUtils.hpp
+++ b/src/armnnSerializer/SerializerUtils.hpp
@@ -30,4 +30,6 @@ armnnSerializer::NormalizationAlgorithmChannel GetFlatBufferNormalizationAlgorit
armnnSerializer::NormalizationAlgorithmMethod GetFlatBufferNormalizationAlgorithmMethod(
armnn::NormalizationAlgorithmMethod normalizationAlgorithmMethod);
+armnnSerializer::ResizeMethod GetFlatBufferResizeMethod(armnn::ResizeMethod method);
+
} // namespace armnnSerializer
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index 294adec12e..e51f76bd33 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -2025,6 +2025,65 @@ BOOST_AUTO_TEST_CASE(SerializeReshape)
deserializedNetwork->Accept(verifier);
}
+BOOST_AUTO_TEST_CASE(SerializeResize)
+{
+ class ResizeLayerVerifier : public LayerVerifierBase
+ {
+ public:
+ ResizeLayerVerifier(const std::string& layerName,
+ const std::vector<armnn::TensorInfo>& inputInfos,
+ const std::vector<armnn::TensorInfo>& outputInfos,
+ const armnn::ResizeDescriptor& descriptor)
+ : LayerVerifierBase(layerName, inputInfos, outputInfos)
+ , m_Descriptor(descriptor) {}
+
+ void VisitResizeLayer(const armnn::IConnectableLayer* layer,
+ const armnn::ResizeDescriptor& descriptor,
+ const char* name) override
+ {
+ VerifyNameAndConnections(layer, name);
+ VerifyDescriptor(descriptor);
+ }
+
+ private:
+ void VerifyDescriptor(const armnn::ResizeDescriptor& descriptor)
+ {
+ BOOST_CHECK(descriptor.m_DataLayout == m_Descriptor.m_DataLayout);
+ BOOST_CHECK(descriptor.m_TargetWidth == m_Descriptor.m_TargetWidth);
+ BOOST_CHECK(descriptor.m_TargetHeight == m_Descriptor.m_TargetHeight);
+ BOOST_CHECK(descriptor.m_Method == m_Descriptor.m_Method);
+ }
+
+ armnn::ResizeDescriptor m_Descriptor;
+ };
+
+ const std::string layerName("resize");
+ const armnn::TensorInfo inputInfo = armnn::TensorInfo({1, 3, 5, 5}, armnn::DataType::Float32);
+ const armnn::TensorInfo outputInfo = armnn::TensorInfo({1, 3, 2, 4}, armnn::DataType::Float32);
+
+ armnn::ResizeDescriptor desc;
+ desc.m_TargetWidth = 4;
+ desc.m_TargetHeight = 2;
+ desc.m_Method = armnn::ResizeMethod::NearestNeighbor;
+
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
+ armnn::IConnectableLayer* const resizeLayer = network->AddResizeLayer(desc, layerName.c_str());
+ armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
+
+ inputLayer->GetOutputSlot(0).Connect(resizeLayer->GetInputSlot(0));
+ resizeLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+
+ inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
+ resizeLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+ armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+ BOOST_CHECK(deserializedNetwork);
+
+ ResizeLayerVerifier verifier(layerName, {inputInfo}, {outputInfo}, desc);
+ deserializedNetwork->Accept(verifier);
+}
+
BOOST_AUTO_TEST_CASE(SerializeResizeBilinear)
{
class ResizeBilinearLayerVerifier : public LayerVerifierBase