aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp')
-rw-r--r--src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp14
1 files changed, 8 insertions, 6 deletions
diff --git a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp
index f7517a50a3..8e3d010786 100644
--- a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp
@@ -82,9 +82,10 @@ Status validate_arguments_3x3(const ITensorInfo *input, const ITensorInfo *weigh
if(needs_permute)
{
- TensorShape permuted_input_shape = input->tensor_shape();
- TensorShape permuted_weights_shape = weights->tensor_shape();
- TensorShape permuted_output_shape = shape_calculator::compute_depthwise_convolution_shape(*input, *weights, conv_info, depth_multiplier, dilation);
+ TensorShape permuted_input_shape = input->tensor_shape();
+ TensorShape permuted_weights_shape = weights->tensor_shape();
+ const ConvolutionInfo info{ conv_info, depth_multiplier, ActivationLayerInfo(), dilation };
+ TensorShape permuted_output_shape = shape_calculator::compute_depthwise_convolution_shape(*input, *weights, info);
permute(permuted_input_shape, PermutationVector(1U, 2U, 0U));
permute(permuted_weights_shape, PermutationVector(1U, 2U, 0U));
@@ -272,9 +273,10 @@ Status CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayerGeneric::validate
if(needs_permute)
{
- TensorShape permuted_input_shape = input->tensor_shape();
- TensorShape permuted_weights_shape = weights->tensor_shape();
- TensorShape permuted_output_shape = shape_calculator::compute_depthwise_convolution_shape(*input, *weights, conv_info, depth_multiplier, dilation);
+ TensorShape permuted_input_shape = input->tensor_shape();
+ TensorShape permuted_weights_shape = weights->tensor_shape();
+ const ConvolutionInfo info{ conv_info, depth_multiplier, ActivationLayerInfo(), dilation };
+ TensorShape permuted_output_shape = shape_calculator::compute_depthwise_convolution_shape(*input, *weights, info);
permute(permuted_input_shape, PermutationVector(2U, 0U, 1U));
permute(permuted_weights_shape, PermutationVector(2U, 0U, 1U));