aboutsummaryrefslogtreecommitdiff
path: root/ConversionUtils_1_2.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'ConversionUtils_1_2.hpp')
-rw-r--r--ConversionUtils_1_2.hpp38
1 files changed, 24 insertions, 14 deletions
diff --git a/ConversionUtils_1_2.hpp b/ConversionUtils_1_2.hpp
index 0ad50f31..824a8f4a 100644
--- a/ConversionUtils_1_2.hpp
+++ b/ConversionUtils_1_2.hpp
@@ -2023,26 +2023,36 @@ bool ConvertSoftmax(const HalOperation& operation, const HalModel& model, Conver
}
SoftmaxDescriptor desc;
- if (!GetInputFloat32<HalPolicy>(operation, 1, desc.m_Beta, model, data))
+ HalOperandType outputType = outputOperand->type;
+
+ // Read beta value
+ if (outputType == HalOperandType::TENSOR_FLOAT16)
{
- return Fail("%s: Operation has invalid inputs", __func__);
- }
+ Half value;
- if (operation.inputs.size() > 2 && !GetInputScalar<HalPolicy>(operation,
- 2,
- HalOperandType::INT32,
- desc.m_Axis,
- model,
- data))
+ if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, value, model, data))
+ {
+ return Fail("%s: Operation has invalid inputs %d", __func__, outputType);
+ }
+
+ desc.m_Beta = static_cast<float>(value);
+ }
+ else
{
- return Fail("%s: Operation has invalid inputs", __func__);
+ if (!GetInputFloat32<HalPolicy>(operation, 1, desc.m_Beta, model, data))
+ {
+ return Fail("%s: Operation has invalid inputs %d", __func__, outputType);
+ }
}
- if (input.GetTensorInfo().GetNumDimensions() > 2 ||
- !(desc.m_Axis == 1 ||
- (desc.m_Axis < 0 && static_cast<int>(input.GetTensorInfo().GetNumDimensions()) + desc.m_Axis == 1)))
+ if (operation.inputs.size() > 2 && !GetInputScalar<HalPolicy>(operation,
+ 2,
+ HalOperandType::INT32,
+ desc.m_Axis,
+ model,
+ data))
{
- return Fail("%s: Unsupported input greater than 2D or axis != 1", __func__);
+ return Fail("%s: Operation has invalid inputs", __func__);
}
bool isSupported = false;