aboutsummaryrefslogtreecommitdiff
path: root/tests/networks
diff options
context:
space:
mode:
Diffstat (limited to 'tests/networks')
-rw-r--r--tests/networks/AlexNetNetwork.h11
-rw-r--r--tests/networks/LeNet5Network.h7
-rw-r--r--tests/networks/MobileNetNetwork.h7
-rw-r--r--tests/networks/MobileNetV1Network.h8
4 files changed, 29 insertions, 4 deletions
diff --git a/tests/networks/AlexNetNetwork.h b/tests/networks/AlexNetNetwork.h
index 819111f897..a30b7f8f75 100644
--- a/tests/networks/AlexNetNetwork.h
+++ b/tests/networks/AlexNetNetwork.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -54,6 +54,13 @@ template <typename ITensorType,
class AlexNetNetwork
{
public:
+ /** Initialize the network.
+ *
+ * @param[in] data_type Data type.
+ * @param[in] fixed_point_position Fixed point position (for quantized data types).
+ * @param[in] batches Number of batches.
+ * @param[in] reshaped_weights Whether the weights need reshaping or not. Default: false.
+ */
void init(DataType data_type, int fixed_point_position, int batches, bool reshaped_weights = false)
{
_data_type = data_type;
@@ -185,6 +192,7 @@ public:
}
}
+ /** Build the network */
void build()
{
input.allocator()->init(TensorInfo(TensorShape(227U, 227U, 3U, _batches), 1, _data_type, _fixed_point_position));
@@ -270,6 +278,7 @@ public:
smx.configure(&fc8_out, &output);
}
+ /** Allocate the network */
void allocate()
{
input.allocator()->allocate();
diff --git a/tests/networks/LeNet5Network.h b/tests/networks/LeNet5Network.h
index a46489f88c..9cfd59284c 100644
--- a/tests/networks/LeNet5Network.h
+++ b/tests/networks/LeNet5Network.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -50,6 +50,10 @@ template <typename TensorType,
class LeNet5Network
{
public:
+ /** Initialize the network.
+ *
+ * @param[in] batches Number of batches.
+ */
void init(int batches)
{
_batches = batches;
@@ -94,6 +98,7 @@ public:
smx.configure(&fc2_out, &output);
}
+ /** Allocate the network */
void allocate()
{
// Allocate tensors
diff --git a/tests/networks/MobileNetNetwork.h b/tests/networks/MobileNetNetwork.h
index 8c3cb1fb2d..ec054b237e 100644
--- a/tests/networks/MobileNetNetwork.h
+++ b/tests/networks/MobileNetNetwork.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -53,6 +53,10 @@ template <typename TensorType,
class MobileNetNetwork
{
public:
+ /** Initialize the network.
+ *
+ * @param[in] batches Number of batches.
+ */
void init(int batches)
{
_batches = batches;
@@ -105,6 +109,7 @@ public:
reshape.configure(&conv_out[14], &output);
}
+ /** Allocate the network. */
void allocate()
{
input.allocator()->allocate();
diff --git a/tests/networks/MobileNetV1Network.h b/tests/networks/MobileNetV1Network.h
index 0957c6b555..aea5c113e8 100644
--- a/tests/networks/MobileNetV1Network.h
+++ b/tests/networks/MobileNetV1Network.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -55,6 +55,11 @@ template <typename TensorType,
class MobileNetV1Network
{
public:
+ /** Initialize the network.
+ *
+ * @param[in] input_spatial_size Size of the spatial input.
+ * @param[in] batches Number of batches.
+ */
void init(unsigned int input_spatial_size, int batches)
{
_batches = batches;
@@ -117,6 +122,7 @@ public:
smx.configure(&reshape_out, &output);
}
+ /** Allocate the network. */
void allocate()
{
input.allocator()->allocate();