aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp')
-rw-r--r--src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp32
1 files changed, 19 insertions, 13 deletions
diff --git a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp
index 973c87b009..d21cfa947a 100644
--- a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp
+++ b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp
@@ -22,26 +22,32 @@ void RefL2NormalizationFloat32Workload::Execute() const
const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
- TensorBufferArrayView<const float> input(inputInfo.GetShape(), GetInputTensorDataFloat(0, m_Data));
- TensorBufferArrayView<float> output(outputInfo.GetShape(), GetOutputTensorDataFloat(0, m_Data));
+ TensorBufferArrayView<const float> input(inputInfo.GetShape(),
+ GetInputTensorDataFloat(0, m_Data),
+ m_Data.m_Parameters.m_DataLayout);
+ TensorBufferArrayView<float> output(outputInfo.GetShape(),
+ GetOutputTensorDataFloat(0, m_Data),
+ m_Data.m_Parameters.m_DataLayout);
- const unsigned int batchSize = inputInfo.GetShape()[0];
- const unsigned int depth = inputInfo.GetShape()[1];
- const unsigned int rows = inputInfo.GetShape()[2];
- const unsigned int cols = inputInfo.GetShape()[3];
+ DataLayoutIndexed dataLayout(m_Data.m_Parameters.m_DataLayout);
- for (unsigned int n = 0; n < batchSize; ++n)
+ const unsigned int batches = inputInfo.GetShape()[0];
+ const unsigned int channels = inputInfo.GetShape()[dataLayout.GetChannelsIndex()];
+ const unsigned int height = inputInfo.GetShape()[dataLayout.GetHeightIndex()];
+ const unsigned int width = inputInfo.GetShape()[dataLayout.GetWidthIndex()];
+
+ for (unsigned int n = 0; n < batches; ++n)
{
- for (unsigned int d = 0; d < depth; ++d)
+ for (unsigned int c = 0; c < channels; ++c)
{
- for (unsigned int h = 0; h < rows; ++h)
+ for (unsigned int h = 0; h < height; ++h)
{
- for (unsigned int w = 0; w < cols; ++w)
+ for (unsigned int w = 0; w < width; ++w)
{
float reduction = 0.0;
- for (unsigned int c = 0; c < depth; ++c)
+ for (unsigned int d = 0; d < channels; ++d)
{
- const float value = input.Get(n, c, h, w);
+ const float value = input.Get(n, d, h, w);
reduction += value * value;
}
@@ -51,7 +57,7 @@ void RefL2NormalizationFloat32Workload::Execute() const
// backend.
// - The reference semantics for this operator do not include this parameter.
const float scale = 1.0f / sqrtf(reduction);
- output.Get(n, d, h, w) = input.Get(n, d, h, w) * scale;
+ output.Get(n, c, h, w) = input.Get(n, c, h, w) * scale;
}
}
}