From f931af987c63466c95426742d7297d49438f8170 Mon Sep 17 00:00:00 2001 From: Teresa Charlin Date: Fri, 10 Apr 2020 16:46:53 +0100 Subject: IVGCVSW-3847 Add Support for GATHER Signed-off-by: Teresa Charlin Change-Id: I69dd78d47628355c207a450119b054b04581c729 --- 1.2/HalPolicy.cpp | 8 +++++ 1.2/HalPolicy.hpp | 2 ++ 1.3/HalPolicy.cpp | 8 +++++ 1.3/HalPolicy.hpp | 2 ++ ConversionUtils_1_2.hpp | 81 +++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 101 insertions(+) diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp index d55e587f..4c2a6b5d 100644 --- a/1.2/HalPolicy.cpp +++ b/1.2/HalPolicy.cpp @@ -53,6 +53,8 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model, return ConvertFloor(operation, model, data); case V1_2::OperationType::FULLY_CONNECTED: return ConvertFullyConnected(operation, model, data); + case V1_2::OperationType::GATHER: + return ConvertGather(operation, model, data); case V1_2::OperationType::GREATER: return ConvertComparison(operation, model, data, ComparisonOperation::Greater); case V1_2::OperationType::GREATER_EQUAL: @@ -240,6 +242,12 @@ bool HalPolicy::ConvertFullyConnected(const Operation& operation, const Model& m return ::ConvertFullyConnected(operation, model, data); } +bool HalPolicy::ConvertGather (const Operation& operation, const Model& model, ConversionData& data) +{ + ALOGV("hal_1_2::HalPolicy::ConvertGather()"); + return ::ConvertGather(operation, model, data); +} + bool HalPolicy::ConvertGroupedConv2d(const Operation& operation, const Model& model, ConversionData& data) { ALOGV("hal_1_2::HalPolicy::ConvertGroupedConv2d()"); diff --git a/1.2/HalPolicy.hpp b/1.2/HalPolicy.hpp index ce43a6e5..be02c22f 100644 --- a/1.2/HalPolicy.hpp +++ b/1.2/HalPolicy.hpp @@ -74,6 +74,8 @@ private: static bool ConvertFullyConnected(const Operation& operation, const Model& model, ConversionData& data); + static bool ConvertGather(const Operation& operation, const Model& model, ConversionData& data); + static bool ConvertGroupedConv2d(const Operation& operation, const Model& model, ConversionData& data); static bool ConvertInstanceNormalization(const Operation& operation, const Model& model, ConversionData& data); diff --git a/1.3/HalPolicy.cpp b/1.3/HalPolicy.cpp index 1077b787..707ef726 100644 --- a/1.3/HalPolicy.cpp +++ b/1.3/HalPolicy.cpp @@ -55,6 +55,8 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model, return ConvertFloor(operation, model, data); case V1_3::OperationType::FULLY_CONNECTED: return ConvertFullyConnected(operation, model, data); + case V1_3::OperationType::GATHER: + return ConvertGather(operation, model, data); case V1_3::OperationType::GREATER: return ConvertComparison(operation, model, data, ComparisonOperation::Greater); case V1_3::OperationType::GREATER_EQUAL: @@ -253,6 +255,12 @@ bool HalPolicy::ConvertFullyConnected(const Operation& operation, const Model& m return ::ConvertFullyConnected(operation, model, data); } +bool HalPolicy::ConvertGather(const Operation& operation, const Model& model, ConversionData& data) +{ + ALOGV("hal_1_3::HalPolicy::ConvertGather()"); + return ::ConvertGather(operation, model, data); +} + bool HalPolicy::ConvertGroupedConv2d(const Operation& operation, const Model& model, ConversionData& data) { ALOGV("hal_1_3::HalPolicy::ConvertGroupedConv2d()"); diff --git a/1.3/HalPolicy.hpp b/1.3/HalPolicy.hpp index b59710a6..024d3ff5 100644 --- a/1.3/HalPolicy.hpp +++ b/1.3/HalPolicy.hpp @@ -77,6 +77,8 @@ private: static bool ConvertFullyConnected(const Operation& operation, const Model& model, ConversionData& data); + static bool ConvertGather(const Operation& operation, const Model& model, ConversionData& data); + static bool ConvertGroupedConv2d(const Operation& operation, const Model& model, ConversionData& data); static bool ConvertHardSwish(const Operation& operation, const Model& model, ConversionData& data); diff --git a/ConversionUtils_1_2.hpp b/ConversionUtils_1_2.hpp index 29367f2f..4f142040 100644 --- a/ConversionUtils_1_2.hpp +++ b/ConversionUtils_1_2.hpp @@ -660,6 +660,87 @@ bool ConvertExpandDims(const HalOperation& operation, const HalModel& model, Con return SetupAndTrackLayerOutputSlot(operation, 0, *layer, model, data); } +template +bool ConvertGather(const HalOperation& operation, const HalModel& model, ConversionData& data) +{ + using HalOperand = typename HalPolicy::Operand; + using HalOperandType = typename HalPolicy::OperandType; + + ALOGV("HalPolicy::ConvertGather()"); + + LayerInputHandle input = ConvertToLayerInputHandle(operation, 0, model, data); + if (!input.IsValid()) + { + return Fail("%s: Operation has invalid input", __func__); + } + auto inputDimensions = input.GetTensorInfo().GetNumDimensions(); + + LayerInputHandle indices = ConvertToLayerInputHandle(operation, 2, model, data); + if (!indices.IsValid()) + { + return Fail("%s: Operation has invalid indices", __func__); + } + auto indicesDimensions = indices.GetTensorInfo().GetNumDimensions(); + + const HalOperand* output = GetOutputOperand(operation, 0, model); + if (!output) + { + return Fail("%s: Operation has invalid output", __func__); + } + const TensorInfo& outputInfo = GetTensorInfoForOperand(*output); + auto outputDimensions = outputInfo.GetNumDimensions(); + if (IsDynamicTensor(outputInfo)) + { + return Fail("%s: Dynamic output tensors are not supported", __func__); + } + if (outputDimensions != inputDimensions + indicesDimensions - 1) + { + return Fail("%s: Operation has invalid output dimensions: %d. Output must be an (%d + %d - 1)-D tensor", + __func__, outputDimensions, inputDimensions,indicesDimensions); + } + + int32_t axis; + if (!GetInputScalar(operation, 1, HalOperandType::INT32, axis, model, data)) + { + return Fail("%s: Operation has invalid or unsupported axis operand", __func__); + } + if (-inputDimensions <= axis || axis > inputDimensions) + { + return Fail("%s: Operation has invalid axis: %d. It is out of bounds [-&d, %d))", __func__, axis, + inputDimensions,inputDimensions); + } + if (axis < 0) + { + axis += inputDimensions; + } + if (axis != 0) + { + return Fail("%s: Only axis 0 is currently supported. Axis: %d", __func__, axis); + } + + bool isSupported = false; + FORWARD_LAYER_SUPPORT_FUNC(__func__, + IsGatherSupported, + data.m_Backends, + isSupported, + input.GetTensorInfo(), + indices.GetTensorInfo(), + outputInfo); + if (!isSupported) + { + return false; + } + + IConnectableLayer* layer = data.m_Network->AddGatherLayer(); + assert(layer != nullptr); + input.Connect(layer->GetInputSlot(0)); + indices.Connect(layer->GetInputSlot(1)); + + return SetupAndTrackLayerOutputSlot(operation, 0, *layer, model, data); +} + template -- cgit v1.2.1