aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2020-08-09 23:55:01 +0100
committerTeresaARM <teresa.charlinreyes@arm.com>2020-08-10 10:50:20 +0000
commitbf866e2dc7bd5936788fe213b5c0f74483ec1532 (patch)
tree30865eba7db754e5a8ba629a7b757cc2ce233ce3
parentf057e6ff03c88c169a0e2996108bde7b3d65273c (diff)
downloadandroid-nn-driver-bf866e2dc7bd5936788fe213b5c0f74483ec1532.tar.gz
IVGCVSW-3568 Eliminate rank and axis restriction in Softmax.
* Restriction in axis will be now part of ACL. Currently, ACL only supports axis = 0, which translates to axis = -1 in ArmNN and Android. * Beta must be Float16 when input/output are Float16 !armnn:3690 Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: I2645a005840e17703367b3ec7e9ed91e83a2f6c7
-rw-r--r--ConversionUtils_1_2.hpp38
1 files changed, 24 insertions, 14 deletions
diff --git a/ConversionUtils_1_2.hpp b/ConversionUtils_1_2.hpp
index 0ad50f31..824a8f4a 100644
--- a/ConversionUtils_1_2.hpp
+++ b/ConversionUtils_1_2.hpp
@@ -2023,26 +2023,36 @@ bool ConvertSoftmax(const HalOperation& operation, const HalModel& model, Conver
}
SoftmaxDescriptor desc;
- if (!GetInputFloat32<HalPolicy>(operation, 1, desc.m_Beta, model, data))
+ HalOperandType outputType = outputOperand->type;
+
+ // Read beta value
+ if (outputType == HalOperandType::TENSOR_FLOAT16)
{
- return Fail("%s: Operation has invalid inputs", __func__);
- }
+ Half value;
- if (operation.inputs.size() > 2 && !GetInputScalar<HalPolicy>(operation,
- 2,
- HalOperandType::INT32,
- desc.m_Axis,
- model,
- data))
+ if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, value, model, data))
+ {
+ return Fail("%s: Operation has invalid inputs %d", __func__, outputType);
+ }
+
+ desc.m_Beta = static_cast<float>(value);
+ }
+ else
{
- return Fail("%s: Operation has invalid inputs", __func__);
+ if (!GetInputFloat32<HalPolicy>(operation, 1, desc.m_Beta, model, data))
+ {
+ return Fail("%s: Operation has invalid inputs %d", __func__, outputType);
+ }
}
- if (input.GetTensorInfo().GetNumDimensions() > 2 ||
- !(desc.m_Axis == 1 ||
- (desc.m_Axis < 0 && static_cast<int>(input.GetTensorInfo().GetNumDimensions()) + desc.m_Axis == 1)))
+ if (operation.inputs.size() > 2 && !GetInputScalar<HalPolicy>(operation,
+ 2,
+ HalOperandType::INT32,
+ desc.m_Axis,
+ model,
+ data))
{
- return Fail("%s: Unsupported input greater than 2D or axis != 1", __func__);
+ return Fail("%s: Operation has invalid inputs", __func__);
}
bool isSupported = false;