diff options
Diffstat (limited to 'src/backends/reference/workloads/Broadcast.cpp')
-rw-r--r-- | src/backends/reference/workloads/Broadcast.cpp | 21 |
1 files changed, 20 insertions, 1 deletions
diff --git a/src/backends/reference/workloads/Broadcast.cpp b/src/backends/reference/workloads/Broadcast.cpp index 8421a0a7ed..24af0fc4b1 100644 --- a/src/backends/reference/workloads/Broadcast.cpp +++ b/src/backends/reference/workloads/Broadcast.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2019 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // @@ -30,4 +30,23 @@ BroadcastLoop::BroadcastLoop(const TensorShape& inShape0, const TensorShape& inS } } +BroadcastLoop::BroadcastLoop(const TensorShape& inShape, const TensorShape& outShape) +: m_DimData(outShape.GetNumDimensions()) +{ + const unsigned int numDims = GetNumDimensions(); + + unsigned int sIn = 1; + unsigned int sOut = 1; + + for (unsigned int j = numDims - 1, k = 0; k < numDims ; k++, j--) + { + m_DimData[j].m_DimSize = outShape[j]; + m_DimData[j].m_Stride1 = (inShape[j] > 1) ? sIn : 0; + m_DimData[j].m_StrideOut = sOut; + + sIn *= inShape[j]; + sOut *= outShape[j]; + } +} + } // namespace armnn |