diff options
author | Teresa Charlin <teresa.charlinreyes@arm.com> | 2020-08-09 23:55:01 +0100 |
---|---|---|
committer | TeresaARM <teresa.charlinreyes@arm.com> | 2020-08-10 10:50:20 +0000 |
commit | bf866e2dc7bd5936788fe213b5c0f74483ec1532 (patch) | |
tree | 30865eba7db754e5a8ba629a7b757cc2ce233ce3 /ConversionUtils_1_2.hpp | |
parent | f057e6ff03c88c169a0e2996108bde7b3d65273c (diff) | |
download | android-nn-driver-bf866e2dc7bd5936788fe213b5c0f74483ec1532.tar.gz |
IVGCVSW-3568 Eliminate rank and axis restriction in Softmax.
* Restriction in axis will be now part of ACL. Currently, ACL only
supports axis = 0, which translates to axis = -1 in ArmNN and Android.
* Beta must be Float16 when input/output are Float16
!armnn:3690
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: I2645a005840e17703367b3ec7e9ed91e83a2f6c7
Diffstat (limited to 'ConversionUtils_1_2.hpp')
-rw-r--r-- | ConversionUtils_1_2.hpp | 38 |
1 files changed, 24 insertions, 14 deletions
diff --git a/ConversionUtils_1_2.hpp b/ConversionUtils_1_2.hpp index 0ad50f31..824a8f4a 100644 --- a/ConversionUtils_1_2.hpp +++ b/ConversionUtils_1_2.hpp @@ -2023,26 +2023,36 @@ bool ConvertSoftmax(const HalOperation& operation, const HalModel& model, Conver } SoftmaxDescriptor desc; - if (!GetInputFloat32<HalPolicy>(operation, 1, desc.m_Beta, model, data)) + HalOperandType outputType = outputOperand->type; + + // Read beta value + if (outputType == HalOperandType::TENSOR_FLOAT16) { - return Fail("%s: Operation has invalid inputs", __func__); - } + Half value; - if (operation.inputs.size() > 2 && !GetInputScalar<HalPolicy>(operation, - 2, - HalOperandType::INT32, - desc.m_Axis, - model, - data)) + if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, value, model, data)) + { + return Fail("%s: Operation has invalid inputs %d", __func__, outputType); + } + + desc.m_Beta = static_cast<float>(value); + } + else { - return Fail("%s: Operation has invalid inputs", __func__); + if (!GetInputFloat32<HalPolicy>(operation, 1, desc.m_Beta, model, data)) + { + return Fail("%s: Operation has invalid inputs %d", __func__, outputType); + } } - if (input.GetTensorInfo().GetNumDimensions() > 2 || - !(desc.m_Axis == 1 || - (desc.m_Axis < 0 && static_cast<int>(input.GetTensorInfo().GetNumDimensions()) + desc.m_Axis == 1))) + if (operation.inputs.size() > 2 && !GetInputScalar<HalPolicy>(operation, + 2, + HalOperandType::INT32, + desc.m_Axis, + model, + data)) { - return Fail("%s: Unsupported input greater than 2D or axis != 1", __func__); + return Fail("%s: Operation has invalid inputs", __func__); } bool isSupported = false; |