aboutsummaryrefslogtreecommitdiff
path: root/1.1
diff options
context:
space:
mode:
authorDavid Beck <david.beck@arm.com>2018-09-12 16:02:24 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-09-18 12:40:42 +0100
commit38e1294770598d49e2a4542d96be0491918546bb (patch)
treec57841da5f5f6029454488729f894742e93f397b /1.1
parentd11c52311b338298c1577a0934c6af40bb53d9dc (diff)
downloadandroid-nn-driver-38e1294770598d49e2a4542d96be0491918546bb.tar.gz
IVGCVSW-1805: model converter functions for the Android 1.1 SUB operator
Change-Id: I939cb8d26766c93ee8e3c92909d69549328b0ea7
Diffstat (limited to '1.1')
-rw-r--r--1.1/HalPolicy.cpp53
-rw-r--r--1.1/HalPolicy.hpp1
2 files changed, 54 insertions, 0 deletions
diff --git a/1.1/HalPolicy.cpp b/1.1/HalPolicy.cpp
index 0e669432..857d29bb 100644
--- a/1.1/HalPolicy.cpp
+++ b/1.1/HalPolicy.cpp
@@ -27,6 +27,8 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model,
{
case V1_1::OperationType::DIV:
return ConvertDiv(operation, model, data);
+ case V1_1::OperationType::SUB:
+ return ConvertSub(operation, model, data);
default:
return Fail("%s: Operation type %s not supported in ArmnnDriver",
__func__, toString(operation.type).c_str());
@@ -85,5 +87,56 @@ bool HalPolicy::ConvertDiv(const Operation& operation, const Model& model, Conve
return Fail("%s: ProcessActivation failed", __func__);
}
+bool HalPolicy::ConvertSub(const Operation& operation, const Model& model, ConversionData& data)
+{
+ LayerInputHandle input0 = ConvertToLayerInputHandle(operation, 0, model, data);
+ LayerInputHandle input1 = ConvertToLayerInputHandle(operation, 1, model, data);
+
+ if (!input0.IsValid() || !input1.IsValid())
+ {
+ return Fail("%s: Operation has invalid inputs", __func__);
+ }
+
+ // The FuseActivation parameter is always the input index 2
+ // and it should be optional
+ ActivationFn activationFunction;
+ if (!GetOptionalInputActivation(operation, 2, activationFunction, model, data))
+ {
+ return Fail("%s: Operation has invalid inputs", __func__);
+ }
+
+ const Operand* outputOperand = GetOutputOperand(operation, 0, model);
+ if (!outputOperand)
+ {
+ return false;
+ }
+
+ const armnn::TensorInfo& outInfo = GetTensorInfoForOperand(*outputOperand);
+
+ if (!IsLayerSupported(__func__,
+ armnn::IsSubtractionSupported,
+ data.m_Compute,
+ input0.GetTensorInfo(),
+ input1.GetTensorInfo(),
+ outInfo))
+ {
+ return false;
+ }
+
+ armnn::IConnectableLayer* const startLayer = data.m_Network->AddSubtractionLayer();
+ armnn::IConnectableLayer* const endLayer = ProcessActivation(outInfo, activationFunction, startLayer, data);
+
+ const armnn::TensorInfo& inputTensorInfo0 = input0.GetTensorInfo();
+ const armnn::TensorInfo& inputTensorInfo1 = input1.GetTensorInfo();
+
+ if (endLayer)
+ {
+ BroadcastTensor(input0, input1, startLayer, *data.m_Network);
+ return SetupAndTrackLayerOutputSlot(operation, 0, *endLayer, model, data);
+ }
+
+ return Fail("%s: ProcessActivation failed", __func__);
+}
+
} // namespace hal_1_1
} // namespace armnn_driver \ No newline at end of file
diff --git a/1.1/HalPolicy.hpp b/1.1/HalPolicy.hpp
index af858781..3b7fe541 100644
--- a/1.1/HalPolicy.hpp
+++ b/1.1/HalPolicy.hpp
@@ -26,6 +26,7 @@ public:
private:
static bool ConvertDiv(const Operation& operation, const Model& model, ConversionData& data);
+ static bool ConvertSub(const Operation& operation, const Model& model, ConversionData& data);
};
} // namespace hal_1_1