aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTee Jung <tee.ty.jung@openedges.com>2019-11-01 07:04:42 +0000
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-11-04 09:12:46 +0000
commit7ff9a6096e3c1facbd6786993a6437b9f72069d2 (patch)
treea4e590d3fbf75d065692b5c52b7bcce6133ed740
parentfcf6fd562f87595c814d8acbec04194421018c32 (diff)
downloadarmnn-7ff9a6096e3c1facbd6786993a6437b9f72069d2.tar.gz
Make the ONNX parser support TanH / Sigmoid / LeakyRelu layers
Signed-off-by: Jung Tae-young tee.ty.jung@openedges.com Change-Id: I44d24b525b78b8d3fee0197abda7bd667eb04d83
-rw-r--r--src/armnnOnnxParser/OnnxParser.cpp26
-rw-r--r--src/armnnOnnxParser/OnnxParser.hpp6
2 files changed, 30 insertions, 2 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 9d374aed71..0d0cc253d2 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -337,7 +337,10 @@ const std::map<std::string, OnnxParser::OperationParsingFunction> OnnxParser::m_
{ "Constant", &OnnxParser::ParseConstant },
{ "MaxPool", &OnnxParser::ParseMaxPool },
{ "Reshape", &OnnxParser::ParseReshape },
+ { "Sigmoid", &OnnxParser::ParseSigmoid },
+ { "Tanh", &OnnxParser::ParseTanh },
{ "Relu", &OnnxParser::ParseRelu },
+ { "LeakyRelu", &OnnxParser::ParseLeakyRelu },
{ "Conv", &OnnxParser::ParseConv },
{ "Add", &OnnxParser::ParseAdd },
};
@@ -1083,7 +1086,7 @@ void OnnxParser::ParseReshape(const onnx::NodeProto& node)
}
}
-void OnnxParser::ParseRelu(const onnx::NodeProto& node)
+void OnnxParser::ParseActivation(const onnx::NodeProto& node, const armnn::ActivationFunction func)
{
CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
@@ -1091,7 +1094,7 @@ void OnnxParser::ParseRelu(const onnx::NodeProto& node)
VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
ActivationDescriptor desc;
- desc.m_Function = ActivationFunction::ReLu;
+ desc.m_Function = func;
IConnectableLayer* const layer = m_Network->AddActivationLayer(desc, node.name().c_str());
BOOST_ASSERT(layer != nullptr);
@@ -1107,6 +1110,25 @@ void OnnxParser::ParseRelu(const onnx::NodeProto& node)
RegisterOutputSlots(layer, {node.output(0)});
}
+void OnnxParser::ParseSigmoid(const onnx::NodeProto& node)
+{
+ ParseActivation(node, ActivationFunction::Sigmoid);
+}
+
+void OnnxParser::ParseTanh(const onnx::NodeProto& node)
+{
+ ParseActivation(node, ActivationFunction::TanH);
+}
+
+void OnnxParser::ParseRelu(const onnx::NodeProto& node)
+{
+ ParseActivation(node, ActivationFunction::ReLu);
+}
+
+void OnnxParser::ParseLeakyRelu(const onnx::NodeProto& node)
+{
+ ParseActivation(node, ActivationFunction::LeakyReLu);
+}
void OnnxParser::AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const Convolution2dDescriptor& convDesc)
{
diff --git a/src/armnnOnnxParser/OnnxParser.hpp b/src/armnnOnnxParser/OnnxParser.hpp
index 91927c24a8..a467180299 100644
--- a/src/armnnOnnxParser/OnnxParser.hpp
+++ b/src/armnnOnnxParser/OnnxParser.hpp
@@ -14,6 +14,7 @@
namespace armnn
{
class TensorInfo;
+enum class ActivationFunction;
}
namespace armnnOnnxParser
@@ -103,7 +104,12 @@ private:
void AddPoolingLayer(const onnx::NodeProto& nodeProto, armnn::Pooling2dDescriptor& desc);
void ParseReshape(const onnx::NodeProto& nodeProto);
+
+ void ParseActivation(const onnx::NodeProto& nodeProto, const armnn::ActivationFunction func);
+ void ParseSigmoid(const onnx::NodeProto& nodeProto);
+ void ParseTanh(const onnx::NodeProto& nodeProto);
void ParseRelu(const onnx::NodeProto& nodeProto);
+ void ParseLeakyRelu(const onnx::NodeProto& nodeProto);
void AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const armnn::Convolution2dDescriptor& convDesc);
void ParseConv(const onnx::NodeProto& nodeProto);