aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--tests/validation/fixtures/ConvolutionLayerFixture.h23
1 files changed, 22 insertions, 1 deletions
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h
index 5b8963ebfe..e4c2e4bfea 100644
--- a/tests/validation/fixtures/ConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/ConvolutionLayerFixture.h
@@ -27,6 +27,9 @@
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/graph/Utils.h"
+#ifdef ARM_COMPUTE_OPENCL_ENABLED
+#include "arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h"
+#endif // ARM_COMPUTE_OPENCL_ENABLED
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "src/core/NEON/kernels/arm_gemm/utils.hpp"
#include "src/graph/mutators/MutatorUtils.h"
@@ -43,6 +46,7 @@
#include "tests/validation/reference/Utils.h"
#include <random>
+#include <type_traits>
namespace arm_compute
{
@@ -53,13 +57,30 @@ namespace validation
namespace detail
{
template <typename ConvolutionFunction, typename TensorType>
-void configure_conv_function(ConvolutionFunction &func,
+#ifdef ARM_COMPUTE_OPENCL_ENABLED
+std::enable_if_t<!std::is_same<ConvolutionFunction, CLGEMMConvolutionLayer>::value, void>
+#else // ARM_COMPUTE_OPENCL_ENABLED
+void
+#endif // ARM_COMPUTE_OPENCL_ENABLED
+configure_conv_function(ConvolutionFunction &func,
+ TensorType *src, const TensorType *weights, const TensorType *bias, TensorType *dst,
+ const PadStrideInfo &info, const WeightsInfo &weights_info,
+ const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups)
+{
+ func.configure(src, weights, bias, dst, info, weights_info, dilation, act_info, false /* enable_fast_math */, num_groups);
+}
+
+#ifdef ARM_COMPUTE_OPENCL_ENABLED
+template <typename ConvolutionFunction, typename TensorType>
+std::enable_if_t<std::is_same<ConvolutionFunction, CLGEMMConvolutionLayer>::value, void>
+configure_conv_function(ConvolutionFunction &func,
TensorType *src, const TensorType *weights, const TensorType *bias, TensorType *dst,
const PadStrideInfo &info, const WeightsInfo &weights_info,
const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups)
{
func.configure(src, weights, bias, dst, info, weights_info, dilation, act_info, num_groups);
}
+#endif // ARM_COMPUTE_OPENCL_ENABLED
} // namespace detail
template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW>