ArmNN
 22.05
PreluImpl.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "PreluImpl.hpp"
7 #include "RefWorkloadUtils.hpp"
8 #include "Broadcast.hpp"
9 
10 namespace armnn
11 {
12 
13 void PreluImpl(const TensorInfo& inputInfo,
14  const TensorInfo& alphaInfo,
15  const TensorInfo& outputInfo,
16  Decoder<float>& inputData,
17  Decoder<float>& alphaData,
18  Encoder<float>& outputData)
19 {
20  const TensorShape& inputShape = inputInfo.GetShape();
21  const TensorShape& alphaShape = alphaInfo.GetShape();
22  const TensorShape& outputShape = outputInfo.GetShape();
23 
24  // PReLU activation: f(x) = alpha * x for x < 0, f(x) = x for x >= 0
25  auto prelu = [](float x, float alpha)
26  {
27  return x < 0 ? alpha * x : x;
28  };
29 
30  BroadcastLoop(inputShape, alphaShape, outputShape).Unroll(prelu, 0, inputData, alphaData, outputData);
31 }
32 
33 } // namespace armnn
const TensorShape & GetShape() const
Definition: Tensor.hpp:191
Copyright (c) 2021 ARM Limited and Contributors.
void PreluImpl(const TensorInfo &inputInfo, const TensorInfo &alphaInfo, const TensorInfo &outputInfo, Decoder< float > &inputData, Decoder< float > &alphaData, Encoder< float > &outputData)
Definition: PreluImpl.cpp:13
void Unroll(Func operationFunc, unsigned int dimension, DecoderOp &inData0, DecoderOp &inData1, EncoderOp &outData)
Definition: Broadcast.hpp:26