aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRyan OShea <Ryan.OShea2@arm.com>2020-06-05 17:17:06 +0100
committerKeithARM <keith.davis@arm.com>2020-06-11 12:17:49 +0000
commitec6c68093eaef8a2b8e1fd64fcc765237973512e (patch)
treee4dd87466d5be44d1310f9e1c9c6e1355df43fbf
parent6350d27286114dfdae5f65ae1823ba1150087efb (diff)
downloadarmnn-ec6c68093eaef8a2b8e1fd64fcc765237973512e.tar.gz
IVGCVSW-4906 Add front-end support for FILL operator
* Added new fill layer * Added visitor tests Signed-off-by: Ryan OShea <Ryan.OShea2@arm.com> Change-Id: Iea677014866b4f2d514004623f59ee83f3c0eef8 Signed-off-by: Keith Davis <keith.davis@arm.com>
-rw-r--r--Android.mk1
-rw-r--r--CMakeLists.txt2
-rw-r--r--include/armnn/Descriptors.hpp19
-rw-r--r--include/armnn/DescriptorsFwd.hpp1
-rw-r--r--include/armnn/ILayerSupport.hpp5
-rw-r--r--include/armnn/ILayerVisitor.hpp8
-rw-r--r--include/armnn/INetwork.hpp9
-rw-r--r--include/armnn/LayerVisitorBase.hpp4
-rw-r--r--src/armnn/InternalTypes.hpp1
-rw-r--r--src/armnn/LayersFwd.hpp2
-rw-r--r--src/armnn/Network.cpp6
-rw-r--r--src/armnn/Network.hpp3
-rw-r--r--src/armnn/layers/FillLayer.cpp51
-rw-r--r--src/armnn/layers/FillLayer.hpp41
-rw-r--r--src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp7
-rw-r--r--src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp1
-rw-r--r--src/armnnSerializer/Serializer.cpp10
-rw-r--r--src/armnnSerializer/Serializer.hpp4
-rw-r--r--src/backends/backendsCommon/LayerSupportBase.cpp8
-rw-r--r--src/backends/backendsCommon/LayerSupportBase.hpp5
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp23
-rw-r--r--src/backends/backendsCommon/WorkloadData.hpp13
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp20
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.hpp3
-rw-r--r--src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp2
25 files changed, 248 insertions, 1 deletions
diff --git a/Android.mk b/Android.mk
index c9254c7f31..83779486c0 100644
--- a/Android.mk
+++ b/Android.mk
@@ -151,6 +151,7 @@ LOCAL_SRC_FILES := \
src/armnn/layers/ElementwiseBaseLayer.cpp \
src/armnn/layers/ElementwiseUnaryLayer.cpp \
src/armnn/layers/FakeQuantizationLayer.cpp \
+ src/armnn/layers/FillLayer.cpp \
src/armnn/layers/FloorLayer.cpp \
src/armnn/layers/FullyConnectedLayer.cpp \
src/armnn/layers/GatherLayer.cpp \
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0dc859bea5..92edf8017f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -302,6 +302,8 @@ list(APPEND armnn_sources
src/armnn/layers/ElementwiseUnaryLayer.cpp
src/armnn/layers/FakeQuantizationLayer.hpp
src/armnn/layers/FakeQuantizationLayer.cpp
+ src/armnn/layers/FillLayer.hpp
+ src/armnn/layers/FillLayer.cpp
src/armnn/layers/FloorLayer.hpp
src/armnn/layers/FloorLayer.cpp
src/armnn/layers/FullyConnectedLayer.hpp
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 5d0990e816..653e64701a 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -711,6 +711,25 @@ struct FakeQuantizationDescriptor
float m_Max;
};
+/// A FillDescriptor for the FillLayer
+struct FillDescriptor
+{
+ FillDescriptor()
+ : m_Value(0)
+ {}
+
+ FillDescriptor(const float& value)
+ : m_Value(value)
+ {}
+
+ bool operator ==(const FillDescriptor& rhs) const
+ {
+ return m_Value == rhs.m_Value;
+ }
+
+ float m_Value;
+};
+
/// A ResizeBilinearDescriptor for the ResizeBilinearLayer.
struct ResizeBilinearDescriptor
{
diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp
index 1c813b534f..e31fb96aec 100644
--- a/include/armnn/DescriptorsFwd.hpp
+++ b/include/armnn/DescriptorsFwd.hpp
@@ -18,6 +18,7 @@ struct DepthwiseConvolution2dDescriptor;
struct DetectionPostProcessDescriptor;
struct ElementwiseUnaryDescriptor;
struct FakeQuantizationDescriptor;
+struct FillDescriptor;
struct FullyConnectedDescriptor;
struct InstanceNormalizationDescriptor;
struct L2NormalizationDescriptor;
diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp
index 58509c906c..33389eb25f 100644
--- a/include/armnn/ILayerSupport.hpp
+++ b/include/armnn/ILayerSupport.hpp
@@ -157,6 +157,11 @@ public:
const FakeQuantizationDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+ virtual bool IsFillSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const FillDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+
virtual bool IsFloorSupported(const TensorInfo& input,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
diff --git a/include/armnn/ILayerVisitor.hpp b/include/armnn/ILayerVisitor.hpp
index 530e74f30a..aa5bdba33c 100644
--- a/include/armnn/ILayerVisitor.hpp
+++ b/include/armnn/ILayerVisitor.hpp
@@ -184,6 +184,14 @@ public:
virtual void VisitEqualLayer(const IConnectableLayer* layer,
const char* name = nullptr) = 0;
+ /// Function a fill layer should call back to when its Accept(ILayerVisitor&) function is invoked.
+ /// @param layer - pointer to the layer which is calling back to this visit function.
+ /// @param fillDescriptor - Description of the layer
+ /// @param name - Optional name for the layer.
+ virtual void VisitFillLayer(const IConnectableLayer* layer,
+ const FillDescriptor& fillDescriptor,
+ const char* name = nullptr) = 0;
+
/// Function a floor layer should call back to when its Accept(ILayerVisitor&) function is invoked.
/// @param layer - pointer to the layer which is calling back to this visit function.
/// @param name - Optional name for the layer.
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index 1dd949d038..ade6c52c90 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -213,10 +213,17 @@ public:
/// Add an ElementwiseUnary layer to the network.
/// @param name - Optional name for the layer.
/// @param desc - Descriptor for the elementwiseUnary operation.
- /// @ return - Interface for configuring the layer.
+ /// @return - Interface for configuring the layer.
virtual IConnectableLayer* AddElementwiseUnaryLayer(const ElementwiseUnaryDescriptor& elementwiseUnaryDescriptor,
const char* name = nullptr) = 0;
+ /// Add an Fill layer to the network.
+ /// @param name - Optional name for the layer.
+ /// @param fillDescriptor - Descriptor for the fill operation.
+ /// @return - Interface for configuring the layer.
+ virtual IConnectableLayer* AddFillLayer(const FillDescriptor& fillDescriptor,
+ const char* name = nullptr) = 0;
+
/// Adds a fully connected layer to the network.
/// @param fullyConnectedDescriptor - Description of the fully connected layer.
/// @param weights - Tensor for the weights data.
diff --git a/include/armnn/LayerVisitorBase.hpp b/include/armnn/LayerVisitorBase.hpp
index 95d6bd37bd..0dc5e545e3 100644
--- a/include/armnn/LayerVisitorBase.hpp
+++ b/include/armnn/LayerVisitorBase.hpp
@@ -101,6 +101,10 @@ public:
void VisitEqualLayer(const IConnectableLayer*,
const char*) override { DefaultPolicy::Apply(__func__); }
+ void VisitFillLayer(const IConnectableLayer*,
+ const FillDescriptor&,
+ const char*) override { DefaultPolicy::Apply(__func__); }
+
void VisitFloorLayer(const IConnectableLayer*,
const char*) override { DefaultPolicy::Apply(__func__); }
diff --git a/src/armnn/InternalTypes.hpp b/src/armnn/InternalTypes.hpp
index 455cb60d5d..e2ad7a2ea5 100644
--- a/src/armnn/InternalTypes.hpp
+++ b/src/armnn/InternalTypes.hpp
@@ -33,6 +33,7 @@
X(Division) \
X(ElementwiseUnary) \
X(FakeQuantization) \
+ X(Fill) \
X(Floor) \
X(FullyConnected) \
X(Gather) \
diff --git a/src/armnn/LayersFwd.hpp b/src/armnn/LayersFwd.hpp
index befe3819d8..575c3e5c68 100644
--- a/src/armnn/LayersFwd.hpp
+++ b/src/armnn/LayersFwd.hpp
@@ -27,6 +27,7 @@
#include "layers/DivisionLayer.hpp"
#include "layers/ElementwiseUnaryLayer.hpp"
#include "layers/FakeQuantizationLayer.hpp"
+#include "layers/FillLayer.hpp"
#include "layers/FloorLayer.hpp"
#include "layers/FullyConnectedLayer.hpp"
#include "layers/GatherLayer.hpp"
@@ -115,6 +116,7 @@ DECLARE_LAYER(DetectionPostProcess)
DECLARE_LAYER(Division)
DECLARE_LAYER(ElementwiseUnary)
DECLARE_LAYER(FakeQuantization)
+DECLARE_LAYER(Fill)
DECLARE_LAYER(Floor)
DECLARE_LAYER(FullyConnected)
DECLARE_LAYER(Gather)
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index fa8eaaf998..180c00b0c3 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -1186,6 +1186,12 @@ IConnectableLayer* Network::AddElementwiseUnaryLayer(const ElementwiseUnaryDescr
return m_Graph->AddLayer<ElementwiseUnaryLayer>(elementwiseUnaryDescriptor, name);
}
+IConnectableLayer* Network::AddFillLayer(const FillDescriptor& fillDescriptor,
+ const char* name)
+{
+ return m_Graph->AddLayer<FillLayer>(fillDescriptor, name);
+}
+
IConnectableLayer* Network::AddFullyConnectedLayerImpl(const FullyConnectedDescriptor& fullyConnectedDescriptor,
const ConstTensor& weights,
const Optional<ConstTensor>& biases,
diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp
index 1711d2e2dd..cac2b3a0e6 100644
--- a/src/armnn/Network.hpp
+++ b/src/armnn/Network.hpp
@@ -97,6 +97,9 @@ public:
IConnectableLayer* AddElementwiseUnaryLayer(const ElementwiseUnaryDescriptor& elementwiseUnaryDescriptor,
const char* name = nullptr) override;
+ IConnectableLayer* AddFillLayer(const FillDescriptor& fillDescriptor,
+ const char* name = nullptr) override;
+
IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
const ConstTensor& weights,
const Optional<ConstTensor>& biases,
diff --git a/src/armnn/layers/FillLayer.cpp b/src/armnn/layers/FillLayer.cpp
new file mode 100644
index 0000000000..03f93f76da
--- /dev/null
+++ b/src/armnn/layers/FillLayer.cpp
@@ -0,0 +1,51 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include "FillLayer.hpp"
+
+#include "LayerCloneBase.hpp"
+
+#include <armnn/TypesUtils.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+#include <backendsCommon/WorkloadFactory.hpp>
+
+namespace armnn
+{
+
+FillLayer::FillLayer(const FillDescriptor& param, const char* name)
+ : LayerWithParameters(1, 1, LayerType::Fill, param, name)
+{
+}
+
+std::unique_ptr<IWorkload> FillLayer::CreateWorkload(const IWorkloadFactory& factory) const
+{
+ FillQueueDescriptor descriptor;
+ return factory.CreateFill(descriptor, PrepInfoAndDesc(descriptor) );
+}
+
+FillLayer* FillLayer::Clone(Graph& graph) const
+{
+ return CloneBase<FillLayer>(graph, m_Param, GetName());
+}
+
+void FillLayer::ValidateTensorShapesFromInputs()
+{
+ VerifyLayerConnections(1, CHECK_LOCATION());
+
+ auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() });
+
+ ARMNN_ASSERT(inferredShapes.size() == 1);
+
+ ConditionalThrowIfNotEqual<LayerValidationException>(
+ "FillLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
+ GetOutputSlot(0).GetTensorInfo().GetShape(),
+ inferredShapes[0]);
+}
+
+void FillLayer::Accept(ILayerVisitor& visitor) const
+{
+ visitor.VisitGatherLayer(this, GetName());
+}
+
+} // namespace armnn
diff --git a/src/armnn/layers/FillLayer.hpp b/src/armnn/layers/FillLayer.hpp
new file mode 100644
index 0000000000..b9a972a27a
--- /dev/null
+++ b/src/armnn/layers/FillLayer.hpp
@@ -0,0 +1,41 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "LayerWithParameters.hpp"
+
+namespace armnn
+{
+
+/// This layer represents a fill operation.
+class FillLayer : public LayerWithParameters<FillDescriptor>
+{
+public:
+ /// Makes a workload for the Fill layer.
+ /// @param [in] factory The workload factory which will create the workload.
+ /// @return A pointer to the created workload, or nullptr if not created.
+ virtual std::unique_ptr<IWorkload> CreateWorkload(const IWorkloadFactory& factory) const override;
+
+ /// Creates a dynamically-allocated copy of this layer.
+ /// @param [in] graph The graph into which this layer is being cloned.
+ FillLayer* Clone(Graph& graph) const override;
+
+ /// Check if the input tensor shape(s)
+ /// will lead to a valid configuration of @ref FillLayer.
+ void ValidateTensorShapesFromInputs() override;
+
+ void Accept(ILayerVisitor& visitor) const override;
+
+protected:
+ /// Constructor to create a FillLayer.
+ /// @param [in] descriptor to configure the fill operation.
+ /// @param [in] name Optional name for the layer.
+ FillLayer(const FillDescriptor& descriptor, const char* name);
+
+ /// Default destructor
+ ~FillLayer() = default;
+};
+
+} // namespace
diff --git a/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp b/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp
index 431db2aa0d..e07e497ab8 100644
--- a/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp
+++ b/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp
@@ -92,6 +92,12 @@ armnn::ElementwiseUnaryDescriptor GetDescriptor<armnn::ElementwiseUnaryDescripto
}
template<>
+armnn::FillDescriptor GetDescriptor<armnn::FillDescriptor>()
+{
+ return armnn::FillDescriptor(1);
+}
+
+template<>
armnn::InstanceNormalizationDescriptor GetDescriptor<armnn::InstanceNormalizationDescriptor>()
{
armnn::InstanceNormalizationDescriptor descriptor;
@@ -264,6 +270,7 @@ TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(BatchToSpaceNd)
TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Comparison)
TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Concat)
TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(ElementwiseUnary)
+TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Fill)
TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(InstanceNormalization)
TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(L2Normalization)
TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(LogSoftmax)
diff --git a/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp b/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp
index b9877a8111..c8df505db0 100644
--- a/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp
+++ b/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp
@@ -49,6 +49,7 @@ DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(Comparison)
DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(Concat)
DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(DepthToSpace)
DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(ElementwiseUnary)
+DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(Fill)
DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(InstanceNormalization)
DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(L2Normalization)
DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(LogSoftmax)
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index 9c62a93e3b..ddd38e18ef 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -468,6 +468,16 @@ void SerializerVisitor::VisitEqualLayer(const armnn::IConnectableLayer* layer, c
CreateAnyLayer(fbEqualLayer.o, serializer::Layer::Layer_EqualLayer);
}
+void SerializerVisitor::VisitFillLayer(const armnn::IConnectableLayer* layer,
+ const armnn::FillDescriptor& fillDescriptor,
+ const char* name)
+{
+ throw UnimplementedException("SerializerVisitor::VisitFillLayer is not implemented");
+ IgnoreUnused(name);
+ IgnoreUnused(layer);
+ IgnoreUnused(fillDescriptor);
+}
+
void SerializerVisitor::VisitFloorLayer(const armnn::IConnectableLayer *layer, const char *name)
{
IgnoreUnused(name);
diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp
index 6dd655827d..65d87b7cf7 100644
--- a/src/armnnSerializer/Serializer.hpp
+++ b/src/armnnSerializer/Serializer.hpp
@@ -121,6 +121,10 @@ public:
void VisitEqualLayer(const armnn::IConnectableLayer* layer,
const char* name = nullptr) override;
+ void VisitFillLayer(const armnn::IConnectableLayer* layer,
+ const armnn::FillDescriptor& fillDescriptor,
+ const char* name = nullptr) override;
+
void VisitFloorLayer(const armnn::IConnectableLayer *layer,
const char *name = nullptr) override;
diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp
index c55f51d315..e509a7b929 100644
--- a/src/backends/backendsCommon/LayerSupportBase.cpp
+++ b/src/backends/backendsCommon/LayerSupportBase.cpp
@@ -248,6 +248,14 @@ bool LayerSupportBase::IsFakeQuantizationSupported(const TensorInfo& /*input*/,
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
+bool LayerSupportBase::IsFillSupported(const TensorInfo& /*input*/,
+ const TensorInfo& /*output*/,
+ const FillDescriptor& /*descriptor*/,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
+}
+
bool LayerSupportBase::IsFloorSupported(const TensorInfo& /*input*/,
const TensorInfo& /*output*/,
Optional<std::string&> reasonIfUnsupported) const
diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp
index fcc3326601..aff4529417 100644
--- a/src/backends/backendsCommon/LayerSupportBase.hpp
+++ b/src/backends/backendsCommon/LayerSupportBase.hpp
@@ -143,6 +143,11 @@ public:
const FakeQuantizationDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ virtual bool IsFillSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const FillDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsFloorSupported(const TensorInfo& input,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 2060093015..3949fa945d 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -966,6 +966,29 @@ void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
"output");
}
+void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ const std::string descriptorName{"FillQueueDescriptor"};
+
+ ValidateNumInputs(workloadInfo, descriptorName, 1);
+ ValidateNumOutputs(workloadInfo, descriptorName, 1);
+
+ const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
+ const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
+
+ ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
+
+ std::vector<DataType> supportedTypes =
+ {
+ DataType::BFloat16,
+ DataType::Float32,
+ DataType::Float16,
+ DataType::Signed32
+ };
+
+ ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
+}
+
void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
const std::string descriptorName{"FullyConnectedQueueDescriptor"};
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp
index adce5570d3..ba9b0f394b 100644
--- a/src/backends/backendsCommon/WorkloadData.hpp
+++ b/src/backends/backendsCommon/WorkloadData.hpp
@@ -137,6 +137,19 @@ struct ArgMinMaxQueueDescriptor : QueueDescriptorWithParameters<ArgMinMaxDescrip
void Validate(const WorkloadInfo& workloadInfo) const;
};
+// Fill layer workload data.
+struct FillQueueDescriptor : QueueDescriptorWithParameters<FillDescriptor>
+{
+ FillQueueDescriptor()
+ : m_Value(nullptr)
+ {
+ }
+
+ const ConstCpuTensorHandle* m_Value;
+
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
// Fully connected layer workload data.
struct FullyConnectedQueueDescriptor : QueueDescriptorWithParameters<FullyConnectedDescriptor>
{
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 0a13c99ab8..d2565cf21d 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -312,6 +312,20 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
reason);
break;
}
+ case LayerType::Fill:
+ {
+ auto cLayer = PolymorphicDowncast<const FillLayer*>(&layer);
+ const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+ const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+ const FillDescriptor& descriptor = cLayer->GetParameters();
+
+ result = layerSupportObject->IsFillSupported(
+ OverrideDataType(input, dataType),
+ OverrideDataType(output, dataType),
+ descriptor,
+ reason);
+ break;
+ }
case LayerType::FakeQuantization:
{
auto cLayer = PolymorphicDowncast<const FakeQuantizationLayer*>(&layer);
@@ -1336,6 +1350,12 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQu
return std::unique_ptr<IWorkload>();
}
+std::unique_ptr<IWorkload> IWorkloadFactory::CreateFill(const FillQueueDescriptor& /*descriptor*/,
+ const WorkloadInfo& /*info*/) const
+{
+ return std::unique_ptr<IWorkload>();
+}
+
std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& /*descriptor*/,
const WorkloadInfo& /*info*/) const
{
diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp
index 89c073c170..e373a4f218 100644
--- a/src/backends/backendsCommon/WorkloadFactory.hpp
+++ b/src/backends/backendsCommon/WorkloadFactory.hpp
@@ -122,6 +122,9 @@ public:
virtual std::unique_ptr<IWorkload> CreateFakeQuantization(const FakeQuantizationQueueDescriptor& descriptor,
const WorkloadInfo& info) const;
+ virtual std::unique_ptr<IWorkload> CreateFill(const FillQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
virtual std::unique_ptr<IWorkload> CreateFloor(const FloorQueueDescriptor& descriptor,
const WorkloadInfo& info) const;
diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
index 0780f4bd27..dcd073d279 100644
--- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
@@ -545,6 +545,8 @@ DECLARE_LAYER_POLICY_2_PARAM(ElementwiseUnary)
DECLARE_LAYER_POLICY_2_PARAM(FakeQuantization)
+DECLARE_LAYER_POLICY_2_PARAM(Fill)
+
DECLARE_LAYER_POLICY_1_PARAM(Floor)
DECLARE_LAYER_POLICY_2_PARAM(FullyConnected)