aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/UNIT/DynamicTensorFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/UNIT/DynamicTensorFixture.h')
-rw-r--r--tests/validation/fixtures/UNIT/DynamicTensorFixture.h38
1 files changed, 29 insertions, 9 deletions
diff --git a/tests/validation/fixtures/UNIT/DynamicTensorFixture.h b/tests/validation/fixtures/UNIT/DynamicTensorFixture.h
index df12a4aa30..66ef6c4aff 100644
--- a/tests/validation/fixtures/UNIT/DynamicTensorFixture.h
+++ b/tests/validation/fixtures/UNIT/DynamicTensorFixture.h
@@ -89,6 +89,26 @@ public:
MemoryGroup mg;
size_t num_pools;
};
+
+template <typename MemoryMgrType, typename FuncType, typename ITensorType>
+class SimpleFunctionWrapper
+{
+public:
+ SimpleFunctionWrapper(std::shared_ptr<MemoryMgrType> mm)
+ : _func(mm)
+ {
+ }
+ void configure(ITensorType *src, ITensorType *dst)
+ {
+ }
+ void run()
+ {
+ _func.run();
+ }
+
+private:
+ FuncType _func;
+};
} // namespace
/** Simple test case to run a single function with different shapes twice.
@@ -102,7 +122,7 @@ template <typename TensorType,
typename LifetimeMgrType,
typename PoolMgrType,
typename MemoryManagerType,
- typename NormalizationFunctionType>
+ typename SimpleFunctionWrapperType>
class DynamicTensorType3SingleFunction : public framework::Fixture
{
using T = float;
@@ -131,15 +151,15 @@ protected:
// Level 0
// Create tensors
- TensorType src = create_tensor<Tensor>(level_0, DataType::F32, 1);
- TensorType dst = create_tensor<Tensor>(level_0, DataType::F32, 1);
+ TensorType src = create_tensor<TensorType>(level_0, DataType::F32, 1);
+ TensorType dst = create_tensor<TensorType>(level_0, DataType::F32, 1);
serv_cross.mg.manage(&src);
serv_cross.mg.manage(&dst);
// Create and configure function
- NormalizationFunctionType norm_layer(serv_internal.mm);
- norm_layer.configure(&src, &dst, NormalizationLayerInfo(NormType::CROSS_MAP, 3));
+ SimpleFunctionWrapperType layer(serv_internal.mm);
+ layer.configure(&src, &dst);
ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -163,8 +183,8 @@ protected:
// Acquire memory manager, fill tensors and compute functions
serv_cross.mg.acquire();
- arm_compute::test::library->fill_tensor_value(Accessor(src), 12.f);
- norm_layer.run();
+ arm_compute::test::library->fill_tensor_value(AccessorType(src), 12.f);
+ layer.run();
serv_cross.mg.release();
// Clear manager
@@ -184,7 +204,7 @@ protected:
serv_cross.mg.manage(&dst);
// Re-configure the function
- norm_layer.configure(&src, &dst, NormalizationLayerInfo(NormType::CROSS_MAP, 3));
+ layer.configure(&src, &dst);
// Allocate tensors
src.allocator()->allocate();
@@ -203,7 +223,7 @@ protected:
// Compute functions
serv_cross.mg.acquire();
arm_compute::test::library->fill_tensor_value(AccessorType(src), 12.f);
- norm_layer.run();
+ layer.run();
serv_cross.mg.release();
// Clear manager