aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2020-06-03 14:39:29 +0100
committerTeresaARM <teresa.charlinreyes@arm.com>2020-06-29 17:37:48 +0000
commit5d4873fb0aa06aef4e5bc709950067606725bd62 (patch)
tree519ba034b04ca65dccfe00342cf6b87f18b1ce70
parent2e32961e568e8e99a65dd7726bffcd56dfb9f87e (diff)
downloadandroid-nn-driver-5d4873fb0aa06aef4e5bc709950067606725bd62.tar.gz
IVGCVSW-4903 Gather support for axis != 0
!armnn:3301 Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: Ieba2cddd45bc353714c3a34f98f5ea49c772f426
-rw-r--r--ConversionUtils_1_2.hpp24
1 files changed, 10 insertions, 14 deletions
diff --git a/ConversionUtils_1_2.hpp b/ConversionUtils_1_2.hpp
index 497446c1..0a153996 100644
--- a/ConversionUtils_1_2.hpp
+++ b/ConversionUtils_1_2.hpp
@@ -698,7 +698,7 @@ bool ConvertGather(const HalOperation& operation, const HalModel& model, Convers
if (outputDimensions != inputDimensions + indicesDimensions - 1)
{
return Fail("%s: Operation has invalid output dimensions: %d. Output must be an (%d + %d - 1)-D tensor",
- __func__, outputDimensions, inputDimensions,indicesDimensions);
+ __func__, outputDimensions, inputDimensions, indicesDimensions);
}
int32_t axis;
@@ -706,20 +706,15 @@ bool ConvertGather(const HalOperation& operation, const HalModel& model, Convers
{
return Fail("%s: Operation has invalid or unsupported axis operand", __func__);
}
- if (-inputDimensions <= axis || axis > inputDimensions)
+ if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
{
- return Fail("%s: Operation has invalid axis: %d. It is out of bounds [-&d, %d))", __func__, axis,
- inputDimensions,inputDimensions);
- }
- if (axis < 0)
- {
- axis += inputDimensions;
- }
- if (axis != 0)
- {
- return Fail("%s: Only axis 0 is currently supported. Axis: %d", __func__, axis);
+ return Fail("%s: Operation has invalid axis: %d. It is out of bounds [-%d, %d))", __func__, axis,
+ inputDimensions, inputDimensions);
}
+ GatherDescriptor desc;
+ desc.m_Axis = axis;
+
bool isSupported = false;
FORWARD_LAYER_SUPPORT_FUNC(__func__,
IsGatherSupported,
@@ -727,13 +722,14 @@ bool ConvertGather(const HalOperation& operation, const HalModel& model, Convers
isSupported,
input.GetTensorInfo(),
indices.GetTensorInfo(),
- outputInfo);
+ outputInfo,
+ desc);
if (!isSupported)
{
return false;
}
- IConnectableLayer* layer = data.m_Network->AddGatherLayer();
+ IConnectableLayer* layer = data.m_Network->AddGatherLayer(desc);
assert(layer != nullptr);
input.Connect(layer->GetInputSlot(0));
indices.Connect(layer->GetInputSlot(1));