From 5d4873fb0aa06aef4e5bc709950067606725bd62 Mon Sep 17 00:00:00 2001 From: Teresa Charlin Date: Wed, 3 Jun 2020 14:39:29 +0100 Subject: IVGCVSW-4903 Gather support for axis != 0 !armnn:3301 Signed-off-by: Teresa Charlin Change-Id: Ieba2cddd45bc353714c3a34f98f5ea49c772f426 --- ConversionUtils_1_2.hpp | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/ConversionUtils_1_2.hpp b/ConversionUtils_1_2.hpp index 497446c1..0a153996 100644 --- a/ConversionUtils_1_2.hpp +++ b/ConversionUtils_1_2.hpp @@ -698,7 +698,7 @@ bool ConvertGather(const HalOperation& operation, const HalModel& model, Convers if (outputDimensions != inputDimensions + indicesDimensions - 1) { return Fail("%s: Operation has invalid output dimensions: %d. Output must be an (%d + %d - 1)-D tensor", - __func__, outputDimensions, inputDimensions,indicesDimensions); + __func__, outputDimensions, inputDimensions, indicesDimensions); } int32_t axis; @@ -706,20 +706,15 @@ bool ConvertGather(const HalOperation& operation, const HalModel& model, Convers { return Fail("%s: Operation has invalid or unsupported axis operand", __func__); } - if (-inputDimensions <= axis || axis > inputDimensions) + if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0))) { - return Fail("%s: Operation has invalid axis: %d. It is out of bounds [-&d, %d))", __func__, axis, - inputDimensions,inputDimensions); - } - if (axis < 0) - { - axis += inputDimensions; - } - if (axis != 0) - { - return Fail("%s: Only axis 0 is currently supported. Axis: %d", __func__, axis); + return Fail("%s: Operation has invalid axis: %d. It is out of bounds [-%d, %d))", __func__, axis, + inputDimensions, inputDimensions); } + GatherDescriptor desc; + desc.m_Axis = axis; + bool isSupported = false; FORWARD_LAYER_SUPPORT_FUNC(__func__, IsGatherSupported, @@ -727,13 +722,14 @@ bool ConvertGather(const HalOperation& operation, const HalModel& model, Convers isSupported, input.GetTensorInfo(), indices.GetTensorInfo(), - outputInfo); + outputInfo, + desc); if (!isSupported) { return false; } - IConnectableLayer* layer = data.m_Network->AddGatherLayer(); + IConnectableLayer* layer = data.m_Network->AddGatherLayer(desc); assert(layer != nullptr); input.Connect(layer->GetInputSlot(0)); indices.Connect(layer->GetInputSlot(1)); -- cgit v1.2.1