diff options
author | Michalis Spyrou <michalis.spyrou@arm.com> | 2020-07-16 17:46:51 +0100 |
---|---|---|
committer | Michalis Spyrou <michalis.spyrou@arm.com> | 2020-07-21 16:22:21 +0000 |
commit | f20d6d6ae5a0da2c856294e93341cdc065db58f9 (patch) | |
tree | 0e11a924371691fb9d345cd9362cd8cc06662432 /arm_compute/runtime/CL/functions | |
parent | c6eaec3610fa27651582f6c1acad35afffe360f6 (diff) | |
download | ComputeLibrary-f20d6d6ae5a0da2c856294e93341cdc065db58f9.tar.gz |
COMPMID-3390: Async support to CLStridedSliceLayerKernel kernels/functions
Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Change-Id: I9ff7e8d2fb4d36c4b7c44e885abf34ff6d4c577c
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3587
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/runtime/CL/functions')
-rw-r--r-- | arm_compute/runtime/CL/functions/CLSlice.h | 66 | ||||
-rw-r--r-- | arm_compute/runtime/CL/functions/CLStridedSlice.h | 76 |
2 files changed, 138 insertions, 4 deletions
diff --git a/arm_compute/runtime/CL/functions/CLSlice.h b/arm_compute/runtime/CL/functions/CLSlice.h index 9f9591e4de..6fe62acaf5 100644 --- a/arm_compute/runtime/CL/functions/CLSlice.h +++ b/arm_compute/runtime/CL/functions/CLSlice.h @@ -24,15 +24,18 @@ #ifndef ARM_COMPUTE_CL_SLICE_H #define ARM_COMPUTE_CL_SLICE_H -#include "arm_compute/runtime/CL/ICLSimpleFunction.h" +#include "arm_compute/runtime/CL/ICLOperator.h" +#include "arm_compute/runtime/IFunction.h" namespace arm_compute { // Forward Declarations class ICLTensor; +namespace experimental +{ /** Basic function to perform tensor slicing */ -class CLSlice : public ICLSimpleFunction +class CLSlice : public ICLOperator { public: /** Configure kernel @@ -42,6 +45,58 @@ public: * @note End coordinates can be negative, which represents the number of elements before the end of that dimension. * @note End indices are not inclusive unless negative. * + * @param[in] compile_context The compile context to be used. + * @param[in] input Source tensor info. Data type supported: All. + * @param[out] output Destination tensor info. Data type supported: Same as @p input + * @param[in] starts The starts of the dimensions of the input tensor to be sliced. The length must be of rank(input). + * @param[in] ends The ends of the dimensions of the input tensor to be sliced. The length must be of rank(input). + */ + void configure(const CLCompileContext &compile_context, const ITensorInfo *input, ITensorInfo *output, const Coordinates &starts, const Coordinates &ends); + + /** Static function to check if given info will lead to a valid configuration of @ref CLSlice + * + * @note Supported tensor rank: up to 4 + * @note Start indices must be non-negative. 0 <= starts[i] + * @note End coordinates can be negative, which represents the number of elements before the end of that dimension. + * @note End indices are not inclusive unless negative. + * + * @param[in] input Source tensor info. Data type supported: All + * @param[in] output Destination tensor info. Data type supported: Same as @p input + * @param[in] starts The starts of the dimensions of the input tensor to be sliced. The length must be of rank(input). + * @param[in] ends The ends of the dimensions of the input tensor to be sliced. The length must be of rank(input). + * + * @return A status + */ + static Status validate(const ITensorInfo *input, const ITensorInfo *output, const Coordinates &starts, const Coordinates &ends); + + // Inherited methods overridden: + MemoryRequirements workspace() const override; +}; +} // namespace experimental + +/** Basic function to perform tensor slicing */ +class CLSlice : public IFunction +{ +public: + /** Default Constructor */ + CLSlice(); + /** Default Destructor */ + ~CLSlice(); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLSlice(const CLSlice &) = delete; + /** Default move constructor */ + CLSlice(CLSlice &&); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLSlice &operator=(const CLSlice &) = delete; + /** Default move assignment operator */ + CLSlice &operator=(CLSlice &&); + /** Configure kernel + * + * @note Supported tensor rank: up to 4 + * @note Start indices must be non-negative. 0 <= starts[i] + * @note End coordinates can be negative, which represents the number of elements before the end of that dimension. + * @note End indices are not inclusive unless negative. + * * @param[in] input Source tensor. Data type supported: All. * @param[out] output Destination tensor. Data type supported: Same as @p input * @param[in] starts The starts of the dimensions of the input tensor to be sliced. The length must be of rank(input). @@ -78,6 +133,13 @@ public: * @return A status */ static Status validate(const ITensorInfo *input, const ITensorInfo *output, const Coordinates &starts, const Coordinates &ends); + + // Inherited methods overridden: + void run() override; + +private: + struct Impl; + std::unique_ptr<Impl> _impl; }; } // namespace arm_compute #endif /* ARM_COMPUTE_CL_SLICE_H */ diff --git a/arm_compute/runtime/CL/functions/CLStridedSlice.h b/arm_compute/runtime/CL/functions/CLStridedSlice.h index 98a3bd49d3..394d8c4f59 100644 --- a/arm_compute/runtime/CL/functions/CLStridedSlice.h +++ b/arm_compute/runtime/CL/functions/CLStridedSlice.h @@ -24,7 +24,9 @@ #ifndef ARM_COMPUTE_CL_STRIDED_SLICE_H #define ARM_COMPUTE_CL_STRIDED_SLICE_H -#include "arm_compute/runtime/CL/ICLSimpleFunction.h" +#include "arm_compute/runtime/CL/CLRuntimeContext.h" +#include "arm_compute/runtime/CL/ICLOperator.h" +#include "arm_compute/runtime/IFunction.h" namespace arm_compute { @@ -32,9 +34,24 @@ namespace arm_compute class ICLTensor; /** Basic function to run @ref CLStridedSliceKernel */ -class CLStridedSlice : public ICLSimpleFunction +class CLStridedSlice : public IFunction { public: + /** Constructor + * + * @param[in] ctx Runtime context to be used by the function + */ + CLStridedSlice(CLRuntimeContext *ctx = nullptr); + /** Destructor */ + ~CLStridedSlice(); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLStridedSlice(const CLStridedSlice &) = delete; + /** Default move constructor */ + CLStridedSlice(CLStridedSlice &&); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLStridedSlice &operator=(const CLStridedSlice &) = delete; + /** Default move assignment operator */ + CLStridedSlice &operator=(CLStridedSlice &&); /** Configure kernel * * @note Supported tensor rank: up to 4 @@ -88,6 +105,61 @@ public: static Status validate(const ITensorInfo *input, const ITensorInfo *output, const Coordinates &starts, const Coordinates &ends, const BiStrides &strides, int32_t begin_mask = 0, int32_t end_mask = 0, int32_t shrink_axis_mask = 0); + + // Inherited methods overridden: + void run() override; + +private: + struct Impl; + std::unique_ptr<Impl> _impl; +}; + +namespace experimental +{ +/** Basic function to run @ref CLStridedSliceKernel */ +class CLStridedSlice : public ICLOperator +{ +public: + /** Configure kernel + * + * @note Supported tensor rank: up to 4 + * + * @param[in] compile_context The compile context to be used. + * @param[in] input Source tensor info. Data type supported: All. + * @param[out] output Destination tensor info. Data type supported: Same as @p input + * @param[in] starts The starts of the dimensions of the input tensor to be sliced. The length must be of rank(input). + * @param[in] ends The ends of the dimensions of the input tensor to be sliced. The length must be of rank(input). + * @param[in] strides The strides of the dimensions of the input tensor to be sliced. The length must be of rank(input). + * @param[in] begin_mask (Optional) If the ith bit of begin_mask is set, starts[i] is ignored and the fullest possible range in that dimension is used instead. + * @param[in] end_mask (Optional) If the ith bit of end_mask is set, ends[i] is ignored and the fullest possible range in that dimension is used instead. + * @param[in] shrink_axis_mask (Optional) If the ith bit of shrink_axis_mask is set, it implies that the ith specification shrinks the dimensionality by 1. + * A slice of size 1 starting from starts[i] in the dimension must be preserved. + */ + void configure(const CLCompileContext &compile_context, const ITensorInfo *input, ITensorInfo *output, + const Coordinates &starts, const Coordinates &ends, const BiStrides &strides, + int32_t begin_mask = 0, int32_t end_mask = 0, int32_t shrink_axis_mask = 0); + + /** Static function to check if given info will lead to a valid configuration of @ref CLStridedSlice + * + * @note Supported tensor rank: up to 4 + * + * @param[in] input Source tensor info. Data type supported: All. + * @param[in] output Destination tensor info. Data type supported: Same as @p input + * @param[in] starts The starts of the dimensions of the input tensor to be sliced. The length must be of rank(input). + * @param[in] ends The ends of the dimensions of the input tensor to be sliced. The length must be of rank(input). + * @param[in] strides The strides of the dimensions of the input tensor to be sliced. The length must be of rank(input). + * @param[in] begin_mask (Optional) If the ith bit of begin_mask is set, starts[i] is ignored and the fullest possible range in that dimension is used instead. + * @param[in] end_mask (Optional) If the ith bit of end_mask is set, ends[i] is ignored and the fullest possible range in that dimension is used instead. + * @param[in] shrink_axis_mask (Optional) If the ith bit of shrink_axis_mask is set, it implies that the ith specification shrinks the dimensionality by 1. + * A slice of size 1 starting from starts[i] in the dimension must be preserved. + */ + static Status validate(const ITensorInfo *input, const ITensorInfo *output, + const Coordinates &starts, const Coordinates &ends, const BiStrides &strides, + int32_t begin_mask = 0, int32_t end_mask = 0, int32_t shrink_axis_mask = 0); + + // Inherited methods overridden: + MemoryRequirements workspace() const override; }; +} // namespace experimental } // namespace arm_compute #endif /* ARM_COMPUTE_CL_STRIDED_SLICE_H */ |