aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/BatchToSpaceNd.cpp
diff options
context:
space:
mode:
authorFrancis Murtagh <francis.murtagh@arm.com>2019-06-20 12:07:19 +0100
committerFrancis Murtagh <francis.murtagh@arm.com>2019-06-20 12:07:30 +0100
commit47ea3c0e8d8d10906d04a0e7c537ffee68b0f819 (patch)
treeab160c7126820dd2b5766974256167fdc66d183a /src/backends/reference/workloads/BatchToSpaceNd.cpp
parent51982472bfedf12e7d82cde6614617f94b2c86d0 (diff)
downloadarmnn-47ea3c0e8d8d10906d04a0e7c537ffee68b0f819.tar.gz
IVGCVSW-3248 Refactor reference BatchToSpace workload
* Add Decoders and Encoders to workload to make it data type agnostic * Merge float32 and Uint8 into single workload Change-Id: I8adfa1898a63f13889eaaf55a31c26fd1e2d7ee8 Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
Diffstat (limited to 'src/backends/reference/workloads/BatchToSpaceNd.cpp')
-rw-r--r--src/backends/reference/workloads/BatchToSpaceNd.cpp9
1 files changed, 6 insertions, 3 deletions
diff --git a/src/backends/reference/workloads/BatchToSpaceNd.cpp b/src/backends/reference/workloads/BatchToSpaceNd.cpp
index 5f64213b39..7efdb9b75c 100644
--- a/src/backends/reference/workloads/BatchToSpaceNd.cpp
+++ b/src/backends/reference/workloads/BatchToSpaceNd.cpp
@@ -37,8 +37,8 @@ void BatchToSpaceNd(const DataLayoutIndexed& dataLayout,
const TensorInfo& outputTensorInfo,
const std::vector<unsigned int>& blockShape,
const std::vector<std::pair<unsigned int, unsigned int>>& cropsData,
- const float* inputData,
- float* outputData)
+ Decoder<float>& inputDecoder,
+ Encoder<float>& outputEncoder)
{
TensorShape inputShape = inputTensorInfo.GetShape();
@@ -90,7 +90,10 @@ void BatchToSpaceNd(const DataLayoutIndexed& dataLayout,
{
unsigned int outOffset = Offset(outputShape, outBatch, outH, outW, c, dataLayout);
unsigned int inOffset = Offset(inputShape, inBatch, inH, inW, c, dataLayout);
- outputData[outOffset] = inputData[inOffset];
+
+ outputEncoder[outOffset];
+ inputDecoder[inOffset];
+ outputEncoder.Set(inputDecoder.Get());
}
}
}