diff options
author | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-10-15 13:33:03 +0100 |
---|---|---|
committer | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-10-15 13:33:03 +0100 |
commit | 75e677939a98298e50d65c4a7e99a03fb51d5e3c (patch) | |
tree | 69a32790256c26a94a45f31641179a37373f0bc8 /1.2/HalPolicy.cpp | |
parent | ad7545385e42bd183f6d27bfea30a57f76a1fb46 (diff) | |
download | android-nn-driver-75e677939a98298e50d65c4a7e99a03fb51d5e3c.tar.gz |
IVGCVSW-3894 Add support for LOG_SOFTMAX to the HAL 1.2 driver
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: I59645b339f3b176e5d0852769acb95f5657101d3
Diffstat (limited to '1.2/HalPolicy.cpp')
-rw-r--r-- | 1.2/HalPolicy.cpp | 86 |
1 file changed, 86 insertions, 0 deletions
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp index 019f5054..55df9dab 100644 --- a/1.2/HalPolicy.cpp +++ b/1.2/HalPolicy.cpp @@ -64,6 +64,8 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model, return ConvertLocalResponseNormalization(operation, model, data); case V1_2::OperationType::LOGISTIC: return ConvertLogistic(operation, model, data); + case V1_2::OperationType::LOG_SOFTMAX: + return ConvertLogSoftmax(operation, model, data); case V1_2::OperationType::LSTM: return ConvertLstm(operation, model, data); case V1_2::OperationType::MAX_POOL_2D: @@ -998,6 +1000,90 @@ bool HalPolicy::ConvertLogistic(const Operation& operation, const Model& model, return ::ConvertLogistic<hal_1_2::HalPolicy>(operation, model, data); } +bool HalPolicy::ConvertLogSoftmax(const Operation& operation, const Model& model, ConversionData& data) +{ + ALOGV("hal_1_2::HalPolicy::ConvertLogSoftmax()"); + + LayerInputHandle input = ConvertToLayerInputHandle<hal_1_2::HalPolicy>(operation, 0, model, data); + if (!input.IsValid()) + { + return Fail("%s: Failed to read input 0", __func__); + } + + const Operand* output = GetOutputOperand<hal_1_2::HalPolicy>(operation, 0, model); + if (!output) + { + return Fail("%s: Failed to read output", __func__); + } + + const TensorInfo& outputInfo = GetTensorInfoForOperand(*output); + if (IsDynamicTensor(outputInfo)) + { + return Fail("%s: Dynamic output tensors are not supported", __func__); + } + + // Determine data type of input tensor + OperandType inputType; + if (!GetOperandType<hal_1_2::HalPolicy>(operation, 0, model, inputType)) + { + return Fail("%s: Operation has invalid inputs", __func__); + } + + LogSoftmaxDescriptor descriptor; + + // Read beta + if (inputType == OperandType::TENSOR_FLOAT16) + { + Half fp16Beta; + if (!GetInputScalar<hal_1_2::HalPolicy>(operation, 1, OperandType::FLOAT16, fp16Beta, model, data)) + { + return Fail("%s: Failed to read input 1 (FLOAT16)", __func__); + } + + descriptor.m_Beta = static_cast<float>(fp16Beta); + } + else if (inputType == OperandType::TENSOR_FLOAT32) + { + if (!GetInputScalar<hal_1_2::HalPolicy>(operation, 1, OperandType::FLOAT32, descriptor.m_Beta, model, data)) + { + return Fail("%s: Failed to read input 1 (FLOAT32)", __func__); + } + } + else + { + return Fail("%s: Unsupported input tensor type: %d", __func__, inputType); + } + + // Read axis + if (!GetInputInt32<hal_1_2::HalPolicy>(operation, 2, descriptor.m_Axis, model, data)) + { + return Fail("%s: Failed to read input 2", __func__); + } + + bool isSupported = false; + FORWARD_LAYER_SUPPORT_FUNC(__func__, + IsLogSoftmaxSupported, + data.m_Backends, + isSupported, + input.GetTensorInfo(), + outputInfo, + descriptor); + if (!isSupported) + { + return false; + } + + armnn::IConnectableLayer* layer = data.m_Network->AddLogSoftmaxLayer(descriptor); + if (!layer) + { + return Fail("%s: AddLogSoftmaxLayer() returned nullptr", __func__); + } + + input.Connect(layer->GetInputSlot(0)); + + return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data); +} + bool HalPolicy::ConvertMaxPool2d(const Operation& operation, const Model& model, ConversionData& data) { ALOGV("hal_1_2::HalPolicy::ConvertMaxPool2d()"); |