diff options
Diffstat (limited to 'src/runtime/CL/functions/CLLocallyConnectedLayer.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLLocallyConnectedLayer.cpp | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp b/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp index d15e5dfa3d..40bf032d69 100644 --- a/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp +++ b/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp @@ -48,7 +48,10 @@ void calculate_shapes(const ITensorInfo *input, const ITensorInfo *weights, cons // Get convolved dimensions unsigned int conv_w = 0; unsigned int conv_h = 0; - std::tie(conv_w, conv_h) = scaled_dimensions(input->dimension(0), input->dimension(1), kernel_width, kernel_height, + std::tie(conv_w, conv_h) = scaled_dimensions(input->dimension(0), + input->dimension(1), + kernel_width, + kernel_height, conv_info); const size_t mat_weights_cols = weights->dimension(3); @@ -61,9 +64,12 @@ void calculate_shapes(const ITensorInfo *input, const ITensorInfo *weights, cons const size_t mat_input_rows = conv_w * conv_h; shape_im2col = input->tensor_shape(); + if(shape_im2col.num_dimensions() >= 3) + { + shape_im2col.remove_dimension(2); + } shape_im2col.set(0, mat_input_cols); shape_im2col.set(1, mat_input_rows); - shape_im2col.set(2, 1); shape_gemm = shape_im2col; shape_gemm.set(0, mat_weights_cols); |