aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/kernels/ClSoftmaxKernel.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/kernels/ClSoftmaxKernel.h')
-rw-r--r--src/gpu/cl/kernels/ClSoftmaxKernel.h103
1 files changed, 29 insertions, 74 deletions
diff --git a/src/gpu/cl/kernels/ClSoftmaxKernel.h b/src/gpu/cl/kernels/ClSoftmaxKernel.h
index 2dd53da346..130dc7835c 100644
--- a/src/gpu/cl/kernels/ClSoftmaxKernel.h
+++ b/src/gpu/cl/kernels/ClSoftmaxKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,11 +21,13 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_CL_SOFTMAX_KERNEL_H
-#define ARM_COMPUTE_CL_SOFTMAX_KERNEL_H
+#ifndef ACL_SRC_GPU_CL_KERNELS_CLSOFTMAXKERNEL_H
+#define ACL_SRC_GPU_CL_KERNELS_CLSOFTMAXKERNEL_H
#include "arm_compute/core/Error.h"
#include "arm_compute/core/KernelDescriptors.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Window.h"
#include "src/core/common/Macros.h"
#include "src/gpu/cl/ClCompileContext.h"
@@ -37,94 +39,47 @@ namespace opencl
{
namespace kernels
{
-/** Interface for max, shifting, exponentiating and summing the logits */
-class ClLogits1DMaxShiftExpSumKernel : public IClKernel
-{
- /**< Grid size (obtained through auto-tuning) */
- static const unsigned int _grid_size;
- /**< Vector size in the serial case (obtained through auto-tuning) */
- static const unsigned int _serial_vector_size;
- /**< Vector size in the parallel case (obtained through auto-tuning, enables the best memory access pattern for Bifrost) .*/
- static const unsigned int _parallel_vector_size;
+/** The CL kernel that performs softmax function. */
+class ClSoftmaxKernel : public IClKernel
+{
public:
- /** Info for whether a parallel reduction will be run and the vector size of the execution. */
- using ParallelReductionInfo = std::tuple<bool, unsigned int>;
+ ClSoftmaxKernel();
- ClLogits1DMaxShiftExpSumKernel();
- ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(ClLogits1DMaxShiftExpSumKernel);
- /** Configure the kernel using the given information about tensors
- *
- * @param[in] compile_context The compile context to be used.
- * @param[in] src Source tensor. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32
- * @param[in,out] max Max values tensor. Data types supported: same as @p src
- * @param[out] dst Destination tensor. Data types supported: same as @p src
- * @param[out] sum Sum of 1D logits tensor. Data types supported: same as @p src
- * @param[in] info Contains information consumed by kernels for softmax described in @ref SoftmaxKernelInfo.
- */
- void configure(const CLCompileContext &compile_context,
- const ITensorInfo &src,
- ITensorInfo &max,
- ITensorInfo &dst,
- ITensorInfo &sum,
- const SoftmaxKernelInfo &info);
- /** Static function to check if given info will lead to a valid configuration
- *
- * Similar to @ref ClLogits1DMaxShiftExpSumKernel::configure()
- *
- * @return a status
- */
- static Status
- validate(const ITensorInfo &src, const ITensorInfo &max, const ITensorInfo &dst, const ITensorInfo &sum);
- /** Checks if the given size is eligible for parallel reduction
- *
- * @note Serial reduction is launched for width < (_grid_size * _serial_vector_size).
- * @note Parallel reduction is launched for width >= (_grid_size * _serial_vector_size) and vector_size is forced to 4.
+ ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(ClSoftmaxKernel);
+
+ /** Check if the kernel arguments are valid.
*
- * @param[in] size Size to check
+ * See @ref ClSoftmaxKernel::configure().
*
- * @return A two-element tuple where the first element is a boolean specifying if a parallel reduction will be run,
- * while the second element is the vector size of the execution.
+ * @return The status.
*/
- static ParallelReductionInfo is_parallel_reduction(size_t size);
-
- // Inherited methods overridden:
- void run_op(ITensorPack &tensors, const Window &window, ::cl::CommandQueue &queue) override;
-};
-
-/** Interface for calculating the final step of the Softmax Layer where each logit value is multiplied by the inverse of the sum of the logits. */
-class ClLogits1DNormKernel : public IClKernel
-{
-public:
- ClLogits1DNormKernel();
- ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(ClLogits1DNormKernel);
+ static Status validate(const ITensorInfo &src, const ITensorInfo &dst, const SoftmaxKernelInfo &info);
- /** Set the input and output tensors.
+ /** Configure the kernel.
*
* @param[in] compile_context The compile context to be used.
- * @param[in] src Source tensor. Data types supported: S32/F16/F32. If this kernel is used for log softmax, only F32/F16 is supported.
- * @param[in] sum Sum tensor. Dimensions should be dim(input)-1. Data types supported: same as @p input
- * @param[out] dst Destination tensor. Data types supported: QASYMM8/QASYMM8_SIGNED for S32 @p input, or same as @p input
+ * @param[in] src Source tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32 for Softmax and F16/F32 for Log Softmax
+ * @param[out] dst Destination tensor info. Data types supported: same as @p src
* @param[in] info Contains information consumed by kernels for softmax described in @ref SoftmaxKernelInfo.
*/
void configure(const CLCompileContext &compile_context,
const ITensorInfo &src,
- const ITensorInfo &sum,
ITensorInfo &dst,
const SoftmaxKernelInfo &info);
- /** Static function to check if given info will lead to a valid configuration
- *
- * Similar to @ref ClLogits1DNormKernel::configure()
- *
- * @return a status
- */
- static Status
- validate(const ITensorInfo &src, const ITensorInfo &sum, const ITensorInfo &dst, const SoftmaxKernelInfo &info);
- // Inherited methods overridden:
- void run_op(ITensorPack &tensors, const Window &window, ::cl::CommandQueue &queue) override;
+ void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override;
+
+ /** Get the tensor info of the temporary tensor. */
+ const TensorInfo &tmp_tensor_info() const;
+
+private:
+ bool _prepared{false};
+ int32_t _axis{0};
+ TensorInfo _tmp_info{};
};
+
} // namespace kernels
} // namespace opencl
} // namespace arm_compute
-#endif /* ARM_COMPUTE_CL_SOFTMAX_KERNEL_H */
+#endif // ACL_SRC_GPU_CL_KERNELS_CLSOFTMAXKERNEL_H