diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-04-12 13:15:58 +0100 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-04-15 16:52:22 +0000 |
commit | 9e4824c909b14dbaf7106e9527b0ffa22ef09bdc (patch) | |
tree | b1cc8f6a8b275a7e227e305f1b02870d5e0f30ec /src/graph/nodes/ResizeLayerNode.cpp | |
parent | d66094e37ecd747e85f30130e1a678bdbaf30788 (diff) | |
download | ComputeLibrary-9e4824c909b14dbaf7106e9527b0ffa22ef09bdc.tar.gz |
COMPMID-2111: ConcatenateLayer API should accept an index instead of an enum
Alters the concatenate layer to be layout agnostic and accept an index
as thec concatenation axis instead of an typed layout dependent
enumeration.
Change-Id: I0eaaf919f66a1ba1b09bbfb47c171fc1d4045530
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/994
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/graph/nodes/ResizeLayerNode.cpp')
-rw-r--r-- | src/graph/nodes/ResizeLayerNode.cpp | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/src/graph/nodes/ResizeLayerNode.cpp b/src/graph/nodes/ResizeLayerNode.cpp index a6aa7bfe5c..a399229013 100644 --- a/src/graph/nodes/ResizeLayerNode.cpp +++ b/src/graph/nodes/ResizeLayerNode.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -68,9 +68,10 @@ TensorDescriptor ResizeLayerNode::configure_output(size_t idx) const const Tensor *src = input(0); ARM_COMPUTE_ERROR_ON(src == nullptr); + const DataLayout data_layout = src->desc().layout; TensorDescriptor output_desc = src->desc(); - size_t width_idx = get_dimension_idx(output_desc, DataLayoutDimension::WIDTH); - size_t height_idx = get_dimension_idx(output_desc, DataLayoutDimension::HEIGHT); + size_t width_idx = get_dimension_idx(data_layout, DataLayoutDimension::WIDTH); + size_t height_idx = get_dimension_idx(data_layout, DataLayoutDimension::HEIGHT); output_desc.shape.set(width_idx, static_cast<int>(output_desc.shape[width_idx] * _scale_width)); output_desc.shape.set(height_idx, static_cast<int>(output_desc.shape[height_idx] * _scale_height)); |