diff options
Diffstat (limited to 'examples/graph_lenet.cpp')
-rw-r--r-- | examples/graph_lenet.cpp | 22 |
1 files changed, 12 insertions, 10 deletions
diff --git a/examples/graph_lenet.cpp b/examples/graph_lenet.cpp index 61bc7bd3bf..e4b8effe5d 100644 --- a/examples/graph_lenet.cpp +++ b/examples/graph_lenet.cpp @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#include "arm_compute/graph/Graph.h" -#include "arm_compute/graph/Nodes.h" +#include "arm_compute/graph2.h" + #include "support/ToolchainSupport.h" #include "utils/GraphUtils.h" #include "utils/Utils.h" @@ -30,7 +30,7 @@ #include <cstdlib> using namespace arm_compute::utils; -using namespace arm_compute::graph; +using namespace arm_compute::graph2::frontend; using namespace arm_compute::graph_utils; /** Example demonstrating how to implement LeNet's network using the Compute Library's graph API @@ -47,8 +47,10 @@ public: unsigned int batches = 4; /** Number of batches */ // Set target. 0 (NEON), 1 (OpenCL), 2 (OpenCL with Tuner). By default it is NEON - const int int_target_hint = argc > 1 ? std::strtol(argv[1], nullptr, 10) : 0; - TargetHint target_hint = set_target_hint(int_target_hint); + const int target = argc > 1 ? std::strtol(argv[1], nullptr, 10) : 0; + Target target_hint = set_target_hint2(target); + bool enable_tuning = (target == 2); + bool enable_memory_management = true; // Parse arguments if(argc < 2) @@ -78,7 +80,7 @@ public: //conv1 << pool1 << conv2 << pool2 << fc1 << act1 << fc2 << smx graph << target_hint - << Tensor(TensorInfo(TensorShape(28U, 28U, 1U, batches), 1, DataType::F32), DummyAccessor()) + << InputLayer(TensorDescriptor(TensorShape(28U, 28U, 1U, batches), DataType::F32), get_input_accessor("")) << ConvolutionLayer( 5U, 5U, 20U, get_weights_accessor(data_path, "/cnn_data/lenet_model/conv1_w.npy"), @@ -101,10 +103,10 @@ public: get_weights_accessor(data_path, "/cnn_data/lenet_model/ip2_w.npy"), get_weights_accessor(data_path, "/cnn_data/lenet_model/ip2_b.npy")) << SoftmaxLayer() - << Tensor(DummyAccessor(0)); + << OutputLayer(get_output_accessor("")); - // In order to enable the OpenCL tuner, graph_init() has to be called only when all nodes have been instantiated - graph.graph_init(int_target_hint == 2); + // Finalize graph + graph.finalize(target_hint, enable_tuning, enable_memory_management); } void do_run() override { @@ -113,7 +115,7 @@ public: } private: - Graph graph{}; + Stream graph{ 0, "LeNet" }; }; /** Main program for LeNet |