diff options
author | Michalis Spyrou <michalis.spyrou@arm.com> | 2020-07-16 17:46:51 +0100 |
---|---|---|
committer | Michalis Spyrou <michalis.spyrou@arm.com> | 2020-07-21 16:22:21 +0000 |
commit | f20d6d6ae5a0da2c856294e93341cdc065db58f9 (patch) | |
tree | 0e11a924371691fb9d345cd9362cd8cc06662432 /arm_compute/runtime/CL/functions/CLStridedSlice.h | |
parent | c6eaec3610fa27651582f6c1acad35afffe360f6 (diff) | |
download | ComputeLibrary-f20d6d6ae5a0da2c856294e93341cdc065db58f9.tar.gz |
COMPMID-3390: Async support to CLStridedSliceLayerKernel kernels/functions
Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Change-Id: I9ff7e8d2fb4d36c4b7c44e885abf34ff6d4c577c
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3587
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLStridedSlice.h')
-rw-r--r-- | arm_compute/runtime/CL/functions/CLStridedSlice.h | 76 |
1 files changed, 74 insertions, 2 deletions
diff --git a/arm_compute/runtime/CL/functions/CLStridedSlice.h b/arm_compute/runtime/CL/functions/CLStridedSlice.h index 98a3bd49d3..394d8c4f59 100644 --- a/arm_compute/runtime/CL/functions/CLStridedSlice.h +++ b/arm_compute/runtime/CL/functions/CLStridedSlice.h @@ -24,7 +24,9 @@ #ifndef ARM_COMPUTE_CL_STRIDED_SLICE_H #define ARM_COMPUTE_CL_STRIDED_SLICE_H -#include "arm_compute/runtime/CL/ICLSimpleFunction.h" +#include "arm_compute/runtime/CL/CLRuntimeContext.h" +#include "arm_compute/runtime/CL/ICLOperator.h" +#include "arm_compute/runtime/IFunction.h" namespace arm_compute { @@ -32,9 +34,24 @@ namespace arm_compute class ICLTensor; /** Basic function to run @ref CLStridedSliceKernel */ -class CLStridedSlice : public ICLSimpleFunction +class CLStridedSlice : public IFunction { public: + /** Constructor + * + * @param[in] ctx Runtime context to be used by the function + */ + CLStridedSlice(CLRuntimeContext *ctx = nullptr); + /** Destructor */ + ~CLStridedSlice(); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLStridedSlice(const CLStridedSlice &) = delete; + /** Default move constructor */ + CLStridedSlice(CLStridedSlice &&); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLStridedSlice &operator=(const CLStridedSlice &) = delete; + /** Default move assignment operator */ + CLStridedSlice &operator=(CLStridedSlice &&); /** Configure kernel * * @note Supported tensor rank: up to 4 @@ -88,6 +105,61 @@ public: static Status validate(const ITensorInfo *input, const ITensorInfo *output, const Coordinates &starts, const Coordinates &ends, const BiStrides &strides, int32_t begin_mask = 0, int32_t end_mask = 0, int32_t shrink_axis_mask = 0); + + // Inherited methods overridden: + void run() override; + +private: + struct Impl; + std::unique_ptr<Impl> _impl; +}; + +namespace experimental +{ +/** Basic function to run @ref CLStridedSliceKernel */ +class CLStridedSlice : public ICLOperator +{ +public: + /** Configure kernel + * + * @note Supported tensor rank: up to 4 + * + * @param[in] compile_context The compile context to be used. + * @param[in] input Source tensor info. Data type supported: All. + * @param[out] output Destination tensor info. Data type supported: Same as @p input + * @param[in] starts The starts of the dimensions of the input tensor to be sliced. The length must be of rank(input). + * @param[in] ends The ends of the dimensions of the input tensor to be sliced. The length must be of rank(input). + * @param[in] strides The strides of the dimensions of the input tensor to be sliced. The length must be of rank(input). + * @param[in] begin_mask (Optional) If the ith bit of begin_mask is set, starts[i] is ignored and the fullest possible range in that dimension is used instead. + * @param[in] end_mask (Optional) If the ith bit of end_mask is set, ends[i] is ignored and the fullest possible range in that dimension is used instead. + * @param[in] shrink_axis_mask (Optional) If the ith bit of shrink_axis_mask is set, it implies that the ith specification shrinks the dimensionality by 1. + * A slice of size 1 starting from starts[i] in the dimension must be preserved. + */ + void configure(const CLCompileContext &compile_context, const ITensorInfo *input, ITensorInfo *output, + const Coordinates &starts, const Coordinates &ends, const BiStrides &strides, + int32_t begin_mask = 0, int32_t end_mask = 0, int32_t shrink_axis_mask = 0); + + /** Static function to check if given info will lead to a valid configuration of @ref CLStridedSlice + * + * @note Supported tensor rank: up to 4 + * + * @param[in] input Source tensor info. Data type supported: All. + * @param[in] output Destination tensor info. Data type supported: Same as @p input + * @param[in] starts The starts of the dimensions of the input tensor to be sliced. The length must be of rank(input). + * @param[in] ends The ends of the dimensions of the input tensor to be sliced. The length must be of rank(input). + * @param[in] strides The strides of the dimensions of the input tensor to be sliced. The length must be of rank(input). + * @param[in] begin_mask (Optional) If the ith bit of begin_mask is set, starts[i] is ignored and the fullest possible range in that dimension is used instead. + * @param[in] end_mask (Optional) If the ith bit of end_mask is set, ends[i] is ignored and the fullest possible range in that dimension is used instead. + * @param[in] shrink_axis_mask (Optional) If the ith bit of shrink_axis_mask is set, it implies that the ith specification shrinks the dimensionality by 1. + * A slice of size 1 starting from starts[i] in the dimension must be preserved. + */ + static Status validate(const ITensorInfo *input, const ITensorInfo *output, + const Coordinates &starts, const Coordinates &ends, const BiStrides &strides, + int32_t begin_mask = 0, int32_t end_mask = 0, int32_t shrink_axis_mask = 0); + + // Inherited methods overridden: + MemoryRequirements workspace() const override; }; +} // namespace experimental } // namespace arm_compute #endif /* ARM_COMPUTE_CL_STRIDED_SLICE_H */ |