diff options
Diffstat (limited to 'ConversionUtils_1_2.hpp')
-rw-r--r-- | ConversionUtils_1_2.hpp | 38 |
1 files changed, 24 insertions, 14 deletions
diff --git a/ConversionUtils_1_2.hpp b/ConversionUtils_1_2.hpp index 0ad50f31..824a8f4a 100644 --- a/ConversionUtils_1_2.hpp +++ b/ConversionUtils_1_2.hpp @@ -2023,26 +2023,36 @@ bool ConvertSoftmax(const HalOperation& operation, const HalModel& model, Conver } SoftmaxDescriptor desc; - if (!GetInputFloat32<HalPolicy>(operation, 1, desc.m_Beta, model, data)) + HalOperandType outputType = outputOperand->type; + + // Read beta value + if (outputType == HalOperandType::TENSOR_FLOAT16) { - return Fail("%s: Operation has invalid inputs", __func__); - } + Half value; - if (operation.inputs.size() > 2 && !GetInputScalar<HalPolicy>(operation, - 2, - HalOperandType::INT32, - desc.m_Axis, - model, - data)) + if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, value, model, data)) + { + return Fail("%s: Operation has invalid inputs %d", __func__, outputType); + } + + desc.m_Beta = static_cast<float>(value); + } + else { - return Fail("%s: Operation has invalid inputs", __func__); + if (!GetInputFloat32<HalPolicy>(operation, 1, desc.m_Beta, model, data)) + { + return Fail("%s: Operation has invalid inputs %d", __func__, outputType); + } } - if (input.GetTensorInfo().GetNumDimensions() > 2 || - !(desc.m_Axis == 1 || - (desc.m_Axis < 0 && static_cast<int>(input.GetTensorInfo().GetNumDimensions()) + desc.m_Axis == 1))) + if (operation.inputs.size() > 2 && !GetInputScalar<HalPolicy>(operation, + 2, + HalOperandType::INT32, + desc.m_Axis, + model, + data)) { - return Fail("%s: Unsupported input greater than 2D or axis != 1", __func__); + return Fail("%s: Operation has invalid inputs", __func__); } bool isSupported = false; |