diff options
author | Moritz Pflanzer <moritz.pflanzer@arm.com> | 2017-07-18 13:42:54 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-09-17 14:16:42 +0100 |
commit | b6c8d24042616341c1fbca6e255a69561c73fedf (patch) | |
tree | f959d732981d4230d6491d8636afe3f31d64e798 | |
parent | fce87954ac2373e910ccb0d83a00f5958ba41e71 (diff) | |
download | ComputeLibrary-b6c8d24042616341c1fbca6e255a69561c73fedf.tar.gz |
COMPMID-415: Use templates for data arguments
Change-Id: I815d705e7cf42022f7a203935dcaaa333a2801fe
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/80311
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
-rw-r--r-- | arm_compute/core/Dimensions.h | 2 | ||||
-rw-r--r-- | framework/Macros.h | 141 | ||||
-rw-r--r-- | framework/TestCase.h | 2 | ||||
-rw-r--r-- | tests/SConscript | 1 | ||||
-rw-r--r-- | tests/fixtures_new/ActivationLayerFixture.h | 1 | ||||
-rw-r--r-- | tests/fixtures_new/AlexNetFixture.h | 1 | ||||
-rw-r--r-- | tests/fixtures_new/ConvolutionLayerFixture.h | 1 | ||||
-rw-r--r-- | tests/fixtures_new/FullyConnectedLayerFixture.h | 1 | ||||
-rw-r--r-- | tests/fixtures_new/GEMMFixture.h | 1 | ||||
-rw-r--r-- | tests/fixtures_new/LeNet5Fixture.h | 1 | ||||
-rw-r--r-- | tests/fixtures_new/NormalizationLayerFixture.h | 1 | ||||
-rw-r--r-- | tests/fixtures_new/PoolingLayerFixture.h | 1 |
12 files changed, 110 insertions, 44 deletions
diff --git a/arm_compute/core/Dimensions.h b/arm_compute/core/Dimensions.h index d2131018fa..96dd3711cb 100644 --- a/arm_compute/core/Dimensions.h +++ b/arm_compute/core/Dimensions.h @@ -49,7 +49,7 @@ public: * @param[in] dims Values to initialize the dimensions. */ template <typename... Ts> - Dimensions(Ts... dims) + explicit Dimensions(Ts... dims) : _id{ { dims... } }, _num_dimensions{ sizeof...(dims) } { } diff --git a/framework/Macros.h b/framework/Macros.h index 38eb29c6a1..8c5afe9304 100644 --- a/framework/Macros.h +++ b/framework/Macros.h @@ -47,6 +47,52 @@ // // +// HELPER MACROS +// + +#define CONCAT(ARG0, ARG1) ARG0##ARG1 + +#define VARIADIC_SIZE_IMPL(e0, e1, e2, e3, e4, e5, e6, e7, e8, e9, size, ...) size +#define VARIADIC_SIZE(...) VARIADIC_SIZE_IMPL(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0) + +#define JOIN_PARAM1(OP, param) OP(0, param) +#define JOIN_PARAM2(OP, param, ...) \ + OP(1, param) \ + , JOIN_PARAM1(OP, __VA_ARGS__) +#define JOIN_PARAM3(OP, param, ...) \ + OP(2, param) \ + , JOIN_PARAM2(OP, __VA_ARGS__) +#define JOIN_PARAM4(OP, param, ...) \ + OP(3, param) \ + , JOIN_PARAM3(OP, __VA_ARGS__) +#define JOIN_PARAM5(OP, param, ...) \ + OP(4, param) \ + , JOIN_PARAM4(OP, __VA_ARGS__) +#define JOIN_PARAM6(OP, param, ...) \ + OP(5, param) \ + , JOIN_PARAM5(OP, __VA_ARGS__) +#define JOIN_PARAM7(OP, param, ...) \ + OP(6, param) \ + , JOIN_PARAM6(OP, __VA_ARGS__) +#define JOIN_PARAM8(OP, param, ...) \ + OP(7, param) \ + , JOIN_PARAM7(OP, __VA_ARGS__) +#define JOIN_PARAM9(OP, param, ...) \ + OP(8, param) \ + , JOIN_PARAM8(OP, __VA_ARGS__) +#define JOIN_PARAM10(OP, param, ...) \ + OP(9, param) \ + , JOIN_PARAM9(OP, __VA_ARGS__) +#define JOIN_PARAM(OP, NUM, ...) \ + CONCAT(JOIN_PARAM, NUM) \ + (OP, __VA_ARGS__) + +#define MAKE_TYPE_PARAM(i, name) typename T##i +#define MAKE_ARG_PARAM(i, name) const T##i &name +#define MAKE_TYPE_PARAMS(...) JOIN_PARAM(MAKE_TYPE_PARAM, VARIADIC_SIZE(__VA_ARGS__), __VA_ARGS__) +#define MAKE_ARG_PARAMS(...) JOIN_PARAM(MAKE_ARG_PARAM, VARIADIC_SIZE(__VA_ARGS__), __VA_ARGS__) + +// // TEST CASE MACROS // #define TEST_CASE_CONSTRUCTOR(TEST_NAME) \ @@ -61,10 +107,10 @@ { \ FIXTURE::setup(); \ } -#define FIXTURE_DATA_SETUP(FIXTURE) \ - void do_setup() override \ - { \ - apply(this, &FIXTURE::setup, _data); \ +#define FIXTURE_DATA_SETUP(FIXTURE) \ + void do_setup() override \ + { \ + apply(this, &FIXTURE::setup<As...>, _data); \ } #define FIXTURE_RUN(FIXTURE) \ void do_run() override \ @@ -81,10 +127,10 @@ { \ #TEST_NAME, MODE \ } -#define DATA_TEST_REGISTRAR(TEST_NAME, MODE, DATASET) \ - static arm_compute::test::framework::detail::TestCaseRegistrar<TEST_NAME> TEST_NAME##_reg \ - { \ - #TEST_NAME, MODE, DATASET \ +#define DATA_TEST_REGISTRAR(TEST_NAME, MODE, DATASET) \ + static arm_compute::test::framework::detail::TestCaseRegistrar<TEST_NAME<decltype(DATASET)::type>> TEST_NAME##_reg \ + { \ + #TEST_NAME, MODE, DATASET \ } #define TEST_CASE(TEST_NAME, MODE) \ @@ -97,19 +143,25 @@ TEST_REGISTRAR(TEST_NAME, MODE); \ void TEST_NAME::do_run() -#define DATA_TEST_CASE(TEST_NAME, MODE, DATASET, ...) \ - class TEST_NAME : public arm_compute::test::framework::DataTestCase<decltype(DATASET)::type> \ - { \ - public: \ - DATA_TEST_CASE_CONSTRUCTOR(TEST_NAME, DATASET) \ - void do_run() override \ - { \ - arm_compute::test::framework::apply(this, &TEST_NAME::run, _data); \ - } \ - void run(__VA_ARGS__); \ - }; \ - DATA_TEST_REGISTRAR(TEST_NAME, MODE, DATASET); \ - void TEST_NAME::run(__VA_ARGS__) +#define DATA_TEST_CASE(TEST_NAME, MODE, DATASET, ...) \ + template <typename T> \ + class TEST_NAME; \ + template <typename... As> \ + class TEST_NAME<std::tuple<As...>> : public arm_compute::test::framework::DataTestCase<decltype(DATASET)::type> \ + { \ + public: \ + DATA_TEST_CASE_CONSTRUCTOR(TEST_NAME, DATASET) \ + void do_run() override \ + { \ + arm_compute::test::framework::apply(this, &TEST_NAME::run<As...>, _data); \ + } \ + template <MAKE_TYPE_PARAMS(__VA_ARGS__)> \ + void run(MAKE_ARG_PARAMS(__VA_ARGS__)); \ + }; \ + DATA_TEST_REGISTRAR(TEST_NAME, MODE, DATASET); \ + template <typename... As> \ + template <MAKE_TYPE_PARAMS(__VA_ARGS__)> \ + void TEST_NAME<std::tuple<As...>>::run(MAKE_ARG_PARAMS(__VA_ARGS__)) #define FIXTURE_TEST_CASE(TEST_NAME, FIXTURE, MODE) \ class TEST_NAME : public arm_compute::test::framework::TestCase, public FIXTURE \ @@ -123,17 +175,21 @@ TEST_REGISTRAR(TEST_NAME, MODE); \ void TEST_NAME::do_run() -#define FIXTURE_DATA_TEST_CASE(TEST_NAME, FIXTURE, MODE, DATASET) \ - class TEST_NAME : public arm_compute::test::framework::DataTestCase<decltype(DATASET)::type>, public FIXTURE \ - { \ - public: \ - DATA_TEST_CASE_CONSTRUCTOR(TEST_NAME, DATASET) \ - FIXTURE_DATA_SETUP(FIXTURE) \ - void do_run() override; \ - FIXTURE_TEARDOWN(FIXTURE) \ - }; \ - DATA_TEST_REGISTRAR(TEST_NAME, MODE, DATASET); \ - void TEST_NAME::do_run() +#define FIXTURE_DATA_TEST_CASE(TEST_NAME, FIXTURE, MODE, DATASET) \ + template <typename T> \ + class TEST_NAME; \ + template <typename... As> \ + class TEST_NAME<std::tuple<As...>> : public arm_compute::test::framework::DataTestCase<decltype(DATASET)::type>, public FIXTURE \ + { \ + public: \ + DATA_TEST_CASE_CONSTRUCTOR(TEST_NAME, DATASET) \ + FIXTURE_DATA_SETUP(FIXTURE) \ + void do_run() override; \ + FIXTURE_TEARDOWN(FIXTURE) \ + }; \ + DATA_TEST_REGISTRAR(TEST_NAME, MODE, DATASET); \ + template <typename... As> \ + void TEST_NAME<std::tuple<As...>>::do_run() #define REGISTER_FIXTURE_TEST_CASE(TEST_NAME, FIXTURE, MODE) \ class TEST_NAME : public arm_compute::test::framework::TestCase, public FIXTURE \ @@ -146,15 +202,18 @@ }; \ TEST_REGISTRAR(TEST_NAME, MODE) -#define REGISTER_FIXTURE_DATA_TEST_CASE(TEST_NAME, FIXTURE, MODE, DATASET) \ - class TEST_NAME : public arm_compute::test::framework::DataTestCase<decltype(DATASET)::type>, public FIXTURE \ - { \ - public: \ - DATA_TEST_CASE_CONSTRUCTOR(TEST_NAME, DATASET) \ - FIXTURE_DATA_SETUP(FIXTURE) \ - FIXTURE_RUN(FIXTURE) \ - FIXTURE_TEARDOWN(FIXTURE) \ - }; \ +#define REGISTER_FIXTURE_DATA_TEST_CASE(TEST_NAME, FIXTURE, MODE, DATASET) \ + template <typename T> \ + class TEST_NAME; \ + template <typename... As> \ + class TEST_NAME<std::tuple<As...>> : public arm_compute::test::framework::DataTestCase<decltype(DATASET)::type>, public FIXTURE \ + { \ + public: \ + DATA_TEST_CASE_CONSTRUCTOR(TEST_NAME, DATASET) \ + FIXTURE_DATA_SETUP(FIXTURE) \ + FIXTURE_RUN(FIXTURE) \ + FIXTURE_TEARDOWN(FIXTURE) \ + }; \ DATA_TEST_REGISTRAR(TEST_NAME, MODE, DATASET) // // TEST CASE MACROS END diff --git a/framework/TestCase.h b/framework/TestCase.h index 43750b1d1b..dbb9312dee 100644 --- a/framework/TestCase.h +++ b/framework/TestCase.h @@ -58,7 +58,7 @@ class DataTestCase : public TestCase { protected: explicit DataTestCase(T data) - : _data{ data } + : _data{ std::move(data) } { } diff --git a/tests/SConscript b/tests/SConscript index 77d370f91a..ec596a7b89 100644 --- a/tests/SConscript +++ b/tests/SConscript @@ -146,7 +146,6 @@ files_benchmark += Glob('new/TensorLibrary.cpp') files_benchmark += Glob('new/RawTensor.cpp') files_benchmark += Glob('new/benchmark_new/*.cpp') -# Add unit tests if env['opencl']: Import('opencl') diff --git a/tests/fixtures_new/ActivationLayerFixture.h b/tests/fixtures_new/ActivationLayerFixture.h index bb03fa2ed0..5066810c79 100644 --- a/tests/fixtures_new/ActivationLayerFixture.h +++ b/tests/fixtures_new/ActivationLayerFixture.h @@ -39,6 +39,7 @@ template <typename TensorType, typename Function, typename Accessor> class ActivationLayerFixture : public framework::Fixture { public: + template <typename...> void setup(TensorShape shape, ActivationLayerInfo info, DataType data_type, int batches) { // Set batched in source and destination shapes diff --git a/tests/fixtures_new/AlexNetFixture.h b/tests/fixtures_new/AlexNetFixture.h index fcac1b2236..0ebdae0091 100644 --- a/tests/fixtures_new/AlexNetFixture.h +++ b/tests/fixtures_new/AlexNetFixture.h @@ -47,6 +47,7 @@ template <typename ITensorType, class AlexNetFixture : public framework::Fixture { public: + template <typename...> void setup(DataType data_type, int batches) { constexpr bool weights_transposed = true; diff --git a/tests/fixtures_new/ConvolutionLayerFixture.h b/tests/fixtures_new/ConvolutionLayerFixture.h index 65426103e2..f41cd1d25e 100644 --- a/tests/fixtures_new/ConvolutionLayerFixture.h +++ b/tests/fixtures_new/ConvolutionLayerFixture.h @@ -39,6 +39,7 @@ template <typename TensorType, typename Function, typename Accessor> class ConvolutionLayerFixture : public framework::Fixture { public: + template <typename...> void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape biases_shape, TensorShape dst_shape, PadStrideInfo info, DataType data_type, int batches) { // Set batched in source and destination shapes diff --git a/tests/fixtures_new/FullyConnectedLayerFixture.h b/tests/fixtures_new/FullyConnectedLayerFixture.h index 9bf18a9689..82ecb39b9c 100644 --- a/tests/fixtures_new/FullyConnectedLayerFixture.h +++ b/tests/fixtures_new/FullyConnectedLayerFixture.h @@ -39,6 +39,7 @@ template <typename TensorType, typename Function, typename Accessor> class FullyConnectedLayerFixture : public framework::Fixture { public: + template <typename...> void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape biases_shape, TensorShape dst_shape, DataType data_type, int batches) { // Set batched in source and destination shapes diff --git a/tests/fixtures_new/GEMMFixture.h b/tests/fixtures_new/GEMMFixture.h index cd357789e5..b23661f3e3 100644 --- a/tests/fixtures_new/GEMMFixture.h +++ b/tests/fixtures_new/GEMMFixture.h @@ -39,6 +39,7 @@ template <typename TensorType, typename Function> class GEMMFixture : public framework::Fixture { public: + template <typename...> void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_c, TensorShape shape_dst, float alpha, float beta, DataType data_type) { constexpr int fixed_point_position = 4; diff --git a/tests/fixtures_new/LeNet5Fixture.h b/tests/fixtures_new/LeNet5Fixture.h index 3f36628c60..d9173af048 100644 --- a/tests/fixtures_new/LeNet5Fixture.h +++ b/tests/fixtures_new/LeNet5Fixture.h @@ -43,6 +43,7 @@ template <typename TensorType, class LeNet5Fixture : public framework::Fixture { public: + template <typename...> void setup(int batches) { network.init(batches); diff --git a/tests/fixtures_new/NormalizationLayerFixture.h b/tests/fixtures_new/NormalizationLayerFixture.h index 63d2d42c88..999eed6cff 100644 --- a/tests/fixtures_new/NormalizationLayerFixture.h +++ b/tests/fixtures_new/NormalizationLayerFixture.h @@ -39,6 +39,7 @@ template <typename TensorType, typename Function, typename Accessor> class NormalizationLayerFixture : public framework::Fixture { public: + template <typename...> void setup(TensorShape shape, NormalizationLayerInfo info, DataType data_type, int batches) { // Set batched in source and destination shapes diff --git a/tests/fixtures_new/PoolingLayerFixture.h b/tests/fixtures_new/PoolingLayerFixture.h index a09b421ad0..fc9c90ae3c 100644 --- a/tests/fixtures_new/PoolingLayerFixture.h +++ b/tests/fixtures_new/PoolingLayerFixture.h @@ -39,6 +39,7 @@ template <typename TensorType, typename Function, typename Accessor> class PoolingLayerFixture : public framework::Fixture { public: + template <typename...> void setup(TensorShape src_shape, TensorShape dst_shape, PoolingLayerInfo info, DataType data_type, int batches) { // Set batched in source and destination shapes |