diff options
Diffstat (limited to 'src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp')
-rw-r--r-- | src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp | 32 |
1 files changed, 19 insertions, 13 deletions
diff --git a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp index 973c87b009..d21cfa947a 100644 --- a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp +++ b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp @@ -22,26 +22,32 @@ void RefL2NormalizationFloat32Workload::Execute() const const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); - TensorBufferArrayView<const float> input(inputInfo.GetShape(), GetInputTensorDataFloat(0, m_Data)); - TensorBufferArrayView<float> output(outputInfo.GetShape(), GetOutputTensorDataFloat(0, m_Data)); + TensorBufferArrayView<const float> input(inputInfo.GetShape(), + GetInputTensorDataFloat(0, m_Data), + m_Data.m_Parameters.m_DataLayout); + TensorBufferArrayView<float> output(outputInfo.GetShape(), + GetOutputTensorDataFloat(0, m_Data), + m_Data.m_Parameters.m_DataLayout); - const unsigned int batchSize = inputInfo.GetShape()[0]; - const unsigned int depth = inputInfo.GetShape()[1]; - const unsigned int rows = inputInfo.GetShape()[2]; - const unsigned int cols = inputInfo.GetShape()[3]; + DataLayoutIndexed dataLayout(m_Data.m_Parameters.m_DataLayout); - for (unsigned int n = 0; n < batchSize; ++n) + const unsigned int batches = inputInfo.GetShape()[0]; + const unsigned int channels = inputInfo.GetShape()[dataLayout.GetChannelsIndex()]; + const unsigned int height = inputInfo.GetShape()[dataLayout.GetHeightIndex()]; + const unsigned int width = inputInfo.GetShape()[dataLayout.GetWidthIndex()]; + + for (unsigned int n = 0; n < batches; ++n) { - for (unsigned int d = 0; d < depth; ++d) + for (unsigned int c = 0; c < channels; ++c) { - for (unsigned int h = 0; h < rows; ++h) + for (unsigned int h = 0; h < height; ++h) { - for (unsigned int w = 0; w < cols; ++w) + for (unsigned int w = 0; w < width; ++w) { float reduction = 0.0; - for (unsigned int c = 0; c < depth; ++c) + for (unsigned int d = 0; d < channels; ++d) { - const float value = input.Get(n, c, h, w); + const float value = input.Get(n, d, h, w); reduction += value * value; } @@ -51,7 +57,7 @@ void RefL2NormalizationFloat32Workload::Execute() const // backend. // - The reference semantics for this operator do not include this parameter. const float scale = 1.0f / sqrtf(reduction); - output.Get(n, d, h, w) = input.Get(n, d, h, w) * scale; + output.Get(n, c, h, w) = input.Get(n, c, h, w) * scale; } } } |