From 2b5f0f2574551f59970bb9d710bafad2bc4bbd4a Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Wed, 10 Jan 2018 14:08:50 +0000 Subject: COMPMID-782 Port examples to the new format Change-Id: Ib178a97c080ff650094d02ee49e2a0aa22376dd0 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/115717 Reviewed-by: Anthony Barbier Tested-by: Jenkins --- examples/neon_cnn.cpp | 415 +++++++++++++++++++++++++------------------------- 1 file changed, 211 insertions(+), 204 deletions(-) (limited to 'examples/neon_cnn.cpp') diff --git a/examples/neon_cnn.cpp b/examples/neon_cnn.cpp index 2be5acfbaf..05b6c832bc 100644 --- a/examples/neon_cnn.cpp +++ b/examples/neon_cnn.cpp @@ -33,254 +33,261 @@ using namespace arm_compute; using namespace utils; -void main_cnn(int argc, char **argv) +class NEONCNNExample : public Example { - ARM_COMPUTE_UNUSED(argc); - ARM_COMPUTE_UNUSED(argv); +public: + void do_setup(int argc, char **argv) override + { + ARM_COMPUTE_UNUSED(argc); + ARM_COMPUTE_UNUSED(argv); - // Create NEON allocator - Allocator allocator; + // Create memory manager components + // We need 2 memory managers: 1 for handling the tensors within the functions (mm_layers) and 1 for handling the input and output tensors of the functions (mm_transitions)) + auto lifetime_mgr0 = std::make_shared(); // Create lifetime manager + auto lifetime_mgr1 = std::make_shared(); // Create lifetime manager + auto pool_mgr0 = std::make_shared(); // Create pool manager + auto pool_mgr1 = std::make_shared(); // Create pool manager + auto mm_layers = std::make_shared(lifetime_mgr0, pool_mgr0); // Create the memory manager + auto mm_transitions = std::make_shared(lifetime_mgr1, pool_mgr1); // Create the memory manager - // Create memory manager components - // We need 2 memory managers: 1 for handling the tensors within the functions (mm_layers) and 1 for handling the input and output tensors of the functions (mm_transitions)) - auto lifetime_mgr0 = std::make_shared(); // Create lifetime manager - auto lifetime_mgr1 = std::make_shared(); // Create lifetime manager - auto pool_mgr0 = std::make_shared(); // Create pool manager - auto pool_mgr1 = std::make_shared(); // Create pool manager - auto mm_layers = std::make_shared(lifetime_mgr0, pool_mgr0); // Create the memory manager - auto mm_transitions = std::make_shared(lifetime_mgr1, pool_mgr1); // Create the memory manager + // The weights and biases tensors should be initialized with the values inferred with the training - // The src tensor should contain the input image - Tensor src; - - // The weights and biases tensors should be initialized with the values inferred with the training - Tensor weights0; - Tensor weights1; - Tensor weights2; - Tensor biases0; - Tensor biases1; - Tensor biases2; - - Tensor out_conv0; - Tensor out_conv1; - Tensor out_act0; - Tensor out_act1; - Tensor out_act2; - Tensor out_pool0; - Tensor out_pool1; - Tensor out_fc0; - Tensor out_softmax; - - // Create layers and set memory manager where allowed to manage internal memory requirements - NEConvolutionLayer conv0(mm_layers); - NEConvolutionLayer conv1(mm_layers); - NEPoolingLayer pool0; - NEPoolingLayer pool1; - NEFullyConnectedLayer fc0(mm_layers); - NEActivationLayer act0; - NEActivationLayer act1; - NEActivationLayer act2; - NESoftmaxLayer softmax(mm_layers); + // Set memory manager where allowed to manage internal memory requirements + conv0 = arm_compute::support::cpp14::make_unique(mm_layers); + conv1 = arm_compute::support::cpp14::make_unique(mm_layers); + fc0 = arm_compute::support::cpp14::make_unique(mm_layers); + softmax = arm_compute::support::cpp14::make_unique(mm_layers); - /* [Initialize tensors] */ + /* [Initialize tensors] */ - // Initialize src tensor - constexpr unsigned int width_src_image = 32; - constexpr unsigned int height_src_image = 32; - constexpr unsigned int ifm_src_img = 1; + // Initialize src tensor + constexpr unsigned int width_src_image = 32; + constexpr unsigned int height_src_image = 32; + constexpr unsigned int ifm_src_img = 1; - const TensorShape src_shape(width_src_image, height_src_image, ifm_src_img); - src.allocator()->init(TensorInfo(src_shape, 1, DataType::F32)); + const TensorShape src_shape(width_src_image, height_src_image, ifm_src_img); + src.allocator()->init(TensorInfo(src_shape, 1, DataType::F32)); - // Initialize tensors of conv0 - constexpr unsigned int kernel_x_conv0 = 5; - constexpr unsigned int kernel_y_conv0 = 5; - constexpr unsigned int ofm_conv0 = 8; + // Initialize tensors of conv0 + constexpr unsigned int kernel_x_conv0 = 5; + constexpr unsigned int kernel_y_conv0 = 5; + constexpr unsigned int ofm_conv0 = 8; - const TensorShape weights_shape_conv0(kernel_x_conv0, kernel_y_conv0, src_shape.z(), ofm_conv0); - const TensorShape biases_shape_conv0(weights_shape_conv0[3]); - const TensorShape out_shape_conv0(src_shape.x(), src_shape.y(), weights_shape_conv0[3]); + const TensorShape weights_shape_conv0(kernel_x_conv0, kernel_y_conv0, src_shape.z(), ofm_conv0); + const TensorShape biases_shape_conv0(weights_shape_conv0[3]); + const TensorShape out_shape_conv0(src_shape.x(), src_shape.y(), weights_shape_conv0[3]); - weights0.allocator()->init(TensorInfo(weights_shape_conv0, 1, DataType::F32)); - biases0.allocator()->init(TensorInfo(biases_shape_conv0, 1, DataType::F32)); - out_conv0.allocator()->init(TensorInfo(out_shape_conv0, 1, DataType::F32)); - - // Initialize tensor of act0 - out_act0.allocator()->init(TensorInfo(out_shape_conv0, 1, DataType::F32)); - - // Initialize tensor of pool0 - TensorShape out_shape_pool0 = out_shape_conv0; - out_shape_pool0.set(0, out_shape_pool0.x() / 2); - out_shape_pool0.set(1, out_shape_pool0.y() / 2); - out_pool0.allocator()->init(TensorInfo(out_shape_pool0, 1, DataType::F32)); - - // Initialize tensors of conv1 - constexpr unsigned int kernel_x_conv1 = 3; - constexpr unsigned int kernel_y_conv1 = 3; - constexpr unsigned int ofm_conv1 = 16; - - const TensorShape weights_shape_conv1(kernel_x_conv1, kernel_y_conv1, out_shape_pool0.z(), ofm_conv1); - - const TensorShape biases_shape_conv1(weights_shape_conv1[3]); - const TensorShape out_shape_conv1(out_shape_pool0.x(), out_shape_pool0.y(), weights_shape_conv1[3]); - - weights1.allocator()->init(TensorInfo(weights_shape_conv1, 1, DataType::F32)); - biases1.allocator()->init(TensorInfo(biases_shape_conv1, 1, DataType::F32)); - out_conv1.allocator()->init(TensorInfo(out_shape_conv1, 1, DataType::F32)); - - // Initialize tensor of act1 - out_act1.allocator()->init(TensorInfo(out_shape_conv1, 1, DataType::F32)); - - // Initialize tensor of pool1 - TensorShape out_shape_pool1 = out_shape_conv1; - out_shape_pool1.set(0, out_shape_pool1.x() / 2); - out_shape_pool1.set(1, out_shape_pool1.y() / 2); - out_pool1.allocator()->init(TensorInfo(out_shape_pool1, 1, DataType::F32)); - - // Initialize tensor of fc0 - constexpr unsigned int num_labels = 128; - - const TensorShape weights_shape_fc0(out_shape_pool1.x() * out_shape_pool1.y() * out_shape_pool1.z(), num_labels); - const TensorShape biases_shape_fc0(num_labels); - const TensorShape out_shape_fc0(num_labels); - - weights2.allocator()->init(TensorInfo(weights_shape_fc0, 1, DataType::F32)); - biases2.allocator()->init(TensorInfo(biases_shape_fc0, 1, DataType::F32)); - out_fc0.allocator()->init(TensorInfo(out_shape_fc0, 1, DataType::F32)); - - // Initialize tensor of act2 - out_act2.allocator()->init(TensorInfo(out_shape_fc0, 1, DataType::F32)); - - // Initialize tensor of softmax - const TensorShape out_shape_softmax(out_shape_fc0.x()); - out_softmax.allocator()->init(TensorInfo(out_shape_softmax, 1, DataType::F32)); + weights0.allocator()->init(TensorInfo(weights_shape_conv0, 1, DataType::F32)); + biases0.allocator()->init(TensorInfo(biases_shape_conv0, 1, DataType::F32)); + out_conv0.allocator()->init(TensorInfo(out_shape_conv0, 1, DataType::F32)); - /* -----------------------End: [Initialize tensors] */ + // Initialize tensor of act0 + out_act0.allocator()->init(TensorInfo(out_shape_conv0, 1, DataType::F32)); - /* [Configure functions] */ + // Initialize tensor of pool0 + TensorShape out_shape_pool0 = out_shape_conv0; + out_shape_pool0.set(0, out_shape_pool0.x() / 2); + out_shape_pool0.set(1, out_shape_pool0.y() / 2); + out_pool0.allocator()->init(TensorInfo(out_shape_pool0, 1, DataType::F32)); - // in:32x32x1: 5x5 convolution, 8 output features maps (OFM) - conv0.configure(&src, &weights0, &biases0, &out_conv0, PadStrideInfo(1 /* stride_x */, 1 /* stride_y */, 2 /* pad_x */, 2 /* pad_y */)); + // Initialize tensors of conv1 + constexpr unsigned int kernel_x_conv1 = 3; + constexpr unsigned int kernel_y_conv1 = 3; + constexpr unsigned int ofm_conv1 = 16; - // in:32x32x8, out:32x32x8, Activation function: relu - act0.configure(&out_conv0, &out_act0, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)); + const TensorShape weights_shape_conv1(kernel_x_conv1, kernel_y_conv1, out_shape_pool0.z(), ofm_conv1); - // in:32x32x8, out:16x16x8 (2x2 pooling), Pool type function: Max - pool0.configure(&out_act0, &out_pool0, PoolingLayerInfo(PoolingType::MAX, 2, PadStrideInfo(2 /* stride_x */, 2 /* stride_y */))); + const TensorShape biases_shape_conv1(weights_shape_conv1[3]); + const TensorShape out_shape_conv1(out_shape_pool0.x(), out_shape_pool0.y(), weights_shape_conv1[3]); + + weights1.allocator()->init(TensorInfo(weights_shape_conv1, 1, DataType::F32)); + biases1.allocator()->init(TensorInfo(biases_shape_conv1, 1, DataType::F32)); + out_conv1.allocator()->init(TensorInfo(out_shape_conv1, 1, DataType::F32)); + + // Initialize tensor of act1 + out_act1.allocator()->init(TensorInfo(out_shape_conv1, 1, DataType::F32)); - // in:16x16x8: 3x3 convolution, 16 output features maps (OFM) - conv1.configure(&out_pool0, &weights1, &biases1, &out_conv1, PadStrideInfo(1 /* stride_x */, 1 /* stride_y */, 1 /* pad_x */, 1 /* pad_y */)); + // Initialize tensor of pool1 + TensorShape out_shape_pool1 = out_shape_conv1; + out_shape_pool1.set(0, out_shape_pool1.x() / 2); + out_shape_pool1.set(1, out_shape_pool1.y() / 2); + out_pool1.allocator()->init(TensorInfo(out_shape_pool1, 1, DataType::F32)); - // in:16x16x16, out:16x16x16, Activation function: relu - act1.configure(&out_conv1, &out_act1, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)); + // Initialize tensor of fc0 + constexpr unsigned int num_labels = 128; - // in:16x16x16, out:8x8x16 (2x2 pooling), Pool type function: Average - pool1.configure(&out_act1, &out_pool1, PoolingLayerInfo(PoolingType::AVG, 2, PadStrideInfo(2 /* stride_x */, 2 /* stride_y */))); + const TensorShape weights_shape_fc0(out_shape_pool1.x() * out_shape_pool1.y() * out_shape_pool1.z(), num_labels); + const TensorShape biases_shape_fc0(num_labels); + const TensorShape out_shape_fc0(num_labels); - // in:8x8x16, out:128 - fc0.configure(&out_pool1, &weights2, &biases2, &out_fc0); + weights2.allocator()->init(TensorInfo(weights_shape_fc0, 1, DataType::F32)); + biases2.allocator()->init(TensorInfo(biases_shape_fc0, 1, DataType::F32)); + out_fc0.allocator()->init(TensorInfo(out_shape_fc0, 1, DataType::F32)); - // in:128, out:128, Activation function: relu - act2.configure(&out_fc0, &out_act2, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)); + // Initialize tensor of act2 + out_act2.allocator()->init(TensorInfo(out_shape_fc0, 1, DataType::F32)); - // in:128, out:128 - softmax.configure(&out_act2, &out_softmax); + // Initialize tensor of softmax + const TensorShape out_shape_softmax(out_shape_fc0.x()); + out_softmax.allocator()->init(TensorInfo(out_shape_softmax, 1, DataType::F32)); - /* -----------------------End: [Configure functions] */ + /* -----------------------End: [Initialize tensors] */ - /*[ Add tensors to memory manager ]*/ + /* [Configure functions] */ - // We need 2 memory groups for handling the input and output - // We call explicitly allocate after manage() in order to avoid overlapping lifetimes - MemoryGroup memory_group0(mm_transitions); - MemoryGroup memory_group1(mm_transitions); + // in:32x32x1: 5x5 convolution, 8 output features maps (OFM) + conv0->configure(&src, &weights0, &biases0, &out_conv0, PadStrideInfo(1 /* stride_x */, 1 /* stride_y */, 2 /* pad_x */, 2 /* pad_y */)); - memory_group0.manage(&out_conv0); - out_conv0.allocator()->allocate(); - memory_group1.manage(&out_act0); - out_act0.allocator()->allocate(); - memory_group0.manage(&out_pool0); - out_pool0.allocator()->allocate(); - memory_group1.manage(&out_conv1); - out_conv1.allocator()->allocate(); - memory_group0.manage(&out_act1); - out_act1.allocator()->allocate(); - memory_group1.manage(&out_pool1); - out_pool1.allocator()->allocate(); - memory_group0.manage(&out_fc0); - out_fc0.allocator()->allocate(); - memory_group1.manage(&out_act2); - out_act2.allocator()->allocate(); - memory_group0.manage(&out_softmax); - out_softmax.allocator()->allocate(); + // in:32x32x8, out:32x32x8, Activation function: relu + act0.configure(&out_conv0, &out_act0, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)); - /* -----------------------End: [ Add tensors to memory manager ] */ + // in:32x32x8, out:16x16x8 (2x2 pooling), Pool type function: Max + pool0.configure(&out_act0, &out_pool0, PoolingLayerInfo(PoolingType::MAX, 2, PadStrideInfo(2 /* stride_x */, 2 /* stride_y */))); - /* [Allocate tensors] */ + // in:16x16x8: 3x3 convolution, 16 output features maps (OFM) + conv1->configure(&out_pool0, &weights1, &biases1, &out_conv1, PadStrideInfo(1 /* stride_x */, 1 /* stride_y */, 1 /* pad_x */, 1 /* pad_y */)); - // Now that the padding requirements are known we can allocate all tensors - src.allocator()->allocate(); - weights0.allocator()->allocate(); - weights1.allocator()->allocate(); - weights2.allocator()->allocate(); - biases0.allocator()->allocate(); - biases1.allocator()->allocate(); - biases2.allocator()->allocate(); + // in:16x16x16, out:16x16x16, Activation function: relu + act1.configure(&out_conv1, &out_act1, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)); - /* -----------------------End: [Allocate tensors] */ + // in:16x16x16, out:8x8x16 (2x2 pooling), Pool type function: Average + pool1.configure(&out_act1, &out_pool1, PoolingLayerInfo(PoolingType::AVG, 2, PadStrideInfo(2 /* stride_x */, 2 /* stride_y */))); - // Finalize layers memory manager + // in:8x8x16, out:128 + fc0->configure(&out_pool1, &weights2, &biases2, &out_fc0); - // Set allocator that the memory manager will use - mm_layers->set_allocator(&allocator); + // in:128, out:128, Activation function: relu + act2.configure(&out_fc0, &out_act2, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)); - // Number of pools that the manager will create. This specifies how many layers you want to run in parallel - mm_layers->set_num_pools(1); + // in:128, out:128 + softmax->configure(&out_act2, &out_softmax); - // Finalize the manager. (Validity checks, memory allocations etc) - mm_layers->finalize(); + /* -----------------------End: [Configure functions] */ - // Finalize transitions memory manager + /*[ Add tensors to memory manager ]*/ - // Set allocator that the memory manager will use - mm_transitions->set_allocator(&allocator); + // We need 2 memory groups for handling the input and output + // We call explicitly allocate after manage() in order to avoid overlapping lifetimes + memory_group0 = arm_compute::support::cpp14::make_unique(mm_transitions); + memory_group1 = arm_compute::support::cpp14::make_unique(mm_transitions); - // Number of pools that the manager will create. This specifies how many models we can run in parallel. - // Setting to 2 as we need one for the input and one for the output at any given time - mm_transitions->set_num_pools(2); + memory_group0->manage(&out_conv0); + out_conv0.allocator()->allocate(); + memory_group1->manage(&out_act0); + out_act0.allocator()->allocate(); + memory_group0->manage(&out_pool0); + out_pool0.allocator()->allocate(); + memory_group1->manage(&out_conv1); + out_conv1.allocator()->allocate(); + memory_group0->manage(&out_act1); + out_act1.allocator()->allocate(); + memory_group1->manage(&out_pool1); + out_pool1.allocator()->allocate(); + memory_group0->manage(&out_fc0); + out_fc0.allocator()->allocate(); + memory_group1->manage(&out_act2); + out_act2.allocator()->allocate(); + memory_group0->manage(&out_softmax); + out_softmax.allocator()->allocate(); - // Finalize the manager. (Validity checks, memory allocations etc) - mm_transitions->finalize(); + /* -----------------------End: [ Add tensors to memory manager ] */ - /* [Initialize weights and biases tensors] */ + /* [Allocate tensors] */ - // Once the tensors have been allocated, the src, weights and biases tensors can be initialized - // ... + // Now that the padding requirements are known we can allocate all tensors + src.allocator()->allocate(); + weights0.allocator()->allocate(); + weights1.allocator()->allocate(); + weights2.allocator()->allocate(); + biases0.allocator()->allocate(); + biases1.allocator()->allocate(); + biases2.allocator()->allocate(); - /* -----------------------[Initialize weights and biases tensors] */ + /* -----------------------End: [Allocate tensors] */ - /* [Execute the functions] */ + // Finalize layers memory manager - // Acquire memory for the memory groups - memory_group0.acquire(); - memory_group1.acquire(); + // Set allocator that the memory manager will use + mm_layers->set_allocator(&allocator); - conv0.run(); - act0.run(); - pool0.run(); - conv1.run(); - act1.run(); - pool1.run(); - fc0.run(); - act2.run(); - softmax.run(); + // Number of pools that the manager will create. This specifies how many layers you want to run in parallel + mm_layers->set_num_pools(1); - // Release memory - memory_group0.release(); - memory_group1.release(); + // Finalize the manager. (Validity checks, memory allocations etc) + mm_layers->finalize(); - /* -----------------------End: [Execute the functions] */ -} + // Finalize transitions memory manager + + // Set allocator that the memory manager will use + mm_transitions->set_allocator(&allocator); + + // Number of pools that the manager will create. This specifies how many models we can run in parallel. + // Setting to 2 as we need one for the input and one for the output at any given time + mm_transitions->set_num_pools(2); + + // Finalize the manager. (Validity checks, memory allocations etc) + mm_transitions->finalize(); + } + void do_run() override + { + // Acquire memory for the memory groups + memory_group0->acquire(); + memory_group1->acquire(); + + conv0->run(); + act0.run(); + pool0.run(); + conv1->run(); + act1.run(); + pool1.run(); + fc0->run(); + act2.run(); + softmax->run(); + + // Release memory + memory_group0->release(); + memory_group1->release(); + } + +private: + // The src tensor should contain the input image + Tensor src{}; + + // Intermediate tensors used + Tensor weights0{}; + Tensor weights1{}; + Tensor weights2{}; + Tensor biases0{}; + Tensor biases1{}; + Tensor biases2{}; + Tensor out_conv0{}; + Tensor out_conv1{}; + Tensor out_act0{}; + Tensor out_act1{}; + Tensor out_act2{}; + Tensor out_pool0{}; + Tensor out_pool1{}; + Tensor out_fc0{}; + Tensor out_softmax{}; + + // NEON allocator + Allocator allocator{}; + + // Memory groups + std::unique_ptr memory_group0{}; + std::unique_ptr memory_group1{}; + + // Layers + std::unique_ptr conv0{}; + std::unique_ptr conv1{}; + std::unique_ptr fc0{}; + std::unique_ptr softmax{}; + NEPoolingLayer pool0{}; + NEPoolingLayer pool1{}; + NEActivationLayer act0{}; + NEActivationLayer act1{}; + NEActivationLayer act2{}; +}; /** Main program for cnn test * @@ -293,5 +300,5 @@ void main_cnn(int argc, char **argv) */ int main(int argc, char **argv) { - return utils::run_example(argc, argv, main_cnn); + return utils::run_example(argc, argv); } -- cgit v1.2.1