// // Copyright © 2020 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "DelegateUtils.hpp" #include #include #include #include namespace armnnDelegate { TfLiteStatus ValidateSoftmaxOperator(DelegateData& delegateData, TfLiteContext* tfLiteContext, const armnn::TensorInfo& inputInfo, const armnn::TensorInfo& outputTensorInfo, const armnn::SoftmaxDescriptor& descriptor) { bool isSupported = false; FORWARD_LAYER_SUPPORT_FUNC(__func__, tfLiteContext, IsSoftmaxSupported, delegateData.m_Backends, isSupported, inputInfo, outputTensorInfo, descriptor); return isSupported ? kTfLiteOk : kTfLiteError; } TfLiteStatus ValidateLogSoftmaxOperator(DelegateData& delegateData, TfLiteContext* tfLiteContext, const armnn::TensorInfo& inputInfo, const armnn::TensorInfo& outputTensorInfo, const armnn::LogSoftmaxDescriptor& descriptor) { bool isSupported = false; FORWARD_LAYER_SUPPORT_FUNC(__func__, tfLiteContext, IsLogSoftmaxSupported, delegateData.m_Backends, isSupported, inputInfo, outputTensorInfo, descriptor); return isSupported ? kTfLiteOk : kTfLiteError; } TfLiteStatus VisitSoftmaxOperator(DelegateData& delegateData, TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode, int nodeIndex, int32_t softmaxOperatorCode) { TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex)); TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex)); const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors; const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]]; if (IsDynamicTensor(tfLiteInputTensor)) { TF_LITE_MAYBE_KERNEL_LOG( tfLiteContext, "TfLiteArmnnDelegate: Dynamic input tensors are not supported in node #%d: ", nodeIndex); return kTfLiteError; } const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]]; if (IsDynamicTensor(tfLiteOutputTensor)) { TF_LITE_MAYBE_KERNEL_LOG( tfLiteContext, "TfLiteArmnnDelegate: Dynamic output tensors are not supported in node #%d: ", nodeIndex); return kTfLiteError; } const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor); const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor); if (!delegateData.m_Network) { switch(softmaxOperatorCode) { case kTfLiteBuiltinSoftmax: { armnn::SoftmaxDescriptor descriptor; auto* params = reinterpret_cast(tfLiteNode->builtin_data); descriptor.m_Beta = params->beta; return ValidateSoftmaxOperator(delegateData, tfLiteContext, inputTensorInfo, outputTensorInfo, descriptor); } case kTfLiteBuiltinLogSoftmax: { armnn::LogSoftmaxDescriptor descriptor; return ValidateLogSoftmaxOperator(delegateData, tfLiteContext, inputTensorInfo, outputTensorInfo, descriptor); } default: return kTfLiteError; } } armnn::IConnectableLayer* softmaxLayer = nullptr; switch(softmaxOperatorCode) { case kTfLiteBuiltinSoftmax: { armnn::SoftmaxDescriptor descriptor; auto* params = reinterpret_cast(tfLiteNode->builtin_data); descriptor.m_Beta = params->beta; softmaxLayer = delegateData.m_Network->AddSoftmaxLayer(descriptor); break; } case kTfLiteBuiltinLogSoftmax: { armnn::LogSoftmaxDescriptor descriptor; softmaxLayer = delegateData.m_Network->AddLogSoftmaxLayer(descriptor); break; } default: return kTfLiteError; } ARMNN_ASSERT(softmaxLayer != nullptr); armnn::IOutputSlot& outputSlot = softmaxLayer->GetOutputSlot(0); outputSlot.SetTensorInfo(outputTensorInfo); // Connect return Connect(softmaxLayer, tfLiteNode, delegateData); } } // namespace armnnDelegate