about summary refs log tree commit diff
path: root/src/armnn/backends/RefWorkloads/ConvImpl.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/RefWorkloads/ConvImpl.hpp')
-rw-r--r--  src/armnn/backends/RefWorkloads/ConvImpl.hpp  26
1 file changed, 13 insertions, 13 deletions
diff --git a/src/armnn/backends/RefWorkloads/ConvImpl.hpp b/src/armnn/backends/RefWorkloads/ConvImpl.hpp
index 8b66b0b7d2..b7d5d17a8d 100644
--- a/src/armnn/backends/RefWorkloads/ConvImpl.hpp
+++ b/src/armnn/backends/RefWorkloads/ConvImpl.hpp
@@ -18,7 +18,7 @@
namespace armnn
{
-/// Performs multiplication of a integer with a multiplier which is less than one,
+/// Performs multiplication of an integer with a multiplier which is less than one,
/// using quantized integer arithmetic which is consistent with AndroidNN's CPU executor.
struct QuantizedMultiplierSmallerThanOne
{
@@ -28,21 +28,21 @@ public:
/// The implementation of this function is adapted from Android NN's QuantizeMultiplierSmallerThanOne().
QuantizedMultiplierSmallerThanOne(float multiplier);
- /// The implementation of this function is adapted from Android NN's MultiplyByQuantizedMultiplierSmallerThanOne()
+ /// The implementation of this function is adapted from Android NN's MultiplyByQuantizedMultiplierSmallerThanOne().
int32_t operator*(int32_t rhs) const;
private:
- /// The implementation of this function is adapted from gemmlowp's SaturatingRoundingDoublingHighMul()
+ /// The implementation of this function is adapted from gemmlowp's SaturatingRoundingDoublingHighMul().
static int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b);
- /// The implementation of this function is adapted from gemmlowp's RoundingDivideByPOT()
+ /// The implementation of this function is adapted from gemmlowp's RoundingDivideByPOT().
static int32_t RoundingDivideByPOT(int32_t x, int exponent);
int32_t m_Multiplier;
int32_t m_RightShift;
};
-/// an implementation shared by normal and depthwise convolution
+/// An implementation shared by normal and depthwise convolution.
template<typename ConvData, typename InputType, typename BiasType, typename AccumulatorType>
static void ConvImpl(ConvData data,
const InputType* inputData,
@@ -55,6 +55,7 @@ static void ConvImpl(ConvData data,
InputType* outputData,
float outputScale,
int32_t outputOffset,
+ const TensorInfo& filterInfo,
bool depthwise = false)
{
if (data.m_Parameters.m_BiasEnabled && !biasData)
@@ -64,7 +65,6 @@ static void ConvImpl(ConvData data,
const TensorInfo& inputInfo0 = GetTensorInfo(data.m_Inputs[0]);
const TensorInfo& outputInfo0 = GetTensorInfo(data.m_Outputs[0]);
- const TensorInfo& filterInfo = data.m_Weight->GetTensorInfo();
unsigned int depthMult = depthwise ? filterInfo.GetShape()[0] : 1;
unsigned int channelsInput = filterInfo.GetShape()[1];
@@ -84,7 +84,7 @@ static void ConvImpl(ConvData data,
unsigned int hStride = data.m_Parameters.m_StrideY;
unsigned int xStride = data.m_Parameters.m_StrideX;
- // the world's least efficient convolution
+ // The world's least efficient convolution.
for (unsigned int batchIdx = 0; batchIdx < batchSize; batchIdx++)
{
for (unsigned int cOutput = 0; cOutput < channelsOutput; cOutput++)
@@ -93,11 +93,11 @@ static void ConvImpl(ConvData data,
{
for (unsigned int xOutput = 0; xOutput < widthOutput; xOutput++)
{
- // this loop goes over each output element
+ // This loop goes over each output element.
AccumulatorType sum = AccumulatorType();
- // for depthwise, each output channel corresponds to exactly one input channel
- // for normal, must loop over each input channel
+ // For depthwise, each output channel corresponds to exactly one input channel.
+ // For normal, must loop over each input channel.
for (unsigned int cInput = 0; cInput < (depthwise ? 1 : channelsInput); cInput++)
{
unsigned int depthwiseMultiplierIdx = 0;
@@ -111,11 +111,11 @@ static void ConvImpl(ConvData data,
{
for (unsigned int xFilter = 0; xFilter < widthFilter; xFilter++)
{
- // this loop goes over each input element for each output element
+ // This loop goes over each input element for each output element.
unsigned int filterIndex;
- // since dimensionality of kernel depends on depthwiseness, so does index
+ // Since dimensionality of kernel depends on depthwiseness, so does index.
if (depthwise)
{
filterIndex = depthwiseMultiplierIdx * widthFilter * heightFilter * channelsInput +
@@ -138,7 +138,7 @@ static void ConvImpl(ConvData data,
AccumulatorType inputValue;
- // check if we're in the padding
+ // Check if we're in the padding.
if (yInput < paddingTop || yInput >= heightInput + paddingTop ||
xInput < paddingLeft || xInput >= widthInput + paddingLeft )
{