aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r--src/backends/reference/workloads/Gather.cpp7
-rw-r--r--src/backends/reference/workloads/Gather.hpp5
-rw-r--r--src/backends/reference/workloads/RefGatherWorkload.cpp4
3 files changed, 10 insertions, 6 deletions
diff --git a/src/backends/reference/workloads/Gather.cpp b/src/backends/reference/workloads/Gather.cpp
index c23edcd3bd..3e2190c81b 100644
--- a/src/backends/reference/workloads/Gather.cpp
+++ b/src/backends/reference/workloads/Gather.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -20,9 +20,12 @@ void Gather(const TensorInfo& paramsInfo,
const TensorInfo& outputInfo,
Decoder<float>& params,
const int32_t* indices,
- Encoder<float>& output)
+ Encoder<float>& output,
+ const int32_t axis)
{
IgnoreUnused(outputInfo);
+ IgnoreUnused(axis);
+
const TensorShape& paramsShape = paramsInfo.GetShape();
unsigned int paramsProduct = 1;
diff --git a/src/backends/reference/workloads/Gather.hpp b/src/backends/reference/workloads/Gather.hpp
index 16c983eec4..1550f4b97c 100644
--- a/src/backends/reference/workloads/Gather.hpp
+++ b/src/backends/reference/workloads/Gather.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -19,6 +19,7 @@ void Gather(const TensorInfo& paramsInfo,
const TensorInfo& outputInfo,
Decoder<float>& params,
const int32_t* indices,
- Encoder<float>& output);
+ Encoder<float>& output,
+ const int32_t = 0);
} //namespace armnn
diff --git a/src/backends/reference/workloads/RefGatherWorkload.cpp b/src/backends/reference/workloads/RefGatherWorkload.cpp
index 8edf14c8f8..eaeed61b0a 100644
--- a/src/backends/reference/workloads/RefGatherWorkload.cpp
+++ b/src/backends/reference/workloads/RefGatherWorkload.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -29,7 +29,7 @@ void RefGatherWorkload::Execute() const
std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
Encoder<float>& encoder = *encoderPtr;
- Gather(inputInfo0, inputInfo1, outputInfo, decoder, indicesData, encoder);
+ Gather(inputInfo0, inputInfo1, outputInfo, decoder, indicesData, encoder, m_Data.m_Parameters.m_Axis);
}
} //namespace armnn