From a9c4472188abef421adb589e2a6fef52727d465f Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Fri, 5 Apr 2019 17:18:36 +0100 Subject: COMPMID-2051 Refactor shape_calculator::calculate_concatenate_shape Change-Id: Ibf316718d11fa975d75f226925747b21c4efd127 Signed-off-by: Michalis Spyrou Reviewed-on: https://review.mlplatform.org/c/974 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio --- .../GLES_COMPUTE/DepthConcatenateLayer.cpp | 33 +++++++++++++--------- 1 file changed, 19 insertions(+), 14 deletions(-) (limited to 'tests/validation/GLES_COMPUTE') diff --git a/tests/validation/GLES_COMPUTE/DepthConcatenateLayer.cpp b/tests/validation/GLES_COMPUTE/DepthConcatenateLayer.cpp index 7af3050c1d..04e91d63ae 100644 --- a/tests/validation/GLES_COMPUTE/DepthConcatenateLayer.cpp +++ b/tests/validation/GLES_COMPUTE/DepthConcatenateLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -24,14 +24,14 @@ #include "arm_compute/core/Types.h" #include "arm_compute/runtime/GLES_COMPUTE/GCTensor.h" #include "arm_compute/runtime/GLES_COMPUTE/GCTensorAllocator.h" -#include "arm_compute/runtime/GLES_COMPUTE/functions/GCDepthConcatenateLayer.h" +#include "arm_compute/runtime/GLES_COMPUTE/functions/GCConcatenateLayer.h" #include "tests/GLES_COMPUTE/GCAccessor.h" #include "tests/datasets/ShapeDatasets.h" #include "tests/framework/Asserts.h" #include "tests/framework/Macros.h" #include "tests/framework/datasets/Datasets.h" #include "tests/validation/Validation.h" -#include "tests/validation/fixtures/DepthConcatenateLayerFixture.h" +#include "tests/validation/fixtures/ConcatenateLayerFixture.h" namespace arm_compute { @@ -42,21 +42,23 @@ namespace validation TEST_SUITE(GC) TEST_SUITE(DepthConcatenateLayer) -//TODO(COMPMID-415): Add configuration test? - template -using GCDepthConcatenateLayerFixture = DepthConcatenateLayerValidationFixture; +using GCDepthConcatenateLayerFixture = ConcatenateLayerValidationFixture; TEST_SUITE(Float) TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(RunSmall, GCDepthConcatenateLayerFixture, framework::DatasetMode::PRECOMMIT, combine(datasets::Small2DShapes(), framework::dataset::make("DataType", - DataType::F16))) +FIXTURE_DATA_TEST_CASE(RunSmall, GCDepthConcatenateLayerFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::Small3DShapes(), + framework::dataset::make("DataType", + DataType::F16)), + framework::dataset::make("Axis", 2))) { // Validate output validate(GCAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, GCDepthConcatenateLayerFixture, framework::DatasetMode::NIGHTLY, combine(datasets::Large2DShapes(), framework::dataset::make("DataType", - DataType::F16))) +FIXTURE_DATA_TEST_CASE(RunLarge, GCDepthConcatenateLayerFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::Large3DShapes(), + framework::dataset::make("DataType", + DataType::F16)), + framework::dataset::make("Axis", 2))) { // Validate output validate(GCAccessor(_target), _reference); @@ -64,14 +66,17 @@ FIXTURE_DATA_TEST_CASE(RunLarge, GCDepthConcatenateLayerFixture, framework TEST_SUITE_END() TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(RunSmall, GCDepthConcatenateLayerFixture, framework::DatasetMode::PRECOMMIT, combine(datasets::Small2DShapes(), framework::dataset::make("DataType", - DataType::F32))) +FIXTURE_DATA_TEST_CASE(RunSmall, GCDepthConcatenateLayerFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::Small3DShapes(), + framework::dataset::make("DataType", + DataType::F32)), + framework::dataset::make("Axis", 2))) { // Validate output validate(GCAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, GCDepthConcatenateLayerFixture, framework::DatasetMode::NIGHTLY, combine(datasets::DepthConcatenateLayerShapes(), framework::dataset::make("DataType", - DataType::F32))) +FIXTURE_DATA_TEST_CASE(RunLarge, GCDepthConcatenateLayerFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::Large3DShapes(), framework::dataset::make("DataType", + DataType::F32)), + framework::dataset::make("Axis", 2))) { // Validate output validate(GCAccessor(_target), _reference); -- cgit v1.2.1