From bfa3b52de2cfbd330efc19e2096134a20c645406 Mon Sep 17 00:00:00 2001 From: Gian Marco Date: Tue, 12 Dec 2017 10:08:38 +0000 Subject: COMPMID-556 - Fix examples - Fixed data type issue in cl_sgemm - Added support for NEON and OpenCL targets in graph examples. Before we could run only OpenCL target - Add auto_init() in NEDepthwiseVectorToTensorKernel Change-Id: I4410ce6f4992b2375b980634fe55f1083cf3c471 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/112850 Reviewed-by: Anthony Barbier Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com --- examples/graph_vgg16.cpp | 43 +++++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 20 deletions(-) (limited to 'examples/graph_vgg16.cpp') diff --git a/examples/graph_vgg16.cpp b/examples/graph_vgg16.cpp index 44dd1f63e4..cac38d30a7 100644 --- a/examples/graph_vgg16.cpp +++ b/examples/graph_vgg16.cpp @@ -35,7 +35,7 @@ using namespace arm_compute::graph_utils; /** Example demonstrating how to implement VGG16's network using the Compute Library's graph API * * @param[in] argc Number of arguments - * @param[in] argv Arguments ( [optional] Path to the weights folder, [optional] image, [optional] labels ) + * @param[in] argv Arguments ( [optional] Target (0 = NEON, 1 = OpenCL), [optional] Path to the weights folder, [optional] image, [optional] labels ) */ void main_graph_vgg16(int argc, const char **argv) { @@ -47,43 +47,46 @@ void main_graph_vgg16(int argc, const char **argv) constexpr float mean_g = 116.779f; /* Mean value to subtract from green channel */ constexpr float mean_b = 103.939f; /* Mean value to subtract from blue channel */ + // Set target. 0 (NEON), 1 (OpenCL). By default it is NEON + TargetHint target_hint = set_target_hint(argc > 1 ? std::strtol(argv[1], nullptr, 10) : 0); + ConvolutionMethodHint convolution_hint = ConvolutionMethodHint::DIRECT; + // Parse arguments if(argc < 2) { // Print help - std::cout << "Usage: " << argv[0] << " [path_to_data] [image] [labels]\n\n"; + std::cout << "Usage: " << argv[0] << " [target] [path_to_data] [image] [labels]\n\n"; std::cout << "No data folder provided: using random values\n\n"; } else if(argc == 2) { - data_path = argv[1]; - std::cout << "Usage: " << argv[0] << " " << argv[1] << " [image] [labels]\n\n"; - std::cout << "No image provided: using random values\n\n"; + std::cout << "Usage: " << argv[0] << " " << argv[1] << " [path_to_data] [image] [labels]\n\n"; + std::cout << "No data folder provided: using random values\n\n"; } else if(argc == 3) { - data_path = argv[1]; - image = argv[2]; - std::cout << "Usage: " << argv[0] << " " << argv[1] << " " << argv[2] << " [labels]\n\n"; - std::cout << "No text file with labels provided: skipping output accessor\n\n"; + data_path = argv[2]; + std::cout << "Usage: " << argv[0] << " " << argv[1] << " " << argv[2] << " [image] [labels]\n\n"; + std::cout << "No image provided: using random values\n\n"; } - else + else if(argc == 4) { - data_path = argv[1]; - image = argv[2]; - label = argv[3]; + data_path = argv[2]; + image = argv[3]; + std::cout << "Usage: " << argv[0] << " " << argv[1] << " " << argv[2] << " " << argv[3] << " [labels]\n\n"; + std::cout << "No text file with labels provided: skipping output accessor\n\n"; } - - // Check if OpenCL is available and initialize the scheduler - TargetHint hint = TargetHint::NEON; - if(Graph::opencl_is_available()) + else { - hint = TargetHint::OPENCL; + data_path = argv[2]; + image = argv[3]; + label = argv[4]; } Graph graph; - graph << hint + graph << target_hint + << convolution_hint << Tensor(TensorInfo(TensorShape(224U, 224U, 3U, 1U), 1, DataType::F32), get_input_accessor(image, mean_r, mean_g, mean_b)) << ConvolutionMethodHint::DIRECT @@ -211,7 +214,7 @@ void main_graph_vgg16(int argc, const char **argv) /** Main program for VGG16 * * @param[in] argc Number of arguments - * @param[in] argv Arguments ( [optional] Path to the weights folder, [optional] image, [optional] labels ) + * @param[in] argv Arguments ( [optional] Target (0 = NEON, 1 = OpenCL), [optional] Path to the weights folder, [optional] image, [optional] labels ) */ int main(int argc, const char **argv) { -- cgit v1.2.1