diff options
Diffstat (limited to '1.2/HalPolicy.cpp')
-rw-r--r-- | 1.2/HalPolicy.cpp | 23 |
1 files changed, 21 insertions, 2 deletions
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp index ac78e96b..ddd85d9b 100644 --- a/1.2/HalPolicy.cpp +++ b/1.2/HalPolicy.cpp @@ -1887,10 +1887,29 @@ bool HalPolicy::ConvertResize(const Operation& operation, descriptor.m_TargetWidth = std::floor(width * widthScale); descriptor.m_TargetHeight = std::floor(height * heightScale); } + else if (operandType1 == OperandType::FLOAT16) + { + Half widthScale; + Half heightScale; + + if (!GetInputScalar<HalPolicy>(operation, 1, HalPolicy::OperandType::FLOAT16, widthScale, model, data) || + !GetInputScalar<HalPolicy>(operation, 2, HalPolicy::OperandType::FLOAT16, heightScale, model, data)) + { + return Fail("%s: Operation has invalid inputs for resizing by scale", __func__); + } + + const TensorShape& inputShape = inputInfo.GetShape(); + armnnUtils::DataLayoutIndexed dataLayoutIndexed(descriptor.m_DataLayout); + + Half width = static_cast<Half>(inputShape[dataLayoutIndexed.GetWidthIndex()]); + Half height = static_cast<Half>(inputShape[dataLayoutIndexed.GetHeightIndex()]); + + descriptor.m_TargetWidth = std::floor(width * widthScale); + descriptor.m_TargetHeight = std::floor(height * heightScale); + } else { - // NOTE: FLOAT16 scales are not supported - return false; + return Fail("%s: Operand has invalid data type for resizing by scale", __func__); } bool isSupported = false; |