aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2020-03-11 14:51:27 +0000
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2020-03-13 09:49:42 +0000
commit44179c372eea9f17c96cbf50ee383e57e14d70a6 (patch)
tree2a2971c2db67426107b21d9a045cfa46a4a1663a /src/backends/reference/workloads
parente9b5d2989abc8008df7ff3ea287ee896ee1121a6 (diff)
downloadarmnn-44179c372eea9f17c96cbf50ee383e57e14d70a6.tar.gz
IVGCVSW-4511 Add BFloat16 to RefLayerSupport and unit tests
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: Ifaae4d5aac468ba927b2c6a4bf31b8c8522aeb2e
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r--src/backends/reference/workloads/Pad.cpp7
-rw-r--r--src/backends/reference/workloads/RefPadWorkload.cpp1
-rw-r--r--src/backends/reference/workloads/RefPadWorkload.hpp1
-rw-r--r--src/backends/reference/workloads/RefPermuteWorkload.cpp1
-rw-r--r--src/backends/reference/workloads/RefPermuteWorkload.hpp1
-rw-r--r--src/backends/reference/workloads/RefTransposeWorkload.cpp1
-rw-r--r--src/backends/reference/workloads/RefTransposeWorkload.hpp1
7 files changed, 13 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/Pad.cpp b/src/backends/reference/workloads/Pad.cpp
index 9fedb44f96..ffdd469609 100644
--- a/src/backends/reference/workloads/Pad.cpp
+++ b/src/backends/reference/workloads/Pad.cpp
@@ -152,6 +152,13 @@ void Pad(const TensorInfo& inputInfo,
}
}
+template void Pad<BFloat16>(const TensorInfo& inputInfo,
+ const TensorInfo& outputInfo,
+ std::vector<std::pair<unsigned int, unsigned int>> m_PadList,
+ const BFloat16* inputData,
+ BFloat16* outData,
+ const float padValue);
+
template void Pad<float>(const TensorInfo& inputInfo,
const TensorInfo& outputInfo,
std::vector<std::pair<unsigned int, unsigned int>> m_PadList,
diff --git a/src/backends/reference/workloads/RefPadWorkload.cpp b/src/backends/reference/workloads/RefPadWorkload.cpp
index 356f6b1172..777682d70c 100644
--- a/src/backends/reference/workloads/RefPadWorkload.cpp
+++ b/src/backends/reference/workloads/RefPadWorkload.cpp
@@ -33,6 +33,7 @@ void RefPadWorkload<DataType>::Execute() const
Pad(inputInfo, outputInfo, m_Data.m_Parameters.m_PadList, inputData, outputData, m_Data.m_Parameters.m_PadValue);
}
+template class RefPadWorkload<DataType::BFloat16>;
template class RefPadWorkload<DataType::Float32>;
template class RefPadWorkload<DataType::Float16>;
template class RefPadWorkload<DataType::QAsymmU8>;
diff --git a/src/backends/reference/workloads/RefPadWorkload.hpp b/src/backends/reference/workloads/RefPadWorkload.hpp
index 28fb55386e..5134ac8bff 100644
--- a/src/backends/reference/workloads/RefPadWorkload.hpp
+++ b/src/backends/reference/workloads/RefPadWorkload.hpp
@@ -30,6 +30,7 @@ public:
void Execute() const override;
};
+using RefPadBFloat16Workload = RefPadWorkload<DataType::BFloat16>;
using RefPadFloat32Workload = RefPadWorkload<DataType::Float32>;
using RefPadFloat16Workload = RefPadWorkload<DataType::Float16>;
using RefPadQAsymm8Workload = RefPadWorkload<DataType::QAsymmU8>;
diff --git a/src/backends/reference/workloads/RefPermuteWorkload.cpp b/src/backends/reference/workloads/RefPermuteWorkload.cpp
index d0e1431ffd..5751ed80a3 100644
--- a/src/backends/reference/workloads/RefPermuteWorkload.cpp
+++ b/src/backends/reference/workloads/RefPermuteWorkload.cpp
@@ -28,6 +28,7 @@ void RefPermuteWorkload<DataType>::Execute() const
src->Map(), dst->Map(), sizeof(T));
}
+template class RefPermuteWorkload<DataType::BFloat16>;
template class RefPermuteWorkload<DataType::Float16>;
template class RefPermuteWorkload<DataType::Float32>;
template class RefPermuteWorkload<DataType::QAsymmU8>;
diff --git a/src/backends/reference/workloads/RefPermuteWorkload.hpp b/src/backends/reference/workloads/RefPermuteWorkload.hpp
index 00a33850aa..a8d308e47c 100644
--- a/src/backends/reference/workloads/RefPermuteWorkload.hpp
+++ b/src/backends/reference/workloads/RefPermuteWorkload.hpp
@@ -27,6 +27,7 @@ public:
void Execute() const override;
};
+using RefPermuteBFloat16Workload = RefPermuteWorkload<DataType::BFloat16>;
using RefPermuteFloat16Workload = RefPermuteWorkload<DataType::Float16>;
using RefPermuteFloat32Workload = RefPermuteWorkload<DataType::Float32>;
using RefPermuteQAsymm8Workload = RefPermuteWorkload<DataType::QAsymmU8>;
diff --git a/src/backends/reference/workloads/RefTransposeWorkload.cpp b/src/backends/reference/workloads/RefTransposeWorkload.cpp
index 6bdfb2111d..242668b6b1 100644
--- a/src/backends/reference/workloads/RefTransposeWorkload.cpp
+++ b/src/backends/reference/workloads/RefTransposeWorkload.cpp
@@ -27,6 +27,7 @@ void RefTransposeWorkload<DataType>::Execute() const
armnnUtils::Transpose(GetTensorInfo(src).GetShape(), mappings, src->Map(), dst->Map(), sizeof(T));
}
+template class RefTransposeWorkload<DataType::BFloat16>;
template class RefTransposeWorkload<DataType::Float16>;
template class RefTransposeWorkload<DataType::Float32>;
template class RefTransposeWorkload<DataType::QAsymmU8>;
diff --git a/src/backends/reference/workloads/RefTransposeWorkload.hpp b/src/backends/reference/workloads/RefTransposeWorkload.hpp
index 4b1c3d303b..dcfe618b75 100644
--- a/src/backends/reference/workloads/RefTransposeWorkload.hpp
+++ b/src/backends/reference/workloads/RefTransposeWorkload.hpp
@@ -27,6 +27,7 @@ public:
void Execute() const override;
};
+using RefTransposeBFloat16Workload = RefTransposeWorkload<DataType::BFloat16>;
using RefTransposeFloat16Workload = RefTransposeWorkload<DataType::Float16>;
using RefTransposeFloat32Workload = RefTransposeWorkload<DataType::Float32>;
using RefTransposeQAsymm8Workload = RefTransposeWorkload<DataType::QAsymmU8>;