aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/CL/functions/CLSoftmaxLayer.h
diff options
context:
space:
mode:
authorSang-Hoon Park <sang-hoon.park@arm.com>2019-10-29 13:13:19 +0000
committerGeorgios Pinitas <georgios.pinitas@arm.com>2019-10-31 22:26:59 +0000
commit62eeb53a5eee9d388a6074553175909fd1b441b5 (patch)
tree62e051ba5b4f73adb5ba909d623fd0323d2704e9 /arm_compute/runtime/CL/functions/CLSoftmaxLayer.h
parent44bfc3fe8dacfc4297702ca88323ea675a7c52e2 (diff)
downloadComputeLibrary-62eeb53a5eee9d388a6074553175909fd1b441b5.tar.gz
COMPMID-2266: [CL] add support for Log Softmax
Change-Id: I4a8f3519328553e24cbb4fe45a8ca4d47c90975d Signed-off-by: Sang-Hoon Park <sang-hoon.park@arm.com> Reviewed-on: https://review.mlplatform.org/c/2182 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLSoftmaxLayer.h')
-rw-r--r--arm_compute/runtime/CL/functions/CLSoftmaxLayer.h13
1 files changed, 10 insertions, 3 deletions
diff --git a/arm_compute/runtime/CL/functions/CLSoftmaxLayer.h b/arm_compute/runtime/CL/functions/CLSoftmaxLayer.h
index 407827087c..e3feebb762 100644
--- a/arm_compute/runtime/CL/functions/CLSoftmaxLayer.h
+++ b/arm_compute/runtime/CL/functions/CLSoftmaxLayer.h
@@ -43,16 +43,20 @@ class ICLTensor;
* Softmax is calculated by :
* @f[ out = exp((x - max(x)) * beta) / sum(exp((x - max(x)) * beta)) @f]
*
+ * Log Softmax is calculated by :
+ * @f[ out = (x - max(x) * beta) - \sum{e^{x - max(x) * beta}} @f]
+ *
* This function runs the following kernels:
* -# @ref CLLogits1DMaxKernel
* -# @ref CLLogits1DShiftExpSumKernel
* -# @ref CLLogits1DNormKernel
*/
-class CLSoftmaxLayer : public IFunction
+template <bool IS_LOG = false>
+class CLSoftmaxLayerGeneric : public IFunction
{
public:
/** Constructor */
- CLSoftmaxLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
+ CLSoftmaxLayerGeneric(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
/** Set the input and output tensors.
*
* @param[in] input Source tensor. Data types supported: QASYMM8/F16/F32
@@ -106,5 +110,8 @@ private:
CLTensor _output_flattened;
bool _needs_flattening;
};
-}
+
+using CLSoftmaxLayer = CLSoftmaxLayerGeneric<false>;
+using CLLogSoftmaxLayer = CLSoftmaxLayerGeneric<true>;
+} // namespace arm_compute
#endif /* __ARM_COMPUTE_CLSOFTMAXLAYER_H__ */