24 #include <shared_mutex>
29 #ifdef ALPAKA_ACC_SYCL_ENABLED
31 # include <sycl/sycl.hpp>
// NOTE(review): DevGenericSyclImpl is referenced elsewhere in this file without
// template arguments (std::make_shared<detail::DevGenericSyclImpl>, and the m_impl
// member of DevGenericSycl), so this template header most likely belongs to a
// declaration that is not visible in this excerpt — confirm.
35 template<
typename TElem,
typename TDim,
typename TIdx,
typename TDev>
// Shared implementation state behind DevGenericSycl: the SYCL device/context pair
// plus weak references to the queues registered on that device, guarded by a
// std::shared_mutex (exclusive for modification, shared for iteration).
40 class DevGenericSyclImpl
// Stores the SYCL device and context this implementation object represents.
// Both arguments are taken by value and moved into the members.
43 DevGenericSyclImpl(sycl::device device, sycl::context context)
44 : m_device{std::move(device)}
45 , m_context{std::move(context)}
50 auto clean_queues() ->
void
54 auto const old_end =
std::end(m_queues);
55 auto const new_end = std::remove_if(start, old_end, [](
auto q_ptr) {
return q_ptr.expired(); });
56 m_queues.erase(new_end, old_end);
59 auto register_queue(std::shared_ptr<QueueGenericSyclImpl>
const& queue) ->
void
61 std::lock_guard<std::shared_mutex> lock{m_mutex};
64 m_queues.emplace_back(queue);
67 auto register_dependency(sycl::event event) ->
void
69 std::shared_lock<std::shared_mutex> lock{m_mutex};
71 for(
auto& q_ptr : m_queues)
73 if(
auto ptr = q_ptr.lock(); ptr !=
nullptr)
74 ptr->register_dependency(event);
// NOTE(review): the header of the enclosing member function and the statement run
// for each live queue are not visible in this excerpt. The visible part takes a
// shared (reader) lock on m_mutex and visits every registered queue whose weak
// reference can still be locked.
80 std::shared_lock<std::shared_mutex> lock{m_mutex};
82 for(
auto& q_ptr : m_queues)
84 if(
auto ptr = q_ptr.lock(); ptr !=
nullptr)
// Accessor for the stored SYCL device (presumably returns m_device by value —
// the body is not visible in this excerpt).
89 auto get_device() const -> sycl::device
// Accessor for the stored SYCL context (presumably returns m_context by value —
// the body is not visible in this excerpt).
94 auto get_context() const -> sycl::context
// The SYCL device passed to the constructor.
100 sycl::device m_device;
// The SYCL context passed to the constructor alongside the device.
101 sycl::context m_context;
// Queues registered via register_queue. Held as weak_ptr so the device never keeps
// a queue alive; expired entries are removed by clean_queues.
102 std::vector<std::weak_ptr<QueueGenericSyclImpl>> m_queues;
// Guards m_queues. Declared mutable, which allows locking it from const member
// functions.
103 std::shared_mutex
mutable m_mutex;
// Device handle for a SYCL device on platform TPlatform. Copies of a handle share
// one DevGenericSyclImpl through std::shared_ptr, and operator== compares that
// shared implementation pointer.
// NOTE(review): the class-head line naming DevGenericSycl is not visible in this
// excerpt.
108 template<
typename TPlatform>
110 :
public concepts::Implements<ConceptCurrentThreadWaitFor, DevGenericSycl<TPlatform>>
111 ,
public concepts::Implements<ConceptDev, DevGenericSycl<TPlatform>>
// Creates a handle with its own, freshly allocated DevGenericSyclImpl built from
// the given SYCL device and context. Two handles constructed separately therefore
// compare unequal, even when they wrap the same sycl::device.
114 DevGenericSycl(sycl::device device, sycl::context context)
115 : m_impl{std::make_shared<detail::DevGenericSyclImpl>(std::move(device), std::move(context))}
119 friend auto operator==(DevGenericSycl
const& lhs, DevGenericSycl
const& rhs) ->
bool
121 return (lhs.m_impl == rhs.m_impl);
124 friend auto operator!=(DevGenericSycl
const& lhs, DevGenericSycl
const& rhs) ->
bool
126 return !(lhs == rhs);
129 [[nodiscard]]
auto getNativeHandle() const -> std::pair<sycl::device, sycl::context>
131 return std::make_pair(m_impl->get_device(), m_impl->get_context());
// Shared implementation state; copies of this handle share (and compare equal via)
// the same object.
134 std::shared_ptr<detail::DevGenericSyclImpl> m_impl;
141 template<
typename TPlatform>
142 struct GetName<DevGenericSycl<TPlatform>>
144 static auto getName(DevGenericSycl<TPlatform>
const& dev) -> std::string
146 auto const device = dev.getNativeHandle().first;
147 return device.template get_info<sycl::info::device::name>();
152 template<
typename TPlatform>
153 struct GetMemBytes<DevGenericSycl<TPlatform>>
155 static auto getMemBytes(DevGenericSycl<TPlatform>
const& dev) -> std::size_t
157 auto const device = dev.getNativeHandle().first;
158 return device.template get_info<sycl::info::device::global_mem_size>();
163 template<
typename TPlatform>
164 struct GetFreeMemBytes<DevGenericSycl<TPlatform>>
166 static auto getFreeMemBytes(DevGenericSycl<TPlatform>
const& ) -> std::size_t
168 static_assert(!
sizeof(TPlatform),
"Querying free device memory not supported for SYCL devices.");
169 return std::size_t{};
174 template<
typename TPlatform>
175 struct GetWarpSizes<DevGenericSycl<TPlatform>>
177 static auto getWarpSizes(DevGenericSycl<TPlatform>
const& dev) -> std::vector<std::size_t>
179 auto const device = dev.getNativeHandle().first;
180 std::vector<std::size_t> warp_sizes = device.template get_info<sycl::info::device::sub_group_sizes>();
182 auto find64 = std::find(warp_sizes.begin(), warp_sizes.end(), 64);
183 if(find64 != warp_sizes.end())
184 warp_sizes.erase(find64);
186 std::sort(warp_sizes.begin(), warp_sizes.end(), std::greater<>{});
// Preferred warp size of a SYCL device: the first entry of GetWarpSizes, which that
// trait sorts in descending order, i.e. the largest supported sub-group size.
// NOTE(review): front() on an empty vector is undefined behaviour; if a device can
// report no usable sub-group sizes, this needs a guard. The enclosing function
// header is not visible in this excerpt.
192 template<
typename TPlatform>
193 struct GetPreferredWarpSize<DevGenericSycl<TPlatform>>
197 return GetWarpSizes<DevGenericSycl<TPlatform>>
::getWarpSizes(dev).front();
202 template<
typename TPlatform>
203 struct Reset<DevGenericSycl<TPlatform>>
205 static auto reset(DevGenericSycl<TPlatform>
const&) ->
void
207 static_assert(!
sizeof(TPlatform),
"Explicit device reset not supported for SYCL devices");
// Exposes the (sycl::device, sycl::context) pair of a DevGenericSycl as its native
// handle by forwarding to the member getNativeHandle().
// NOTE(review): the struct header of this specialization is not visible in this
// excerpt.
212 template<
typename TPlatform>
215 [[nodiscard]]
static auto getNativeHandle(DevGenericSycl<TPlatform>
const& dev)
217 return dev.getNativeHandle();
222 template<
typename TElem,
typename TDim,
typename TIdx,
typename TPlatform>
223 struct BufType<DevGenericSycl<TPlatform>, TElem, TDim, TIdx>
225 using type = BufGenericSycl<TElem, TDim, TIdx, TPlatform>;
229 template<
typename TPlatform>
230 struct PlatformType<DevGenericSycl<TPlatform>>
232 using type = TPlatform;
// CurrentThreadWaitFor specialization for DevGenericSycl devices.
// NOTE(review): the function body is not visible in this excerpt — confirm what
// exactly the calling thread waits on.
236 template<
typename TPlatform>
237 struct CurrentThreadWaitFor<DevGenericSycl<TPlatform>>
239 static auto currentThreadWaitFor(DevGenericSycl<TPlatform>
const& dev) ->
void
246 template<
typename TPlatform>
247 struct QueueType<DevGenericSycl<TPlatform>, Blocking>
249 using type = detail::QueueGenericSyclBase<DevGenericSycl<TPlatform>,
true>;
253 template<
typename TPlatform>
254 struct QueueType<DevGenericSycl<TPlatform>, NonBlocking>
256 using type = detail::QueueGenericSyclBase<DevGenericSycl<TPlatform>,
false>;
ALPAKA_FN_HOST auto end(TView &view) -> Iterator< TView >
ALPAKA_FN_HOST auto begin(TView &view) -> Iterator< TView >
The alpaka accelerator library.
ALPAKA_FN_HOST auto getName(TDev const &dev) -> std::string
ALPAKA_FN_HOST auto getWarpSizes(TDev const &dev) -> std::vector< std::size_t >
ALPAKA_FN_HOST auto reset(TDev const &dev) -> void
Resets the device. What this method does is dependent on the accelerator.
ALPAKA_FN_HOST auto getFreeMemBytes(TDev const &dev) -> std::size_t
ALPAKA_FN_HOST auto getMemBytes(TDev const &dev) -> std::size_t
constexpr ALPAKA_FN_HOST_ACC bool operator==(Complex< T > const &lhs, Complex< T > const &rhs)
Equality of two complex numbers.
constexpr ALPAKA_FN_HOST auto getPreferredWarpSize(TDev const &dev) -> std::size_t
decltype(getNativeHandle(std::declval< TImpl >())) NativeHandle
Alias to the type of the native handle.
ALPAKA_FN_HOST auto getNativeHandle(TImpl const &impl)
Get the native handle of the alpaka object. It will return the alpaka object handle if there is any,...
ALPAKA_FN_HOST auto wait(TAwaited const &awaited) -> void
Makes the calling thread wait until the given awaited action has completed.
constexpr ALPAKA_FN_HOST_ACC bool operator!=(Complex< T > const &lhs, Complex< T > const &rhs)
Inequality of two complex numbers.
static auto getNativeHandle(TImpl const &)