aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/graph_alexnet.cpp2
-rw-r--r--examples/graph_googlenet.cpp2
-rw-r--r--examples/graph_inception_resnet_v1.cpp2
-rw-r--r--examples/graph_inception_resnet_v2.cpp2
-rw-r--r--examples/graph_inception_v3.cpp2
-rw-r--r--examples/graph_inception_v4.cpp2
-rw-r--r--examples/graph_mobilenet.cpp2
-rw-r--r--examples/graph_mobilenet_v2.cpp2
-rw-r--r--examples/graph_resnet12.cpp2
-rw-r--r--examples/graph_resnet50.cpp2
-rw-r--r--examples/graph_resnet_v2_50.cpp2
-rw-r--r--examples/graph_resnext50.cpp2
-rw-r--r--examples/graph_shufflenet.cpp2
-rw-r--r--examples/graph_squeezenet.cpp2
-rw-r--r--examples/graph_squeezenet_v1_1.cpp2
-rw-r--r--examples/graph_srcnn955.cpp2
-rw-r--r--examples/graph_vgg16.cpp2
-rw-r--r--examples/graph_vgg19.cpp2
-rw-r--r--examples/graph_vgg_vdsr.cpp2
-rw-r--r--utils/CommonGraphOptions.cpp3
-rw-r--r--utils/CommonGraphOptions.h2
21 files changed, 24 insertions, 19 deletions
diff --git a/examples/graph_alexnet.cpp b/examples/graph_alexnet.cpp
index 7f4e75aaf8..53a4547e04 100644
--- a/examples/graph_alexnet.cpp
+++ b/examples/graph_alexnet.cpp
@@ -74,7 +74,7 @@ public:
// Create input descriptor
const auto operation_layout = common_params.data_layout;
- const TensorShape tensor_shape = permute_shape(TensorShape(227U, 227U, 3U, 1U), DataLayout::NCHW, operation_layout);
+ const TensorShape tensor_shape = permute_shape(TensorShape(227U, 227U, 3U, common_params.batches), DataLayout::NCHW, operation_layout);
TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(operation_layout);
// Set weights trained layout
diff --git a/examples/graph_googlenet.cpp b/examples/graph_googlenet.cpp
index 7555d805c1..683205b3b5 100644
--- a/examples/graph_googlenet.cpp
+++ b/examples/graph_googlenet.cpp
@@ -70,7 +70,7 @@ public:
// Create input descriptor
const auto operation_layout = common_params.data_layout;
- const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, 1U), DataLayout::NCHW, operation_layout);
+ const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, common_params.batches), DataLayout::NCHW, operation_layout);
TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(operation_layout);
// Set weights trained layout
diff --git a/examples/graph_inception_resnet_v1.cpp b/examples/graph_inception_resnet_v1.cpp
index 6ae5b5dc77..d789d7f6e7 100644
--- a/examples/graph_inception_resnet_v1.cpp
+++ b/examples/graph_inception_resnet_v1.cpp
@@ -96,7 +96,7 @@ public:
// Create input descriptor
const auto operation_layout = common_params.data_layout;
- const TensorShape tensor_shape = permute_shape(TensorShape(image_width, image_height, 3U, 1U), DataLayout::NCHW, operation_layout);
+ const TensorShape tensor_shape = permute_shape(TensorShape(image_width, image_height, 3U, common_params.batches), DataLayout::NCHW, operation_layout);
TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(operation_layout);
// Set weights trained layout
diff --git a/examples/graph_inception_resnet_v2.cpp b/examples/graph_inception_resnet_v2.cpp
index ae37ee507d..1d0c51e9ad 100644
--- a/examples/graph_inception_resnet_v2.cpp
+++ b/examples/graph_inception_resnet_v2.cpp
@@ -80,7 +80,7 @@ public:
// Create input descriptor
const auto operation_layout = common_params.data_layout;
- const TensorShape tensor_shape = permute_shape(TensorShape(299U, 299U, 3U, 1U), DataLayout::NCHW, operation_layout);
+ const TensorShape tensor_shape = permute_shape(TensorShape(299U, 299U, 3U, common_params.batches), DataLayout::NCHW, operation_layout);
TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(operation_layout);
// Set weights trained layout
diff --git a/examples/graph_inception_v3.cpp b/examples/graph_inception_v3.cpp
index 928efb9124..160e7f04f4 100644
--- a/examples/graph_inception_v3.cpp
+++ b/examples/graph_inception_v3.cpp
@@ -68,7 +68,7 @@ public:
// Create input descriptor
const auto operation_layout = common_params.data_layout;
- const TensorShape tensor_shape = permute_shape(TensorShape(299U, 299U, 3U, 1U), DataLayout::NCHW, operation_layout);
+ const TensorShape tensor_shape = permute_shape(TensorShape(299U, 299U, 3U, common_params.batches), DataLayout::NCHW, operation_layout);
TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(operation_layout);
// Set weights trained layout
diff --git a/examples/graph_inception_v4.cpp b/examples/graph_inception_v4.cpp
index 0c67215136..6d8fab4141 100644
--- a/examples/graph_inception_v4.cpp
+++ b/examples/graph_inception_v4.cpp
@@ -70,7 +70,7 @@ public:
// Create input descriptor
const auto operation_layout = common_params.data_layout;
- const TensorShape tensor_shape = permute_shape(TensorShape(299U, 299U, 3U, 1U), DataLayout::NCHW, operation_layout);
+ const TensorShape tensor_shape = permute_shape(TensorShape(299U, 299U, 3U, common_params.batches), DataLayout::NCHW, operation_layout);
TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(operation_layout);
// Set weights trained layout
diff --git a/examples/graph_mobilenet.cpp b/examples/graph_mobilenet.cpp
index 4d4e17715d..4630dc958a 100644
--- a/examples/graph_mobilenet.cpp
+++ b/examples/graph_mobilenet.cpp
@@ -72,7 +72,7 @@ public:
unsigned int spatial_size = (model_id == 0 || common_params.data_type == DataType::QASYMM8) ? 224 : 160;
// Create input descriptor
- const TensorShape tensor_shape = permute_shape(TensorShape(spatial_size, spatial_size, 3U, 1U), DataLayout::NCHW, common_params.data_layout);
+ const TensorShape tensor_shape = permute_shape(TensorShape(spatial_size, spatial_size, 3U, common_params.batches), DataLayout::NCHW, common_params.data_layout);
TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(common_params.data_layout);
// Set graph hints
diff --git a/examples/graph_mobilenet_v2.cpp b/examples/graph_mobilenet_v2.cpp
index b1b33be2f5..c027e6f13e 100644
--- a/examples/graph_mobilenet_v2.cpp
+++ b/examples/graph_mobilenet_v2.cpp
@@ -64,7 +64,7 @@ public:
std::cout << common_params << std::endl;
// Create input descriptor
- const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, 1U), DataLayout::NCHW, common_params.data_layout);
+ const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, common_params.batches), DataLayout::NCHW, common_params.data_layout);
TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(common_params.data_layout);
// Set graph hints
diff --git a/examples/graph_resnet12.cpp b/examples/graph_resnet12.cpp
index 8818cf742a..48708ce29a 100644
--- a/examples/graph_resnet12.cpp
+++ b/examples/graph_resnet12.cpp
@@ -84,7 +84,7 @@ public:
std::unique_ptr<IPreprocessor> preprocessor = std::make_unique<TFPreproccessor>();
// Create input descriptor
- const TensorShape tensor_shape = permute_shape(TensorShape(image_width, image_height, 3U, 1U), DataLayout::NCHW, common_params.data_layout);
+ const TensorShape tensor_shape = permute_shape(TensorShape(image_width, image_height, 3U, common_params.batches), DataLayout::NCHW, common_params.data_layout);
TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(common_params.data_layout);
// Set weights trained layout
diff --git a/examples/graph_resnet50.cpp b/examples/graph_resnet50.cpp
index 5834d9be77..0d3322c886 100644
--- a/examples/graph_resnet50.cpp
+++ b/examples/graph_resnet50.cpp
@@ -68,7 +68,7 @@ public:
// Create input descriptor
const auto operation_layout = common_params.data_layout;
- const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, 1U), DataLayout::NCHW, operation_layout);
+ const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, common_params.batches), DataLayout::NCHW, operation_layout);
TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(operation_layout);
// Set weights trained layout
diff --git a/examples/graph_resnet_v2_50.cpp b/examples/graph_resnet_v2_50.cpp
index cd4e6fd6df..6d5abb4f4b 100644
--- a/examples/graph_resnet_v2_50.cpp
+++ b/examples/graph_resnet_v2_50.cpp
@@ -71,7 +71,7 @@ public:
// Create input descriptor
const auto operation_layout = common_params.data_layout;
- const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, 1U), DataLayout::NCHW, operation_layout);
+ const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, common_params.batches), DataLayout::NCHW, operation_layout);
TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(operation_layout);
// Set weights trained layout
diff --git a/examples/graph_resnext50.cpp b/examples/graph_resnext50.cpp
index ec87e0b882..97eb85ccd1 100644
--- a/examples/graph_resnext50.cpp
+++ b/examples/graph_resnext50.cpp
@@ -66,7 +66,7 @@ public:
// Create input descriptor
const auto operation_layout = common_params.data_layout;
- const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, 1U), DataLayout::NCHW, operation_layout);
+ const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, common_params.batches), DataLayout::NCHW, operation_layout);
TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(operation_layout);
// Set weights trained layout
diff --git a/examples/graph_shufflenet.cpp b/examples/graph_shufflenet.cpp
index f90f36149d..6e13c5eeb4 100644
--- a/examples/graph_shufflenet.cpp
+++ b/examples/graph_shufflenet.cpp
@@ -82,7 +82,7 @@ public:
// Create input descriptor
const auto operation_layout = common_params.data_layout;
- const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, 1U), DataLayout::NCHW, operation_layout);
+ const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, common_params.batches), DataLayout::NCHW, operation_layout);
TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(operation_layout);
// Set weights trained layout
diff --git a/examples/graph_squeezenet.cpp b/examples/graph_squeezenet.cpp
index 82d95143be..3ea2fea38f 100644
--- a/examples/graph_squeezenet.cpp
+++ b/examples/graph_squeezenet.cpp
@@ -67,7 +67,7 @@ public:
// Create input descriptor
const auto operation_layout = common_params.data_layout;
- const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, 1U), DataLayout::NCHW, operation_layout);
+ const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, common_params.batches), DataLayout::NCHW, operation_layout);
TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(operation_layout);
// Set weights trained layout
diff --git a/examples/graph_squeezenet_v1_1.cpp b/examples/graph_squeezenet_v1_1.cpp
index 1a7d752d50..9cc183fbbd 100644
--- a/examples/graph_squeezenet_v1_1.cpp
+++ b/examples/graph_squeezenet_v1_1.cpp
@@ -67,7 +67,7 @@ public:
// Create input descriptor
const auto operation_layout = common_params.data_layout;
- const TensorShape tensor_shape = permute_shape(TensorShape(227U, 227U, 3U, 1U), DataLayout::NCHW, operation_layout);
+ const TensorShape tensor_shape = permute_shape(TensorShape(227U, 227U, 3U, common_params.batches), DataLayout::NCHW, operation_layout);
TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(operation_layout);
// Set weights trained layout
diff --git a/examples/graph_srcnn955.cpp b/examples/graph_srcnn955.cpp
index ccad0b65e4..855bbd848e 100644
--- a/examples/graph_srcnn955.cpp
+++ b/examples/graph_srcnn955.cpp
@@ -81,7 +81,7 @@ public:
std::unique_ptr<IPreprocessor> preprocessor = std::make_unique<TFPreproccessor>();
// Create input descriptor
- const TensorShape tensor_shape = permute_shape(TensorShape(image_width, image_height, 3U, 1U), DataLayout::NCHW, common_params.data_layout);
+ const TensorShape tensor_shape = permute_shape(TensorShape(image_width, image_height, 3U, common_params.batches), DataLayout::NCHW, common_params.data_layout);
TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(common_params.data_layout);
// Set weights trained layout
diff --git a/examples/graph_vgg16.cpp b/examples/graph_vgg16.cpp
index 3e453b3626..fcfe6ef50d 100644
--- a/examples/graph_vgg16.cpp
+++ b/examples/graph_vgg16.cpp
@@ -67,7 +67,7 @@ public:
// Create input descriptor
const auto operation_layout = common_params.data_layout;
- const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, 1U), DataLayout::NCHW, operation_layout);
+ const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, common_params.batches), DataLayout::NCHW, operation_layout);
TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(operation_layout);
// Set weights trained layout
diff --git a/examples/graph_vgg19.cpp b/examples/graph_vgg19.cpp
index d79aa01326..efc0bcce19 100644
--- a/examples/graph_vgg19.cpp
+++ b/examples/graph_vgg19.cpp
@@ -66,7 +66,7 @@ public:
// Create input descriptor
const auto operation_layout = common_params.data_layout;
- const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, 1U), DataLayout::NCHW, operation_layout);
+ const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, common_params.batches), DataLayout::NCHW, operation_layout);
TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(operation_layout);
// Set weights trained layout
diff --git a/examples/graph_vgg_vdsr.cpp b/examples/graph_vgg_vdsr.cpp
index 226edcd15b..3fe28e0fed 100644
--- a/examples/graph_vgg_vdsr.cpp
+++ b/examples/graph_vgg_vdsr.cpp
@@ -82,7 +82,7 @@ public:
std::unique_ptr<IPreprocessor> preprocessor = std::make_unique<TFPreproccessor>();
// Create input descriptor
- const TensorShape tensor_shape = permute_shape(TensorShape(image_width, image_height, 1U, 1U), DataLayout::NCHW, common_params.data_layout);
+ const TensorShape tensor_shape = permute_shape(TensorShape(image_width, image_height, 1U, common_params.batches), DataLayout::NCHW, common_params.data_layout);
TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(common_params.data_layout);
// Set weights trained layout
diff --git a/utils/CommonGraphOptions.cpp b/utils/CommonGraphOptions.cpp
index 6ee7ed1878..c0270726da 100644
--- a/utils/CommonGraphOptions.cpp
+++ b/utils/CommonGraphOptions.cpp
@@ -117,6 +117,7 @@ namespace utils
CommonGraphOptions::CommonGraphOptions(CommandLineParser &parser)
: help(parser.add_option<ToggleOption>("help")),
threads(parser.add_option<SimpleOption<int>>("threads", 1)),
+ batches(parser.add_option<SimpleOption<int>>("batches", 1)),
target(),
data_type(),
data_layout(),
@@ -168,6 +169,7 @@ CommonGraphOptions::CommonGraphOptions(CommandLineParser &parser)
help->set_help("Show this help message");
threads->set_help("Number of threads to use");
+ batches->set_help("Number of batches to use for the inputs");
target->set_help("Target to execute on");
data_type->set_help("Data type to use");
data_layout->set_help("Data layout to use");
@@ -197,6 +199,7 @@ CommonGraphParams consume_common_graph_parameters(CommonGraphOptions &options)
CommonGraphParams common_params;
common_params.help = options.help->is_set() ? options.help->value() : false;
common_params.threads = options.threads->value();
+ common_params.batches = options.batches->value();
common_params.target = options.target->value();
common_params.data_type = options.data_type->value();
if(options.data_layout->is_set())
diff --git a/utils/CommonGraphOptions.h b/utils/CommonGraphOptions.h
index f2cbd48b72..afdb78b1be 100644
--- a/utils/CommonGraphOptions.h
+++ b/utils/CommonGraphOptions.h
@@ -94,6 +94,7 @@ struct CommonGraphParams
{
bool help{ false };
int threads{ 0 };
+ int batches{ 1 };
arm_compute::graph::Target target{ arm_compute::graph::Target::NEON };
arm_compute::DataType data_type{ DataType::F32 };
arm_compute::DataLayout data_layout{ DataLayout::NHWC };
@@ -151,6 +152,7 @@ public:
ToggleOption *help; /**< Show help option */
SimpleOption<int> *threads; /**< Number of threads option */
+ SimpleOption<int> *batches; /**< Number of batches */
EnumOption<arm_compute::graph::Target> *target; /**< Graph execution target */
EnumOption<arm_compute::DataType> *data_type; /**< Graph data type */
EnumOption<arm_compute::DataLayout> *data_layout; /**< Graph data layout */