11 namespace optimizations
20 if (IsReshape(transpose))
24 const std::string name = std::string(
"as_reshape-") + transpose.
GetName();
45 const unsigned int numDimensions = permutation.
GetSize();
46 std::map<unsigned int, unsigned int> permuteMappings;
47 for (
unsigned int i = 0; i < permutation.
GetSize(); ++i)
49 permuteMappings[permutation[i]] = i;
52 std::vector<unsigned int> permuteVector;
53 for (
unsigned int i = 0; i < permutation.
GetSize(); ++i)
55 permuteVector.push_back(permuteMappings.at(i));
58 unsigned int lastGtOne = 0;
59 while ((lastGtOne < numDimensions) && (outShape[(permuteVector[lastGtOne])] == 1U))
64 bool isReshape =
true;
65 for (
unsigned int i = lastGtOne + 1U; isReshape && (i < numDimensions); ++i)
67 if (outShape[permuteVector[i]] > 1U)
69 isReshape = permuteVector[lastGtOne] < permuteVector[i];
const TensorShape & GetShape() const
A ReshapeDescriptor for the ReshapeLayer.
~TransposeAsReshapeImpl()=default
This layer represents a reshape operation.
Copyright (c) 2020 ARM Limited.
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
This layer represents a transpose operation.
void SetTensorInfo(const TensorInfo &tensorInfo)
Sets the TensorInfo used by this output handler.
TransposeAsReshapeImpl()=default
const OutputHandler & GetOutputHandler(unsigned int i=0) const
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
const char * GetName() const override
Returns the name of the layer.
LayerT * InsertNewLayer(InputSlot &insertBefore, Args &&... args)
Inserts a new layer between the output slot currently connected to insertBefore and insertBefore itse...
const PermutationVector & GetPermutation() const
void MoveAllConnections(OutputSlot &destination)
Moves all connections to another OutputSlot.
void Run(Graph &graph, TransposeLayer &transpose) const
Run for every TransposeLayer. Replaces it with a ReshapeLayer if they are equivalent.
const TensorInfo & GetTensorInfo() const
Gets the matching TensorInfo for the output.