diff options
Diffstat (limited to 'src/backends/reference/workloads/Broadcast.hpp')
-rw-r--r-- | src/backends/reference/workloads/Broadcast.hpp | 35 |
1 files changed, 34 insertions, 1 deletions
diff --git a/src/backends/reference/workloads/Broadcast.hpp b/src/backends/reference/workloads/Broadcast.hpp index 5bf6be8939..a3d944ae75 100644 --- a/src/backends/reference/workloads/Broadcast.hpp +++ b/src/backends/reference/workloads/Broadcast.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2019 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // @@ -15,6 +15,8 @@ struct BroadcastLoop { BroadcastLoop(const TensorShape& inShape0, const TensorShape& inShape1, const TensorShape& outShape); + BroadcastLoop(const TensorShape& inShape, const TensorShape& outShape); + unsigned int GetNumDimensions() { return static_cast<unsigned int>(m_DimData.size()); @@ -56,6 +58,37 @@ struct BroadcastLoop outData -= outDataMovement; } + template <typename Func, typename DecoderOp, typename EncoderOp> + void Unroll(Func operationFunc, + unsigned int dimension, + DecoderOp& inData, + EncoderOp& outData) + { + if (dimension >= GetNumDimensions()) + { + outData.Set(operationFunc(inData.Get())); + return; + } + + unsigned int inDataMovement = 0; + unsigned int outDataMovement = 0; + + for (unsigned int i = 0; i < m_DimData[dimension].m_DimSize; i++) + { + Unroll(operationFunc, dimension + 1, inData, outData); + + inData += m_DimData[dimension].m_Stride1; + outData += m_DimData[dimension].m_StrideOut; + + inDataMovement += m_DimData[dimension].m_Stride1; + outDataMovement += m_DimData[dimension].m_StrideOut; + } + + // move iterator back to the start + inData -= inDataMovement; + outData -= outDataMovement; + } + private: // Struct to hold the dimension data. struct BroadcastDimensionData |