diff options
author | Teresa Charlin <teresa.charlinreyes@arm.com> | 2020-06-03 14:39:29 +0100 |
---|---|---|
committer | TeresaARM <teresa.charlinreyes@arm.com> | 2020-06-29 17:37:48 +0000 |
commit | 5d4873fb0aa06aef4e5bc709950067606725bd62 (patch) | |
tree | 519ba034b04ca65dccfe00342cf6b87f18b1ce70 | |
parent | 2e32961e568e8e99a65dd7726bffcd56dfb9f87e (diff) | |
download | android-nn-driver-5d4873fb0aa06aef4e5bc709950067606725bd62.tar.gz |
IVGCVSW-4903 Gather support for axis != 0
!armnn:3301
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: Ieba2cddd45bc353714c3a34f98f5ea49c772f426
-rw-r--r-- | ConversionUtils_1_2.hpp | 24 |
1 files changed, 10 insertions, 14 deletions
diff --git a/ConversionUtils_1_2.hpp b/ConversionUtils_1_2.hpp index 497446c1..0a153996 100644 --- a/ConversionUtils_1_2.hpp +++ b/ConversionUtils_1_2.hpp @@ -698,7 +698,7 @@ bool ConvertGather(const HalOperation& operation, const HalModel& model, Convers if (outputDimensions != inputDimensions + indicesDimensions - 1) { return Fail("%s: Operation has invalid output dimensions: %d. Output must be an (%d + %d - 1)-D tensor", - __func__, outputDimensions, inputDimensions,indicesDimensions); + __func__, outputDimensions, inputDimensions, indicesDimensions); } int32_t axis; @@ -706,20 +706,15 @@ bool ConvertGather(const HalOperation& operation, const HalModel& model, Convers { return Fail("%s: Operation has invalid or unsupported axis operand", __func__); } - if (-inputDimensions <= axis || axis > inputDimensions) + if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0))) { - return Fail("%s: Operation has invalid axis: %d. It is out of bounds [-&d, %d))", __func__, axis, - inputDimensions,inputDimensions); - } - if (axis < 0) - { - axis += inputDimensions; - } - if (axis != 0) - { - return Fail("%s: Only axis 0 is currently supported. Axis: %d", __func__, axis); + return Fail("%s: Operation has invalid axis: %d. It is out of bounds [-%d, %d))", __func__, axis, + inputDimensions, inputDimensions); } + GatherDescriptor desc; + desc.m_Axis = axis; + bool isSupported = false; FORWARD_LAYER_SUPPORT_FUNC(__func__, IsGatherSupported, @@ -727,13 +722,14 @@ bool ConvertGather(const HalOperation& operation, const HalModel& model, Convers isSupported, input.GetTensorInfo(), indices.GetTensorInfo(), - outputInfo); + outputInfo, + desc); if (!isSupported) { return false; } - IConnectableLayer* layer = data.m_Network->AddGatherLayer(); + IConnectableLayer* layer = data.m_Network->AddGatherLayer(desc); assert(layer != nullptr); input.Connect(layer->GetInputSlot(0)); indices.Connect(layer->GetInputSlot(1)); |