aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ConvolutionLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/ConvolutionLayerFixture.h')
-rw-r--r--tests/validation/fixtures/ConvolutionLayerFixture.h16
1 files changed, 13 insertions, 3 deletions
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h
index ec13e1d3e0..e1452f5dfc 100644
--- a/tests/validation/fixtures/ConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/ConvolutionLayerFixture.h
@@ -42,12 +42,22 @@
namespace arm_compute
{
-class NEConvolutionLayer;
-
namespace test
{
namespace validation
{
+namespace detail
+{
+template <typename ConvolutionFunction, typename TensorType>
+void configure_conv_function(ConvolutionFunction &func,
+ TensorType *src, const TensorType *weights, const TensorType *bias, TensorType *dst,
+ const PadStrideInfo &info, const WeightsInfo &weights_info,
+ const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups)
+{
+ func.configure(src, weights, bias, dst, info, weights_info, dilation, act_info, num_groups);
+}
+} // namespace detail
+
template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW>
class ConvolutionValidationGenericFixture : public framework::Fixture
{
@@ -171,7 +181,7 @@ protected:
// Create and configure function
FunctionType conv;
- conv.configure(&src, &weights, &bias, &dst, info, weights_info, dilation, act_info, num_groups);
+ detail::configure_conv_function(conv, &src, &weights, &bias, &dst, info, weights_info, dilation, act_info, num_groups);
ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS);