aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ConvolutionLayerFixture.h
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2020-11-02 01:37:17 +0000
committerGeorgios Pinitas <georgios.pinitas@arm.com>2020-11-12 15:59:25 +0000
commitc0b6f76561580414f08633a804fc548ccad65659 (patch)
tree4d46b7f479de04f799e29095392948aeb370c029 /tests/validation/fixtures/ConvolutionLayerFixture.h
parent824061d9910ebb42cbe46b677c0b843db212c9a2 (diff)
downloadComputeLibrary-c0b6f76561580414f08633a804fc548ccad65659.tar.gz
COMPMID-3776: Indirect GEMM
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Change-Id: I51a1b0f098bc3a8c408c50c92221e4df3061e12c Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4343 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Sang-Hoon Park <sang-hoon.park@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures/ConvolutionLayerFixture.h')
-rw-r--r--tests/validation/fixtures/ConvolutionLayerFixture.h16
1 files changed, 13 insertions, 3 deletions
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h
index ec13e1d3e0..e1452f5dfc 100644
--- a/tests/validation/fixtures/ConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/ConvolutionLayerFixture.h
@@ -42,12 +42,22 @@
namespace arm_compute
{
-class NEConvolutionLayer;
-
namespace test
{
namespace validation
{
+namespace detail
+{
+template <typename ConvolutionFunction, typename TensorType>
+void configure_conv_function(ConvolutionFunction &func,
+ TensorType *src, const TensorType *weights, const TensorType *bias, TensorType *dst,
+ const PadStrideInfo &info, const WeightsInfo &weights_info,
+ const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups)
+{
+ func.configure(src, weights, bias, dst, info, weights_info, dilation, act_info, num_groups);
+}
+} // namespace detail
+
template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW>
class ConvolutionValidationGenericFixture : public framework::Fixture
{
@@ -171,7 +181,7 @@ protected:
// Create and configure function
FunctionType conv;
- conv.configure(&src, &weights, &bias, &dst, info, weights_info, dilation, act_info, num_groups);
+ detail::configure_conv_function(conv, &src, &weights, &bias, &dst, info, weights_info, dilation, act_info, num_groups);
ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS);