diff options
author | Francis Murtagh <francis.murtagh@arm.com> | 2019-07-22 16:40:57 +0100 |
---|---|---|
committer | Francis Murtagh <francis.murtagh@arm.com> | 2019-07-22 16:40:57 +0100 |
commit | 074c25a1535b648fdf19d7f6648e8aceef9aa7ad (patch) | |
tree | 87fa092745acac3f7788cb7412078105b5c7e6b1 /1.2 | |
parent | 65c42dc4d68ac163b77a3139feee3e7d4530b5c5 (diff) | |
download | android-nn-driver-074c25a1535b648fdf19d7f6648e8aceef9aa7ad.tar.gz |
IVGCVSW-3140 Fix Hal 1.1 Softmax failures on Android Q
* Add support for Softmax Axis parameter.
!armnn:1567
Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
Change-Id: I7e47d36f13b122dbad7976c0d59773845bc182b1
Diffstat (limited to '1.2')
-rw-r--r-- | 1.2/HalPolicy.cpp | 65 | ||||
-rw-r--r-- | 1.2/HalPolicy.hpp | 2 |
2 files changed, 66 insertions, 1 deletion
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp index 5fe54d80..3c00388c 100644 --- a/1.2/HalPolicy.cpp +++ b/1.2/HalPolicy.cpp @@ -46,7 +46,6 @@ bool HandledByV1_0(V1_2::OperationType operationType) case V1_0::OperationType::RELU6: case V1_0::OperationType::RESHAPE: case V1_0::OperationType::RNN: - case V1_0::OperationType::SOFTMAX: case V1_0::OperationType::SPACE_TO_DEPTH: case V1_0::OperationType::SVDF: case V1_0::OperationType::TANH: @@ -154,6 +153,8 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model, return ConvertResize(operation, model, data, armnn::ResizeMethod::Bilinear); case V1_2::OperationType::RESIZE_NEAREST_NEIGHBOR: return ConvertResize(operation, model, data, armnn::ResizeMethod::NearestNeighbor); + case V1_2::OperationType::SOFTMAX: + return ConvertSoftmax(operation, model, data); default: return Fail("%s: Operation type %s not supported in ArmnnDriver", __func__, toString(operation.type).c_str()); @@ -945,5 +946,67 @@ bool HalPolicy::ConvertSpaceToDepth(const Operation& operation, const Model& mod return SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 0, *layer, model, data); } +bool HalPolicy::ConvertSoftmax(const Operation& operation, const Model& model, ConversionData& data) +{ + LayerInputHandle input = ConvertToLayerInputHandle<hal_1_2::HalPolicy>(operation, 0, model, data); + if (!input.IsValid()) + { + return Fail("%s: Operation has invalid inputs", __func__); + } + + const Operand* outputOperand = GetOutputOperand<hal_1_2::HalPolicy>(operation, 0, model); + if (!outputOperand) + { + return Fail("%s: Operation has no outputs", __func__); + } + + armnn::TensorInfo outputInfo = GetTensorInfoForOperand(*outputOperand); + if (IsDynamicOutput(outputInfo)) + { + ALOGD("Output shape not set, will infer from input"); + outputInfo.SetShape(input.GetTensorInfo().GetShape()); + } + + armnn::SoftmaxDescriptor desc; + if (!GetInputFloat32<hal_1_2::HalPolicy>(operation, 1, desc.m_Beta, model, data)) + { + return Fail("%s: Operation has invalid inputs", __func__); + } + + if (operation.inputs.size() > 2 && !GetInputScalar<hal_1_2::HalPolicy>(operation, + 2, + HalPolicy::OperandType::INT32, + desc.m_Axis, + model, + data)) + { + return Fail("%s: Operation has invalid inputs", __func__); + } + + bool isSupported = false; + FORWARD_LAYER_SUPPORT_FUNC(__func__, + IsSoftmaxSupported, + data.m_Backends, + isSupported, + input.GetTensorInfo(), + outputInfo, + desc); + if (!isSupported) + { + return false; + } + + armnn::IConnectableLayer* layer = data.m_Network->AddSoftmaxLayer(desc); + assert(layer != nullptr); + input.Connect(layer->GetInputSlot(0)); + + return SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, + 0, + *layer, + model, + data, + armnn::Optional<armnn::TensorInfo>(outputInfo)); +} + } // namespace hal_1_2 } // namespace armnn_driver diff --git a/1.2/HalPolicy.hpp b/1.2/HalPolicy.hpp index 18cf0359..3c4906cb 100644 --- a/1.2/HalPolicy.hpp +++ b/1.2/HalPolicy.hpp @@ -48,6 +48,8 @@ private: ConversionData& data, armnn::ResizeMethod resizeMethod); + static bool ConvertSoftmax(const Operation& operation, const Model& model, ConversionData& data); + static bool ConvertSpaceToDepth(const Operation& operation, const Model& model, ConversionData& data); }; |