diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/validation/CL/DepthConcatenateLayer.cpp | 2 | ||||
-rw-r--r-- | tests/validation/NEON/DepthConcatenateLayer.cpp | 2 | ||||
-rw-r--r-- | tests/validation/fixtures/DepthConcatenateLayerFixture.h | 9 |
3 files changed, 3 insertions, 10 deletions
diff --git a/tests/validation/CL/DepthConcatenateLayer.cpp b/tests/validation/CL/DepthConcatenateLayer.cpp index a8ef1c37c7..5ad423fe8f 100644 --- a/tests/validation/CL/DepthConcatenateLayer.cpp +++ b/tests/validation/CL/DepthConcatenateLayer.cpp @@ -46,7 +46,7 @@ TEST_SUITE(DepthConcatenateLayer) //TODO(COMPMID-415): Add configuration test? template <typename T> -using CLDepthConcatenateLayerFixture = DepthConcatenateValidationFixture<CLTensor, CLAccessor, CLDepthConcatenate, T>; +using CLDepthConcatenateLayerFixture = DepthConcatenateValidationFixture<CLTensor, ICLTensor, CLAccessor, CLDepthConcatenate, T>; TEST_SUITE(Float) TEST_SUITE(FP16) diff --git a/tests/validation/NEON/DepthConcatenateLayer.cpp b/tests/validation/NEON/DepthConcatenateLayer.cpp index 19a41ee9d6..d282cb5501 100644 --- a/tests/validation/NEON/DepthConcatenateLayer.cpp +++ b/tests/validation/NEON/DepthConcatenateLayer.cpp @@ -46,7 +46,7 @@ TEST_SUITE(DepthConcatenateLayer) //TODO(COMPMID-415): Add configuration test? template <typename T> -using NEDepthConcatenateLayerFixture = DepthConcatenateValidationFixture<Tensor, Accessor, NEDepthConcatenate, T>; +using NEDepthConcatenateLayerFixture = DepthConcatenateValidationFixture<Tensor, ITensor, Accessor, NEDepthConcatenate, T>; TEST_SUITE(Float) #ifdef ARM_COMPUTE_ENABLE_FP16 diff --git a/tests/validation/fixtures/DepthConcatenateLayerFixture.h b/tests/validation/fixtures/DepthConcatenateLayerFixture.h index 2a2e96e821..633dba23e0 100644 --- a/tests/validation/fixtures/DepthConcatenateLayerFixture.h +++ b/tests/validation/fixtures/DepthConcatenateLayerFixture.h @@ -38,16 +38,11 @@ namespace arm_compute { -class ITensor; -class Tensor; -class ICLTensor; -class CLTensor; - namespace test { namespace validation { -template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +template <typename TensorType, typename ITensorType, typename AccessorType, typename FunctionType, typename T> class DepthConcatenateValidationFixture : public framework::Fixture { public: @@ -99,8 +94,6 @@ protected: TensorType compute_target(std::vector<TensorShape> shapes, DataType data_type) { - using ITensorType = typename std::conditional<std::is_same<TensorType, Tensor>::value, ITensor, ICLTensor>::type; - std::vector<TensorType> srcs; std::vector<ITensorType *> src_ptrs; |