From 9e4824c909b14dbaf7106e9527b0ffa22ef09bdc Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 12 Apr 2019 13:15:58 +0100 Subject: COMPMID-2111: ConcatenateLayer API should accept an index instead of an enum Alters the concatenate layer to be layout agnostic and accept an index as thec concatenation axis instead of an typed layout dependent enumeration. Change-Id: I0eaaf919f66a1ba1b09bbfb47c171fc1d4045530 Signed-off-by: Georgios Pinitas Reviewed-on: https://review.mlplatform.org/c/994 Comments-Addressed: Arm Jenkins Reviewed-by: Michele Di Giorgio Tested-by: Arm Jenkins --- src/graph/mutators/DepthConcatSubTensorMutator.cpp | 4 ++-- src/graph/mutators/GroupedConvolutionMutator.cpp | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) (limited to 'src/graph/mutators') diff --git a/src/graph/mutators/DepthConcatSubTensorMutator.cpp b/src/graph/mutators/DepthConcatSubTensorMutator.cpp index 0e0a26b886..7994541b78 100644 --- a/src/graph/mutators/DepthConcatSubTensorMutator.cpp +++ b/src/graph/mutators/DepthConcatSubTensorMutator.cpp @@ -62,9 +62,9 @@ void DepthConcatSubTensorMutator::mutate(Graph &g) // Get output tensor auto output_tensor = node->output(0); - // Check concatenation axis (Sub-tensor optimization is support for concatenation axis >=2) + // Check concatenation axis (Sub-tensor optimization is supported for concatenation axis >=2) auto *concat_node = arm_compute::utils::cast::polymorphic_downcast(node); - if(output_tensor == nullptr || get_dimension_idx(output_tensor->desc(), concat_node->concatenation_axis()) < 2) + if(output_tensor == nullptr || get_dimension_idx(output_tensor->desc().layout, concat_node->concatenation_axis()) < 2) { continue; } diff --git a/src/graph/mutators/GroupedConvolutionMutator.cpp b/src/graph/mutators/GroupedConvolutionMutator.cpp index d69d2cd7d0..3d53f49218 100644 --- a/src/graph/mutators/GroupedConvolutionMutator.cpp +++ b/src/graph/mutators/GroupedConvolutionMutator.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -47,12 +47,12 @@ NodeID create_grouped_convolution(Graph &g, const NodeParams ¶ms, NodeIdxPai // Split input const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]); - const unsigned int input_idx = get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL); + const unsigned int input_idx = get_dimension_idx(input_tensor_desc.layout, DataLayoutDimension::CHANNEL); NodeID input_split = GraphBuilder::add_split_node(g, params, input, num_groups, input_idx); // Split weights const TensorDescriptor weights_tensor_desc = get_tensor_descriptor(g, g.node(weights)->outputs()[0]); - const unsigned int batch_idx = get_dimension_idx(weights_tensor_desc, DataLayoutDimension::BATCHES); + const unsigned int batch_idx = get_dimension_idx(weights_tensor_desc.layout, DataLayoutDimension::BATCHES); NodeID weights_split = GraphBuilder::add_split_node(g, params, { weights, 0 }, num_groups, batch_idx); // Split bias -- cgit v1.2.1