diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/armnnOnnxParser/OnnxParser.cpp | 15 | ||||
-rw-r--r-- | src/armnnOnnxParser/OnnxParser.hpp | 1 | ||||
-rw-r--r-- | src/armnnOnnxParser/test/Clip.cpp | 112 |
3 files changed, 126 insertions, 2 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp index 0c1af03af4..e4259980ca 100644 --- a/src/armnnOnnxParser/OnnxParser.cpp +++ b/src/armnnOnnxParser/OnnxParser.cpp @@ -5,7 +5,6 @@ #include "OnnxParser.hpp" #include <armnn/Descriptors.hpp> -#include <armnn/Utils.hpp> #include <VerificationHelpers.hpp> #include <boost/format.hpp> @@ -352,6 +351,7 @@ const std::map<std::string, OnnxParser::OperationParsingFunction> OnnxParser::m_ { "BatchNormalization", &OnnxParser::ParseBatchNormalization}, { "GlobalAveragePool", &OnnxParser::ParseGlobalAveragePool}, { "AveragePool", &OnnxParser::ParseAveragePool }, + { "Clip", &OnnxParser::ParseClip }, { "Constant", &OnnxParser::ParseConstant }, { "MaxPool", &OnnxParser::ParseMaxPool }, { "Reshape", &OnnxParser::ParseReshape }, @@ -1106,7 +1106,7 @@ void OnnxParser::ParseReshape(const onnx::NodeProto& node) void OnnxParser::ParseActivation(const onnx::NodeProto& node, const armnn::ActivationFunction func) { - CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1); + CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1, 3); CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1); VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT)); @@ -1114,6 +1114,12 @@ void OnnxParser::ParseActivation(const onnx::NodeProto& node, const armnn::Activ ActivationDescriptor desc; desc.m_Function = func; + if (func == ActivationFunction::BoundedReLu) + { + desc.m_A = node.input(2).empty() ? std::numeric_limits<float>::max() : std::stof(node.input(2)); + desc.m_B = node.input(1).empty() ? std::numeric_limits<float>::lowest() : std::stof(node.input(1)); + } + IConnectableLayer* const layer = m_Network->AddActivationLayer(desc, node.name().c_str()); BOOST_ASSERT(layer != nullptr); @@ -1128,6 +1134,11 @@ void OnnxParser::ParseActivation(const onnx::NodeProto& node, const armnn::Activ RegisterOutputSlots(layer, {node.output(0)}); } +void OnnxParser::ParseClip(const onnx::NodeProto& node) +{ + ParseActivation(node, ActivationFunction::BoundedReLu); +} + void OnnxParser::ParseSigmoid(const onnx::NodeProto& node) { ParseActivation(node, ActivationFunction::Sigmoid); diff --git a/src/armnnOnnxParser/OnnxParser.hpp b/src/armnnOnnxParser/OnnxParser.hpp index f9fa6d969f..cc012ff34c 100644 --- a/src/armnnOnnxParser/OnnxParser.hpp +++ b/src/armnnOnnxParser/OnnxParser.hpp @@ -106,6 +106,7 @@ private: void ParseReshape(const onnx::NodeProto& nodeProto); void ParseActivation(const onnx::NodeProto& nodeProto, const armnn::ActivationFunction func); + void ParseClip(const onnx::NodeProto& nodeProto); void ParseSigmoid(const onnx::NodeProto& nodeProto); void ParseTanh(const onnx::NodeProto& nodeProto); void ParseRelu(const onnx::NodeProto& nodeProto); diff --git a/src/armnnOnnxParser/test/Clip.cpp b/src/armnnOnnxParser/test/Clip.cpp new file mode 100644 index 0000000000..6420304291 --- /dev/null +++ b/src/armnnOnnxParser/test/Clip.cpp @@ -0,0 +1,112 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include <boost/test/unit_test.hpp> +#include "armnnOnnxParser/IOnnxParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(OnnxParser) + +struct ClipMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> +{ + ClipMainFixture(std::string min, std::string max) + { + m_Prototext = R"( + ir_version: 3 + producer_name: "CNTK" + producer_version: "2.5.1" + domain: "ai.cntk" + model_version: 1 + graph { + name: "CNTKGraph" + input { + name: "Input" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 5 + } + } + } + } + } + node { + input: "Input" + input:")" + min + R"(" + input:")" + max + R"(" + output: "Output" + name: "ActivationLayer" + op_type: "Clip" + } + output { + name: "Output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 5 + } + } + } + } + } + } + opset_import { + version: 7 + })"; + Setup(); + } +}; + +struct ClipFixture : ClipMainFixture +{ + ClipFixture() : ClipMainFixture("2", "3.5") {} +}; + +BOOST_FIXTURE_TEST_CASE(ValidClipTest, ClipFixture) +{ + RunTest<1>({{"Input", { -1.5f, 1.25f, 3.5f, 8.0, 2.5}}}, + {{ "Output", { 2.0f, 2.0f, 3.5f, 3.5, 2.5}}}); +} + +struct ClipNoMaxInputFixture : ClipMainFixture +{ + ClipNoMaxInputFixture() : ClipMainFixture("0", std::string()) {} +}; + +BOOST_FIXTURE_TEST_CASE(ValidNoMaxInputClipTest, ClipNoMaxInputFixture) +{ + RunTest<1>({{"Input", { -1.5f, -5.25f, -0.5f, 8.0f, std::numeric_limits<float>::max() }}}, + {{ "Output", { 0.0f, 0.0f, 0.0f, 8.0f, std::numeric_limits<float>::max() }}}); +} + +struct ClipNoMinInputFixture : ClipMainFixture +{ + ClipNoMinInputFixture() : ClipMainFixture(std::string(), "6") {} +}; + +BOOST_FIXTURE_TEST_CASE(ValidNoMinInputClipTest, ClipNoMinInputFixture) +{ + RunTest<1>({{"Input", { std::numeric_limits<float>::lowest(), -5.25f, -0.5f, 8.0f, 200.0f }}}, + {{ "Output", { std::numeric_limits<float>::lowest(), -5.25f, -0.5f, 6.0f, 6.0f }}}); +} + +struct ClipNoInputFixture : ClipMainFixture +{ + ClipNoInputFixture() : ClipMainFixture(std::string(), std::string()) {} +}; + +BOOST_FIXTURE_TEST_CASE(ValidNoInputClipTest, ClipNoInputFixture) +{ + RunTest<1>({{"Input", { std::numeric_limits<float>::lowest(), -1.25f, 3.5f, 8.0f, + std::numeric_limits<float>::max()}}}, + {{ "Output", { std::numeric_limits<float>::lowest(), -1.25f, 3.5f, 8.0f, + std::numeric_limits<float>::max()}}}); +} + +BOOST_AUTO_TEST_SUITE_END() |