Diffstat (limited to '1.2/HalPolicy.cpp')
-rw-r--r--  1.2/HalPolicy.cpp  71
1 file changed, 71 insertions, 0 deletions
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp
index 1de57e5a..7aa6967a 100644
--- a/1.2/HalPolicy.cpp
+++ b/1.2/HalPolicy.cpp
@@ -9,6 +9,7 @@
#include <DataLayoutIndexed.hpp>
#include <Half.hpp>
+#include <TensorUtils.hpp>
#include <cmath>
@@ -39,6 +40,8 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model,
return ConvertDequantize(operation, model, data);
case V1_2::OperationType::DIV:
return ConvertDiv(operation, model, data);
+ case V1_2::OperationType::EXPAND_DIMS:
+ return ConvertExpandDims(operation, model, data);
case V1_2::OperationType::FLOOR:
return ConvertFloor(operation, model, data);
case V1_2::OperationType::FULLY_CONNECTED:
@@ -473,6 +476,74 @@ bool HalPolicy::ConvertDiv(const Operation& operation, const Model& model, Conve
return ::ConvertDiv<hal_1_2::HalPolicy>(operation, model, data);
}
+bool HalPolicy::ConvertExpandDims(const Operation& operation, const Model& model, ConversionData& data)
+{
+ ALOGV("hal_1_2::HalPolicy::ConvertExpandDims()");
+
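+ // EXPAND_DIMS inserts a dimension of size 1 at the given axis. Only the
+ // shape metadata changes, so the operation is lowered to a Reshape layer.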
+ LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
+
+ if (!input.IsValid())
+ {
+ return Fail("%s: Operation has invalid input", __func__);
+ }
+
+ const Operand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
+ if (!output)
+ {
+ return Fail("%s: Operation has invalid output", __func__);
+ }
+
+ const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
+ if (IsDynamicTensor(outputInfo))
+ {
+ return Fail("%s: Dynamic output tensors are not supported", __func__);
+ }
+
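+ // The axis is a scalar INT32 operand; a negative value counts back from
+ // the end of the input shape.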
+ int32_t axis;
+ if (!GetInputScalar<HalPolicy>(operation, 1, OperandType::INT32, axis, model, data))
+ {
+ return Fail("%s: failed to get axis input value", __func__);
+ }
+
+ armnn::TensorShape targetShape;
+
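+ // armnnUtils::ExpandDims resolves the axis and throws if it is out of
+ // range, so translate any exception into a conversion failure.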
+ try
+ {
+ targetShape = armnnUtils::ExpandDims(input.GetTensorInfo().GetShape(), axis);
+ }
+ catch (const std::exception &e)
+ {
+ return Fail("%s: %s", __func__, e.what());
+ }
+
+ if (targetShape != outputInfo.GetShape())
+ {
+ return Fail("%s: Shape of the output operand does not match the resolved expanded shape", __func__);
+ }
+
+ armnn::ReshapeDescriptor reshapeDescriptor;
+ reshapeDescriptor.m_TargetShape = targetShape;
+
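+ // Check that at least one backend supports the equivalent Reshape before
+ // adding the layer to the network.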
+ bool isSupported = false;
+ FORWARD_LAYER_SUPPORT_FUNC(__func__,
+ IsReshapeSupported,
+ data.m_Backends,
+ isSupported,
+ input.GetTensorInfo(),
+ reshapeDescriptor);
+
+ if (!isSupported)
+ {
+ return false;
+ }
+
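+ // Add the Reshape layer and connect the validated input to it.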
+ armnn::IConnectableLayer* layer = data.m_Network->AddReshapeLayer(reshapeDescriptor);
+ assert(layer != nullptr);
+ input.Connect(layer->GetInputSlot(0));
+
+ return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data);
+}
+
bool HalPolicy::ConvertFloor(const Operation& operation, const Model& model, ConversionData& data)
{
ALOGV("hal_1_2::HalPolicy::ConvertFloor()");