aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/CLHelpers.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/CLHelpers.cpp')
-rw-r--r--src/core/CL/CLHelpers.cpp10
1 files changed, 6 insertions, 4 deletions
diff --git a/src/core/CL/CLHelpers.cpp b/src/core/CL/CLHelpers.cpp
index cd60c6e446..3965be76fd 100644
--- a/src/core/CL/CLHelpers.cpp
+++ b/src/core/CL/CLHelpers.cpp
@@ -153,7 +153,7 @@ bool cl_winograd_convolution_layer_supported(const Size2D &output_tile, const Si
using WinogradConfiguration = std::pair<std::pair<int, int>, std::pair<int, int>>;
- std::vector<WinogradConfiguration> winograd_filter_transform_nchw =
+ std::vector<WinogradConfiguration> winograd_configs_nchw =
{
WinogradConfiguration(std::pair<int, int>(1, 2), std::pair<int, int>(1, 3)),
WinogradConfiguration(std::pair<int, int>(1, 4), std::pair<int, int>(1, 3)),
@@ -166,9 +166,11 @@ bool cl_winograd_convolution_layer_supported(const Size2D &output_tile, const Si
WinogradConfiguration(std::pair<int, int>(1, 4), std::pair<int, int>(1, 5))
};
- std::vector<WinogradConfiguration> winograd_filter_transform_nhwc =
+ std::vector<WinogradConfiguration> winograd_configs_nhwc =
{
WinogradConfiguration(std::pair<int, int>(2, 2), std::pair<int, int>(3, 3)),
+ WinogradConfiguration(std::pair<int, int>(1, 4), std::pair<int, int>(1, 3)),
+ WinogradConfiguration(std::pair<int, int>(4, 1), std::pair<int, int>(3, 1)),
WinogradConfiguration(std::pair<int, int>(4, 4), std::pair<int, int>(3, 3)),
WinogradConfiguration(std::pair<int, int>(4, 4), std::pair<int, int>(5, 5))
};
@@ -179,11 +181,11 @@ bool cl_winograd_convolution_layer_supported(const Size2D &output_tile, const Si
// Return true if supported
if(data_layout == DataLayout::NCHW)
{
- return (std::find(winograd_filter_transform_nchw.begin(), winograd_filter_transform_nchw.end(), p) != winograd_filter_transform_nchw.end());
+ return (std::find(winograd_configs_nchw.begin(), winograd_configs_nchw.end(), p) != winograd_configs_nchw.end());
}
else
{
- return (std::find(winograd_filter_transform_nhwc.begin(), winograd_filter_transform_nhwc.end(), p) != winograd_filter_transform_nhwc.end());
+ return (std::find(winograd_configs_nhwc.begin(), winograd_configs_nhwc.end(), p) != winograd_configs_nhwc.end());
}
}
} // namespace arm_compute