29 return CloneBase<GatherLayer>(graph,
GetName());
41 const unsigned int outputDim = paramsDim - 1 + indicesDim;
43 std::vector<unsigned int> dimSizes;
45 for (
unsigned int i = 0; i < indicesDim; ++i)
47 dimSizes.push_back(indices.
GetShape()[i]);
49 for (
unsigned int i = 1; i < paramsDim; ++i)
51 dimSizes.push_back(params.
GetShape()[i]);
56 ConditionalThrowIfNotEqual<LayerValidationException>(
57 "GatherLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
void Accept(ILayerVisitor &visitor) const override
const char * GetName() const override
virtual std::unique_ptr< IWorkload > CreateWorkload(const IWorkloadFactory &factory) const override
virtual void VisitGatherLayer(const IConnectableLayer *layer, const char *name=nullptr)=0
virtual const TensorInfo & GetTensorInfo() const =0
unsigned int GetNumDimensions() const
WorkloadInfo PrepInfoAndDesc(QueueDescriptor &descriptor) const
Helper function to reduce duplication in *LayerCreateWorkload.
void Gather(const TensorInfo &paramsInfo, const TensorInfo &indicesInfo, const TensorInfo &outputInfo, Decoder< float > &params, const int32_t *indices, Encoder< float > &output)
virtual std::unique_ptr< IWorkload > CreateGather(const GatherQueueDescriptor &descriptor, const WorkloadInfo &info) const
GatherLayer(const char *name)
void VerifyLayerConnections(unsigned int expectedConnections, const CheckLocation &location) const
This layer represents a Gather operator.
void ValidateTensorShapesFromInputs() override
GatherLayer * Clone(Graph &graph) const override
const TensorShape & GetShape() const
const TensorInfo & GetTensorInfo() const override
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
const InputSlot & GetInputSlot(unsigned int index) const override