aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser
diff options
context:
space:
mode:
authorFinn Williams <finn.williams@arm.com>2019-01-22 14:18:11 +0000
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-01-24 09:13:42 +0000
commitc42c38413c72cf7c31bb0353a2b836b7a2754f37 (patch)
tree5000f60a900a45b97f771ff3813c477c73309d72 /src/armnnTfLiteParser
parentdb2b160bf9e7759d0157dfa57ee940290f5170e3 (diff)
downloadarmnn-c42c38413c72cf7c31bb0353a2b836b7a2754f37.tar.gz
IVGCVSW-2430 Add logistic parser to tf-lite
* Added implementation and unit tests for sigmoid function for tf-lite parser * Refactored relu, relu6 and logisitc parser to reduce code duplication Change-Id: I00a2bd90bbc9144a2f84981f63b2cd1756b68a16
Diffstat (limited to 'src/armnnTfLiteParser')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp79
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.hpp2
-rw-r--r--src/armnnTfLiteParser/test/Activations.cpp9
3 files changed, 55 insertions, 35 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index affd858d77..359695b94d 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -464,6 +464,7 @@ TfLiteParser::TfLiteParser()
m_ParserFunctions[tflite::BuiltinOperator_CONV_2D] = &TfLiteParser::ParseConv2D;
m_ParserFunctions[tflite::BuiltinOperator_DEPTHWISE_CONV_2D] = &TfLiteParser::ParseDepthwiseConv2D;
m_ParserFunctions[tflite::BuiltinOperator_FULLY_CONNECTED] = &TfLiteParser::ParseFullyConnected;
+ m_ParserFunctions[tflite::BuiltinOperator_LOGISTIC] = &TfLiteParser::ParseLogistic;
m_ParserFunctions[tflite::BuiltinOperator_MAX_POOL_2D] = &TfLiteParser::ParseMaxPool2D;
m_ParserFunctions[tflite::BuiltinOperator_RELU] = &TfLiteParser::ParseRelu;
m_ParserFunctions[tflite::BuiltinOperator_RELU6] = &TfLiteParser::ParseRelu6;
@@ -1219,42 +1220,26 @@ void TfLiteParser::ParsePad(size_t subgraphIndex, size_t operatorIndex)
RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
}
+
void TfLiteParser::ParseRelu(size_t subgraphIndex, size_t operatorIndex)
{
- CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
-
- const auto & operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
- boost::ignore_unused(operatorPtr);
-
- auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
- CHECK_VALID_SIZE(inputs.size(), 1);
-
- auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
- CHECK_VALID_SIZE(outputs.size(), 1);
-
- auto layerName = str(boost::format("Activation:RELU:%1%:%2%") % subgraphIndex % operatorIndex);
- ActivationDescriptor activationDesc;
- activationDesc.m_Function = ActivationFunction::ReLu;
- IConnectableLayer* const layer =
- m_Network->AddActivationLayer(activationDesc, layerName.c_str());
-
- TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
- layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+ ParseActivation(subgraphIndex,operatorIndex, ActivationFunction::ReLu);
+}
- // register the input connection slots for the layer, connections are made after all layers have been created
- // only the tensors for the inputs are relevant, exclude the const tensors
- auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
- RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
+void TfLiteParser::ParseRelu6(size_t subgraphIndex, size_t operatorIndex)
+{
+ ParseActivation(subgraphIndex,operatorIndex, ActivationFunction::BoundedReLu);
+}
- // register the output connection slots for the layer, connections are made after all layers have been created
- auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
- RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
+void TfLiteParser::ParseLogistic(size_t subgraphIndex, size_t operatorIndex)
+{
+ ParseActivation(subgraphIndex,operatorIndex,ActivationFunction::Sigmoid);
}
-void TfLiteParser::ParseRelu6(size_t subgraphIndex, size_t operatorIndex)
+
+void TfLiteParser::ParseActivation(size_t subgraphIndex, size_t operatorIndex, ActivationFunction activationType)
{
CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
-
const auto & operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
boost::ignore_unused(operatorPtr);
@@ -1264,13 +1249,38 @@ void TfLiteParser::ParseRelu6(size_t subgraphIndex, size_t operatorIndex)
auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
CHECK_VALID_SIZE(outputs.size(), 1);
- auto layerName = str(boost::format("Activation:RELU6:%1%:%2%") % subgraphIndex % operatorIndex);
+ auto layerName = str(boost::format("Activation:"));
ActivationDescriptor activationDesc;
- activationDesc.m_Function = ActivationFunction::BoundedReLu;
- activationDesc.m_A = 6.0f;
- activationDesc.m_B = 0.0f;
- IConnectableLayer* const layer =
- m_Network->AddActivationLayer(activationDesc, layerName.c_str());
+ activationDesc.m_Function = activationType;
+
+ switch (activationType)
+ {
+ case ActivationFunction::ReLu:
+ {
+ layerName += str(boost::format("RELU:%1%:%2%") % subgraphIndex % operatorIndex);
+ break;
+ }
+ case ActivationFunction::BoundedReLu:
+ {
+ layerName += str(boost::format("RELU6:%1%:%2%") % subgraphIndex % operatorIndex);
+ activationDesc.m_A = 6.0f;
+ activationDesc.m_B = 0.0f;
+ break;
+ }
+ case ActivationFunction::Sigmoid:
+ {
+ layerName += str(boost::format("SIGMOID:%1%:%2%") % subgraphIndex % operatorIndex);
+ break;
+ }
+ default:
+ {
+ throw ParseException(
+ boost::str(boost::format("Unexpected ActivationFunction[%1%] when creating layerName "
+ " %2% ") %static_cast<int>(activationType)% CHECK_LOCATION().AsString()));
+ }
+ }
+
+ IConnectableLayer* const layer = m_Network->AddActivationLayer(activationDesc, layerName.c_str());
TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
@@ -1284,7 +1294,6 @@ void TfLiteParser::ParseRelu6(size_t subgraphIndex, size_t operatorIndex)
auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
}
-
armnn::TensorInfo TfLiteParser::OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo,
const std::vector<int32_t> & targetDimsIn)
{
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp
index 34ae07f392..594f9e7f4b 100644
--- a/src/armnnTfLiteParser/TfLiteParser.hpp
+++ b/src/armnnTfLiteParser/TfLiteParser.hpp
@@ -90,11 +90,13 @@ private:
using OperatorParsingFunction = void(TfLiteParser::*)(size_t subgraphIndex, size_t operatorIndex);
void ParseUnsupportedOperator(size_t subgraphIndex, size_t operatorIndex);
+ void ParseActivation(size_t subgraphIndex, size_t operatorIndex, armnn::ActivationFunction activationType);
void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex);
void ParseConcatenation(size_t subgraphIndex, size_t operatorIndex);
void ParseConv2D(size_t subgraphIndex, size_t operatorIndex);
void ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorIndex);
void ParseFullyConnected(size_t subgraphIndex, size_t operatorIndex);
+ void ParseLogistic(size_t subgraphIndex, size_t operatorIndex);
void ParseMaxPool2D(size_t subgraphIndex, size_t operatorIndex);
void ParseRelu(size_t subgraphIndex, size_t operatorIndex);
void ParseRelu6(size_t subgraphIndex, size_t operatorIndex);
diff --git a/src/armnnTfLiteParser/test/Activations.cpp b/src/armnnTfLiteParser/test/Activations.cpp
index 534ae4cb73..dac16ce2c6 100644
--- a/src/armnnTfLiteParser/test/Activations.cpp
+++ b/src/armnnTfLiteParser/test/Activations.cpp
@@ -84,4 +84,13 @@ BOOST_FIXTURE_TEST_CASE(ParseReLu6, ReLu6Fixture)
{ 0.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.5f, 0.0f });
}
+struct SigmoidFixture : ActivationFixture
+{
+ SigmoidFixture() : ActivationFixture("LOGISTIC", "FLOAT32") {}
+};
+BOOST_FIXTURE_TEST_CASE(ParseLogistic, SigmoidFixture)
+{
+ RunTest<2, armnn::DataType::Float32>(0, { -1.0f, -0.5f, 4.0f, -4.0f, 0.0f, 0.5f, -0.75f },
+ {0.268941f, 0.377541f, 0.982013f, 0.0179862f, 0.5f, 0.622459f, 0.320821f });
+}
BOOST_AUTO_TEST_SUITE_END()