aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/TfLiteParser.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.hpp')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.hpp12
1 files changed, 10 insertions, 2 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp
index fac2599322..fb01fe8ba2 100644
--- a/src/armnnTfLiteParser/TfLiteParser.hpp
+++ b/src/armnnTfLiteParser/TfLiteParser.hpp
@@ -10,6 +10,7 @@
#include <schema_generated.h>
#include <functional>
+#include <unordered_map>
#include <vector>
namespace armnnTfLiteParser
@@ -58,7 +59,7 @@ public:
/// Return the output tensor names for a given subgraph
virtual std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const override;
- TfLiteParser();
+ TfLiteParser(const armnn::Optional<ITfLiteParser::TfLiteParserOptions>& options = armnn::EmptyOptional());
virtual ~TfLiteParser() {}
public:
@@ -89,7 +90,9 @@ private:
// signature for the parser functions
using OperatorParsingFunction = void(TfLiteParser::*)(size_t subgraphIndex, size_t operatorIndex);
+ void ParseCustomOperator(size_t subgraphIndex, size_t operatorIndex);
void ParseUnsupportedOperator(size_t subgraphIndex, size_t operatorIndex);
+
void ParseActivation(size_t subgraphIndex, size_t operatorIndex, armnn::ActivationFunction activationType);
void ParseAdd(size_t subgraphIndex, size_t operatorIndex);
void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex);
@@ -180,11 +183,16 @@ private:
armnn::TensorInfo& tensorInfo,
armnn::Optional<armnn::PermutationVector&> permutationVector);
+ // Settings for configuring the TfLiteParser
+ armnn::Optional<ITfLiteParser::TfLiteParserOptions> m_Options;
+
/// The network we're building. Gets cleared after it is passed to the user
armnn::INetworkPtr m_Network;
- std::vector<OperatorParsingFunction> m_ParserFunctions;
ModelPtr m_Model;
+ std::vector<OperatorParsingFunction> m_ParserFunctions;
+ std::unordered_map<std::string, OperatorParsingFunction> m_CustomParserFunctions;
+
/// A mapping of an output slot to each of the input slots it should be connected to
/// The outputSlot is from the layer that creates this tensor as one of its ouputs
/// The inputSlots are from the layers that use this tensor as one of their inputs