aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
diff options
context:
space:
mode:
authorMichele Di Giorgio <michele.digiorgio@arm.com>2018-05-31 17:31:05 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:52:54 +0000
commitb62280aca3148dd6762e57e5af3da0cb0a9e2db5 (patch)
treeaa10c3750dcb8b13151d40529facf92667c336c9 /arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
parentda2491fb6d3cefb69846f220356fff282486495c (diff)
downloadComputeLibrary-b62280aca3148dd6762e57e5af3da0cb0a9e2db5.tar.gz
COMPMID-1244: Allow retaining weights in CLGEMMConvolutionLayer and CLFullyConnectedLayer
Change-Id: I1c3b2197906cd4b905309bbd5f2012bbae6a7dba Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/133730 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h')
-rw-r--r--arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h34
1 files changed, 20 insertions, 14 deletions
diff --git a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
index 7fb5af9229..127d8acf10 100644
--- a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
+++ b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
@@ -86,26 +86,32 @@ public:
CLFullyConnectedLayer &operator=(CLFullyConnectedLayer &&) = default;
/** Set the input and output tensors.
*
- * @param[in] input Source tensor. Data type supported: QS8/QASYMM8/QS16/F16/F32.
- * @param[in] weights Weights tensor. The weights must be 2 dimensional. Data type supported: Same as @p input
- * @param[in] biases Bias tensor. It can be nullptr. Data type supported:Same as @p input.
- * @param[out] output Destination tensor. Data type supported: Same as @p input.
- * @param[in] transpose_weights (Optional) Transpose weights if true. Defaults to true.
- * @param[in] are_weights_reshaped (Optional) Reshape the weights tensor if false. Defaults to false.
+ * @param[in] input Source tensor. Data type supported: QS8/QASYMM8/QS16/F16/F32.
+ * @param[in] weights Weights tensor. The weights must be 2 dimensional. Data type supported: Same as @p input
+ * @param[in] biases Bias tensor. It can be nullptr. Data type supported:Same as @p input.
+ * @param[out] output Destination tensor. Data type supported: Same as @p input.
+ * @param[in] transpose_weights (Optional) Transpose weights if true. Defaults to true.
+ * @param[in] are_weights_reshaped (Optional) Reshape the weights tensor if false. Defaults to false.
+ * @param[in] retain_internal_weights (Optional) Retain internal reshaped weights. Defaults to false.
+ * Used for reconfiguration purposes.
*/
- void configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, bool transpose_weights = true, bool are_weights_reshaped = false);
+ void configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, bool transpose_weights = true, bool are_weights_reshaped = false,
+ bool retain_internal_weights = false);
/** Static function to check if given info will lead to a valid configuration of @ref CLFullyConnectedLayer
*
- * @param[in] input Source tensor. Data type supported: QS8/QASYMM8/QS16/F16/F32.
- * @param[in] weights Weights tensor. The weights must be 2 dimensional. Data type supported: Same as @p input
- * @param[in] biases Bias tensor. It can be nullptr. Data type supported:Same as @p input.
- * @param[in] output Destination tensor. Data type supported: Same as @p input.
- * @param[in] transpose_weights (Optional) Transpose weights if true. Defaults to true.
- * @param[in] are_weights_reshaped (Optional) Reshape the weights tensor if false. Defaults to false.
+ * @param[in] input Source tensor. Data type supported: QS8/QASYMM8/QS16/F16/F32.
+ * @param[in] weights Weights tensor. The weights must be 2 dimensional. Data type supported: Same as @p input
+ * @param[in] biases Bias tensor. It can be nullptr. Data type supported:Same as @p input.
+ * @param[in] output Destination tensor. Data type supported: Same as @p input.
+ * @param[in] transpose_weights (Optional) Transpose weights if true. Defaults to true.
+ * @param[in] are_weights_reshaped (Optional) Reshape the weights tensor if false. Defaults to false.
+ * @param[in] retain_internal_weights (Optional) Retain internal reshaped weights. Defaults to false.
+ * Used for reconfiguration purposes.
*
* @return a status
*/
- static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, bool transpose_weights = true, bool are_weights_reshaped = false);
+ static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, bool transpose_weights = true, bool are_weights_reshaped = false,
+ bool retain_internal_weights = false);
//Inherited methods override
void run() override;