aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-07-03 12:06:23 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:10 +0000
commit12be7ab4876f77fecfab903df70791623219b3da (patch)
tree1cfa6852e60948bee9db0831a9f3abc97a2031c8 /arm_compute/graph
parente39334c15c7fd141bb8173d5017ea5ca157fca2c (diff)
downloadComputeLibrary-12be7ab4876f77fecfab903df70791623219b3da.tar.gz
COMPMID-1310: Create graph validation executables.
Change-Id: I9e0b57b1b83fe5a95777cdaeddba6ecef650bafc Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/138697 Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/graph')
-rw-r--r--arm_compute/graph/Graph.h13
-rw-r--r--arm_compute/graph/TypeLoader.h105
-rw-r--r--arm_compute/graph/TypePrinter.h239
-rw-r--r--arm_compute/graph/algorithms/BFS.h2
-rw-r--r--arm_compute/graph/detail/ExecutionHelpers.h8
5 files changed, 120 insertions, 247 deletions
diff --git a/arm_compute/graph/Graph.h b/arm_compute/graph/Graph.h
index 16f5f97986..2a776826e5 100644
--- a/arm_compute/graph/Graph.h
+++ b/arm_compute/graph/Graph.h
@@ -72,7 +72,7 @@ public:
* @tparam NT Node operation
* @tparam Ts Arguments to operation
*
- * @param args Node arguments
+ * @param[in] args Node arguments
*
* @return ID of the node
*/
@@ -114,9 +114,11 @@ public:
GraphID id() const;
/** Returns graph input nodes
*
- * @return vector containing the graph inputs
+ * @param[in] type Type of nodes to return
+ *
+ * @return vector containing the graph node of given type
*/
- const std::vector<NodeID> &inputs();
+ const std::vector<NodeID> &nodes(NodeType type);
/** Returns nodes of graph
*
* @warning Nodes can be nullptr if they have been removed during the mutation steps of the graph
@@ -238,10 +240,7 @@ inline NodeID Graph::add_node(Ts &&... args)
node->set_id(nid);
// Keep track of input nodes
- if(node->type() == NodeType::Input)
- {
- _tagged_nodes[NodeType::Input].push_back(nid);
- }
+ _tagged_nodes[node->type()].push_back(nid);
// Associate a new tensor with each output
for(auto &output : node->_outputs)
diff --git a/arm_compute/graph/TypeLoader.h b/arm_compute/graph/TypeLoader.h
new file mode 100644
index 0000000000..77f096133d
--- /dev/null
+++ b/arm_compute/graph/TypeLoader.h
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_GRAPH_TYPE_LOADER_H__
+#define __ARM_COMPUTE_GRAPH_TYPE_LOADER_H__
+
+#include "arm_compute/graph/Types.h"
+
+#include <istream>
+
+namespace arm_compute
+{
+/** Converts a string to a strong types enumeration @ref DataType
+ *
+ * @param[in] name String to convert
+ *
+ * @return Converted DataType enumeration
+ */
+arm_compute::DataType data_type_from_name(const std::string &name);
+
+/** Input Stream operator for @ref DataType
+ *
+ * @param[in] stream Stream to parse
+ * @param[out] data_type Output data type
+ *
+ * @return Updated stream
+ */
+inline ::std::istream &operator>>(::std::istream &stream, arm_compute::DataType &data_type)
+{
+ std::string value;
+ stream >> value;
+ data_type = data_type_from_name(value);
+ return stream;
+}
+
+/** Converts a string to a strong types enumeration @ref DataLayout
+ *
+ * @param[in] name String to convert
+ *
+ * @return Converted DataLayout enumeration
+ */
+arm_compute::DataLayout data_layout_from_name(const std::string &name);
+
+/** Input Stream operator for @ref DataLayout
+ *
+ * @param[in] stream Stream to parse
+ * @param[out] data_layout Output data layout
+ *
+ * @return Updated stream
+ */
+inline ::std::istream &operator>>(::std::istream &stream, arm_compute::DataLayout &data_layout)
+{
+ std::string value;
+ stream >> value;
+ data_layout = data_layout_from_name(value);
+ return stream;
+}
+
+namespace graph
+{
+/** Converts a string to a strong types enumeration @ref Target
+ *
+ * @param[in] name String to convert
+ *
+ * @return Converted Target enumeration
+ */
+Target target_from_name(const std::string &name);
+
+/** Input Stream operator for @ref Target
+ *
+ * @param[in] stream Stream to parse
+ * @param[out] target Output target
+ *
+ * @return Updated stream
+ */
+inline ::std::istream &operator>>(::std::istream &stream, Target &target)
+{
+ std::string value;
+ stream >> value;
+ target = target_from_name(value);
+ return stream;
+}
+} // namespace graph
+} // namespace arm_compute
+#endif /* __ARM_COMPUTE_GRAPH_TYPE_LOADER_H__ */
diff --git a/arm_compute/graph/TypePrinter.h b/arm_compute/graph/TypePrinter.h
index 177a5e2f38..c3601f2373 100644
--- a/arm_compute/graph/TypePrinter.h
+++ b/arm_compute/graph/TypePrinter.h
@@ -28,89 +28,12 @@
#include "arm_compute/core/Types.h"
#include "arm_compute/graph/Types.h"
+#include "utils/TypePrinter.h"
+
namespace arm_compute
{
namespace graph
{
-/** Formatted output of the Dimensions type. */
-template <typename T>
-inline ::std::ostream &operator<<(::std::ostream &os, const arm_compute::Dimensions<T> &dimensions)
-{
- if(dimensions.num_dimensions() > 0)
- {
- os << dimensions[0];
-
- for(unsigned int d = 1; d < dimensions.num_dimensions(); ++d)
- {
- os << "x" << dimensions[d];
- }
- }
-
- return os;
-}
-
-/** Formatted output of the Size2D type. */
-inline ::std::ostream &operator<<(::std::ostream &os, const Size2D &size)
-{
- os << size.width << "x" << size.height;
-
- return os;
-}
-
-/** Formatted output of the DataType type. */
-inline ::std::ostream &operator<<(::std::ostream &os, const DataType &data_type)
-{
- switch(data_type)
- {
- case DataType::UNKNOWN:
- os << "UNKNOWN";
- break;
- case DataType::U8:
- os << "U8";
- break;
- case DataType::QASYMM8:
- os << "QASYMM8";
- break;
- case DataType::S8:
- os << "S8";
- break;
- case DataType::U16:
- os << "U16";
- break;
- case DataType::S16:
- os << "S16";
- break;
- case DataType::U32:
- os << "U32";
- break;
- case DataType::S32:
- os << "S32";
- break;
- case DataType::U64:
- os << "U64";
- break;
- case DataType::S64:
- os << "S64";
- break;
- case DataType::F16:
- os << "F16";
- break;
- case DataType::F32:
- os << "F32";
- break;
- case DataType::F64:
- os << "F64";
- break;
- case DataType::SIZET:
- os << "SIZET";
- break;
- default:
- ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
- }
-
- return os;
-}
-
/** Formatted output of the Target. */
inline ::std::ostream &operator<<(::std::ostream &os, const Target &target)
{
@@ -135,24 +58,6 @@ inline ::std::ostream &operator<<(::std::ostream &os, const Target &target)
return os;
}
-/** Formatted output of the DataLayout */
-inline ::std::ostream &operator<<(::std::ostream &os, const DataLayout &data_layout)
-{
- switch(data_layout)
- {
- case DataLayout::NCHW:
- os << "NCHW";
- break;
- case DataLayout::NHWC:
- os << "NHWC";
- break;
- default:
- ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
- }
-
- return os;
-}
-
inline ::std::ostream &operator<<(::std::ostream &os, const NodeType &node_type)
{
switch(node_type)
@@ -224,100 +129,6 @@ inline ::std::ostream &operator<<(::std::ostream &os, const NodeType &node_type)
return os;
}
-/** Formatted output of the activation function type. */
-inline ::std::ostream &operator<<(::std::ostream &os, const ActivationLayerInfo::ActivationFunction &act_function)
-{
- switch(act_function)
- {
- case ActivationLayerInfo::ActivationFunction::ABS:
- os << "ABS";
- break;
- case ActivationLayerInfo::ActivationFunction::LINEAR:
- os << "LINEAR";
- break;
- case ActivationLayerInfo::ActivationFunction::LOGISTIC:
- os << "LOGISTIC";
- break;
- case ActivationLayerInfo::ActivationFunction::RELU:
- os << "RELU";
- break;
- case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU:
- os << "BOUNDED_RELU";
- break;
- case ActivationLayerInfo::ActivationFunction::LEAKY_RELU:
- os << "LEAKY_RELU";
- break;
- case ActivationLayerInfo::ActivationFunction::SOFT_RELU:
- os << "SOFT_RELU";
- break;
- case ActivationLayerInfo::ActivationFunction::SQRT:
- os << "SQRT";
- break;
- case ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU:
- os << "LU_BOUNDED_RELU";
- break;
- case ActivationLayerInfo::ActivationFunction::SQUARE:
- os << "SQUARE";
- break;
- case ActivationLayerInfo::ActivationFunction::TANH:
- os << "TANH";
- break;
- default:
- ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
- }
-
- return os;
-}
-
-inline std::string to_string(const ActivationLayerInfo::ActivationFunction &act_function)
-{
- std::stringstream str;
- str << act_function;
- return str.str();
-}
-
-/** Formatted output of the PoolingType type. */
-inline ::std::ostream &operator<<(::std::ostream &os, const PoolingType &pool_type)
-{
- switch(pool_type)
- {
- case PoolingType::AVG:
- os << "AVG";
- break;
- case PoolingType::MAX:
- os << "MAX";
- break;
- case PoolingType::L2:
- os << "L2";
- break;
- default:
- ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
- }
-
- return os;
-}
-
-/** Formatted output of the NormType type. */
-inline ::std::ostream &operator<<(::std::ostream &os, const NormType &norm_type)
-{
- switch(norm_type)
- {
- case NormType::CROSS_MAP:
- os << "CROSS_MAP";
- break;
- case NormType::IN_MAP_1D:
- os << "IN_MAP_1D";
- break;
- case NormType::IN_MAP_2D:
- os << "IN_MAP_2D";
- break;
- default:
- ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
- }
-
- return os;
-}
-
/** Formatted output of the EltwiseOperation type. */
inline ::std::ostream &operator<<(::std::ostream &os, const EltwiseOperation &eltwise_op)
{
@@ -401,52 +212,6 @@ inline ::std::ostream &operator<<(::std::ostream &os, const DepthwiseConvolution
return os;
}
-
-/** Formatted output of the PadStrideInfo type. */
-inline ::std::ostream &operator<<(::std::ostream &os, const PadStrideInfo &pad_stride_info)
-{
- os << pad_stride_info.stride().first << "," << pad_stride_info.stride().second;
- os << ";";
- os << pad_stride_info.pad_left() << "," << pad_stride_info.pad_right() << ","
- << pad_stride_info.pad_top() << "," << pad_stride_info.pad_bottom();
-
- return os;
-}
-
-/** Formatted output of the QuantizationInfo type. */
-inline ::std::ostream &operator<<(::std::ostream &os, const QuantizationInfo &quantization_info)
-{
- os << "Scale:" << quantization_info.scale << "~"
- << "Offset:" << quantization_info.offset;
- return os;
-}
-
-/** Formatted output of the Interpolation policy type.
- *
- * @param[out] os Output stream.
- * @param[in] policy Interpolation policy to output.
- *
- * @return Modified output stream.
- */
-inline ::std::ostream &operator<<(::std::ostream &os, const InterpolationPolicy &policy)
-{
- switch(policy)
- {
- case InterpolationPolicy::NEAREST_NEIGHBOR:
- os << "NEAREST NEIGHBOR";
- break;
- case InterpolationPolicy::BILINEAR:
- os << "BILINEAR";
- break;
- case InterpolationPolicy::AREA:
- os << "AREA";
- break;
- default:
- ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
- }
-
- return os;
-}
} // namespace graph
} // namespace arm_compute
#endif /* __ARM_COMPUTE_GRAPH_TYPE_PRINTER_H__ */
diff --git a/arm_compute/graph/algorithms/BFS.h b/arm_compute/graph/algorithms/BFS.h
index 36ca872f15..97292d733b 100644
--- a/arm_compute/graph/algorithms/BFS.h
+++ b/arm_compute/graph/algorithms/BFS.h
@@ -85,7 +85,7 @@ inline std::vector<NodeID> bfs(Graph &g)
std::list<NodeID> queue;
// Push inputs and mark as visited
- for(auto &input : g.inputs())
+ for(auto &input : g.nodes(NodeType::Input))
{
if(input != EmptyNodeID)
{
diff --git a/arm_compute/graph/detail/ExecutionHelpers.h b/arm_compute/graph/detail/ExecutionHelpers.h
index 23dd207695..3a357776e4 100644
--- a/arm_compute/graph/detail/ExecutionHelpers.h
+++ b/arm_compute/graph/detail/ExecutionHelpers.h
@@ -95,13 +95,17 @@ void call_all_const_node_accessors(Graph &g);
/** Call all input node accessors
*
* @param[in] workload Workload to execute
+ *
+ * @return True if all the accesses were valid
*/
-void call_all_input_node_accessors(ExecutionWorkload &workload);
+bool call_all_input_node_accessors(ExecutionWorkload &workload);
/** Call all output node accessors
*
* @param[in] workload Workload to execute
+ *
+ * @return True if all the accessors expect more data
*/
-void call_all_output_node_accessors(ExecutionWorkload &workload);
+bool call_all_output_node_accessors(ExecutionWorkload &workload);
/** Prepares all tasks for execution
*
* @param[in] workload Workload to prepare