From b785dd4a4e1e662630f4d79e0f578513958a71fd Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Thu, 19 Sep 2019 12:09:32 +0100 Subject: COMPMID-2670: [CL/GC] Create a test case for dynamic tensor support Change-Id: I35d28786ee3843ac11c1211aea55328782a99382 Signed-off-by: Georgios Pinitas Reviewed-on: https://review.mlplatform.org/c/1958 Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins --- .../fixtures/UNIT/DynamicTensorFixture.h | 38 +++++++++++++++++----- 1 file changed, 29 insertions(+), 9 deletions(-) (limited to 'tests/validation/fixtures/UNIT/DynamicTensorFixture.h') diff --git a/tests/validation/fixtures/UNIT/DynamicTensorFixture.h b/tests/validation/fixtures/UNIT/DynamicTensorFixture.h index df12a4aa30..66ef6c4aff 100644 --- a/tests/validation/fixtures/UNIT/DynamicTensorFixture.h +++ b/tests/validation/fixtures/UNIT/DynamicTensorFixture.h @@ -89,6 +89,26 @@ public: MemoryGroup mg; size_t num_pools; }; + +template +class SimpleFunctionWrapper +{ +public: + SimpleFunctionWrapper(std::shared_ptr mm) + : _func(mm) + { + } + void configure(ITensorType *src, ITensorType *dst) + { + } + void run() + { + _func.run(); + } + +private: + FuncType _func; +}; } // namespace /** Simple test case to run a single function with different shapes twice. @@ -102,7 +122,7 @@ template + typename SimpleFunctionWrapperType> class DynamicTensorType3SingleFunction : public framework::Fixture { using T = float; @@ -131,15 +151,15 @@ protected: // Level 0 // Create tensors - TensorType src = create_tensor(level_0, DataType::F32, 1); - TensorType dst = create_tensor(level_0, DataType::F32, 1); + TensorType src = create_tensor(level_0, DataType::F32, 1); + TensorType dst = create_tensor(level_0, DataType::F32, 1); serv_cross.mg.manage(&src); serv_cross.mg.manage(&dst); // Create and configure function - NormalizationFunctionType norm_layer(serv_internal.mm); - norm_layer.configure(&src, &dst, NormalizationLayerInfo(NormType::CROSS_MAP, 3)); + SimpleFunctionWrapperType layer(serv_internal.mm); + layer.configure(&src, &dst); ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -163,8 +183,8 @@ protected: // Acquire memory manager, fill tensors and compute functions serv_cross.mg.acquire(); - arm_compute::test::library->fill_tensor_value(Accessor(src), 12.f); - norm_layer.run(); + arm_compute::test::library->fill_tensor_value(AccessorType(src), 12.f); + layer.run(); serv_cross.mg.release(); // Clear manager @@ -184,7 +204,7 @@ protected: serv_cross.mg.manage(&dst); // Re-configure the function - norm_layer.configure(&src, &dst, NormalizationLayerInfo(NormType::CROSS_MAP, 3)); + layer.configure(&src, &dst); // Allocate tensors src.allocator()->allocate(); @@ -203,7 +223,7 @@ protected: // Compute functions serv_cross.mg.acquire(); arm_compute::test::library->fill_tensor_value(AccessorType(src), 12.f); - norm_layer.run(); + layer.run(); serv_cross.mg.release(); // Clear manager -- cgit v1.2.1