diff options
Diffstat (limited to 'arm_compute/core/utils/misc/ShapeCalculator.h')
-rw-r--r-- | arm_compute/core/utils/misc/ShapeCalculator.h | 25 |
1 files changed, 24 insertions, 1 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index 354f60d016..9cb8023463 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -184,7 +184,7 @@ inline TensorShape compute_fully_connected_reshaped_weights_shape(const ITensorI output_shape = compute_transposed_shape(*input); } - // If the we run multiple batches we need 1xW transpose, too. + // If we run multiple batches we need 1xW transpose, too. if(is_batched_fc_layer) { output_shape = compute_transposed_shape(input->clone()->set_tensor_shape(output_shape)); @@ -193,6 +193,29 @@ inline TensorShape compute_fully_connected_reshaped_weights_shape(const ITensorI return output_shape; } + +inline TensorShape compute_winograd_filter_transform_shape(const ITensorInfo &input) +{ + // COMPMID-984 (giaiod01) + TensorShape tensor_shape{ input.tensor_shape() }; + + if(input.data_layout() == DataLayout::NCHW) + { + tensor_shape.remove_dimension(0); + tensor_shape.set(Window::DimX, input.dimension(3)); + tensor_shape.set(Window::DimY, input.dimension(2)); + tensor_shape.set(Window::DimZ, 16); + } + else + { + tensor_shape.remove_dimension(1); + tensor_shape.set(Window::DimY, input.dimension(2)); + tensor_shape.set(Window::DimZ, 16); + } + + return tensor_shape; +} + inline TensorShape compute_winograd_input_transform_shape(const ITensorInfo &input, const PadStrideInfo &conv_info, const Size2D &kernel_size) { // Compute height |