diff options
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.hpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.hpp | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp index fac2599322..fb01fe8ba2 100644 --- a/src/armnnTfLiteParser/TfLiteParser.hpp +++ b/src/armnnTfLiteParser/TfLiteParser.hpp @@ -10,6 +10,7 @@ #include <schema_generated.h> #include <functional> +#include <unordered_map> #include <vector> namespace armnnTfLiteParser @@ -58,7 +59,7 @@ public: /// Return the output tensor names for a given subgraph virtual std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const override; - TfLiteParser(); + TfLiteParser(const armnn::Optional<ITfLiteParser::TfLiteParserOptions>& options = armnn::EmptyOptional()); virtual ~TfLiteParser() {} public: @@ -89,7 +90,9 @@ private: // signature for the parser functions using OperatorParsingFunction = void(TfLiteParser::*)(size_t subgraphIndex, size_t operatorIndex); + void ParseCustomOperator(size_t subgraphIndex, size_t operatorIndex); void ParseUnsupportedOperator(size_t subgraphIndex, size_t operatorIndex); + void ParseActivation(size_t subgraphIndex, size_t operatorIndex, armnn::ActivationFunction activationType); void ParseAdd(size_t subgraphIndex, size_t operatorIndex); void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex); @@ -180,11 +183,16 @@ private: armnn::TensorInfo& tensorInfo, armnn::Optional<armnn::PermutationVector&> permutationVector); + // Settings for configuring the TfLiteParser + armnn::Optional<ITfLiteParser::TfLiteParserOptions> m_Options; + /// The network we're building. Gets cleared after it is passed to the user armnn::INetworkPtr m_Network; - std::vector<OperatorParsingFunction> m_ParserFunctions; ModelPtr m_Model; + std::vector<OperatorParsingFunction> m_ParserFunctions; + std::unordered_map<std::string, OperatorParsingFunction> m_CustomParserFunctions; + /// A mapping of an output slot to each of the input slots it should be connected to /// The outputSlot is from the layer that creates this tensor as one of its ouputs /// The inputSlots are from the layers that use this tensor as one of their inputs |