diff options
Diffstat (limited to '1.2/HalPolicy.cpp')
-rw-r--r-- | 1.2/HalPolicy.cpp | 71 |
1 files changed, 71 insertions, 0 deletions
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp index 1de57e5a..7aa6967a 100644 --- a/1.2/HalPolicy.cpp +++ b/1.2/HalPolicy.cpp @@ -9,6 +9,7 @@ #include <DataLayoutIndexed.hpp> #include <Half.hpp> +#include <TensorUtils.hpp> #include <cmath> @@ -39,6 +40,8 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model, return ConvertDequantize(operation, model, data); case V1_2::OperationType::DIV: return ConvertDiv(operation, model, data); + case V1_2::OperationType::EXPAND_DIMS: + return ConvertExpandDims(operation, model, data); case V1_2::OperationType::FLOOR: return ConvertFloor(operation, model, data); case V1_2::OperationType::FULLY_CONNECTED: @@ -473,6 +476,74 @@ bool HalPolicy::ConvertDiv(const Operation& operation, const Model& model, Conve return ::ConvertDiv<hal_1_2::HalPolicy>(operation, model, data); } +bool HalPolicy::ConvertExpandDims(const Operation& operation, const Model& model, ConversionData& data) +{ + ALOGV("hal_1_2::HalPolicy::ConvertExpandDims()"); + + LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data); + + if (!input.IsValid()) + { + return Fail("%s: Operation has invalid input", __func__); + } + + const Operand* output = GetOutputOperand<HalPolicy>(operation, 0, model); + if (!output) + { + return Fail("%s: Operation has invalid output", __func__); + } + + const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output); + if (IsDynamicTensor(outputInfo)) + { + return Fail("%s: Dynamic output tensors are not supported", __func__); + } + + int32_t axis; + if (!GetInputScalar<HalPolicy>(operation, 1, OperandType::INT32, axis, model, data)) + { + return Fail("%s: failed to get axis input value", __func__); + } + + armnn::TensorShape targetShape; + + try + { + targetShape = armnnUtils::ExpandDims(input.GetTensorInfo().GetShape(), axis); + } + catch (const std::exception &e) + { + return Fail("%s: %s", __func__, e.what()); + } + + if (targetShape != outputInfo.GetShape()) + { + return Fail("%s: Shape of the output operand does not match the resolved expanded shape", __func__); + } + + armnn::ReshapeDescriptor reshapeDescriptor; + reshapeDescriptor.m_TargetShape = targetShape; + + bool isSupported = false; + FORWARD_LAYER_SUPPORT_FUNC(__func__, + IsReshapeSupported, + data.m_Backends, + isSupported, + input.GetTensorInfo(), + reshapeDescriptor); + + if (!isSupported) + { + return false; + } + + armnn::IConnectableLayer* layer = data.m_Network->AddReshapeLayer(reshapeDescriptor); + assert(layer != nullptr); + input.Connect(layer->GetInputSlot(0)); + + return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data); +} + bool HalPolicy::ConvertFloor(const Operation& operation, const Model& model, ConversionData& data) { ALOGV("hal_1_2::HalPolicy::ConvertFloor()"); |