aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h')
-rw-r--r--arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h51
1 files changed, 51 insertions, 0 deletions
diff --git a/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h b/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h
index 43abb6769b..e4e6f0760e 100644
--- a/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h
+++ b/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h
@@ -25,7 +25,9 @@
#define __ARM_COMPUTE_CLCONVERTFULLYCONNECTEDWEIGHTS_H__
#include "arm_compute/core/CL/kernels/CLConvertFullyConnectedWeightsKernel.h"
+#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/CL/ICLSimpleFunction.h"
+#include "arm_compute/runtime/ITransformWeights.h"
namespace arm_compute
{
@@ -54,5 +56,54 @@ public:
*/
static Status validate(const ITensorInfo *input, const ITensorInfo *output, const TensorShape &original_input_shape, DataLayout data_layout);
};
+
+namespace weights_transformations
+{
+/** Basic function to run @ref CLConvertFullyConnectedWeightsKernel. */
+class CLConvertFullyConnectedWeightsManaged : public ITransformWeights
+{
+public:
+ //Inherited method override
+ void run() override
+ {
+ _output.allocator()->allocate();
+ _func.run();
+ _reshape_run = true;
+ }
+
+ //Inherited method override
+ void release() override
+ {
+ _output.allocator()->free();
+ }
+
+ //Inherited method override
+ ICLTensor *get_weights() override
+ {
+ return &_output;
+ }
+
+ //Inherited method override
+ uint32_t uid() override
+ {
+ return _uid;
+ }
+ /** Configures the @ref CLConvertFullyConnectedWeights function
+ *
+ * @param[in] input Source weights tensor info to convert. Data type supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32.
+ * @param[in] original_input_shape Shape of the original input tensor (the one entering fully connected layer).
+ * @param[in] data_layout The data layout the weights have been trained in.
+ */
+ void configure(const ICLTensor *input, const TensorShape &original_input_shape, DataLayout data_layout)
+ {
+ _func.configure(input, &_output, original_input_shape, data_layout);
+ }
+
+private:
+ static constexpr uint32_t _uid = 0x5;
+ CLTensor _output{};
+ CLConvertFullyConnectedWeights _func{};
+};
+} // namespace weights_transformations
} // namespace arm_compute
#endif /* __ARM_COMPUTE_CLCONVERTFULLYCONNECTEDWEIGHTS_H__ */