diff options
Diffstat (limited to 'tests/validation/fixtures/ConvolutionLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/ConvolutionLayerFixture.h | 23 |
1 files changed, 22 insertions, 1 deletions
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h index 5b8963ebfe..e4c2e4bfea 100644 --- a/tests/validation/fixtures/ConvolutionLayerFixture.h +++ b/tests/validation/fixtures/ConvolutionLayerFixture.h @@ -27,6 +27,9 @@ #include "arm_compute/core/TensorShape.h" #include "arm_compute/core/Types.h" #include "arm_compute/graph/Utils.h" +#ifdef ARM_COMPUTE_OPENCL_ENABLED +#include "arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h" +#endif // ARM_COMPUTE_OPENCL_ENABLED #include "arm_compute/runtime/NEON/NEScheduler.h" #include "src/core/NEON/kernels/arm_gemm/utils.hpp" #include "src/graph/mutators/MutatorUtils.h" @@ -43,6 +46,7 @@ #include "tests/validation/reference/Utils.h" #include <random> +#include <type_traits> namespace arm_compute { @@ -53,13 +57,30 @@ namespace validation namespace detail { template <typename ConvolutionFunction, typename TensorType> -void configure_conv_function(ConvolutionFunction &func, +#ifdef ARM_COMPUTE_OPENCL_ENABLED +std::enable_if_t<!std::is_same<ConvolutionFunction, CLGEMMConvolutionLayer>::value, void> +#else // ARM_COMPUTE_OPENCL_ENABLED +void +#endif // ARM_COMPUTE_OPENCL_ENABLED +configure_conv_function(ConvolutionFunction &func, + TensorType *src, const TensorType *weights, const TensorType *bias, TensorType *dst, + const PadStrideInfo &info, const WeightsInfo &weights_info, + const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups) +{ + func.configure(src, weights, bias, dst, info, weights_info, dilation, act_info, false /* enable_fast_math */, num_groups); +} + +#ifdef ARM_COMPUTE_OPENCL_ENABLED +template <typename ConvolutionFunction, typename TensorType> +std::enable_if_t<std::is_same<ConvolutionFunction, CLGEMMConvolutionLayer>::value, void> +configure_conv_function(ConvolutionFunction &func, TensorType *src, const TensorType *weights, const TensorType *bias, TensorType *dst, const PadStrideInfo &info, const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups) { func.configure(src, weights, bias, dst, info, weights_info, dilation, act_info, num_groups); } +#endif // ARM_COMPUTE_OPENCL_ENABLED } // namespace detail template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW> |