diff options
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.hpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.hpp | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp index 836c4e8f51..49ccd2705c 100644 --- a/src/armnnTfLiteParser/TfLiteParser.hpp +++ b/src/armnnTfLiteParser/TfLiteParser.hpp @@ -64,20 +64,20 @@ public: public: // testable helpers - static ModelPtr LoadModelFromFile(const char * fileName); - static ModelPtr LoadModelFromBinary(const uint8_t * binaryContent, size_t len); - static TensorRawPtrVector GetInputs(const ModelPtr & model, size_t subgraphIndex, size_t operatorIndex); - static TensorRawPtrVector GetOutputs(const ModelPtr & model, size_t subgraphIndex, size_t operatorIndex); - static TensorIdRawPtrVector GetSubgraphInputs(const ModelPtr & model, size_t subgraphIndex); - static TensorIdRawPtrVector GetSubgraphOutputs(const ModelPtr & model, size_t subgraphIndex); + static ModelPtr LoadModelFromFile(const char* fileName); + static ModelPtr LoadModelFromBinary(const uint8_t* binaryContent, size_t len); + static TensorRawPtrVector GetInputs(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex); + static TensorRawPtrVector GetOutputs(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex); + static TensorIdRawPtrVector GetSubgraphInputs(const ModelPtr& model, size_t subgraphIndex); + static TensorIdRawPtrVector GetSubgraphOutputs(const ModelPtr& model, size_t subgraphIndex); static std::vector<int32_t>& GetInputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex); static std::vector<int32_t>& GetOutputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex); static BufferRawPtr GetBuffer(const ModelPtr& model, size_t bufferIndex); - static armnn::TensorInfo OutputShapeOfSqueeze(const std::vector<uint32_t> & squeezeDims, - const armnn::TensorInfo & inputTensorInfo); - static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo, - const std::vector<int32_t> & targetDimsIn); + static armnn::TensorInfo OutputShapeOfSqueeze(std::vector<uint32_t> squeezeDims, + const armnn::TensorInfo& inputTensorInfo); + static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo& inputTensorInfo, + const std::vector<int32_t>& targetDimsIn); /// Retrieve version in X.Y.Z form static const std::string GetVersion(); @@ -116,6 +116,7 @@ private: void ParseElementwiseUnary(size_t subgraphIndex, size_t operatorIndex, armnn::UnaryOperation unaryOperation); void ParseElu(size_t subgraphIndex, size_t operatorIndex); void ParseExp(size_t subgraphIndex, size_t operatorIndex); + void ParseExpandDims(size_t subgraphIndex, size_t operatorIndex); void ParseFullyConnected(size_t subgraphIndex, size_t operatorIndex); void ParseGather(size_t subgraphIndex, size_t operatorIndex); void ParseHardSwish(size_t subgraphIndex, size_t operatorIndex); |