author     Georgios Pinitas <georgios.pinitas@arm.com>    2018-07-20 13:23:44 +0100
committer  Anthony Barbier <anthony.barbier@arm.com>      2018-11-02 16:54:54 +0000
commit     e2220551b7a64b929650ba9a60529c31e70c13c5 (patch)
tree       5d609887f15b4392cdade7bb388710ceafc62260 /src/graph
parent     eff8d95991205e874091576e2d225f63246dd0bb (diff)
download   ComputeLibrary-e2220551b7a64b929650ba9a60529c31e70c13c5.tar.gz
COMPMID-1367: Enable NHWC in graph examples
Change-Id: Iabc54a3a1bdcd46a9a921cda39c7c85fef672b72
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141449
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
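The heart of this change is that dimension indices are no longer hard-coded for NCHW: named axes (WIDTH, HEIGHT, CHANNEL, BATCHES) are resolved against each tensor's data layout. A minimal standalone sketch of that mapping, assuming ACL's lowest-stride-first shape ordering; the enum and function below are illustrative, not the library's own:

    #include <cstddef>

    enum class DataLayout { NCHW, NHWC };
    enum class Dim { WIDTH, HEIGHT, CHANNEL, BATCHES };

    // Shapes are indexed lowest-stride first, so NCHW stores [W, H, C, N]
    // and NHWC stores [C, W, H, N].
    std::size_t dimension_idx(DataLayout layout, Dim dim)
    {
        switch(dim)
        {
            case Dim::WIDTH:   return (layout == DataLayout::NCHW) ? 0 : 1;
            case Dim::HEIGHT:  return (layout == DataLayout::NCHW) ? 1 : 2;
            case Dim::CHANNEL: return (layout == DataLayout::NCHW) ? 2 : 0;
            case Dim::BATCHES: return 3;
        }
        return 0; // unreachable: all enumerators handled above
    }

Under NCHW the previously hard-coded split axes (2 for channel, 3 for batch) were already correct; under NHWC the channel axis resolves to 0, which is exactly what the GraphBuilder changes below now compute via get_dimension_idx.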
Diffstat (limited to 'src/graph')
-rw-r--r--  src/graph/GraphBuilder.cpp                          | 18
-rw-r--r--  src/graph/backends/CL/CLFunctionsFactory.cpp        |  4
-rw-r--r--  src/graph/backends/GLES/GCFunctionsFactory.cpp      | 50
-rw-r--r--  src/graph/backends/GLES/GCNodeValidator.cpp         |  6
-rw-r--r--  src/graph/backends/NEON/NEFunctionFactory.cpp       |  8
-rw-r--r--  src/graph/mutators/DepthConcatSubTensorMutator.cpp  | 14
-rw-r--r--  src/graph/nodes/ConcatenateLayerNode.cpp (renamed from src/graph/nodes/DepthConcatenateLayerNode.cpp) | 62
-rw-r--r--  src/graph/printers/DotGraphPrinter.cpp              | 10
8 files changed, 119 insertions, 53 deletions
diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp
index d26039ec35..b3721719d9 100644
--- a/src/graph/GraphBuilder.cpp
+++ b/src/graph/GraphBuilder.cpp
@@ -88,10 +88,14 @@ NodeID create_grouped_convolution(Graph &g, const NodeParams &params, NodeIdxPai
bool has_bias = (bias != EmptyNodeID);
// Split input
- NodeID input_split = GraphBuilder::add_split_node(g, params, input, num_groups, 2);
+ const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]);
+ const unsigned int input_idx = get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL);
+ NodeID input_split = GraphBuilder::add_split_node(g, params, input, num_groups, input_idx);
// Split weights
- NodeID weights_split = GraphBuilder::add_split_node(g, params, { weights, 0 }, num_groups, 3);
+ const TensorDescriptor weights_tensor_desc = get_tensor_descriptor(g, g.node(weights)->outputs()[0]);
+ const unsigned int batch_idx = get_dimension_idx(weights_tensor_desc, DataLayoutDimension::BATCHES);
+ NodeID weights_split = GraphBuilder::add_split_node(g, params, { weights, 0 }, num_groups, batch_idx);
// Split bias
NodeID bias_split = EmptyNodeID;
@@ -122,7 +126,7 @@ NodeID create_grouped_convolution(Graph &g, const NodeParams &params, NodeIdxPai
}
// Depth concatenate output
- return GraphBuilder::add_depth_concatenate_node(g, params, convolution_outputs);
+ return GraphBuilder::add_concatenate_node(g, params, convolution_outputs, DataLayoutDimension::CHANNEL);
}
} // namespace
@@ -329,11 +333,11 @@ NodeID GraphBuilder::add_deconvolution_node(Graph &g, NodeParams params, NodeIdx
return deconv_nid;
}
-NodeID GraphBuilder::add_depth_concatenate_node(Graph &g, NodeParams params, std::vector<NodeIdxPair> inputs)
+NodeID GraphBuilder::add_concatenate_node(Graph &g, NodeParams params, std::vector<NodeIdxPair> inputs, DataLayoutDimension axis)
{
ARM_COMPUTE_ERROR_ON(inputs.size() == 0);
- NodeID nid = g.add_node<DepthConcatenateLayerNode>(inputs.size());
+ NodeID nid = g.add_node<ConcatenateLayerNode>(inputs.size(), axis);
unsigned int i = 0;
for(const auto &input : inputs)
@@ -508,9 +512,9 @@ NodeID GraphBuilder::add_scale_layer(Graph &g, const NodeParams &params, NodeIdx
NodeIdxPair add_const_nidxp = { add_const_nid, 0 };
// Create node and connect
- NodeID mul_node = GraphBuilder::add_elementwise_node(g, params, input, mul_const_nidxp, EltwiseOperation::MUL);
+ NodeID mul_node = GraphBuilder::add_elementwise_node(g, params, input, mul_const_nidxp, EltwiseOperation::Mul);
NodeIdxPair mulnode_nidxp = { mul_node, 0 };
- NodeID add_node = GraphBuilder::add_elementwise_node(g, params, mulnode_nidxp, add_const_nidxp, EltwiseOperation::ADD);
+ NodeID add_node = GraphBuilder::add_elementwise_node(g, params, mulnode_nidxp, add_const_nidxp, EltwiseOperation::Add);
return add_node;
}
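For callers, the builder API now takes a semantic axis instead of a raw dimension index. A hedged usage sketch (graph construction elided; the identifiers are taken from the diff above, the setup around them is hypothetical):

    // Concatenate per-group convolution outputs along the channel axis.
    // The node resolves CHANNEL to index 2 under NCHW or index 0 under NHWC
    // when descriptors are forwarded, so the same graph code serves both layouts.
    std::vector<NodeIdxPair> conv_outputs; // one NodeIdxPair per group
    NodeID concat_nid = GraphBuilder::add_concatenate_node(g, params, conv_outputs, DataLayoutDimension::CHANNEL);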
diff --git a/src/graph/backends/CL/CLFunctionsFactory.cpp b/src/graph/backends/CL/CLFunctionsFactory.cpp
index 4d6734846a..57871487ef 100644
--- a/src/graph/backends/CL/CLFunctionsFactory.cpp
+++ b/src/graph/backends/CL/CLFunctionsFactory.cpp
@@ -89,8 +89,8 @@ std::unique_ptr<IFunction> CLFunctionFactory::create(INode *node, GraphContext &
return detail::create_convolution_layer<CLConvolutionLayerFunctions, CLTargetInfo>(*polymorphic_downcast<ConvolutionLayerNode *>(node), ctx);
case NodeType::DeconvolutionLayer:
return detail::create_deconvolution_layer<CLDeconvolutionLayer, CLTargetInfo>(*polymorphic_downcast<DeconvolutionLayerNode *>(node), ctx);
- case NodeType::DepthConcatenateLayer:
- return detail::create_depth_concatenate_layer<CLDepthConcatenateLayer, CLTargetInfo>(*polymorphic_downcast<DepthConcatenateLayerNode *>(node));
+ case NodeType::ConcatenateLayer:
+ return detail::create_concatenate_layer<CLConcatenateLayer, CLTargetInfo>(*polymorphic_downcast<ConcatenateLayerNode *>(node));
case NodeType::DepthwiseConvolutionLayer:
return detail::create_depthwise_convolution_layer<CLDepthwiseConvolutionLayerFunctions, CLTargetInfo>(*polymorphic_downcast<DepthwiseConvolutionLayerNode *>(node));
case NodeType::EltwiseLayer:
diff --git a/src/graph/backends/GLES/GCFunctionsFactory.cpp b/src/graph/backends/GLES/GCFunctionsFactory.cpp
index e6bd5a5f02..f72513c87c 100644
--- a/src/graph/backends/GLES/GCFunctionsFactory.cpp
+++ b/src/graph/backends/GLES/GCFunctionsFactory.cpp
@@ -68,6 +68,42 @@ struct GCEltwiseFunctions
namespace detail
{
+// Specialize functions
+template <>
+std::unique_ptr<IFunction> create_concatenate_layer<GCDepthConcatenateLayer, GCTargetInfo>(ConcatenateLayerNode &node)
+{
+ ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating Concatenate node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
+ ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
+
+ // Return nullptr if depth concatenate is switched off
+ if(!node.is_enabled())
+ {
+ return nullptr;
+ }
+
+ // Extract IO and info
+ std::vector<GCTargetInfo::TensorType *> inputs;
+ for(unsigned int i = 0; i < node.num_inputs(); ++i)
+ {
+ inputs.push_back(get_backing_tensor<GCTargetInfo>(node.input(i)));
+ }
+ typename GCTargetInfo::TensorType *output = get_backing_tensor<GCTargetInfo>(node.output(0));
+
+ // Create and configure function
+ auto func = support::cpp14::make_unique<GCDepthConcatenateLayer>();
+ func->configure(inputs, output);
+
+ // Log info
+ ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << node.type()
+ << " Target " << GCTargetInfo::TargetType
+ << " Data Type: " << output->info()->data_type()
+ << " Shape: " << output->info()->tensor_shape()
+ << " Num Inputs: " << inputs.size()
+ << std::endl);
+
+ return std::move(func);
+}
+
template <>
std::unique_ptr<IFunction> create_convolution_layer<GCConvolutionLayerFunctions, GCTargetInfo>(ConvolutionLayerNode &node, GraphContext &ctx)
{
@@ -92,7 +128,7 @@ std::unique_ptr<IFunction> create_convolution_layer<GCConvolutionLayerFunctions,
std::unique_ptr<IFunction> func;
std::string func_name;
- if(conv_algorithm == ConvolutionMethod::DIRECT)
+ if(conv_algorithm == ConvolutionMethod::Direct)
{
std::tie(func, func_name) = create_named_function<GCConvolutionLayerFunctions::DirectConvolutionLayer>(
std::string("DirectConvolutionLayer"),
@@ -139,7 +175,7 @@ std::unique_ptr<IFunction> create_depthwise_convolution_layer<GCDepthwiseConvolu
// Create and configure function (we assume that functions have been validated before creation)
std::unique_ptr<IFunction> func;
std::string func_name;
- if(dwc_algorithm == DepthwiseConvolutionMethod::OPTIMIZED_3x3)
+ if(dwc_algorithm == DepthwiseConvolutionMethod::Optimized3x3)
{
std::tie(func, func_name) = create_named_function<GCDepthwiseConvolutionLayerFunctions::DepthwiseConvolutionLayer3x3>(
std::string("DepthwiseConvolutionLayer3x3"),
@@ -183,17 +219,17 @@ std::unique_ptr<IFunction> create_eltwise_layer<GCEltwiseFunctions, GCTargetInfo
std::unique_ptr<IFunction> func = nullptr;
std::string func_name;
- if(eltwise_op == EltwiseOperation::ADD)
+ if(eltwise_op == EltwiseOperation::Add)
{
std::tie(func, func_name) = create_named_function<GCEltwiseFunctions::Addition>(
std::string("GCArithmeticAddition"),
input1, input2, output, convert_policy);
}
- else if(eltwise_op == EltwiseOperation::SUB)
+ else if(eltwise_op == EltwiseOperation::Sub)
{
ARM_COMPUTE_ERROR("Arithmetic subtraction is not supported in GLES backend");
}
- else if(eltwise_op == EltwiseOperation::MUL)
+ else if(eltwise_op == EltwiseOperation::Mul)
{
std::tie(func, func_name) = create_named_function<GCEltwiseFunctions::Multiplication>(
std::string("PixelWiseMultiplication"),
@@ -232,8 +268,8 @@ std::unique_ptr<IFunction> GCFunctionFactory::create(INode *node, GraphContext &
return detail::create_batch_normalization_layer<GCBatchNormalizationLayer, GCTargetInfo>(*polymorphic_downcast<BatchNormalizationLayerNode *>(node));
case NodeType::ConvolutionLayer:
return detail::create_convolution_layer<GCConvolutionLayerFunctions, GCTargetInfo>(*polymorphic_downcast<ConvolutionLayerNode *>(node), ctx);
- case NodeType::DepthConcatenateLayer:
- return detail::create_depth_concatenate_layer<GCDepthConcatenateLayer, GCTargetInfo>(*polymorphic_downcast<DepthConcatenateLayerNode *>(node));
+ case NodeType::ConcatenateLayer:
+ return detail::create_concatenate_layer<GCDepthConcatenateLayer, GCTargetInfo>(*polymorphic_downcast<ConcatenateLayerNode *>(node));
case NodeType::DepthwiseConvolutionLayer:
return detail::create_depthwise_convolution_layer<GCDepthwiseConvolutionLayerFunctions, GCTargetInfo>(*polymorphic_downcast<DepthwiseConvolutionLayerNode *>(node));
case NodeType::EltwiseLayer:
diff --git a/src/graph/backends/GLES/GCNodeValidator.cpp b/src/graph/backends/GLES/GCNodeValidator.cpp
index 4bef89329a..8118a7c476 100644
--- a/src/graph/backends/GLES/GCNodeValidator.cpp
+++ b/src/graph/backends/GLES/GCNodeValidator.cpp
@@ -58,7 +58,7 @@ Status validate_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node)
// TODO (geopin01) : Switch when validation is implemented
// Validate function
ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->tensor_shape().x() != 3 && weights->tensor_shape().y() != 3, "Unsupported depthwise convolution");
- node.set_depthwise_convolution_method(DepthwiseConvolutionMethod::OPTIMIZED_3x3);
+ node.set_depthwise_convolution_method(DepthwiseConvolutionMethod::Optimized3x3);
return Status{};
}
@@ -80,14 +80,14 @@ Status validate_convolution_layer(ConvolutionLayerNode &node)
const ConvolutionMethod conv_algorithm = node.convolution_method();
// Validate function
- if(conv_algorithm == ConvolutionMethod::DIRECT)
+ if(conv_algorithm == ConvolutionMethod::Direct)
{
bool is_square = weights->tensor_shape().x() == weights->tensor_shape().y();
bool is_direct = (weights->tensor_shape().x() == 1) || (weights->tensor_shape().x() == 3) || (weights->tensor_shape().x() == 5);
bool is_correct_stride = (conv_info.stride().first) <= 2 && (conv_info.stride().second <= 2);
if(!(is_square && is_direct && is_correct_stride))
{
- node.set_convolution_method(ConvolutionMethod::DEFAULT);
+ node.set_convolution_method(ConvolutionMethod::Default);
}
}
diff --git a/src/graph/backends/NEON/NEFunctionFactory.cpp b/src/graph/backends/NEON/NEFunctionFactory.cpp
index 3b7417da3f..6c912a02f1 100644
--- a/src/graph/backends/NEON/NEFunctionFactory.cpp
+++ b/src/graph/backends/NEON/NEFunctionFactory.cpp
@@ -102,7 +102,7 @@ std::unique_ptr<IFunction> create_convolution_layer<NEConvolutionLayerFunctions,
std::shared_ptr<IMemoryManager> mm = get_memory_manager(ctx, Target::NEON);
std::unique_ptr<IFunction> func;
std::string func_name;
- if(conv_algorithm == ConvolutionMethod::DIRECT)
+ if(conv_algorithm == ConvolutionMethod::Direct)
{
std::tie(func, func_name) = create_named_memory_managed_function<NEDirectConvolutionLayer>(
std::string("DirectConvolutionLayer"), mm, input, weights, biases, output, conv_info);
@@ -112,7 +112,7 @@ std::unique_ptr<IFunction> create_convolution_layer<NEConvolutionLayerFunctions,
std::tie(func, func_name) = create_named_memory_managed_function<NEGEMMConvolutionLayer>(
std::string("GEMMConvolutionLayer"), mm, input, weights, biases, output, conv_info);
}
- else if(conv_algorithm == ConvolutionMethod::WINOGRAD)
+ else if(conv_algorithm == ConvolutionMethod::Winograd)
{
std::tie(func, func_name) = create_named_memory_managed_function<NEWinogradConvolutionLayer>(
std::string("WinogradConvolutionLayer"), mm, input, weights, biases, output, conv_info);
@@ -183,8 +183,8 @@ std::unique_ptr<IFunction> NEFunctionFactory::create(INode *node, GraphContext &
return detail::create_convolution_layer<NEConvolutionLayerFunctions, NETargetInfo>(*polymorphic_downcast<ConvolutionLayerNode *>(node), ctx);
case NodeType::DeconvolutionLayer:
return detail::create_deconvolution_layer<NEDeconvolutionLayer, NETargetInfo>(*polymorphic_downcast<DeconvolutionLayerNode *>(node), ctx);
- case NodeType::DepthConcatenateLayer:
- return detail::create_depth_concatenate_layer<NEDepthConcatenateLayer, NETargetInfo>(*polymorphic_downcast<DepthConcatenateLayerNode *>(node));
+ case NodeType::ConcatenateLayer:
+ return detail::create_concatenate_layer<NEConcatenateLayer, NETargetInfo>(*polymorphic_downcast<ConcatenateLayerNode *>(node));
case NodeType::DepthwiseConvolutionLayer:
return detail::create_depthwise_convolution_layer<NEDepthwiseConvolutionLayerFunctions, NETargetInfo>(*polymorphic_downcast<DepthwiseConvolutionLayerNode *>(node));
case NodeType::EltwiseLayer:
diff --git a/src/graph/mutators/DepthConcatSubTensorMutator.cpp b/src/graph/mutators/DepthConcatSubTensorMutator.cpp
index c56f4c5106..241c07b367 100644
--- a/src/graph/mutators/DepthConcatSubTensorMutator.cpp
+++ b/src/graph/mutators/DepthConcatSubTensorMutator.cpp
@@ -25,8 +25,9 @@
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/Logger.h"
+#include "arm_compute/graph/Utils.h"
#include "arm_compute/graph/backends/BackendRegistry.h"
-#include "arm_compute/graph/nodes/DepthConcatenateLayerNode.h"
+#include "arm_compute/graph/nodes/ConcatenateLayerNode.h"
#include "arm_compute/core/utils/misc/Cast.h"
#include "arm_compute/core/utils/misc/Iterable.h"
@@ -45,11 +46,18 @@ void DepthConcatSubTensorMutator::mutate(Graph &g)
// Should be in reverse order of execution
for(auto &node : arm_compute::utils::iterable::reverse_iterate(g.nodes()))
{
- if(node && node->type() == NodeType::DepthConcatenateLayer && node->output(0) != nullptr)
+ if(node && node->type() == NodeType::ConcatenateLayer && node->output(0) != nullptr)
{
// Get output tensor
auto output_tensor = node->output(0);
+ // Check concatenation axis (sub-tensor optimization is supported for concatenation axis >= 2)
+ auto *concat_node = arm_compute::utils::cast::polymorphic_downcast<ConcatenateLayerNode *>(node.get());
+ if(output_tensor == nullptr || get_dimension_idx(output_tensor->desc(), concat_node->concatenation_axis()) < 2)
+ {
+ continue;
+ }
+
// Check that all tensors have the same target and valid inputs
bool is_valid = std::all_of(node->input_edges().cbegin(), node->input_edges().cend(),
[&](const EdgeID & eid)
@@ -76,7 +84,7 @@ void DepthConcatSubTensorMutator::mutate(Graph &g)
depth += input_shape.z();
}
- auto *dc_node = arm_compute::utils::cast::polymorphic_downcast<DepthConcatenateLayerNode *>(node.get());
+ auto *dc_node = arm_compute::utils::cast::polymorphic_downcast<ConcatenateLayerNode *>(node.get());
dc_node->set_enabled(false);
}
}
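The new guard encodes why the sub-tensor trick only applies when the concatenated axis resolves to index 2 or higher: only then does each input occupy one contiguous run inside the output buffer. A minimal sketch of the offset arithmetic, assuming NCHW, dense strides, and inputs sharing width and height (the struct and helper are illustrative, not library code):

    #include <cstddef>
    #include <vector>

    struct Shape { std::size_t w, h, c; };

    // Element offset of each input inside the concatenated output: input i
    // starts at z_i * W * H, where z_i is the sum of preceding channel counts.
    std::vector<std::size_t> subtensor_offsets(const std::vector<Shape> &inputs)
    {
        std::vector<std::size_t> offsets;
        std::size_t z = 0;
        for(const Shape &s : inputs)
        {
            offsets.push_back(z * s.w * s.h); // contiguous z-slice of the output
            z += s.c;
        }
        return offsets;
    }

Under NHWC the channel axis resolves to index 0, the innermost dimension, so the inputs would interleave element by element rather than form contiguous slices; the mutator therefore skips the optimization and a real concatenation function is created instead.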
diff --git a/src/graph/nodes/DepthConcatenateLayerNode.cpp b/src/graph/nodes/ConcatenateLayerNode.cpp
index 08cccc1ff1..ade3f6e1a9 100644
--- a/src/graph/nodes/DepthConcatenateLayerNode.cpp
+++ b/src/graph/nodes/ConcatenateLayerNode.cpp
@@ -21,58 +21,74 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#include "arm_compute/graph/nodes/DepthConcatenateLayerNode.h"
+#include "arm_compute/graph/nodes/ConcatenateLayerNode.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/INodeVisitor.h"
+#include "arm_compute/graph/Utils.h"
+
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
namespace arm_compute
{
namespace graph
{
-DepthConcatenateLayerNode::DepthConcatenateLayerNode(unsigned int total_nodes)
- : _total_nodes(total_nodes), _is_enabled(true)
+ConcatenateLayerNode::ConcatenateLayerNode(unsigned int total_nodes, DataLayoutDimension axis)
+ : _total_nodes(total_nodes), _axis(axis), _is_enabled(true)
{
_input_edges.resize(_total_nodes, EmptyEdgeID);
_outputs.resize(1, NullTensorID);
}
-void DepthConcatenateLayerNode::set_enabled(bool is_enabled)
+void ConcatenateLayerNode::set_enabled(bool is_enabled)
{
_is_enabled = is_enabled;
}
-bool DepthConcatenateLayerNode::is_enabled() const
+bool ConcatenateLayerNode::is_enabled() const
{
return _is_enabled;
}
-TensorDescriptor DepthConcatenateLayerNode::compute_output_descriptor(const std::vector<TensorDescriptor> &input_descriptors)
+DataLayoutDimension ConcatenateLayerNode::concatenation_axis() const
+{
+ return _axis;
+}
+
+TensorDescriptor ConcatenateLayerNode::compute_output_descriptor(const std::vector<TensorDescriptor> &input_descriptors,
+ DataLayoutDimension axis)
{
ARM_COMPUTE_ERROR_ON(input_descriptors.size() == 0);
TensorDescriptor output_descriptor = input_descriptors[0];
+ const int axis_idx = get_dimension_idx(output_descriptor, axis);
- size_t max_x = 0;
- size_t max_y = 0;
- size_t depth = 0;
-
- for(const auto &input_descriptor : input_descriptors)
+ // Extract shapes
+ std::vector<const TensorShape *> shapes;
+ for(auto &input_descriptor : input_descriptors)
{
- max_x = std::max(input_descriptor.shape.x(), max_x);
- max_y = std::max(input_descriptor.shape.y(), max_y);
- depth += input_descriptor.shape.z();
+ shapes.emplace_back(&input_descriptor.shape);
}
- output_descriptor.shape.set(0, max_x);
- output_descriptor.shape.set(1, max_y);
- output_descriptor.shape.set(2, depth);
+ // Calculate output shape
+ if(axis_idx == 0)
+ {
+ output_descriptor.shape = arm_compute::misc::shape_calculator::calculate_width_concatenate_shape(shapes);
+ }
+ else if(axis_idx == 2)
+ {
+ output_descriptor.shape = arm_compute::misc::shape_calculator::calculate_depth_concatenate_shape(shapes);
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Unsupported concatenation axis!");
+ }
return output_descriptor;
}
-bool DepthConcatenateLayerNode::forward_descriptors()
+bool ConcatenateLayerNode::forward_descriptors()
{
if(_outputs[0] != NullTensorID)
{
@@ -84,7 +100,7 @@ bool DepthConcatenateLayerNode::forward_descriptors()
return false;
}
-TensorDescriptor DepthConcatenateLayerNode::configure_output(size_t idx) const
+TensorDescriptor ConcatenateLayerNode::configure_output(size_t idx) const
{
ARM_COMPUTE_UNUSED(idx);
ARM_COMPUTE_ERROR_ON(idx >= _outputs.size());
@@ -106,18 +122,18 @@ TensorDescriptor DepthConcatenateLayerNode::configure_output(size_t idx) const
ARM_COMPUTE_ERROR_ON(t == nullptr);
inputs_descriptors.push_back(t->desc());
}
- output_info = compute_output_descriptor(inputs_descriptors);
+ output_info = compute_output_descriptor(inputs_descriptors, _axis);
}
return output_info;
}
-NodeType DepthConcatenateLayerNode::type() const
+NodeType ConcatenateLayerNode::type() const
{
- return NodeType::DepthConcatenateLayer;
+ return NodeType::ConcatenateLayer;
}
-void DepthConcatenateLayerNode::accept(INodeVisitor &v)
+void ConcatenateLayerNode::accept(INodeVisitor &v)
{
v.visit(*this);
}
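The output-shape rule in compute_output_descriptor reduces to summing the concatenated dimension and keeping the rest. A simplified standalone restatement, limited to three-dimensional shapes; this is a sketch of the rule, not the library's shape calculator:

    #include <array>
    #include <cstddef>
    #include <stdexcept>
    #include <vector>

    using Shape3 = std::array<std::size_t, 3>;

    Shape3 concat_shape(const std::vector<Shape3> &inputs, std::size_t axis_idx)
    {
        // Mirror the node: only indices 0 and 2 are handled at this point.
        if(axis_idx != 0 && axis_idx != 2)
        {
            throw std::invalid_argument("Unsupported concatenation axis!");
        }
        Shape3 out = inputs.front();
        out[axis_idx] = 0;
        for(const Shape3 &s : inputs)
        {
            out[axis_idx] += s[axis_idx]; // sum along the concatenated dimension
        }
        return out; // remaining dimensions are assumed equal across inputs
    }

Note that axis_idx == 0 is width concatenation under NCHW but channel concatenation under NHWC, which is how NHWC graphs reach the existing width-concatenate shape path.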
diff --git a/src/graph/printers/DotGraphPrinter.cpp b/src/graph/printers/DotGraphPrinter.cpp
index 61cf42356f..ef156ea252 100644
--- a/src/graph/printers/DotGraphPrinter.cpp
+++ b/src/graph/printers/DotGraphPrinter.cpp
@@ -47,17 +47,19 @@ void DotGraphVisitor::visit(BatchNormalizationLayerNode &n)
_info = ss.str();
}
-void DotGraphVisitor::visit(ConvolutionLayerNode &n)
+void DotGraphVisitor::visit(ConcatenateLayerNode &n)
{
std::stringstream ss;
- ss << n.convolution_method();
+ ss << "Enabled: " << n.is_enabled();
+ ss << R"( \n )";
+ ss << "Axis: " << n.concatenation_axis();
_info = ss.str();
}
-void DotGraphVisitor::visit(DepthConcatenateLayerNode &n)
+void DotGraphVisitor::visit(ConvolutionLayerNode &n)
{
std::stringstream ss;
- ss << "Enabled: " << n.is_enabled();
+ ss << n.convolution_method();
_info = ss.str();
}