aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations/ConvertConstants.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/optimizations/ConvertConstants.hpp')
-rw-r--r--src/armnn/optimizations/ConvertConstants.hpp98
1 files changed, 98 insertions, 0 deletions
diff --git a/src/armnn/optimizations/ConvertConstants.hpp b/src/armnn/optimizations/ConvertConstants.hpp
new file mode 100644
index 0000000000..d2dd650665
--- /dev/null
+++ b/src/armnn/optimizations/ConvertConstants.hpp
@@ -0,0 +1,98 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#pragma once
+
+#include "Optimization.hpp"
+#include "backends/CpuTensorHandle.hpp"
+#include "Half.hpp"
+#include "FloatingPointConverter.hpp"
+
+namespace armnn
+{
+namespace optimizations
+{
+
+struct Float16ToFloat32
+{
+ static void Func(std::unique_ptr<ScopedCpuTensorHandle>& handle)
+ {
+ const TensorInfo& info = handle->GetTensorInfo();
+
+ if (info.GetDataType() == DataType::Float16)
+ {
+ std::vector<float> newValues(info.GetNumElements());
+
+ armnnUtils::FloatingPointConverter::ConvertFloat16To32(handle->GetTensor<Half>(),
+ info.GetNumElements(),
+ newValues.data());
+
+ TensorInfo newInfo(info.GetShape(), DataType::Float32);
+ ConstTensor newInput(newInfo, newValues);
+ handle.reset(new ScopedCpuTensorHandle(newInput));
+ }
+ }
+};
+
+struct Float32ToFloat16
+{
+ static void Func(std::unique_ptr<ScopedCpuTensorHandle>& handle)
+ {
+ const TensorInfo& info = handle->GetTensorInfo();
+
+ if (info.GetDataType() == DataType::Float32)
+ {
+ std::vector<Half> newValues(info.GetNumElements());
+
+ armnnUtils::FloatingPointConverter::ConvertFloat32To16(handle->GetTensor<float>(),
+ info.GetNumElements(),
+ newValues.data());
+
+ TensorInfo newInfo(info.GetShape(), DataType::Float16);
+ ConstTensor newInput(newInfo, newValues);
+ handle.reset(new ScopedCpuTensorHandle(newInput));
+ }
+ }
+};
+
+template<typename Converter, typename Predicate>
+class ConvertConstants : public Optimization
+{
+public:
+ ConvertConstants() = default;
+ ConvertConstants(const ConvertConstants&) = default;
+ virtual ~ConvertConstants() = default;
+
+ void Run(Graph& graph, Layer& layer) const override
+ {
+ if (Predicate::Test(layer))
+ {
+ layer.OperateOnConstantTensors(Converter::Func);
+ }
+ }
+protected:
+};
+
+struct IsFloat32Layer
+{
+ static bool Test(const Layer& layer)
+ {
+ return layer.GetDataType() == DataType::Float32;
+ }
+};
+
+struct IsFloat16Layer
+{
+ static bool Test(const Layer& layer)
+ {
+ return layer.GetDataType() == DataType::Float16;
+ }
+};
+
+using ConvertConstantsHalfToFloat = ConvertConstants<Float16ToFloat32, IsFloat32Layer>;
+using ConvertConstantsFloatToHalf = ConvertConstants<Float32ToFloat16, IsFloat16Layer>;
+
+} //namespace optimizations
+} //namespace armnn