aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/UNIT/DynamicTensorFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/UNIT/DynamicTensorFixture.h')
-rw-r--r--tests/validation/fixtures/UNIT/DynamicTensorFixture.h113
1 files changed, 113 insertions, 0 deletions
diff --git a/tests/validation/fixtures/UNIT/DynamicTensorFixture.h b/tests/validation/fixtures/UNIT/DynamicTensorFixture.h
index 08b90c5b52..804b394649 100644
--- a/tests/validation/fixtures/UNIT/DynamicTensorFixture.h
+++ b/tests/validation/fixtures/UNIT/DynamicTensorFixture.h
@@ -378,6 +378,119 @@ private:
TensorType _bias_target{};
std::unique_ptr<ComplexFunctionType> _f_target{};
};
+
+/** Fixture that create a pipeline of Convolutions and changes the inputs dynamically
+ *
+ * Runs a list of convolutions and then resizes the inputs and reruns.
+ * Updates the memory manager and allocated memory.
+ */
+template <typename TensorType,
+ typename AccessorType,
+ typename MemoryManagementServiceType,
+ typename ComplexFunctionType>
+class DynamicTensorType2PipelineFunction : public framework::Fixture
+{
+ using T = float;
+
+public:
+ template <typename...>
+ void setup(std::vector<TensorShape> input_shapes)
+ {
+ _data_type = DataType::F32;
+ _data_layout = DataLayout::NHWC;
+ _input_shapes = input_shapes;
+
+ run();
+ }
+
+protected:
+ template <typename U>
+ void fill(U &&tensor, int i)
+ {
+ switch(tensor.data_type())
+ {
+ case DataType::F32:
+ {
+ std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+ library->fill(tensor, distribution, i);
+ break;
+ }
+ default:
+ library->fill_tensor_uniform(tensor, i);
+ }
+ }
+
+ void run()
+ {
+ const unsigned int num_functions = 5;
+ const unsigned int num_tensors = num_functions + 1;
+ const unsigned int num_resizes = _input_shapes.size();
+
+ for(unsigned int i = 0; i < num_functions; ++i)
+ {
+ _functions.emplace_back(support::cpp14::make_unique<ComplexFunctionType>(_ms.mm));
+ }
+
+ for(unsigned int i = 0; i < num_resizes; ++i)
+ {
+ TensorShape input_shape = _input_shapes[i];
+ TensorShape weights_shape = TensorShape(3U, 3U, input_shape[2], input_shape[2]);
+ TensorShape output_shape = input_shape;
+ PadStrideInfo info(1U, 1U, 1U, 1U);
+
+ if(_data_layout == DataLayout::NHWC)
+ {
+ permute(input_shape, PermutationVector(2U, 0U, 1U));
+ permute(weights_shape, PermutationVector(2U, 0U, 1U));
+ permute(output_shape, PermutationVector(2U, 0U, 1U));
+ }
+
+ std::vector<TensorType> tensors(num_tensors);
+ std::vector<TensorType> ws(num_functions);
+ std::vector<TensorType> bs(num_functions);
+
+ auto tensor_info = TensorInfo(input_shape, 1, _data_type);
+ auto weights_info = TensorInfo(weights_shape, 1, _data_type);
+ tensor_info.set_data_layout(_data_layout);
+ weights_info.set_data_layout(_data_layout);
+
+ for(unsigned int f = 0; f < num_functions; ++f)
+ {
+ tensors[f].allocator()->init(tensor_info);
+ tensors[f + 1].allocator()->init(tensor_info);
+ ws[f].allocator()->init(weights_info);
+
+ _functions[f]->configure(&tensors[f], &ws[f], nullptr, &tensors[f + 1], info);
+
+ // Allocate tensors
+ tensors[f].allocator()->allocate();
+ ws[f].allocator()->allocate();
+ }
+ tensors[num_functions].allocator()->allocate();
+
+ // Populate and validate memory manager
+ _ms.clear();
+ _ms.populate(1);
+ _ms.mg.acquire();
+
+ // Run pipeline
+ for(unsigned int f = 0; f < num_functions; ++f)
+ {
+ _functions[f]->run();
+ }
+
+ // Release memory group
+ _ms.mg.release();
+ }
+ }
+
+private:
+ DataType _data_type{ DataType::UNKNOWN };
+ DataLayout _data_layout{ DataLayout::UNKNOWN };
+ std::vector<TensorShape> _input_shapes{};
+ MemoryManagementServiceType _ms{};
+ std::vector<std::unique_ptr<ComplexFunctionType>> _functions{};
+};
} // namespace validation
} // namespace test
} // namespace arm_compute