aboutsummaryrefslogtreecommitdiff
path: root/src/armnnOnnxParser/OnnxParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnOnnxParser/OnnxParser.cpp')
-rw-r--r--src/armnnOnnxParser/OnnxParser.cpp15
1 files changed, 13 insertions, 2 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 0c1af03af4..e4259980ca 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -5,7 +5,6 @@
#include "OnnxParser.hpp"
#include <armnn/Descriptors.hpp>
-#include <armnn/Utils.hpp>
#include <VerificationHelpers.hpp>
#include <boost/format.hpp>
@@ -352,6 +351,7 @@ const std::map<std::string, OnnxParser::OperationParsingFunction> OnnxParser::m_
{ "BatchNormalization", &OnnxParser::ParseBatchNormalization},
{ "GlobalAveragePool", &OnnxParser::ParseGlobalAveragePool},
{ "AveragePool", &OnnxParser::ParseAveragePool },
+ { "Clip", &OnnxParser::ParseClip },
{ "Constant", &OnnxParser::ParseConstant },
{ "MaxPool", &OnnxParser::ParseMaxPool },
{ "Reshape", &OnnxParser::ParseReshape },
@@ -1106,7 +1106,7 @@ void OnnxParser::ParseReshape(const onnx::NodeProto& node)
void OnnxParser::ParseActivation(const onnx::NodeProto& node, const armnn::ActivationFunction func)
{
- CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
+ CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1, 3);
CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
@@ -1114,6 +1114,12 @@ void OnnxParser::ParseActivation(const onnx::NodeProto& node, const armnn::Activ
ActivationDescriptor desc;
desc.m_Function = func;
+ if (func == ActivationFunction::BoundedReLu)
+ {
+ desc.m_A = node.input(2).empty() ? std::numeric_limits<float>::max() : std::stof(node.input(2));
+ desc.m_B = node.input(1).empty() ? std::numeric_limits<float>::lowest() : std::stof(node.input(1));
+ }
+
IConnectableLayer* const layer = m_Network->AddActivationLayer(desc, node.name().c_str());
BOOST_ASSERT(layer != nullptr);
@@ -1128,6 +1134,11 @@ void OnnxParser::ParseActivation(const onnx::NodeProto& node, const armnn::Activ
RegisterOutputSlots(layer, {node.output(0)});
}
+void OnnxParser::ParseClip(const onnx::NodeProto& node)
+{
+ ParseActivation(node, ActivationFunction::BoundedReLu);
+}
+
void OnnxParser::ParseSigmoid(const onnx::NodeProto& node)
{
ParseActivation(node, ActivationFunction::Sigmoid);