aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/Pad.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/Pad.cpp')
-rw-r--r--src/backends/reference/workloads/Pad.cpp51
1 files changed, 35 insertions, 16 deletions
diff --git a/src/backends/reference/workloads/Pad.cpp b/src/backends/reference/workloads/Pad.cpp
index 7a928a1336..1e58124627 100644
--- a/src/backends/reference/workloads/Pad.cpp
+++ b/src/backends/reference/workloads/Pad.cpp
@@ -5,8 +5,10 @@
#include "Pad.hpp"
#include "backendsCommon/WorkloadData.hpp"
-#include <boost/numeric/conversion/cast.hpp>
#include "TensorBufferArrayView.hpp"
+#include "Encoders.hpp"
+
+#include <boost/numeric/conversion/cast.hpp>
#include <cmath>
#include <cstddef>
#include <functional>
@@ -15,12 +17,25 @@
namespace armnn
{
+
+template <typename T>
+T ConvertToDataType(const float& value,
+ const armnn::TensorInfo& tensorInfo)
+{
+ std::vector<T> output(1);
+ std::unique_ptr<armnn::Encoder<float>> pEncoder = armnn::MakeEncoder<float>(tensorInfo, output.data());
+ armnn::Encoder<float>& rEncoder = *pEncoder;
+ rEncoder.Set(value);
+ return output[0];
+}
+
template <typename T>
void Pad(const TensorInfo& inputInfo,
const TensorInfo& outputInfo,
- std::vector<std::pair<unsigned int, unsigned int>> m_PadList,
+ std::vector<std::pair<unsigned int, unsigned int>> m_padList,
const T* inputData,
- T* outData)
+ T* outData,
+ const float padValue)
{
unsigned int numOutputElements = outputInfo.GetNumElements();
@@ -45,9 +60,11 @@ void Pad(const TensorInfo& inputInfo,
unsigned int outputHeight = 0;
unsigned int outputWidth = 0;
+ T convertedPadValue = ConvertToDataType<T>(padValue, inputInfo);
+
for (unsigned int i = 0; i < numOutputElements; ++i)
{
- outData[i] = 0;
+ outData[i] = convertedPadValue;
}
switch(numInputDimensions) {
@@ -58,7 +75,7 @@ void Pad(const TensorInfo& inputInfo,
for (unsigned int w = 0; w < inputWidth ; w++)
{
- outData[w+std::get<0>(m_PadList[0])] = inputData[w];
+ outData[w+std::get<0>(m_padList[0])] = inputData[w];
}
break;
@@ -74,8 +91,8 @@ void Pad(const TensorInfo& inputInfo,
{
for (unsigned int w = 0; w < inputWidth ; w++)
{
- outData[(h+std::get<0>(m_PadList[0]))*outputWidth
- + (w+std::get<0>(m_PadList[1]))] = inputData[h * inputWidth + w];
+ outData[(h+std::get<0>(m_padList[0]))*outputWidth
+ + (w+std::get<0>(m_padList[1]))] = inputData[h * inputWidth + w];
}
}
@@ -96,9 +113,9 @@ void Pad(const TensorInfo& inputInfo,
{
for (unsigned int w = 0; w < inputWidth ; w++)
{
- outData[(c+std::get<0>(m_PadList[0]))*outputHeight*outputWidth
- + (h+std::get<0>(m_PadList[1]))*outputWidth
- + (w+std::get<0>(m_PadList[2]))] = inputData[c * inputHeight * inputWidth
+ outData[(c+std::get<0>(m_padList[0]))*outputHeight*outputWidth
+ + (h+std::get<0>(m_padList[1]))*outputWidth
+ + (w+std::get<0>(m_padList[2]))] = inputData[c * inputHeight * inputWidth
+ h * inputWidth
+ w];
}
@@ -125,10 +142,10 @@ void Pad(const TensorInfo& inputInfo,
{
for (unsigned int w = 0; w < inputWidth ; w++)
{
- outData[(b+std::get<0>(m_PadList[0])) * outputChannels * outputHeight * outputWidth
- + (c+std::get<0>(m_PadList[1])) * outputHeight * outputWidth
- + (h+std::get<0>(m_PadList[2])) * outputWidth
- + (w+std::get<0>(m_PadList[3]))] = inputData[b * inputChannels * inputHeight
+ outData[(b+std::get<0>(m_padList[0])) * outputChannels * outputHeight * outputWidth
+ + (c+std::get<0>(m_padList[1])) * outputHeight * outputWidth
+ + (h+std::get<0>(m_padList[2])) * outputWidth
+ + (w+std::get<0>(m_padList[3]))] = inputData[b * inputChannels * inputHeight
* inputWidth
+ c * inputHeight * inputWidth
+ h * inputWidth
@@ -150,11 +167,13 @@ template void Pad<float>(const TensorInfo& inputInfo,
const TensorInfo& outputInfo,
std::vector<std::pair<unsigned int, unsigned int>> m_PadList,
const float* inputData,
- float* outData);
+ float* outData,
+ const float padValue);
template void Pad<uint8_t>(const TensorInfo& inputInfo,
const TensorInfo& outputInfo,
std::vector<std::pair<unsigned int, unsigned int>> m_PadList,
const uint8_t* inputData,
- uint8_t* outData);
+ uint8_t* outData,
+ const float padValue);
} //namespace armnn \ No newline at end of file