diff options
author | narpra01 <narumol.prangnawarat@arm.com> | 2018-09-17 14:25:04 +0100 |
---|---|---|
committer | Matthew Bentham <matthew.bentham@arm.com> | 2018-10-12 11:59:53 +0100 |
commit | 3c05256c46bd6f8180305758c3cf48dbaa3f671b (patch) | |
tree | 80e3200c0c2612e5ad03cd93fea0e6749a2d4ee9 /1.1/HalPolicy.cpp | |
parent | 38e1294770598d49e2a4542d96be0491918546bb (diff) | |
download | android-nn-driver-3c05256c46bd6f8180305758c3cf48dbaa3f671b.tar.gz |
IVGCVSW-1814 - Add ConvertMean functionality to HalPolicy
Change-Id: I04b6150bb490a25fa7bf2da68c339ca5f1fe75de
Diffstat (limited to '1.1/HalPolicy.cpp')
-rw-r--r-- | 1.1/HalPolicy.cpp | 62 |
1 files changed, 62 insertions, 0 deletions
diff --git a/1.1/HalPolicy.cpp b/1.1/HalPolicy.cpp index 857d29bb..e3ccf73c 100644 --- a/1.1/HalPolicy.cpp +++ b/1.1/HalPolicy.cpp @@ -29,6 +29,8 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model, return ConvertDiv(operation, model, data); case V1_1::OperationType::SUB: return ConvertSub(operation, model, data); + case V1_1::OperationType::MEAN: + return ConvertMean(operation, model, data); default: return Fail("%s: Operation type %s not supported in ArmnnDriver", __func__, toString(operation.type).c_str()); @@ -138,5 +140,65 @@ bool HalPolicy::ConvertSub(const Operation& operation, const Model& model, Conve return Fail("%s: ProcessActivation failed", __func__); } +bool HalPolicy::ConvertMean(const Operation& operation, const Model& model, ConversionData& data) +{ + LayerInputHandle input = ConvertToLayerInputHandle(operation, 0, model, data); + + if (!input.IsValid()) + { + return Fail("%s: Operation has invalid inputs", __func__); + } + + const armnn::TensorInfo& inputInfo = input.GetTensorInfo(); + + armnn::MeanDescriptor descriptor; + + const Operand* axisOperand = GetInputOperand(operation, 1, model); + if (axisOperand) + { + std::vector<int32_t> axis; + GetTensorInt32Values(*axisOperand, axis, model, data); + unsigned int rank = inputInfo.GetNumDimensions(); + // convert the axis to unsigned int. + for (auto& i : axis) + { + unsigned int unsignedAxis = (i + rank) % rank; + if (std::find(descriptor.m_Axis.begin(), descriptor.m_Axis.end(), unsignedAxis) == descriptor.m_Axis.end()) + { + descriptor.m_Axis.push_back(unsignedAxis); + } + } + } + + int32_t keepDims; + GetInputInt32(operation, 2, keepDims, model, data); + if (keepDims > 0) + { + descriptor.m_KeepDims = true; + } + + const Operand* output = GetOutputOperand(operation, 0, model); + if (!output) + { + return Fail("%s: Could not read output 0", __func__); + } + + const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output); + + if (!IsLayerSupported(__func__, + armnn::IsMeanSupported, + data.m_Compute, + inputInfo, + outputInfo, + descriptor)) + { + return false; + } + + armnn::IConnectableLayer* const layer = data.m_Network->AddMeanLayer(descriptor); + + return SetupAndTrackLayerOutputSlot(operation, 0, *layer, model, data); +} + } // namespace hal_1_1 } // namespace armnn_driver
\ No newline at end of file |