alpaka
Abstraction Library for Parallel Kernel Acceleration
Loading...
Searching...
No Matches
DevGenericSycl.hpp
Go to the documentation of this file.
1/* Copyright 2024 Jan Stephan, Antonio Di Pilato, Luca Ferragina, Aurora Perego, Andrea Bocci
2 * SPDX-License-Identifier: MPL-2.0
3 */
4
5#pragma once
6
7#include "alpaka/acc/Tag.hpp"
10#include "alpaka/core/Sycl.hpp"
11#include "alpaka/dev/Traits.hpp"
19
20#include <algorithm>
21#include <cstddef>
22#include <functional>
23#include <memory>
24#include <mutex>
25#include <shared_mutex>
26#include <string>
27#include <utility>
28#include <vector>
29
30#ifdef ALPAKA_ACC_SYCL_ENABLED
31
32# include <sycl/sycl.hpp>
33
34namespace alpaka
35{
36 namespace trait
37 {
38 template<typename TPlatform, typename TSfinae>
39 struct GetDevByIdx;
40 } // namespace trait
41
42 template<concepts::Tag TTag>
43 using QueueGenericSyclBlocking = detail::QueueGenericSyclBase<TTag, true>;
44
45 template<concepts::Tag TTag>
46 using QueueGenericSyclNonBlocking = detail::QueueGenericSyclBase<TTag, false>;
47
48 template<concepts::Tag TTag>
49 struct PlatformGenericSycl;
50
51 template<typename TElem, typename TDim, typename TIdx, concepts::Tag TTag>
52 class BufGenericSycl;
53
54 namespace detail
55 {
56 class DevGenericSyclImpl
57 {
58 public:
59 DevGenericSyclImpl(sycl::device device, sycl::context context)
60 : m_device{std::move(device)}
61 , m_context{std::move(context)}
62 {
63 }
64
65 // Don't call this without locking first!
66 auto clean_queues() -> void
67 {
68 // Clean up dead queues
69 auto const start = std::begin(m_queues);
70 auto const old_end = std::end(m_queues);
71 auto const new_end = std::remove_if(start, old_end, [](auto q_ptr) { return q_ptr.expired(); });
72 m_queues.erase(new_end, old_end);
73 }
74
75 auto register_queue(std::shared_ptr<QueueGenericSyclImpl> const& queue) -> void
76 {
77 std::lock_guard<std::shared_mutex> lock{m_mutex};
78
79 clean_queues();
80 m_queues.emplace_back(queue);
81 }
82
83 auto register_dependency(sycl::event event) -> void
84 {
85 std::shared_lock<std::shared_mutex> lock{m_mutex};
86
87 for(auto& q_ptr : m_queues)
88 {
89 if(auto ptr = q_ptr.lock(); ptr != nullptr)
90 ptr->register_dependency(event);
91 }
92 }
93
94 auto wait()
95 {
96 std::shared_lock<std::shared_mutex> lock{m_mutex};
97
98 for(auto& q_ptr : m_queues)
99 {
100 if(auto ptr = q_ptr.lock(); ptr != nullptr)
101 ptr->wait();
102 }
103 }
104
105 auto get_device() const -> sycl::device
106 {
107 return m_device;
108 }
109
110 auto get_context() const -> sycl::context
111 {
112 return m_context;
113 }
114
115 private:
116 sycl::device m_device;
117 sycl::context m_context;
118 std::vector<std::weak_ptr<QueueGenericSyclImpl>> m_queues;
119 std::shared_mutex mutable m_mutex;
120 };
121 } // namespace detail
122
123 //! The SYCL device handle.
124 template<concepts::Tag TTag>
125 class DevGenericSycl
126 : public interface::Implements<ConceptCurrentThreadWaitFor, DevGenericSycl<TTag>>
127 , public interface::Implements<ConceptDev, DevGenericSycl<TTag>>
128 {
129 friend struct trait::GetDevByIdx<PlatformGenericSycl<TTag>>;
130
131 public:
132 DevGenericSycl(sycl::device device, sycl::context context)
133 : m_impl{std::make_shared<detail::DevGenericSyclImpl>(std::move(device), std::move(context))}
134 {
135 }
136
137 friend auto operator==(DevGenericSycl const& lhs, DevGenericSycl const& rhs) -> bool
138 {
139 return (lhs.m_impl == rhs.m_impl);
140 }
141
142 friend auto operator!=(DevGenericSycl const& lhs, DevGenericSycl const& rhs) -> bool
143 {
144 return !(lhs == rhs);
145 }
146
147 [[nodiscard]] auto getNativeHandle() const -> std::pair<sycl::device, sycl::context>
148 {
149 return std::make_pair(m_impl->get_device(), m_impl->get_context());
150 }
151
152 std::shared_ptr<detail::DevGenericSyclImpl> m_impl;
153 };
154
155 namespace trait
156 {
157 //! The SYCL device name get trait specialization.
158 template<concepts::Tag TTag>
159 struct GetName<DevGenericSycl<TTag>>
160 {
161 static auto getName(DevGenericSycl<TTag> const& dev) -> std::string
162 {
163 auto const device = dev.getNativeHandle().first;
164 return device.template get_info<sycl::info::device::name>();
165 }
166 };
167
168 //! The SYCL device available memory get trait specialization.
169 template<concepts::Tag TTag>
170 struct GetMemBytes<DevGenericSycl<TTag>>
171 {
172 static auto getMemBytes(DevGenericSycl<TTag> const& dev) -> std::size_t
173 {
174 auto const device = dev.getNativeHandle().first;
175 return device.template get_info<sycl::info::device::global_mem_size>();
176 }
177 };
178
179 //! The SYCL device free memory get trait specialization.
180 template<concepts::Tag TTag>
181 struct GetFreeMemBytes<DevGenericSycl<TTag>>
182 {
183 static auto getFreeMemBytes(DevGenericSycl<TTag> const& /* dev */) -> std::size_t
184 {
185 static_assert(
186 !sizeof(PlatformGenericSycl<TTag>),
187 "Querying free device memory not supported for SYCL devices.");
188 return std::size_t{};
189 }
190 };
191
192 //! The SYCL device warp size get trait specialization.
193 template<concepts::Tag TTag>
194 struct GetWarpSizes<DevGenericSycl<TTag>>
195 {
196 static auto getWarpSizes(DevGenericSycl<TTag> const& dev) -> std::vector<std::size_t>
197 {
198 auto const device = dev.getNativeHandle().first;
199 std::vector<std::size_t> warp_sizes = device.template get_info<sycl::info::device::sub_group_sizes>();
200 // The CPU runtime supports a sub-group size of 64, but the SYCL implementation currently does not
201 auto find64 = std::find(warp_sizes.begin(), warp_sizes.end(), 64);
202 if(find64 != warp_sizes.end())
203 warp_sizes.erase(find64);
204 // Sort the warp sizes in decreasing order
205 std::sort(warp_sizes.begin(), warp_sizes.end(), std::greater<>{});
206 return warp_sizes;
207 }
208 };
209
210 //! The SYCL device preferred warp size get trait specialization.
211 template<concepts::Tag TTag>
212 struct GetPreferredWarpSize<DevGenericSycl<TTag>>
213 {
214 static auto getPreferredWarpSize(DevGenericSycl<TTag> const& dev) -> std::size_t
215 {
216 return GetWarpSizes<DevGenericSycl<TTag>>::getWarpSizes(dev).front();
217 }
218 };
219
220 //! The SYCL device reset trait specialization.
221 template<concepts::Tag TTag>
222 struct Reset<DevGenericSycl<TTag>>
223 {
224 static auto reset(DevGenericSycl<TTag> const&) -> void
225 {
226 static_assert(
227 !sizeof(PlatformGenericSycl<TTag>),
228 "Explicit device reset not supported for SYCL devices");
229 }
230 };
231
232 //! The SYCL device native handle trait specialization.
233 template<concepts::Tag TTag>
234 struct NativeHandle<DevGenericSycl<TTag>>
235 {
236 [[nodiscard]] static auto getNativeHandle(DevGenericSycl<TTag> const& dev)
237 {
238 return dev.getNativeHandle();
239 }
240 };
241
242 //! The SYCL device memory buffer type trait specialization.
243 template<typename TElem, typename TDim, typename TIdx, concepts::Tag TTag>
244 struct BufType<DevGenericSycl<TTag>, TElem, TDim, TIdx>
245 {
246 using type = BufGenericSycl<TElem, TDim, TIdx, TTag>;
247 };
248
249 //! The SYCL device platform type trait specialization.
250 template<concepts::Tag TTag>
251 struct PlatformType<DevGenericSycl<TTag>>
252 {
253 using type = PlatformGenericSycl<TTag>;
254 };
255
256 //! The thread SYCL device wait specialization.
257 template<concepts::Tag TTag>
258 struct CurrentThreadWaitFor<DevGenericSycl<TTag>>
259 {
260 static auto currentThreadWaitFor(DevGenericSycl<TTag> const& dev) -> void
261 {
262 dev.m_impl->wait();
263 }
264 };
265
266 //! The SYCL blocking queue trait specialization.
267 template<concepts::Tag TTag>
268 struct QueueType<DevGenericSycl<TTag>, Blocking>
269 {
270 using type = QueueGenericSyclBlocking<TTag>;
271 };
272
273 //! The SYCL non-blocking queue trait specialization.
274 template<concepts::Tag TTag>
275 struct QueueType<DevGenericSycl<TTag>, NonBlocking>
276 {
277 using type = QueueGenericSyclNonBlocking<TTag>;
278 };
279
280 } // namespace trait
281} // namespace alpaka
282
283#endif
constexpr ALPAKA_FN_HOST_ACC bool operator==(Complex< T > const &lhs, Complex< T > const &rhs)
Equality of two complex numbers.
Definition Complex.hpp:294
constexpr ALPAKA_FN_HOST_ACC bool operator!=(Complex< T > const &lhs, Complex< T > const &rhs)
Inequality of two complex numbers.
Definition Complex.hpp:320
The alpaka accelerator library.
ALPAKA_FN_HOST constexpr auto getPreferredWarpSize(TDev const &dev) -> std::size_t
Definition Traits.hpp:118
ALPAKA_FN_HOST auto getName(TDev const &dev) -> std::string
Definition Traits.hpp:87
ALPAKA_FN_HOST auto getWarpSizes(TDev const &dev) -> std::vector< std::size_t >
Definition Traits.hpp:111
ALPAKA_FN_HOST auto reset(TDev const &dev) -> void
Resets the device. What this method does is dependent on the accelerator.
Definition Traits.hpp:126
ALPAKA_FN_HOST auto getFreeMemBytes(TDev const &dev) -> std::size_t
Definition Traits.hpp:104
ALPAKA_FN_HOST auto getMemBytes(TDev const &dev) -> std::size_t
Definition Traits.hpp:95
decltype(getNativeHandle(std::declval< TImpl >())) NativeHandle
Alias to the type of the native handle.
Definition Traits.hpp:36
ALPAKA_FN_HOST auto getNativeHandle(TImpl const &impl)
Get the native handle of the alpaka object. It will return the alpaka object handle if there is any,...
Definition Traits.hpp:29
ALPAKA_FN_HOST auto wait(TAwaited const &awaited) -> void
Blocks the current thread until the given awaited action has completed.
Definition Traits.hpp:34
STL namespace.
static auto getNativeHandle(TImpl const &)
Definition Traits.hpp:18