From bf866e2dc7bd5936788fe213b5c0f74483ec1532 Mon Sep 17 00:00:00 2001 From: Teresa Charlin Date: Sun, 9 Aug 2020 23:55:01 +0100 Subject: IVGCVSW-3568 Eliminate rank and axis restriction in Softmax. * Restriction in axis will be now part of ACL. Currently, ACL only supports axis = 0, which translates to axis = -1 in ArmNN and Android. * Beta must be Float16 when input/output are Float16 !armnn:3690 Signed-off-by: Teresa Charlin Change-Id: I2645a005840e17703367b3ec7e9ed91e83a2f6c7 --- ConversionUtils_1_2.hpp | 38 ++++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/ConversionUtils_1_2.hpp b/ConversionUtils_1_2.hpp index 0ad50f31..824a8f4a 100644 --- a/ConversionUtils_1_2.hpp +++ b/ConversionUtils_1_2.hpp @@ -2023,26 +2023,36 @@ bool ConvertSoftmax(const HalOperation& operation, const HalModel& model, Conver } SoftmaxDescriptor desc; - if (!GetInputFloat32(operation, 1, desc.m_Beta, model, data)) + HalOperandType outputType = outputOperand->type; + + // Read beta value + if (outputType == HalOperandType::TENSOR_FLOAT16) { - return Fail("%s: Operation has invalid inputs", __func__); - } + Half value; - if (operation.inputs.size() > 2 && !GetInputScalar(operation, - 2, - HalOperandType::INT32, - desc.m_Axis, - model, - data)) + if (!GetInputScalar(operation, 1, HalOperandType::FLOAT16, value, model, data)) + { + return Fail("%s: Operation has invalid inputs %d", __func__, outputType); + } + + desc.m_Beta = static_cast(value); + } + else { - return Fail("%s: Operation has invalid inputs", __func__); + if (!GetInputFloat32(operation, 1, desc.m_Beta, model, data)) + { + return Fail("%s: Operation has invalid inputs %d", __func__, outputType); + } } - if (input.GetTensorInfo().GetNumDimensions() > 2 || - !(desc.m_Axis == 1 || - (desc.m_Axis < 0 && static_cast(input.GetTensorInfo().GetNumDimensions()) + desc.m_Axis == 1))) + if (operation.inputs.size() > 2 && !GetInputScalar(operation, + 2, + HalOperandType::INT32, + desc.m_Axis, + model, + data)) { - return Fail("%s: Unsupported input greater than 2D or axis != 1", __func__); + return Fail("%s: Operation has invalid inputs", __func__); } bool isSupported = false; -- cgit v1.2.1