diff options
Diffstat (limited to 'tests/validation/fixtures/UNIT/DynamicTensorFixture.h')
-rw-r--r-- | tests/validation/fixtures/UNIT/DynamicTensorFixture.h | 38 |
1 files changed, 29 insertions, 9 deletions
diff --git a/tests/validation/fixtures/UNIT/DynamicTensorFixture.h b/tests/validation/fixtures/UNIT/DynamicTensorFixture.h index df12a4aa30..66ef6c4aff 100644 --- a/tests/validation/fixtures/UNIT/DynamicTensorFixture.h +++ b/tests/validation/fixtures/UNIT/DynamicTensorFixture.h @@ -89,6 +89,26 @@ public: MemoryGroup mg; size_t num_pools; }; + +template <typename MemoryMgrType, typename FuncType, typename ITensorType> +class SimpleFunctionWrapper +{ +public: + SimpleFunctionWrapper(std::shared_ptr<MemoryMgrType> mm) + : _func(mm) + { + } + void configure(ITensorType *src, ITensorType *dst) + { + } + void run() + { + _func.run(); + } + +private: + FuncType _func; +}; } // namespace /** Simple test case to run a single function with different shapes twice. @@ -102,7 +122,7 @@ template <typename TensorType, typename LifetimeMgrType, typename PoolMgrType, typename MemoryManagerType, - typename NormalizationFunctionType> + typename SimpleFunctionWrapperType> class DynamicTensorType3SingleFunction : public framework::Fixture { using T = float; @@ -131,15 +151,15 @@ protected: // Level 0 // Create tensors - TensorType src = create_tensor<Tensor>(level_0, DataType::F32, 1); - TensorType dst = create_tensor<Tensor>(level_0, DataType::F32, 1); + TensorType src = create_tensor<TensorType>(level_0, DataType::F32, 1); + TensorType dst = create_tensor<TensorType>(level_0, DataType::F32, 1); serv_cross.mg.manage(&src); serv_cross.mg.manage(&dst); // Create and configure function - NormalizationFunctionType norm_layer(serv_internal.mm); - norm_layer.configure(&src, &dst, NormalizationLayerInfo(NormType::CROSS_MAP, 3)); + SimpleFunctionWrapperType layer(serv_internal.mm); + layer.configure(&src, &dst); ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -163,8 +183,8 @@ protected: // Acquire memory manager, fill tensors and compute functions serv_cross.mg.acquire(); - arm_compute::test::library->fill_tensor_value(Accessor(src), 12.f); - norm_layer.run(); + arm_compute::test::library->fill_tensor_value(AccessorType(src), 12.f); + layer.run(); serv_cross.mg.release(); // Clear manager @@ -184,7 +204,7 @@ protected: serv_cross.mg.manage(&dst); // Re-configure the function - norm_layer.configure(&src, &dst, NormalizationLayerInfo(NormType::CROSS_MAP, 3)); + layer.configure(&src, &dst); // Allocate tensors src.allocator()->allocate(); @@ -203,7 +223,7 @@ protected: // Compute functions serv_cross.mg.acquire(); arm_compute::test::library->fill_tensor_value(AccessorType(src), 12.f); - norm_layer.run(); + layer.run(); serv_cross.mg.release(); // Clear manager |