aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/validation/CL/UNIT/DynamicTensor.cpp1
-rw-r--r--tests/validation/CL/WeightsReshape.cpp20
-rw-r--r--tests/validation/fixtures/WeightsReshapeFixture.h16
3 files changed, 23 insertions, 14 deletions
diff --git a/tests/validation/CL/UNIT/DynamicTensor.cpp b/tests/validation/CL/UNIT/DynamicTensor.cpp
index f83a92ec2f..ac433721d8 100644
--- a/tests/validation/CL/UNIT/DynamicTensor.cpp
+++ b/tests/validation/CL/UNIT/DynamicTensor.cpp
@@ -31,7 +31,6 @@
#include "src/core/CL/kernels/CLFillBorderKernel.h"
#include "src/core/CL/kernels/CLL2NormalizeLayerKernel.h"
#include "src/core/CL/kernels/CLReductionOperationKernel.h"
-#include "src/core/CL/kernels/CLWeightsReshapeKernel.h"
#include "tests/AssetsLibrary.h"
#include "tests/CL/CLAccessor.h"
#include "tests/Globals.h"
diff --git a/tests/validation/CL/WeightsReshape.cpp b/tests/validation/CL/WeightsReshape.cpp
index d04c10cee2..93be75df98 100644
--- a/tests/validation/CL/WeightsReshape.cpp
+++ b/tests/validation/CL/WeightsReshape.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -22,7 +22,7 @@
* SOFTWARE.
*/
#include "arm_compute/core/Types.h"
-#include "src/core/CL/kernels/CLWeightsReshapeKernel.h"
+#include "src/core/gpu/cl/kernels/ClWeightsReshapeKernel.h"
#include "tests/CL/CLAccessor.h"
#include "tests/CL/Helper.h"
#include "tests/datasets/ShapeDatasets.h"
@@ -41,7 +41,7 @@ namespace validation
TEST_SUITE(CL)
TEST_SUITE(WeightsReshape)
-using CLWeightsReshape = CLSynthetizeFunction<CLWeightsReshapeKernel>;
+using ClWeightsReshape = ClSynthetizeOperatorWithBorder<opencl::kernels::ClWeightsReshapeKernel>;
/** Validate tests
*
@@ -87,15 +87,15 @@ framework::dataset::make("NumGroups", { 1, 1, 1, 2, 1, 2 })),
framework::dataset::make("Expected", { false, false, false, false, false, false })),
input_info, biases_info, output_info, num_groups, expected)
{
- bool status = bool(CLWeightsReshape::validate(&input_info, &biases_info, &output_info, num_groups));
+ bool status = bool(opencl::kernels::ClWeightsReshapeKernel::validate(&input_info, &biases_info, &output_info, num_groups));
ARM_COMPUTE_EXPECT(status == expected, framework::LogLevel::ERRORS);
}
template <typename T>
-using CLWeightsReshapeFixture = WeightsReshapeValidationFixture<CLTensor, CLAccessor, CLWeightsReshape, T>;
+using ClWeightsReshapeFixture = WeightsReshapeOpValidationFixture<CLTensor, CLAccessor, ClWeightsReshape, T>;
TEST_SUITE(Float)
-FIXTURE_DATA_TEST_CASE(FP32, CLWeightsReshapeFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(framework::dataset::make("InputShape", { TensorShape(3U, 3U, 48U, 120U) }),
+FIXTURE_DATA_TEST_CASE(FP32, ClWeightsReshapeFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(framework::dataset::make("InputShape", { TensorShape(3U, 3U, 48U, 120U) }),
framework::dataset::make("DataType", DataType::F32)),
framework::dataset::make("HasBias", { true, false })),
framework::dataset::make("NumGroups", { 1, 2 })))
@@ -104,7 +104,7 @@ FIXTURE_DATA_TEST_CASE(FP32, CLWeightsReshapeFixture<float>, framework::DatasetM
validate(CLAccessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(FP16, CLWeightsReshapeFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(framework::dataset::make("InputShape", { TensorShape(13U, 13U, 96U, 240U) }),
+FIXTURE_DATA_TEST_CASE(FP16, ClWeightsReshapeFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(framework::dataset::make("InputShape", { TensorShape(13U, 13U, 96U, 240U) }),
framework::dataset::make("DataType", DataType::F16)),
framework::dataset::make("HasBias", { true, false })),
framework::dataset::make("NumGroups", { 3, 4 })))
@@ -113,7 +113,7 @@ FIXTURE_DATA_TEST_CASE(FP16, CLWeightsReshapeFixture<half>, framework::DatasetMo
validate(CLAccessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(BFloat16, CLWeightsReshapeFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(framework::dataset::make("InputShape", { TensorShape(9U, 9U, 96U, 240U) }),
+FIXTURE_DATA_TEST_CASE(BFloat16, ClWeightsReshapeFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(framework::dataset::make("InputShape", { TensorShape(9U, 9U, 96U, 240U) }),
framework::dataset::make("DataType", DataType::BFLOAT16)),
framework::dataset::make("HasBias", { false })),
framework::dataset::make("NumGroups", { 3, 4 })))
@@ -125,7 +125,7 @@ FIXTURE_DATA_TEST_CASE(BFloat16, CLWeightsReshapeFixture<half>, framework::Datas
TEST_SUITE_END()
TEST_SUITE(Quantized)
-FIXTURE_DATA_TEST_CASE(QASYMM8, CLWeightsReshapeFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(framework::dataset::make("InputShape", { TensorShape(5U, 5U, 48U, 120U) }),
+FIXTURE_DATA_TEST_CASE(QASYMM8, ClWeightsReshapeFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(framework::dataset::make("InputShape", { TensorShape(5U, 5U, 48U, 120U) }),
framework::dataset::make("DataType", DataType::QASYMM8)),
framework::dataset::make("HasBias", { false })),
framework::dataset::make("NumGroups", { 1, 2 })))
@@ -134,7 +134,7 @@ FIXTURE_DATA_TEST_CASE(QASYMM8, CLWeightsReshapeFixture<uint8_t>, framework::Dat
validate(CLAccessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(QASYMM8_SIGNED, CLWeightsReshapeFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(framework::dataset::make("InputShape", { TensorShape(5U, 5U, 48U, 120U) }),
+FIXTURE_DATA_TEST_CASE(QASYMM8_SIGNED, ClWeightsReshapeFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(framework::dataset::make("InputShape", { TensorShape(5U, 5U, 48U, 120U) }),
framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
framework::dataset::make("HasBias", { false })),
framework::dataset::make("NumGroups", { 1, 2 })))
diff --git a/tests/validation/fixtures/WeightsReshapeFixture.h b/tests/validation/fixtures/WeightsReshapeFixture.h
index 0b3e76d677..7c7214acac 100644
--- a/tests/validation/fixtures/WeightsReshapeFixture.h
+++ b/tests/validation/fixtures/WeightsReshapeFixture.h
@@ -45,7 +45,7 @@ namespace validation
using namespace arm_compute::misc::shape_calculator;
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class WeightsReshapeValidationFixture : public framework::Fixture
+class WeightsReshapeOpValidationFixture : public framework::Fixture
{
public:
template <typename...>
@@ -73,7 +73,7 @@ protected:
// Create and configure function
FunctionType weights_reshape_func;
- weights_reshape_func.configure(&src, (has_bias ? &bias : nullptr), &dst, num_groups);
+ weights_reshape_func.configure(src.info(), (has_bias ? bias.info() : nullptr), dst.info(), num_groups);
ARM_COMPUTE_ASSERT(src.info()->is_resizable());
ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
@@ -99,8 +99,18 @@ protected:
fill(AccessorType(bias), 1);
}
+ arm_compute::ITensorPack pack =
+ {
+ { arm_compute::TensorType::ACL_SRC, &src },
+ { arm_compute::TensorType::ACL_DST, &dst }
+ };
+
+ if(has_bias)
+ {
+ pack.add_const_tensor(arm_compute::TensorType::ACL_BIAS, &bias);
+ }
// Compute function
- weights_reshape_func.run();
+ weights_reshape_func.run(pack);
return dst;
}