diff options
Diffstat (limited to 'include')
-rw-r--r-- | include/armnn/INetwork.hpp | 2 | ||||
-rw-r--r-- | include/armnn/Types.hpp | 3 | ||||
-rw-r--r-- | include/armnn/TypesUtils.hpp | 50 | ||||
-rw-r--r-- | include/armnn/Version.hpp | 2 | ||||
-rw-r--r-- | include/armnnTfParser/ITfParser.hpp | 60 |
5 files changed, 116 insertions, 1 deletions
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp index 8545629c96..5cff810db5 100644 --- a/include/armnn/INetwork.hpp +++ b/include/armnn/INetwork.hpp @@ -63,6 +63,7 @@ public: virtual const IOutputSlot& GetOutputSlot(unsigned int index) const = 0; virtual IOutputSlot& GetOutputSlot(unsigned int index) = 0; + virtual LayerGuid GetGuid() const = 0; protected: ~IConnectableLayer() {} // Objects are not deletable via the handle }; @@ -265,6 +266,7 @@ public: static void Destroy(IOptimizedNetwork* network); virtual Status PrintGraph() = 0; + virtual Status SerializeToDot(std::ostream& stream) const = 0; protected: ~IOptimizedNetwork() {} diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp index e1aa393ecc..c9a4bf13e5 100644 --- a/include/armnn/Types.hpp +++ b/include/armnn/Types.hpp @@ -152,4 +152,7 @@ private: SizeType m_NumDimMappings; }; +// Define LayerGuid type. +using LayerGuid = unsigned int; + } diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp index a851b66b28..ba18e0045b 100644 --- a/include/armnn/TypesUtils.hpp +++ b/include/armnn/TypesUtils.hpp @@ -34,6 +34,56 @@ constexpr char const* GetComputeDeviceAsCString(Compute compute) } } +constexpr char const* GetActivationFunctionAsCString(ActivationFunction activation) +{ + switch (activation) + { + case ActivationFunction::Sigmoid: return "Sigmoid"; + case ActivationFunction::TanH: return "TanH"; + case ActivationFunction::Linear: return "Linear"; + case ActivationFunction::ReLu: return "ReLu"; + case ActivationFunction::BoundedReLu: return "BoundedReLu"; + case ActivationFunction::SoftReLu: return "SoftReLu"; + case ActivationFunction::LeakyReLu: return "LeakyReLu"; + case ActivationFunction::Abs: return "Abs"; + case ActivationFunction::Sqrt: return "Sqrt"; + case ActivationFunction::Square: return "Square"; + default: return "Unknown"; + } +} + +constexpr char const* GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling) +{ + switch (pooling) + { + case PoolingAlgorithm::Average: return "Average"; + case PoolingAlgorithm::Max: return "Max"; + case PoolingAlgorithm::L2: return "L2"; + default: return "Unknown"; + } +} + +constexpr char const* GetOutputShapeRoundingAsCString(OutputShapeRounding rounding) +{ + switch (rounding) + { + case OutputShapeRounding::Ceiling: return "Ceiling"; + case OutputShapeRounding::Floor: return "Floor"; + default: return "Unknown"; + } +} + + +constexpr char const* GetPaddingMethodAsCString(PaddingMethod method) +{ + switch (method) + { + case PaddingMethod::Exclude: return "Exclude"; + case PaddingMethod::IgnoreValue: return "IgnoreValue"; + default: return "Unknown"; + } +} + constexpr unsigned int GetDataTypeSize(DataType dataType) { switch (dataType) diff --git a/include/armnn/Version.hpp b/include/armnn/Version.hpp index 6ce8256faa..5fdcf8dbc6 100644 --- a/include/armnn/Version.hpp +++ b/include/armnn/Version.hpp @@ -9,4 +9,4 @@ // YYYY = 4-digit year number // MM = 2-digit month number // PP = 2-digit patch number -#define ARMNN_VERSION "20180200" +#define ARMNN_VERSION "20180300" diff --git a/include/armnnTfParser/ITfParser.hpp b/include/armnnTfParser/ITfParser.hpp new file mode 100644 index 0000000000..a6f56c8a19 --- /dev/null +++ b/include/armnnTfParser/ITfParser.hpp @@ -0,0 +1,60 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#pragma once + +#include "armnn/Types.hpp" +#include "armnn/Tensor.hpp" +#include "armnn/INetwork.hpp" + +#include <map> +#include <memory> +#include <unordered_map> +#include <vector> + +namespace armnnTfParser +{ + +using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>; + +class ITfParser; +using ITfParserPtr = std::unique_ptr<ITfParser, void(*)(ITfParser* parser)>; + +/// parses a directed acyclic graph from a tensorflow protobuf file +class ITfParser +{ +public: + static ITfParser* CreateRaw(); + static ITfParserPtr Create(); + static void Destroy(ITfParser* parser); + + /// Create the network from a protobuf text file on disk + virtual armnn::INetworkPtr CreateNetworkFromTextFile( + const char* graphFile, + const std::map<std::string, armnn::TensorShape>& inputShapes, + const std::vector<std::string>& requestedOutputs) = 0; + + /// Create the network from a protobuf binary file on disk + virtual armnn::INetworkPtr CreateNetworkFromBinaryFile( + const char* graphFile, + const std::map<std::string, armnn::TensorShape>& inputShapes, + const std::vector<std::string>& requestedOutputs) = 0; + + /// Create the network directly from protobuf text in a string. Useful for debugging/testing + virtual armnn::INetworkPtr CreateNetworkFromString( + const char* protoText, + const std::map<std::string, armnn::TensorShape>& inputShapes, + const std::vector<std::string>& requestedOutputs) = 0; + + /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name + virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const = 0; + + /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name + virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const = 0; + +protected: + virtual ~ITfParser() {}; +}; + +} |