aboutsummaryrefslogtreecommitdiff
path: root/1.1/HalPolicy.cpp
diff options
context:
space:
mode:
Diffstat (limited to '1.1/HalPolicy.cpp')
-rw-r--r--1.1/HalPolicy.cpp62
1 files changed, 62 insertions, 0 deletions
diff --git a/1.1/HalPolicy.cpp b/1.1/HalPolicy.cpp
index 857d29bb..e3ccf73c 100644
--- a/1.1/HalPolicy.cpp
+++ b/1.1/HalPolicy.cpp
@@ -29,6 +29,8 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model,
return ConvertDiv(operation, model, data);
case V1_1::OperationType::SUB:
return ConvertSub(operation, model, data);
+ case V1_1::OperationType::MEAN:
+ return ConvertMean(operation, model, data);
default:
return Fail("%s: Operation type %s not supported in ArmnnDriver",
__func__, toString(operation.type).c_str());
@@ -138,5 +140,65 @@ bool HalPolicy::ConvertSub(const Operation& operation, const Model& model, Conve
return Fail("%s: ProcessActivation failed", __func__);
}
+bool HalPolicy::ConvertMean(const Operation& operation, const Model& model, ConversionData& data)
+{
+ LayerInputHandle input = ConvertToLayerInputHandle(operation, 0, model, data);
+
+ if (!input.IsValid())
+ {
+ return Fail("%s: Operation has invalid inputs", __func__);
+ }
+
+ const armnn::TensorInfo& inputInfo = input.GetTensorInfo();
+
+ armnn::MeanDescriptor descriptor;
+
+ const Operand* axisOperand = GetInputOperand(operation, 1, model);
+ if (axisOperand)
+ {
+ std::vector<int32_t> axis;
+ GetTensorInt32Values(*axisOperand, axis, model, data);
+ unsigned int rank = inputInfo.GetNumDimensions();
+ // convert the axis to unsigned int.
+ for (auto& i : axis)
+ {
+ unsigned int unsignedAxis = (i + rank) % rank;
+ if (std::find(descriptor.m_Axis.begin(), descriptor.m_Axis.end(), unsignedAxis) == descriptor.m_Axis.end())
+ {
+ descriptor.m_Axis.push_back(unsignedAxis);
+ }
+ }
+ }
+
+ int32_t keepDims;
+ GetInputInt32(operation, 2, keepDims, model, data);
+ if (keepDims > 0)
+ {
+ descriptor.m_KeepDims = true;
+ }
+
+ const Operand* output = GetOutputOperand(operation, 0, model);
+ if (!output)
+ {
+ return Fail("%s: Could not read output 0", __func__);
+ }
+
+ const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
+
+ if (!IsLayerSupported(__func__,
+ armnn::IsMeanSupported,
+ data.m_Compute,
+ inputInfo,
+ outputInfo,
+ descriptor))
+ {
+ return false;
+ }
+
+ armnn::IConnectableLayer* const layer = data.m_Network->AddMeanLayer(descriptor);
+
+ return SetupAndTrackLayerOutputSlot(operation, 0, *layer, model, data);
+}
+
} // namespace hal_1_1
} // namespace armnn_driver \ No newline at end of file