aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornarpra01 <narumol.prangnawarat@arm.com>2018-09-17 14:25:04 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-12 11:59:53 +0100
commit3c05256c46bd6f8180305758c3cf48dbaa3f671b (patch)
tree80e3200c0c2612e5ad03cd93fea0e6749a2d4ee9
parent38e1294770598d49e2a4542d96be0491918546bb (diff)
downloadandroid-nn-driver-3c05256c46bd6f8180305758c3cf48dbaa3f671b.tar.gz
IVGCVSW-1814 - Add ConvertMean functionality to HalPolicy
Change-Id: I04b6150bb490a25fa7bf2da68c339ca5f1fe75de
-rw-r--r--1.1/HalPolicy.cpp62
-rw-r--r--1.1/HalPolicy.hpp1
2 files changed, 63 insertions, 0 deletions
diff --git a/1.1/HalPolicy.cpp b/1.1/HalPolicy.cpp
index 857d29bb..e3ccf73c 100644
--- a/1.1/HalPolicy.cpp
+++ b/1.1/HalPolicy.cpp
@@ -29,6 +29,8 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model,
return ConvertDiv(operation, model, data);
case V1_1::OperationType::SUB:
return ConvertSub(operation, model, data);
+ case V1_1::OperationType::MEAN:
+ return ConvertMean(operation, model, data);
default:
return Fail("%s: Operation type %s not supported in ArmnnDriver",
__func__, toString(operation.type).c_str());
@@ -138,5 +140,65 @@ bool HalPolicy::ConvertSub(const Operation& operation, const Model& model, Conve
return Fail("%s: ProcessActivation failed", __func__);
}
+bool HalPolicy::ConvertMean(const Operation& operation, const Model& model, ConversionData& data)
+{
+ LayerInputHandle input = ConvertToLayerInputHandle(operation, 0, model, data);
+
+ if (!input.IsValid())
+ {
+ return Fail("%s: Operation has invalid inputs", __func__);
+ }
+
+ const armnn::TensorInfo& inputInfo = input.GetTensorInfo();
+
+ armnn::MeanDescriptor descriptor;
+
+ const Operand* axisOperand = GetInputOperand(operation, 1, model);
+ if (axisOperand)
+ {
+ std::vector<int32_t> axis;
+ GetTensorInt32Values(*axisOperand, axis, model, data);
+ unsigned int rank = inputInfo.GetNumDimensions();
+ // convert the axis to unsigned int.
+ for (auto& i : axis)
+ {
+ unsigned int unsignedAxis = (i + rank) % rank;
+ if (std::find(descriptor.m_Axis.begin(), descriptor.m_Axis.end(), unsignedAxis) == descriptor.m_Axis.end())
+ {
+ descriptor.m_Axis.push_back(unsignedAxis);
+ }
+ }
+ }
+
+ int32_t keepDims;
+ GetInputInt32(operation, 2, keepDims, model, data);
+ if (keepDims > 0)
+ {
+ descriptor.m_KeepDims = true;
+ }
+
+ const Operand* output = GetOutputOperand(operation, 0, model);
+ if (!output)
+ {
+ return Fail("%s: Could not read output 0", __func__);
+ }
+
+ const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
+
+ if (!IsLayerSupported(__func__,
+ armnn::IsMeanSupported,
+ data.m_Compute,
+ inputInfo,
+ outputInfo,
+ descriptor))
+ {
+ return false;
+ }
+
+ armnn::IConnectableLayer* const layer = data.m_Network->AddMeanLayer(descriptor);
+
+ return SetupAndTrackLayerOutputSlot(operation, 0, *layer, model, data);
+}
+
} // namespace hal_1_1
} // namespace armnn_driver \ No newline at end of file
diff --git a/1.1/HalPolicy.hpp b/1.1/HalPolicy.hpp
index 3b7fe541..dc3332c0 100644
--- a/1.1/HalPolicy.hpp
+++ b/1.1/HalPolicy.hpp
@@ -27,6 +27,7 @@ public:
private:
static bool ConvertDiv(const Operation& operation, const Model& model, ConversionData& data);
static bool ConvertSub(const Operation& operation, const Model& model, ConversionData& data);
+ static bool ConvertMean(const Operation& operation, const Model& model, ConversionData& data);
};
} // namespace hal_1_1