aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuFullyConnected.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators/CpuFullyConnected.h')
-rw-r--r--src/cpu/operators/CpuFullyConnected.h52
1 files changed, 38 insertions, 14 deletions
diff --git a/src/cpu/operators/CpuFullyConnected.h b/src/cpu/operators/CpuFullyConnected.h
index 1e8c6478d0..7073fb9f7c 100644
--- a/src/cpu/operators/CpuFullyConnected.h
+++ b/src/cpu/operators/CpuFullyConnected.h
@@ -24,11 +24,11 @@
#ifndef ARM_COMPUTE_CPU_FULLY_CONNECTED_H
#define ARM_COMPUTE_CPU_FULLY_CONNECTED_H
-#include "src/cpu/ICpuOperator.h"
-
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/function_info/FullyConnectedLayerInfo.h"
+#include "src/cpu/ICpuOperator.h"
+
#include <memory>
namespace arm_compute
@@ -86,16 +86,24 @@ public:
* @param[in] fc_info (Optional) Fully connected layer additional info
* @param[in] weights_info (Optional) Stores neccessary compute information when weights are already reshaped
*/
- void configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst,
- FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), const WeightsInfo &weights_info = WeightsInfo());
+ void configure(const ITensorInfo *src,
+ const ITensorInfo *weights,
+ const ITensorInfo *biases,
+ ITensorInfo *dst,
+ FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(),
+ const WeightsInfo &weights_info = WeightsInfo());
/** Static function to check if given info will lead to a valid configuration of @ref CpuFullyConnected
*
* Similar to @ref CpuFullyConnected::configure()
*
* @return a status
*/
- static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
- FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), const WeightsInfo &weights_info = WeightsInfo());
+ static Status validate(const ITensorInfo *src,
+ const ITensorInfo *weights,
+ const ITensorInfo *biases,
+ const ITensorInfo *dst,
+ FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(),
+ const WeightsInfo &weights_info = WeightsInfo());
/** Static function that queries whether there exists fixed-format kernel and if it exists it will return in the first argument in what format
* weights are expected to be reshaped as defined by WeightFormat class. Apart from the first argument the rest of the arguments are the same
@@ -103,19 +111,35 @@ public:
*
* @return a status
*/
- static Status has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights,
- const ITensorInfo *biases, const ITensorInfo *dst,
- FullyConnectedLayerInfo fc_info, WeightsInfo weights_info);
+ static Status has_opt_impl(arm_compute::WeightFormat &expected_weight_format,
+ const ITensorInfo *src,
+ const ITensorInfo *weights,
+ const ITensorInfo *biases,
+ const ITensorInfo *dst,
+ FullyConnectedLayerInfo fc_info,
+ WeightsInfo weights_info);
//Inherited methods override
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &tensors) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &tensors) override;
experimental::MemoryRequirements workspace() const override;
private:
- void configure_fc_fc(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const ActivationLayerInfo &act);
- void configure_conv_fc(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const ActivationLayerInfo &act);
- void configure_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const ActivationLayerInfo &act);
+ void configure_fc_fc(const ITensorInfo *src,
+ const ITensorInfo *weights,
+ const ITensorInfo *biases,
+ ITensorInfo *dst,
+ const ActivationLayerInfo &act);
+ void configure_conv_fc(const ITensorInfo *src,
+ const ITensorInfo *weights,
+ const ITensorInfo *biases,
+ ITensorInfo *dst,
+ const ActivationLayerInfo &act);
+ void configure_mm(const ITensorInfo *src,
+ const ITensorInfo *weights,
+ const ITensorInfo *biases,
+ ITensorInfo *dst,
+ const ActivationLayerInfo &act);
enum AuxTensorIdx
{