aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorsurmeh01 <surabhi.mehta@arm.com>2018-03-29 16:29:27 +0100
committersurmeh01 <surabhi.mehta@arm.com>2018-03-29 16:29:27 +0100
commitbceff2fb3fc68bb0aa88b886900c34b77340c826 (patch)
treed867d3e090d58d3012dfbbac456e9ea8c7f789bc /include
parent4fcda0101ec3d110c1d6d7bee5c83416b645528a (diff)
downloadarmnn-bceff2fb3fc68bb0aa88b886900c34b77340c826.tar.gz
Release 18.03
Diffstat (limited to 'include')
-rw-r--r--include/armnn/INetwork.hpp2
-rw-r--r--include/armnn/Types.hpp3
-rw-r--r--include/armnn/TypesUtils.hpp50
-rw-r--r--include/armnn/Version.hpp2
-rw-r--r--include/armnnTfParser/ITfParser.hpp60
5 files changed, 116 insertions, 1 deletions
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index 8545629c96..5cff810db5 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -63,6 +63,7 @@ public:
virtual const IOutputSlot& GetOutputSlot(unsigned int index) const = 0;
virtual IOutputSlot& GetOutputSlot(unsigned int index) = 0;
+ virtual LayerGuid GetGuid() const = 0;
protected:
~IConnectableLayer() {} // Objects are not deletable via the handle
};
@@ -265,6 +266,7 @@ public:
static void Destroy(IOptimizedNetwork* network);
virtual Status PrintGraph() = 0;
+ virtual Status SerializeToDot(std::ostream& stream) const = 0;
protected:
~IOptimizedNetwork() {}
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index e1aa393ecc..c9a4bf13e5 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -152,4 +152,7 @@ private:
SizeType m_NumDimMappings;
};
+// Define LayerGuid type.
+using LayerGuid = unsigned int;
+
}
diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp
index a851b66b28..ba18e0045b 100644
--- a/include/armnn/TypesUtils.hpp
+++ b/include/armnn/TypesUtils.hpp
@@ -34,6 +34,56 @@ constexpr char const* GetComputeDeviceAsCString(Compute compute)
}
}
+constexpr char const* GetActivationFunctionAsCString(ActivationFunction activation)
+{
+ switch (activation)
+ {
+ case ActivationFunction::Sigmoid: return "Sigmoid";
+ case ActivationFunction::TanH: return "TanH";
+ case ActivationFunction::Linear: return "Linear";
+ case ActivationFunction::ReLu: return "ReLu";
+ case ActivationFunction::BoundedReLu: return "BoundedReLu";
+ case ActivationFunction::SoftReLu: return "SoftReLu";
+ case ActivationFunction::LeakyReLu: return "LeakyReLu";
+ case ActivationFunction::Abs: return "Abs";
+ case ActivationFunction::Sqrt: return "Sqrt";
+ case ActivationFunction::Square: return "Square";
+ default: return "Unknown";
+ }
+}
+
+constexpr char const* GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)
+{
+ switch (pooling)
+ {
+ case PoolingAlgorithm::Average: return "Average";
+ case PoolingAlgorithm::Max: return "Max";
+ case PoolingAlgorithm::L2: return "L2";
+ default: return "Unknown";
+ }
+}
+
+constexpr char const* GetOutputShapeRoundingAsCString(OutputShapeRounding rounding)
+{
+ switch (rounding)
+ {
+ case OutputShapeRounding::Ceiling: return "Ceiling";
+ case OutputShapeRounding::Floor: return "Floor";
+ default: return "Unknown";
+ }
+}
+
+
+constexpr char const* GetPaddingMethodAsCString(PaddingMethod method)
+{
+ switch (method)
+ {
+ case PaddingMethod::Exclude: return "Exclude";
+ case PaddingMethod::IgnoreValue: return "IgnoreValue";
+ default: return "Unknown";
+ }
+}
+
constexpr unsigned int GetDataTypeSize(DataType dataType)
{
switch (dataType)
diff --git a/include/armnn/Version.hpp b/include/armnn/Version.hpp
index 6ce8256faa..5fdcf8dbc6 100644
--- a/include/armnn/Version.hpp
+++ b/include/armnn/Version.hpp
@@ -9,4 +9,4 @@
// YYYY = 4-digit year number
// MM = 2-digit month number
// PP = 2-digit patch number
-#define ARMNN_VERSION "20180200"
+#define ARMNN_VERSION "20180300"
diff --git a/include/armnnTfParser/ITfParser.hpp b/include/armnnTfParser/ITfParser.hpp
new file mode 100644
index 0000000000..a6f56c8a19
--- /dev/null
+++ b/include/armnnTfParser/ITfParser.hpp
@@ -0,0 +1,60 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+#pragma once
+
+#include "armnn/Types.hpp"
+#include "armnn/Tensor.hpp"
+#include "armnn/INetwork.hpp"
+
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+namespace armnnTfParser
+{
+
+using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>;
+
+class ITfParser;
+using ITfParserPtr = std::unique_ptr<ITfParser, void(*)(ITfParser* parser)>;
+
+/// parses a directed acyclic graph from a tensorflow protobuf file
+class ITfParser
+{
+public:
+ static ITfParser* CreateRaw();
+ static ITfParserPtr Create();
+ static void Destroy(ITfParser* parser);
+
+ /// Create the network from a protobuf text file on disk
+ virtual armnn::INetworkPtr CreateNetworkFromTextFile(
+ const char* graphFile,
+ const std::map<std::string, armnn::TensorShape>& inputShapes,
+ const std::vector<std::string>& requestedOutputs) = 0;
+
+ /// Create the network from a protobuf binary file on disk
+ virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(
+ const char* graphFile,
+ const std::map<std::string, armnn::TensorShape>& inputShapes,
+ const std::vector<std::string>& requestedOutputs) = 0;
+
+ /// Create the network directly from protobuf text in a string. Useful for debugging/testing
+ virtual armnn::INetworkPtr CreateNetworkFromString(
+ const char* protoText,
+ const std::map<std::string, armnn::TensorShape>& inputShapes,
+ const std::vector<std::string>& requestedOutputs) = 0;
+
+ /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name
+ virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const = 0;
+
+ /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name
+ virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const = 0;
+
+protected:
+ virtual ~ITfParser() {};
+};
+
+}