aboutsummaryrefslogtreecommitdiff
path: root/src/graph/nodes/DepthwiseConvolutionLayerNode.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/nodes/DepthwiseConvolutionLayerNode.cpp')
-rw-r--r--src/graph/nodes/DepthwiseConvolutionLayerNode.cpp31
1 files changed, 16 insertions, 15 deletions
diff --git a/src/graph/nodes/DepthwiseConvolutionLayerNode.cpp b/src/graph/nodes/DepthwiseConvolutionLayerNode.cpp
index 67a39029e6..1a6f8d398d 100644
--- a/src/graph/nodes/DepthwiseConvolutionLayerNode.cpp
+++ b/src/graph/nodes/DepthwiseConvolutionLayerNode.cpp
@@ -26,6 +26,7 @@
#include "arm_compute/core/Utils.h"
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/INodeVisitor.h"
+#include "arm_compute/graph/Utils.h"
namespace arm_compute
{
@@ -53,17 +54,25 @@ PadStrideInfo DepthwiseConvolutionLayerNode::convolution_info() const
return _info;
}
-TensorShape DepthwiseConvolutionLayerNode::compute_output_shape(TensorShape input_shape, TensorShape weights_shape, PadStrideInfo info)
+TensorDescriptor DepthwiseConvolutionLayerNode::compute_output_descriptor(const TensorDescriptor &input_descriptor,
+ const TensorDescriptor &weights_descriptor,
+ const PadStrideInfo &info)
{
unsigned int output_width = 0;
unsigned int output_height = 0;
- std::tie(output_width, output_height) = scaled_dimensions(input_shape.x(), input_shape.y(), weights_shape.x(), weights_shape.y(), info);
- TensorShape output_shape{ input_shape };
- output_shape.set(0, output_width);
- output_shape.set(1, output_height);
+ const unsigned int input_width = get_dimension_size(input_descriptor, DataLayoutDimension::WIDTH);
+ const unsigned int input_height = get_dimension_size(input_descriptor, DataLayoutDimension::HEIGHT);
+ const unsigned int kernel_width = get_dimension_size(weights_descriptor, DataLayoutDimension::WIDTH);
+ const unsigned int kernel_height = get_dimension_size(weights_descriptor, DataLayoutDimension::HEIGHT);
- return output_shape;
+ std::tie(output_width, output_height) = scaled_dimensions(input_width, input_height, kernel_width, kernel_height, info);
+
+ TensorDescriptor output_descriptor = input_descriptor;
+ output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::WIDTH), output_width);
+ output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::HEIGHT), output_height);
+
+ return output_descriptor;
}
bool DepthwiseConvolutionLayerNode::forward_descriptors()
@@ -86,15 +95,7 @@ TensorDescriptor DepthwiseConvolutionLayerNode::configure_output(size_t idx) con
ARM_COMPUTE_ERROR_ON(src == nullptr || weights == nullptr);
- TensorDescriptor output_info = src->desc();
- TensorShape output_shape = compute_output_shape(src->desc().shape, weights->desc().shape, _info);
- output_info.shape = output_shape;
- return output_info;
-}
-
-Status DepthwiseConvolutionLayerNode::validate()
-{
- return Status{};
+ return compute_output_descriptor(src->desc(), weights->desc(), _info);
}
NodeType DepthwiseConvolutionLayerNode::type() const