aboutsummaryrefslogtreecommitdiff
path: root/arm_compute
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute')
-rw-r--r--arm_compute/core/CL/kernels/CLSoftmaxLayerKernel.h3
-rw-r--r--arm_compute/core/KernelDescriptors.h5
-rw-r--r--arm_compute/core/Utils.h9
3 files changed, 14 insertions, 3 deletions
diff --git a/arm_compute/core/CL/kernels/CLSoftmaxLayerKernel.h b/arm_compute/core/CL/kernels/CLSoftmaxLayerKernel.h
index 93e403e257..f64739ae32 100644
--- a/arm_compute/core/CL/kernels/CLSoftmaxLayerKernel.h
+++ b/arm_compute/core/CL/kernels/CLSoftmaxLayerKernel.h
@@ -187,10 +187,11 @@ public:
* @param[in] input Source tensor. Data types supported: S32/F16/F32
* @param[in] sum Sum tensor. Dimensions should be dim(input)-1. Data types supported: same as @p input
* @param[in] output Destination tensor. Data types supported: QASYMM8 for S32 @p input, or same as @p input
+ * @param[in] info Contains information consumed by kernels for softmax described in @ref SoftmaxKernelInfo.
*
* @return a status
*/
- static Status validate(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output);
+ static Status validate(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output, const SoftmaxKernelInfo &info);
// Inherited methods overridden:
void run(const Window &window, cl::CommandQueue &queue) override;
diff --git a/arm_compute/core/KernelDescriptors.h b/arm_compute/core/KernelDescriptors.h
index 2aa076246e..f358153b0d 100644
--- a/arm_compute/core/KernelDescriptors.h
+++ b/arm_compute/core/KernelDescriptors.h
@@ -79,8 +79,9 @@ struct DWCWeightsKernelInfo
/** Descriptor used by the softmax kernels */
struct SoftmaxKernelInfo
{
- float beta{ 1.f }; /**< A scaling factor for the exponent with default value 1.0 */
- bool is_log{ false }; /**< Flag used to perform Log Softmax operation */
+ float beta{ 1.f }; /**< A scaling factor for the exponent with default value 1.0 */
+ bool is_log{ false }; /**< Flag used to perform Log Softmax operation */
+ DataType input_data_type{ DataType::UNKNOWN }; /**< Input tensor data type */
};
} // namespace arm_compute
#endif /* ARM_COMPUTE_CORE_KERNEL_DESCRIPTORS_H */
diff --git a/arm_compute/core/Utils.h b/arm_compute/core/Utils.h
index 18c5471f8f..0a7eeefded 100644
--- a/arm_compute/core/Utils.h
+++ b/arm_compute/core/Utils.h
@@ -958,6 +958,15 @@ std::pair<unsigned int, unsigned int> scaled_dimensions(unsigned int width, unsi
*/
bool needs_serialized_reduction(ReductionOperation op, DataType dt, unsigned int axis);
+/** Returns output quantization information for softmax layer
+ *
+ * @param[in] input_type The data type of the input tensor
+ * @param[in] is_log True for log softmax
+ *
+ * @return Quantization information for the output tensor
+ */
+QuantizationInfo get_softmax_output_quantization_info(DataType input_type, bool is_log);
+
/** Convert a tensor format into a string.
*
* @param[in] format @ref Format to be translated to string.