25 #include <shared_mutex>
30 #ifdef ALPAKA_ACC_SYCL_ENABLED
32 # include <sycl/sycl.hpp>
38 template<
typename TPlatform,
typename TSfinae>
42 template<concepts::Tag TTag>
43 using QueueGenericSyclBlocking = detail::QueueGenericSyclBase<TTag, true>;
45 template<concepts::Tag TTag>
46 using QueueGenericSyclNonBlocking = detail::QueueGenericSyclBase<TTag, false>;
48 template<concepts::Tag TTag>
49 struct PlatformGenericSycl;
51 template<
typename TElem,
typename TDim,
typename TIdx, concepts::Tag TTag>
56 class DevGenericSyclImpl
59 DevGenericSyclImpl(sycl::device device, sycl::context context)
60 : m_device{std::move(device)}
61 , m_context{std::move(context)}
66 auto clean_queues() ->
void
70 auto const old_end =
std::end(m_queues);
71 auto const new_end = std::remove_if(start, old_end, [](
auto q_ptr) {
return q_ptr.expired(); });
72 m_queues.erase(new_end, old_end);
75 auto register_queue(std::shared_ptr<QueueGenericSyclImpl>
const& queue) ->
void
77 std::lock_guard<std::shared_mutex> lock{m_mutex};
80 m_queues.emplace_back(queue);
83 auto register_dependency(sycl::event event) ->
void
85 std::shared_lock<std::shared_mutex> lock{m_mutex};
87 for(
auto& q_ptr : m_queues)
89 if(
auto ptr = q_ptr.lock(); ptr !=
nullptr)
90 ptr->register_dependency(event);
96 std::shared_lock<std::shared_mutex> lock{m_mutex};
98 for(
auto& q_ptr : m_queues)
100 if(
auto ptr = q_ptr.lock(); ptr !=
nullptr)
105 auto get_device() const -> sycl::device
110 auto get_context() const -> sycl::context
116 sycl::device m_device;
117 sycl::context m_context;
118 std::vector<std::weak_ptr<QueueGenericSyclImpl>> m_queues;
119 std::shared_mutex
mutable m_mutex;
124 template<concepts::Tag TTag>
126 :
public interface::Implements<ConceptCurrentThreadWaitFor, DevGenericSycl<TTag>>
127 ,
public interface::Implements<ConceptDev, DevGenericSycl<TTag>>
129 friend struct trait::GetDevByIdx<PlatformGenericSycl<TTag>>;
132 DevGenericSycl(sycl::device device, sycl::context context)
133 : m_impl{std::make_shared<detail::DevGenericSyclImpl>(std::move(device), std::move(context))}
137 friend auto operator==(DevGenericSycl
const& lhs, DevGenericSycl
const& rhs) ->
bool
139 return (lhs.m_impl == rhs.m_impl);
142 friend auto operator!=(DevGenericSycl
const& lhs, DevGenericSycl
const& rhs) ->
bool
144 return !(lhs == rhs);
147 [[nodiscard]]
auto getNativeHandle() const -> std::pair<sycl::device, sycl::context>
149 return std::make_pair(m_impl->get_device(), m_impl->get_context());
152 std::shared_ptr<detail::DevGenericSyclImpl> m_impl;
158 template<concepts::Tag TTag>
159 struct GetName<DevGenericSycl<TTag>>
161 static auto getName(DevGenericSycl<TTag>
const& dev) -> std::string
163 auto const device = dev.getNativeHandle().first;
164 return device.template get_info<sycl::info::device::name>();
169 template<concepts::Tag TTag>
170 struct GetMemBytes<DevGenericSycl<TTag>>
172 static auto getMemBytes(DevGenericSycl<TTag>
const& dev) -> std::size_t
174 auto const device = dev.getNativeHandle().first;
175 return device.template get_info<sycl::info::device::global_mem_size>();
180 template<concepts::Tag TTag>
181 struct GetFreeMemBytes<DevGenericSycl<TTag>>
183 static auto getFreeMemBytes(DevGenericSycl<TTag>
const& ) -> std::size_t
186 !
sizeof(PlatformGenericSycl<TTag>),
187 "Querying free device memory not supported for SYCL devices.");
188 return std::size_t{};
193 template<concepts::Tag TTag>
194 struct GetWarpSizes<DevGenericSycl<TTag>>
196 static auto getWarpSizes(DevGenericSycl<TTag>
const& dev) -> std::vector<std::size_t>
198 auto const device = dev.getNativeHandle().first;
199 std::vector<std::size_t> warp_sizes = device.template get_info<sycl::info::device::sub_group_sizes>();
201 auto find64 = std::find(warp_sizes.begin(), warp_sizes.end(), 64);
202 if(find64 != warp_sizes.end())
203 warp_sizes.erase(find64);
205 std::sort(warp_sizes.begin(), warp_sizes.end(), std::greater<>{});
211 template<concepts::Tag TTag>
212 struct GetPreferredWarpSize<DevGenericSycl<TTag>>
216 return GetWarpSizes<DevGenericSycl<TTag>>
::getWarpSizes(dev).front();
221 template<concepts::Tag TTag>
222 struct Reset<DevGenericSycl<TTag>>
224 static auto reset(DevGenericSycl<TTag>
const&) ->
void
227 !
sizeof(PlatformGenericSycl<TTag>),
228 "Explicit device reset not supported for SYCL devices");
233 template<concepts::Tag TTag>
236 [[nodiscard]]
static auto getNativeHandle(DevGenericSycl<TTag>
const& dev)
238 return dev.getNativeHandle();
243 template<
typename TElem,
typename TDim,
typename TIdx, concepts::Tag TTag>
244 struct BufType<DevGenericSycl<TTag>, TElem, TDim, TIdx>
246 using type = BufGenericSycl<TElem, TDim, TIdx, TTag>;
250 template<concepts::Tag TTag>
251 struct PlatformType<DevGenericSycl<TTag>>
253 using type = PlatformGenericSycl<TTag>;
257 template<concepts::Tag TTag>
258 struct CurrentThreadWaitFor<DevGenericSycl<TTag>>
260 static auto currentThreadWaitFor(DevGenericSycl<TTag>
const& dev) ->
void
267 template<concepts::Tag TTag>
268 struct QueueType<DevGenericSycl<TTag>, Blocking>
270 using type = QueueGenericSyclBlocking<TTag>;
274 template<concepts::Tag TTag>
275 struct QueueType<DevGenericSycl<TTag>, NonBlocking>
277 using type = QueueGenericSyclNonBlocking<TTag>;
constexpr ALPAKA_FN_HOST_ACC bool operator==(Complex< T > const &lhs, Complex< T > const &rhs)
Equality of two complex numbers.
constexpr ALPAKA_FN_HOST_ACC bool operator!=(Complex< T > const &lhs, Complex< T > const &rhs)
Inequality of two complex numbers.
ALPAKA_FN_HOST auto end(TView &view) -> Iterator< TView >
ALPAKA_FN_HOST auto begin(TView &view) -> Iterator< TView >
The alpaka accelerator library.
ALPAKA_FN_HOST auto getName(TDev const &dev) -> std::string
ALPAKA_FN_HOST auto getWarpSizes(TDev const &dev) -> std::vector< std::size_t >
ALPAKA_FN_HOST auto reset(TDev const &dev) -> void
Resets the device. What this method does is dependent on the accelerator.
ALPAKA_FN_HOST auto getFreeMemBytes(TDev const &dev) -> std::size_t
ALPAKA_FN_HOST auto getMemBytes(TDev const &dev) -> std::size_t
constexpr ALPAKA_FN_HOST auto getPreferredWarpSize(TDev const &dev) -> std::size_t
decltype(getNativeHandle(std::declval< TImpl >())) NativeHandle
Alias to the type of the native handle.
ALPAKA_FN_HOST auto getNativeHandle(TImpl const &impl)
Get the native handle of the alpaka object. It will return the alpaka object handle if there is any,...
ALPAKA_FN_HOST auto wait(TAwaited const &awaited) -> void
Makes the current thread wait until the given awaited action has completed.
static auto getNativeHandle(TImpl const &)