aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h')
-rw-r--r--arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h29
1 files changed, 21 insertions, 8 deletions
diff --git a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
index 05b7ce3735..885f8430cf 100644
--- a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
@@ -28,7 +28,6 @@
#include "arm_compute/runtime/IFunction.h"
#include "arm_compute/runtime/IMemoryManager.h"
#include "arm_compute/runtime/IWeightsManager.h"
-
#include "arm_compute/runtime/NEON/functions/NETranspose.h"
#include "arm_compute/runtime/Tensor.h"
@@ -88,7 +87,8 @@ class NEFullyConnectedLayer : public IFunction
{
public:
/** Constructor */
- NEFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
+ NEFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr,
+ IWeightsManager *weights_manager = nullptr);
/** Prevent instances of this class from being copied (As this class contains pointers) */
NEFullyConnectedLayer(const NEFullyConnectedLayer &) = delete;
/** Prevent instances of this class from being moved (As this class contains pointers) */
@@ -126,16 +126,24 @@ public:
* @param[in] fc_info (Optional) Fully connected layer additional info
* @param[in] weights_info (Optional) Stores neccessary compute information when weights are already reshaped
*/
- void configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output,
- FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), const WeightsInfo &weights_info = WeightsInfo());
+ void configure(const ITensor *input,
+ const ITensor *weights,
+ const ITensor *biases,
+ ITensor *output,
+ FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(),
+ const WeightsInfo &weights_info = WeightsInfo());
/** Static function to check if given info will lead to a valid configuration of @ref NEFullyConnectedLayer
*
* Similar to @ref NEFullyConnectedLayer::configure()
*
* @return a status
*/
- static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
- FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), const WeightsInfo &weights_info = WeightsInfo());
+ static Status validate(const ITensorInfo *input,
+ const ITensorInfo *weights,
+ const ITensorInfo *biases,
+ const ITensorInfo *output,
+ FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(),
+ const WeightsInfo &weights_info = WeightsInfo());
/** Static function that queries whether fixed-format kernel exists for a given problem description
*
@@ -149,8 +157,13 @@ public:
*
* @return a status
*/
- static Status has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *input, const ITensorInfo *weights,
- const ITensorInfo *biases, const ITensorInfo *output, const FullyConnectedLayerInfo &fc_info, const WeightsInfo &weights_info);
+ static Status has_opt_impl(arm_compute::WeightFormat &expected_weight_format,
+ const ITensorInfo *input,
+ const ITensorInfo *weights,
+ const ITensorInfo *biases,
+ const ITensorInfo *output,
+ const FullyConnectedLayerInfo &fc_info,
+ const WeightsInfo &weights_info);
//Inherited methods override
void run() override;