diff options
Diffstat (limited to 'ConversionUtils_1_3.hpp')
-rw-r--r-- | ConversionUtils_1_3.hpp | 73 |
1 files changed, 73 insertions, 0 deletions
diff --git a/ConversionUtils_1_3.hpp b/ConversionUtils_1_3.hpp index a7f00fc3..150735e9 100644 --- a/ConversionUtils_1_3.hpp +++ b/ConversionUtils_1_3.hpp @@ -153,6 +153,79 @@ bool ConvertFill(const HalOperation& operation, const HalModel& model, Conversio template<typename HalPolicy, typename HalOperation = typename HalPolicy::Operation, typename HalModel = typename HalPolicy::Model> +bool ConvertLogicalBinary(const HalOperation& operation, + const HalModel& model, + ConversionData& data, + LogicalBinaryOperation logicalOperation) +{ + using HalOperand = typename HalPolicy::Operand; + + ALOGV("HalPolicy::ConvertLogicalBinary()"); + ALOGV("logicalOperation = %s", GetLogicalBinaryOperationAsCString(logicalOperation)); + + LayerInputHandle input0 = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data); + LayerInputHandle input1 = ConvertToLayerInputHandle<HalPolicy>(operation, 1, model, data); + + if (!(input0.IsValid() && input1.IsValid())) + { + return Fail("%s: Operation has invalid inputs", __func__); + } + + const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model); + if (!output) + { + return Fail("%s: Could not read output 0", __func__); + } + + const TensorInfo& inputInfo0 = input0.GetTensorInfo(); + const TensorInfo& inputInfo1 = input1.GetTensorInfo(); + const TensorInfo& outputInfo = GetTensorInfoForOperand(*output); + + LogicalBinaryDescriptor descriptor(logicalOperation); + + bool isSupported = false; + + auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported) + { + FORWARD_LAYER_SUPPORT_FUNC(__func__, + IsLogicalBinarySupported, + data.m_Backends, + isSupported, + inputInfo0, + inputInfo1, + outputInfo, + descriptor); + }; + + if(!IsDynamicTensor(outputInfo)) + { + validateFunc(outputInfo, isSupported); + } + else + { + isSupported = AreDynamicTensorsSupported(); + } + + if (!isSupported) + { + return false; + } + + IConnectableLayer* layer = data.m_Network->AddLogicalBinaryLayer(descriptor); + assert(layer != nullptr); + + bool isReshapeSupported = BroadcastTensor(input0, input1, layer, data); + if (!isReshapeSupported) + { + return false; + } + + return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, nullptr, validateFunc); +} + +template<typename HalPolicy, + typename HalOperation = typename HalPolicy::Operation, + typename HalModel = typename HalPolicy::Model> bool ConvertQuantizedLstm(const HalOperation& operation, const HalModel& model, ConversionData& data) { using HalOperand = typename HalPolicy::Operand; |