aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMoritz Pflanzer <moritz.pflanzer@arm.com>2017-07-18 13:42:54 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-09-17 14:16:42 +0100
commitb6c8d24042616341c1fbca6e255a69561c73fedf (patch)
treef959d732981d4230d6491d8636afe3f31d64e798
parentfce87954ac2373e910ccb0d83a00f5958ba41e71 (diff)
downloadComputeLibrary-b6c8d24042616341c1fbca6e255a69561c73fedf.tar.gz
COMPMID-415: Use templates for data arguments
Change-Id: I815d705e7cf42022f7a203935dcaaa333a2801fe Reviewed-on: http://mpd-gerrit.cambridge.arm.com/80311 Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
-rw-r--r--arm_compute/core/Dimensions.h2
-rw-r--r--framework/Macros.h141
-rw-r--r--framework/TestCase.h2
-rw-r--r--tests/SConscript1
-rw-r--r--tests/fixtures_new/ActivationLayerFixture.h1
-rw-r--r--tests/fixtures_new/AlexNetFixture.h1
-rw-r--r--tests/fixtures_new/ConvolutionLayerFixture.h1
-rw-r--r--tests/fixtures_new/FullyConnectedLayerFixture.h1
-rw-r--r--tests/fixtures_new/GEMMFixture.h1
-rw-r--r--tests/fixtures_new/LeNet5Fixture.h1
-rw-r--r--tests/fixtures_new/NormalizationLayerFixture.h1
-rw-r--r--tests/fixtures_new/PoolingLayerFixture.h1
12 files changed, 110 insertions, 44 deletions
diff --git a/arm_compute/core/Dimensions.h b/arm_compute/core/Dimensions.h
index d2131018fa..96dd3711cb 100644
--- a/arm_compute/core/Dimensions.h
+++ b/arm_compute/core/Dimensions.h
@@ -49,7 +49,7 @@ public:
* @param[in] dims Values to initialize the dimensions.
*/
template <typename... Ts>
- Dimensions(Ts... dims)
+ explicit Dimensions(Ts... dims)
: _id{ { dims... } }, _num_dimensions{ sizeof...(dims) }
{
}
diff --git a/framework/Macros.h b/framework/Macros.h
index 38eb29c6a1..8c5afe9304 100644
--- a/framework/Macros.h
+++ b/framework/Macros.h
@@ -47,6 +47,52 @@
//
//
+// HELPER MACROS
+//
+
+#define CONCAT(ARG0, ARG1) ARG0##ARG1
+
+#define VARIADIC_SIZE_IMPL(e0, e1, e2, e3, e4, e5, e6, e7, e8, e9, size, ...) size
+#define VARIADIC_SIZE(...) VARIADIC_SIZE_IMPL(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)
+
+#define JOIN_PARAM1(OP, param) OP(0, param)
+#define JOIN_PARAM2(OP, param, ...) \
+ OP(1, param) \
+ , JOIN_PARAM1(OP, __VA_ARGS__)
+#define JOIN_PARAM3(OP, param, ...) \
+ OP(2, param) \
+ , JOIN_PARAM2(OP, __VA_ARGS__)
+#define JOIN_PARAM4(OP, param, ...) \
+ OP(3, param) \
+ , JOIN_PARAM3(OP, __VA_ARGS__)
+#define JOIN_PARAM5(OP, param, ...) \
+ OP(4, param) \
+ , JOIN_PARAM4(OP, __VA_ARGS__)
+#define JOIN_PARAM6(OP, param, ...) \
+ OP(5, param) \
+ , JOIN_PARAM5(OP, __VA_ARGS__)
+#define JOIN_PARAM7(OP, param, ...) \
+ OP(6, param) \
+ , JOIN_PARAM6(OP, __VA_ARGS__)
+#define JOIN_PARAM8(OP, param, ...) \
+ OP(7, param) \
+ , JOIN_PARAM7(OP, __VA_ARGS__)
+#define JOIN_PARAM9(OP, param, ...) \
+ OP(8, param) \
+ , JOIN_PARAM8(OP, __VA_ARGS__)
+#define JOIN_PARAM10(OP, param, ...) \
+ OP(9, param) \
+ , JOIN_PARAM9(OP, __VA_ARGS__)
+#define JOIN_PARAM(OP, NUM, ...) \
+ CONCAT(JOIN_PARAM, NUM) \
+ (OP, __VA_ARGS__)
+
+#define MAKE_TYPE_PARAM(i, name) typename T##i
+#define MAKE_ARG_PARAM(i, name) const T##i &name
+#define MAKE_TYPE_PARAMS(...) JOIN_PARAM(MAKE_TYPE_PARAM, VARIADIC_SIZE(__VA_ARGS__), __VA_ARGS__)
+#define MAKE_ARG_PARAMS(...) JOIN_PARAM(MAKE_ARG_PARAM, VARIADIC_SIZE(__VA_ARGS__), __VA_ARGS__)
+
+//
// TEST CASE MACROS
//
#define TEST_CASE_CONSTRUCTOR(TEST_NAME) \
@@ -61,10 +107,10 @@
{ \
FIXTURE::setup(); \
}
-#define FIXTURE_DATA_SETUP(FIXTURE) \
- void do_setup() override \
- { \
- apply(this, &FIXTURE::setup, _data); \
+#define FIXTURE_DATA_SETUP(FIXTURE) \
+ void do_setup() override \
+ { \
+ apply(this, &FIXTURE::setup<As...>, _data); \
}
#define FIXTURE_RUN(FIXTURE) \
void do_run() override \
@@ -81,10 +127,10 @@
{ \
#TEST_NAME, MODE \
}
-#define DATA_TEST_REGISTRAR(TEST_NAME, MODE, DATASET) \
- static arm_compute::test::framework::detail::TestCaseRegistrar<TEST_NAME> TEST_NAME##_reg \
- { \
- #TEST_NAME, MODE, DATASET \
+#define DATA_TEST_REGISTRAR(TEST_NAME, MODE, DATASET) \
+ static arm_compute::test::framework::detail::TestCaseRegistrar<TEST_NAME<decltype(DATASET)::type>> TEST_NAME##_reg \
+ { \
+ #TEST_NAME, MODE, DATASET \
}
#define TEST_CASE(TEST_NAME, MODE) \
@@ -97,19 +143,25 @@
TEST_REGISTRAR(TEST_NAME, MODE); \
void TEST_NAME::do_run()
-#define DATA_TEST_CASE(TEST_NAME, MODE, DATASET, ...) \
- class TEST_NAME : public arm_compute::test::framework::DataTestCase<decltype(DATASET)::type> \
- { \
- public: \
- DATA_TEST_CASE_CONSTRUCTOR(TEST_NAME, DATASET) \
- void do_run() override \
- { \
- arm_compute::test::framework::apply(this, &TEST_NAME::run, _data); \
- } \
- void run(__VA_ARGS__); \
- }; \
- DATA_TEST_REGISTRAR(TEST_NAME, MODE, DATASET); \
- void TEST_NAME::run(__VA_ARGS__)
+#define DATA_TEST_CASE(TEST_NAME, MODE, DATASET, ...) \
+ template <typename T> \
+ class TEST_NAME; \
+ template <typename... As> \
+ class TEST_NAME<std::tuple<As...>> : public arm_compute::test::framework::DataTestCase<decltype(DATASET)::type> \
+ { \
+ public: \
+ DATA_TEST_CASE_CONSTRUCTOR(TEST_NAME, DATASET) \
+ void do_run() override \
+ { \
+ arm_compute::test::framework::apply(this, &TEST_NAME::run<As...>, _data); \
+ } \
+ template <MAKE_TYPE_PARAMS(__VA_ARGS__)> \
+ void run(MAKE_ARG_PARAMS(__VA_ARGS__)); \
+ }; \
+ DATA_TEST_REGISTRAR(TEST_NAME, MODE, DATASET); \
+ template <typename... As> \
+ template <MAKE_TYPE_PARAMS(__VA_ARGS__)> \
+ void TEST_NAME<std::tuple<As...>>::run(MAKE_ARG_PARAMS(__VA_ARGS__))
#define FIXTURE_TEST_CASE(TEST_NAME, FIXTURE, MODE) \
class TEST_NAME : public arm_compute::test::framework::TestCase, public FIXTURE \
@@ -123,17 +175,21 @@
TEST_REGISTRAR(TEST_NAME, MODE); \
void TEST_NAME::do_run()
-#define FIXTURE_DATA_TEST_CASE(TEST_NAME, FIXTURE, MODE, DATASET) \
- class TEST_NAME : public arm_compute::test::framework::DataTestCase<decltype(DATASET)::type>, public FIXTURE \
- { \
- public: \
- DATA_TEST_CASE_CONSTRUCTOR(TEST_NAME, DATASET) \
- FIXTURE_DATA_SETUP(FIXTURE) \
- void do_run() override; \
- FIXTURE_TEARDOWN(FIXTURE) \
- }; \
- DATA_TEST_REGISTRAR(TEST_NAME, MODE, DATASET); \
- void TEST_NAME::do_run()
+#define FIXTURE_DATA_TEST_CASE(TEST_NAME, FIXTURE, MODE, DATASET) \
+ template <typename T> \
+ class TEST_NAME; \
+ template <typename... As> \
+ class TEST_NAME<std::tuple<As...>> : public arm_compute::test::framework::DataTestCase<decltype(DATASET)::type>, public FIXTURE \
+ { \
+ public: \
+ DATA_TEST_CASE_CONSTRUCTOR(TEST_NAME, DATASET) \
+ FIXTURE_DATA_SETUP(FIXTURE) \
+ void do_run() override; \
+ FIXTURE_TEARDOWN(FIXTURE) \
+ }; \
+ DATA_TEST_REGISTRAR(TEST_NAME, MODE, DATASET); \
+ template <typename... As> \
+ void TEST_NAME<std::tuple<As...>>::do_run()
#define REGISTER_FIXTURE_TEST_CASE(TEST_NAME, FIXTURE, MODE) \
class TEST_NAME : public arm_compute::test::framework::TestCase, public FIXTURE \
@@ -146,15 +202,18 @@
}; \
TEST_REGISTRAR(TEST_NAME, MODE)
-#define REGISTER_FIXTURE_DATA_TEST_CASE(TEST_NAME, FIXTURE, MODE, DATASET) \
- class TEST_NAME : public arm_compute::test::framework::DataTestCase<decltype(DATASET)::type>, public FIXTURE \
- { \
- public: \
- DATA_TEST_CASE_CONSTRUCTOR(TEST_NAME, DATASET) \
- FIXTURE_DATA_SETUP(FIXTURE) \
- FIXTURE_RUN(FIXTURE) \
- FIXTURE_TEARDOWN(FIXTURE) \
- }; \
+#define REGISTER_FIXTURE_DATA_TEST_CASE(TEST_NAME, FIXTURE, MODE, DATASET) \
+ template <typename T> \
+ class TEST_NAME; \
+ template <typename... As> \
+ class TEST_NAME<std::tuple<As...>> : public arm_compute::test::framework::DataTestCase<decltype(DATASET)::type>, public FIXTURE \
+ { \
+ public: \
+ DATA_TEST_CASE_CONSTRUCTOR(TEST_NAME, DATASET) \
+ FIXTURE_DATA_SETUP(FIXTURE) \
+ FIXTURE_RUN(FIXTURE) \
+ FIXTURE_TEARDOWN(FIXTURE) \
+ }; \
DATA_TEST_REGISTRAR(TEST_NAME, MODE, DATASET)
//
// TEST CASE MACROS END
diff --git a/framework/TestCase.h b/framework/TestCase.h
index 43750b1d1b..dbb9312dee 100644
--- a/framework/TestCase.h
+++ b/framework/TestCase.h
@@ -58,7 +58,7 @@ class DataTestCase : public TestCase
{
protected:
explicit DataTestCase(T data)
- : _data{ data }
+ : _data{ std::move(data) }
{
}
diff --git a/tests/SConscript b/tests/SConscript
index 77d370f91a..ec596a7b89 100644
--- a/tests/SConscript
+++ b/tests/SConscript
@@ -146,7 +146,6 @@ files_benchmark += Glob('new/TensorLibrary.cpp')
files_benchmark += Glob('new/RawTensor.cpp')
files_benchmark += Glob('new/benchmark_new/*.cpp')
-# Add unit tests
if env['opencl']:
Import('opencl')
diff --git a/tests/fixtures_new/ActivationLayerFixture.h b/tests/fixtures_new/ActivationLayerFixture.h
index bb03fa2ed0..5066810c79 100644
--- a/tests/fixtures_new/ActivationLayerFixture.h
+++ b/tests/fixtures_new/ActivationLayerFixture.h
@@ -39,6 +39,7 @@ template <typename TensorType, typename Function, typename Accessor>
class ActivationLayerFixture : public framework::Fixture
{
public:
+ template <typename...>
void setup(TensorShape shape, ActivationLayerInfo info, DataType data_type, int batches)
{
// Set batched in source and destination shapes
diff --git a/tests/fixtures_new/AlexNetFixture.h b/tests/fixtures_new/AlexNetFixture.h
index fcac1b2236..0ebdae0091 100644
--- a/tests/fixtures_new/AlexNetFixture.h
+++ b/tests/fixtures_new/AlexNetFixture.h
@@ -47,6 +47,7 @@ template <typename ITensorType,
class AlexNetFixture : public framework::Fixture
{
public:
+ template <typename...>
void setup(DataType data_type, int batches)
{
constexpr bool weights_transposed = true;
diff --git a/tests/fixtures_new/ConvolutionLayerFixture.h b/tests/fixtures_new/ConvolutionLayerFixture.h
index 65426103e2..f41cd1d25e 100644
--- a/tests/fixtures_new/ConvolutionLayerFixture.h
+++ b/tests/fixtures_new/ConvolutionLayerFixture.h
@@ -39,6 +39,7 @@ template <typename TensorType, typename Function, typename Accessor>
class ConvolutionLayerFixture : public framework::Fixture
{
public:
+ template <typename...>
void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape biases_shape, TensorShape dst_shape, PadStrideInfo info, DataType data_type, int batches)
{
// Set batched in source and destination shapes
diff --git a/tests/fixtures_new/FullyConnectedLayerFixture.h b/tests/fixtures_new/FullyConnectedLayerFixture.h
index 9bf18a9689..82ecb39b9c 100644
--- a/tests/fixtures_new/FullyConnectedLayerFixture.h
+++ b/tests/fixtures_new/FullyConnectedLayerFixture.h
@@ -39,6 +39,7 @@ template <typename TensorType, typename Function, typename Accessor>
class FullyConnectedLayerFixture : public framework::Fixture
{
public:
+ template <typename...>
void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape biases_shape, TensorShape dst_shape, DataType data_type, int batches)
{
// Set batched in source and destination shapes
diff --git a/tests/fixtures_new/GEMMFixture.h b/tests/fixtures_new/GEMMFixture.h
index cd357789e5..b23661f3e3 100644
--- a/tests/fixtures_new/GEMMFixture.h
+++ b/tests/fixtures_new/GEMMFixture.h
@@ -39,6 +39,7 @@ template <typename TensorType, typename Function>
class GEMMFixture : public framework::Fixture
{
public:
+ template <typename...>
void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_c, TensorShape shape_dst, float alpha, float beta, DataType data_type)
{
constexpr int fixed_point_position = 4;
diff --git a/tests/fixtures_new/LeNet5Fixture.h b/tests/fixtures_new/LeNet5Fixture.h
index 3f36628c60..d9173af048 100644
--- a/tests/fixtures_new/LeNet5Fixture.h
+++ b/tests/fixtures_new/LeNet5Fixture.h
@@ -43,6 +43,7 @@ template <typename TensorType,
class LeNet5Fixture : public framework::Fixture
{
public:
+ template <typename...>
void setup(int batches)
{
network.init(batches);
diff --git a/tests/fixtures_new/NormalizationLayerFixture.h b/tests/fixtures_new/NormalizationLayerFixture.h
index 63d2d42c88..999eed6cff 100644
--- a/tests/fixtures_new/NormalizationLayerFixture.h
+++ b/tests/fixtures_new/NormalizationLayerFixture.h
@@ -39,6 +39,7 @@ template <typename TensorType, typename Function, typename Accessor>
class NormalizationLayerFixture : public framework::Fixture
{
public:
+ template <typename...>
void setup(TensorShape shape, NormalizationLayerInfo info, DataType data_type, int batches)
{
// Set batched in source and destination shapes
diff --git a/tests/fixtures_new/PoolingLayerFixture.h b/tests/fixtures_new/PoolingLayerFixture.h
index a09b421ad0..fc9c90ae3c 100644
--- a/tests/fixtures_new/PoolingLayerFixture.h
+++ b/tests/fixtures_new/PoolingLayerFixture.h
@@ -39,6 +39,7 @@ template <typename TensorType, typename Function, typename Accessor>
class PoolingLayerFixture : public framework::Fixture
{
public:
+ template <typename...>
void setup(TensorShape src_shape, TensorShape dst_shape, PoolingLayerInfo info, DataType data_type, int batches)
{
// Set batched in source and destination shapes