From 38e1294770598d49e2a4542d96be0491918546bb Mon Sep 17 00:00:00 2001 From: David Beck Date: Wed, 12 Sep 2018 16:02:24 +0100 Subject: IVGCVSW-1805: model converter functions for the Android 1.1 SUB operator Change-Id: I939cb8d26766c93ee8e3c92909d69549328b0ea7 --- 1.1/HalPolicy.cpp | 53 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1.1/HalPolicy.hpp | 1 + 2 files changed, 54 insertions(+) (limited to '1.1') diff --git a/1.1/HalPolicy.cpp b/1.1/HalPolicy.cpp index 0e669432..857d29bb 100644 --- a/1.1/HalPolicy.cpp +++ b/1.1/HalPolicy.cpp @@ -27,6 +27,8 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model, { case V1_1::OperationType::DIV: return ConvertDiv(operation, model, data); + case V1_1::OperationType::SUB: + return ConvertSub(operation, model, data); default: return Fail("%s: Operation type %s not supported in ArmnnDriver", __func__, toString(operation.type).c_str()); @@ -85,5 +87,56 @@ bool HalPolicy::ConvertDiv(const Operation& operation, const Model& model, Conve return Fail("%s: ProcessActivation failed", __func__); } +bool HalPolicy::ConvertSub(const Operation& operation, const Model& model, ConversionData& data) +{ + LayerInputHandle input0 = ConvertToLayerInputHandle(operation, 0, model, data); + LayerInputHandle input1 = ConvertToLayerInputHandle(operation, 1, model, data); + + if (!input0.IsValid() || !input1.IsValid()) + { + return Fail("%s: Operation has invalid inputs", __func__); + } + + // The FuseActivation parameter is always the input index 2 + // and it should be optional + ActivationFn activationFunction; + if (!GetOptionalInputActivation(operation, 2, activationFunction, model, data)) + { + return Fail("%s: Operation has invalid inputs", __func__); + } + + const Operand* outputOperand = GetOutputOperand(operation, 0, model); + if (!outputOperand) + { + return false; + } + + const armnn::TensorInfo& outInfo = GetTensorInfoForOperand(*outputOperand); + + if (!IsLayerSupported(__func__, + armnn::IsSubtractionSupported, + data.m_Compute, + input0.GetTensorInfo(), + input1.GetTensorInfo(), + outInfo)) + { + return false; + } + + armnn::IConnectableLayer* const startLayer = data.m_Network->AddSubtractionLayer(); + armnn::IConnectableLayer* const endLayer = ProcessActivation(outInfo, activationFunction, startLayer, data); + + const armnn::TensorInfo& inputTensorInfo0 = input0.GetTensorInfo(); + const armnn::TensorInfo& inputTensorInfo1 = input1.GetTensorInfo(); + + if (endLayer) + { + BroadcastTensor(input0, input1, startLayer, *data.m_Network); + return SetupAndTrackLayerOutputSlot(operation, 0, *endLayer, model, data); + } + + return Fail("%s: ProcessActivation failed", __func__); +} + } // namespace hal_1_1 } // namespace armnn_driver \ No newline at end of file diff --git a/1.1/HalPolicy.hpp b/1.1/HalPolicy.hpp index af858781..3b7fe541 100644 --- a/1.1/HalPolicy.hpp +++ b/1.1/HalPolicy.hpp @@ -26,6 +26,7 @@ public: private: static bool ConvertDiv(const Operation& operation, const Model& model, ConversionData& data); + static bool ConvertSub(const Operation& operation, const Model& model, ConversionData& data); }; } // namespace hal_1_1 -- cgit v1.2.1