diff options
-rw-r--r-- | ConversionUtils_1_2.hpp | 24 |
1 files changed, 10 insertions, 14 deletions
diff --git a/ConversionUtils_1_2.hpp b/ConversionUtils_1_2.hpp index 497446c1..0a153996 100644 --- a/ConversionUtils_1_2.hpp +++ b/ConversionUtils_1_2.hpp @@ -698,7 +698,7 @@ bool ConvertGather(const HalOperation& operation, const HalModel& model, Convers if (outputDimensions != inputDimensions + indicesDimensions - 1) { return Fail("%s: Operation has invalid output dimensions: %d. Output must be an (%d + %d - 1)-D tensor", - __func__, outputDimensions, inputDimensions,indicesDimensions); + __func__, outputDimensions, inputDimensions, indicesDimensions); } int32_t axis; @@ -706,20 +706,15 @@ bool ConvertGather(const HalOperation& operation, const HalModel& model, Convers { return Fail("%s: Operation has invalid or unsupported axis operand", __func__); } - if (-inputDimensions <= axis || axis > inputDimensions) + if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0))) { - return Fail("%s: Operation has invalid axis: %d. It is out of bounds [-&d, %d))", __func__, axis, - inputDimensions,inputDimensions); - } - if (axis < 0) - { - axis += inputDimensions; - } - if (axis != 0) - { - return Fail("%s: Only axis 0 is currently supported. Axis: %d", __func__, axis); + return Fail("%s: Operation has invalid axis: %d. It is out of bounds [-%d, %d))", __func__, axis, + inputDimensions, inputDimensions); } + GatherDescriptor desc; + desc.m_Axis = axis; + bool isSupported = false; FORWARD_LAYER_SUPPORT_FUNC(__func__, IsGatherSupported, @@ -727,13 +722,14 @@ bool ConvertGather(const HalOperation& operation, const HalModel& model, Convers isSupported, input.GetTensorInfo(), indices.GetTensorInfo(), - outputInfo); + outputInfo, + desc); if (!isSupported) { return false; } - IConnectableLayer* layer = data.m_Network->AddGatherLayer(); + IConnectableLayer* layer = data.m_Network->AddGatherLayer(desc); assert(layer != nullptr); input.Connect(layer->GetInputSlot(0)); indices.Connect(layer->GetInputSlot(1)); |