aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations/ConvertConstants.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/optimizations/ConvertConstants.hpp')
-rw-r--r--src/armnn/optimizations/ConvertConstants.hpp18
1 files changed, 9 insertions, 9 deletions
diff --git a/src/armnn/optimizations/ConvertConstants.hpp b/src/armnn/optimizations/ConvertConstants.hpp
index df5a5b4f67..66b3d2685a 100644
--- a/src/armnn/optimizations/ConvertConstants.hpp
+++ b/src/armnn/optimizations/ConvertConstants.hpp
@@ -9,7 +9,7 @@
#include <armnnUtils/FloatingPointConverter.hpp>
-#include <backendsCommon/CpuTensorHandle.hpp>
+#include <backendsCommon/TensorHandle.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
@@ -23,7 +23,7 @@ namespace optimizations
struct BFloat16ToFloat32
{
- static void Func(std::shared_ptr<ConstCpuTensorHandle>& handle)
+ static void Func(std::shared_ptr<ConstTensorHandle>& handle)
{
const TensorInfo& info = handle->GetTensorInfo();
@@ -37,14 +37,14 @@ struct BFloat16ToFloat32
TensorInfo newInfo(info.GetShape(), DataType::Float32);
ConstTensor newInput(newInfo, newValues);
- handle.reset(new ScopedCpuTensorHandle(newInput));
+ handle.reset(new ScopedTensorHandle(newInput));
}
}
};
struct Float16ToFloat32
{
- static void Func(std::shared_ptr<ConstCpuTensorHandle>& handle)
+ static void Func(std::shared_ptr<ConstTensorHandle>& handle)
{
const TensorInfo& info = handle->GetTensorInfo();
@@ -58,14 +58,14 @@ struct Float16ToFloat32
TensorInfo newInfo(info.GetShape(), DataType::Float32);
ConstTensor newInput(newInfo, newValues);
- handle.reset(new ScopedCpuTensorHandle(newInput));
+ handle.reset(new ScopedTensorHandle(newInput));
}
}
};
struct Float32ToBFloat16
{
- static void Func(std::shared_ptr<ConstCpuTensorHandle>& handle)
+ static void Func(std::shared_ptr<ConstTensorHandle>& handle)
{
const TensorInfo& info = handle->GetTensorInfo();
@@ -79,14 +79,14 @@ struct Float32ToBFloat16
TensorInfo newInfo(info.GetShape(), DataType::BFloat16);
ConstTensor newInput(newInfo, newValues);
- handle.reset(new ScopedCpuTensorHandle(newInput));
+ handle.reset(new ScopedTensorHandle(newInput));
}
}
};
struct Float32ToFloat16
{
- static void Func(std::shared_ptr<ConstCpuTensorHandle>& handle)
+ static void Func(std::shared_ptr<ConstTensorHandle>& handle)
{
const TensorInfo& info = handle->GetTensorInfo();
@@ -100,7 +100,7 @@ struct Float32ToFloat16
TensorInfo newInfo(info.GetShape(), DataType::Float16);
ConstTensor newInput(newInfo, newValues);
- handle.reset(new ScopedCpuTensorHandle(newInput));
+ handle.reset(new ScopedTensorHandle(newInput));
}
}
};