diff options
Diffstat (limited to 'src/graph/nodes/DepthwiseConvolutionLayerNode.cpp')
-rw-r--r-- | src/graph/nodes/DepthwiseConvolutionLayerNode.cpp | 31 |
1 files changed, 16 insertions, 15 deletions
diff --git a/src/graph/nodes/DepthwiseConvolutionLayerNode.cpp b/src/graph/nodes/DepthwiseConvolutionLayerNode.cpp index 67a39029e6..1a6f8d398d 100644 --- a/src/graph/nodes/DepthwiseConvolutionLayerNode.cpp +++ b/src/graph/nodes/DepthwiseConvolutionLayerNode.cpp @@ -26,6 +26,7 @@ #include "arm_compute/core/Utils.h" #include "arm_compute/graph/Graph.h" #include "arm_compute/graph/INodeVisitor.h" +#include "arm_compute/graph/Utils.h" namespace arm_compute { @@ -53,17 +54,25 @@ PadStrideInfo DepthwiseConvolutionLayerNode::convolution_info() const return _info; } -TensorShape DepthwiseConvolutionLayerNode::compute_output_shape(TensorShape input_shape, TensorShape weights_shape, PadStrideInfo info) +TensorDescriptor DepthwiseConvolutionLayerNode::compute_output_descriptor(const TensorDescriptor &input_descriptor, + const TensorDescriptor &weights_descriptor, + const PadStrideInfo &info) { unsigned int output_width = 0; unsigned int output_height = 0; - std::tie(output_width, output_height) = scaled_dimensions(input_shape.x(), input_shape.y(), weights_shape.x(), weights_shape.y(), info); - TensorShape output_shape{ input_shape }; - output_shape.set(0, output_width); - output_shape.set(1, output_height); + const unsigned int input_width = get_dimension_size(input_descriptor, DataLayoutDimension::WIDTH); + const unsigned int input_height = get_dimension_size(input_descriptor, DataLayoutDimension::HEIGHT); + const unsigned int kernel_width = get_dimension_size(weights_descriptor, DataLayoutDimension::WIDTH); + const unsigned int kernel_height = get_dimension_size(weights_descriptor, DataLayoutDimension::HEIGHT); - return output_shape; + std::tie(output_width, output_height) = scaled_dimensions(input_width, input_height, kernel_width, kernel_height, info); + + TensorDescriptor output_descriptor = input_descriptor; + output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::WIDTH), output_width); + output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::HEIGHT), output_height); + + return output_descriptor; } bool DepthwiseConvolutionLayerNode::forward_descriptors() @@ -86,15 +95,7 @@ TensorDescriptor DepthwiseConvolutionLayerNode::configure_output(size_t idx) con ARM_COMPUTE_ERROR_ON(src == nullptr || weights == nullptr); - TensorDescriptor output_info = src->desc(); - TensorShape output_shape = compute_output_shape(src->desc().shape, weights->desc().shape, _info); - output_info.shape = output_shape; - return output_info; -} - -Status DepthwiseConvolutionLayerNode::validate() -{ - return Status{}; + return compute_output_descriptor(src->desc(), weights->desc(), _info); } NodeType DepthwiseConvolutionLayerNode::type() const |