diff options
Diffstat (limited to 'src/core/CL/CLHelpers.cpp')
-rw-r--r-- | src/core/CL/CLHelpers.cpp | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/src/core/CL/CLHelpers.cpp b/src/core/CL/CLHelpers.cpp index 23c24c0337..df06aff647 100644 --- a/src/core/CL/CLHelpers.cpp +++ b/src/core/CL/CLHelpers.cpp @@ -27,6 +27,7 @@ #include "arm_compute/core/Log.h" #include "arm_compute/core/Types.h" +#include <utility> #include <vector> namespace arm_compute @@ -164,4 +165,41 @@ bool device_supports_extension(const cl::Device &device, const char *extension_n return (pos != std::string::npos); } +bool cl_winograd_convolution_layer_supported(const Size2D &output_tile, const Size2D &kernel_size, DataLayout data_layout) +{ + ARM_COMPUTE_ERROR_ON(data_layout == DataLayout::UNKNOWN); + + using WinogradConfiguration = std::pair<std::pair<int, int>, std::pair<int, int>>; + + std::vector<WinogradConfiguration> winograd_filter_transform_nchw = + { + WinogradConfiguration(std::pair<int, int>(1, 2), std::pair<int, int>(1, 3)), + WinogradConfiguration(std::pair<int, int>(1, 4), std::pair<int, int>(1, 3)), + WinogradConfiguration(std::pair<int, int>(2, 1), std::pair<int, int>(3, 1)), + WinogradConfiguration(std::pair<int, int>(4, 1), std::pair<int, int>(3, 1)), + WinogradConfiguration(std::pair<int, int>(2, 2), std::pair<int, int>(3, 3)), + WinogradConfiguration(std::pair<int, int>(4, 4), std::pair<int, int>(3, 3)), + WinogradConfiguration(std::pair<int, int>(4, 4), std::pair<int, int>(5, 5)) + }; + + std::vector<WinogradConfiguration> winograd_filter_transform_nhwc = + { + WinogradConfiguration(std::pair<int, int>(2, 2), std::pair<int, int>(3, 3)), + WinogradConfiguration(std::pair<int, int>(4, 4), std::pair<int, int>(3, 3)), + WinogradConfiguration(std::pair<int, int>(4, 4), std::pair<int, int>(5, 5)) + }; + + auto p = std::make_pair(std::pair<int, int>(output_tile.width, output_tile.height), + std::pair<int, int>(kernel_size.width, kernel_size.height)); + + // Return true if supported + if(data_layout == DataLayout::NCHW) + { + return (std::find(winograd_filter_transform_nchw.begin(), winograd_filter_transform_nchw.end(), p) != winograd_filter_transform_nchw.end()); + } + else + { + return (std::find(winograd_filter_transform_nhwc.begin(), winograd_filter_transform_nhwc.end(), p) != winograd_filter_transform_nhwc.end()); + } +} } // namespace arm_compute |