aboutsummaryrefslogtreecommitdiff
path: root/src/armnnOnnxParser
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-09-15 17:30:37 +0100
committerJim Flynn <jim.flynn@arm.com>2021-09-16 09:08:39 +0000
commitf106ab745a12a5c773a9c315dcddef0c8bf11225 (patch)
treed0d3f2ca4e084c044e50fc1950f2fe933e309992 /src/armnnOnnxParser
parent7ba84d6881685d6ebfedc597a9af98b16fa42d51 (diff)
downloadarmnn-f106ab745a12a5c773a9c315dcddef0c8bf11225.tar.gz
Add support for Clip with attribute on ONNX parser
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: I1bae42dade7eabf3da09252066e912e803a8ea32
Diffstat (limited to 'src/armnnOnnxParser')
-rw-r--r--src/armnnOnnxParser/OnnxParser.cpp12
-rw-r--r--src/armnnOnnxParser/test/Clip.cpp73
2 files changed, 83 insertions, 2 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index a7e6902fdd..49f0271aeb 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -1189,8 +1189,16 @@ void OnnxParserImpl::ParseActivation(const onnx::NodeProto& node, const armnn::A
if (func == ActivationFunction::BoundedReLu)
{
- desc.m_A = node.input(2).empty() ? std::numeric_limits<float>::max() : std::stof(node.input(2));
- desc.m_B = node.input(1).empty() ? std::numeric_limits<float>::lowest() : std::stof(node.input(1));
+ if (node.input_size() == 1 && node.attribute_size() > 0)
+ {
+ desc.m_A = ReadOptionalNodeFloatAttribute(node, "max", std::numeric_limits<float>::max());
+ desc.m_B = ReadOptionalNodeFloatAttribute(node, "min", std::numeric_limits<float>::lowest());
+ }
+ else
+ {
+ desc.m_A = node.input(2).empty() ? std::numeric_limits<float>::max() : std::stof(node.input(2));
+ desc.m_B = node.input(1).empty() ? std::numeric_limits<float>::lowest() : std::stof(node.input(1));
+ }
}
IConnectableLayer* const layer = m_Network->AddActivationLayer(desc, node.name().c_str());
diff --git a/src/armnnOnnxParser/test/Clip.cpp b/src/armnnOnnxParser/test/Clip.cpp
index b0447bcad5..2b43574d6c 100644
--- a/src/armnnOnnxParser/test/Clip.cpp
+++ b/src/armnnOnnxParser/test/Clip.cpp
@@ -62,6 +62,68 @@ struct ClipMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParse
}
};
+struct ClipAttributeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
+{
+ ClipAttributeFixture(std::string min, std::string max)
+ {
+ m_Prototext = R"(
+ ir_version: 3
+ producer_name: "CNTK"
+ producer_version: "2.5.1"
+ domain: "ai.cntk"
+ model_version: 1
+ graph {
+ name: "CNTKGraph"
+ input {
+ name: "Input"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 5
+ }
+ }
+ }
+ }
+ }
+ node {
+ input: "Input"
+ output: "Output"
+ name: "ActivationLayer"
+ op_type: "Clip"
+ attribute {
+ name: "min"
+ f: )" + min + R"(
+ type: FLOAT
+ }
+ attribute {
+ name: "max"
+ f: )" + max + R"(
+ type: FLOAT
+ }
+ }
+ output {
+ name: "Output"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 5
+ }
+ }
+ }
+ }
+ }
+ }
+ opset_import {
+ version: 7
+ })";
+ Setup();
+ }
+};
+
struct ClipFixture : ClipMainFixture
{
ClipFixture() : ClipMainFixture("2", "3.5") {}
@@ -108,4 +170,15 @@ TEST_CASE_FIXTURE(ClipNoInputFixture, "ValidNoInputClipTest")
std::numeric_limits<float>::max()}}});
}
+struct ClipMinMaxAttributeFixture : ClipAttributeFixture
+{
+ ClipMinMaxAttributeFixture() : ClipAttributeFixture("2", "3.5") {}
+};
+
+TEST_CASE_FIXTURE(ClipMinMaxAttributeFixture, "ValidClipAttributeTest")
+{
+ RunTest<1>({{ "Input", { -1.5f, 1.25f, 3.5f, 8.0, 2.5}}},
+ {{ "Output", { 2.0f, 2.0f, 3.5f, 3.5, 2.5}}});
+}
+
}