26#include <shared_mutex>
31#ifdef ALPAKA_ACC_SYCL_ENABLED
33# include <sycl/sycl.hpp>
// NOTE(review): this template head has no declarator in this listing —
// source lines 40-42 are omitted, so the entity it introduces is not visible.
39 template<
typename TPlatform,
typename TSfinae>
// Blocking queue type: QueueGenericSyclBase with `true` as its second template
// argument. It is what QueueType<DevGenericSycl<TTag>, Blocking> resolves to.
43 template<concepts::Tag TTag>
44 using QueueGenericSyclBlocking = detail::QueueGenericSyclBase<TTag, true>;
// Non-blocking queue type: QueueGenericSyclBase with `false` as its second
// template argument. It is what QueueType<DevGenericSycl<TTag>, NonBlocking> resolves to.
46 template<concepts::Tag TTag>
47 using QueueGenericSyclNonBlocking = detail::QueueGenericSyclBase<TTag, false>;
// Forward declaration of the SYCL platform type for tag TTag.
49 template<concepts::Tag TTag>
50 struct PlatformGenericSycl;
// NOTE(review): this template head's declarator (source lines 53-56) is omitted
// from this listing. The BufType trait maps to BufGenericSycl<TElem, TDim, TIdx, TTag>,
// which matches this parameter list — presumably its forward declaration; confirm.
52 template<
typename TElem,
typename TDim,
typename TIdx, concepts::Tag TTag>
// Shared state behind a DevGenericSycl handle: the sycl::device and
// sycl::context, weak references to queues registered on the device, and
// lazily computed alpaka::DeviceProperties.
// NOTE(review): this listing omits many source lines of this class (braces,
// access specifiers, parts of several bodies); the comments below describe
// only the visible lines.
57 class DevGenericSyclImpl
// Stores the device and context by moving the arguments into the members.
60 DevGenericSyclImpl(sycl::device device, sycl::context context)
61 : m_device{
std::move(device)}
62 , m_context{
std::move(context)}
// Erases the entries of m_queues whose weak_ptr has expired (queue destroyed).
// NOTE(review): no lock is taken on the visible lines (source lines 68-69 are
// omitted). If none is taken there either, callers must hold m_mutex
// exclusively — confirm.
// The predicate takes each weak_ptr by value, which copies it (atomic
// refcount traffic); `auto const& q_ptr` would avoid that.
67 auto clean_queues() ->
void
70 auto const start = std::begin(m_queues);
71 auto const old_end = std::end(m_queues);
72 auto const new_end = std::remove_if(start, old_end, [](
auto q_ptr) {
return q_ptr.expired(); });
73 m_queues.erase(new_end, old_end);
// Appends a weak reference to `queue` to m_queues while holding m_mutex
// exclusively. Only a weak_ptr is stored, so the device does not keep the
// queue alive.
76 auto register_queue(std::shared_ptr<QueueGenericSyclImpl>
const& queue) ->
void
78 std::lock_guard<std::shared_mutex> lock{m_mutex};
81 m_queues.emplace_back(queue);
// Forwards `event` to register_dependency() on every registered queue that is
// still alive. Only a shared lock is held, so several threads may run this
// loop concurrently.
// NOTE(review): concurrent callers can therefore invoke
// QueueGenericSyclImpl::register_dependency on the same queue in parallel —
// confirm that method is thread-safe.
84 auto register_dependency(sycl::event event) ->
void
86 std::shared_lock<std::shared_mutex> lock{m_mutex};
88 for(
auto& q_ptr : m_queues)
90 if(
auto ptr = q_ptr.lock(); ptr !=
nullptr)
91 ptr->register_dependency(event);
// NOTE(review): this listing omits the signature of the member containing the
// following lines (source lines 92-96) and the body of its `if` (source lines
// 102-105). The visible part iterates the live queues under a shared lock.
97 std::shared_lock<std::shared_mutex> lock{m_mutex};
99 for(
auto& q_ptr : m_queues)
101 if(
auto ptr = q_ptr.lock(); ptr !=
nullptr)
// Returns the sycl::device by value (body omitted from this listing;
// presumably a copy of m_device).
106 auto get_device() const -> sycl::device
// Returns the sycl::context by value (body omitted from this listing;
// presumably a copy of m_context).
111 auto get_context() const -> sycl::context
// Returns a reference to the cached device properties, filled in from the
// sycl::device queries below.
// NOTE(review): source lines 117-121 are omitted. The guard around the
// initialisation is presumably std::call_once on m_onceFlag — confirm.
116 auto deviceProperties() -> std::optional<alpaka::DeviceProperties>&
122 m_deviceProperties = std::make_optional<alpaka::DeviceProperties>();
123 auto const& device = this->get_device();
124 m_deviceProperties->name = device.template get_info<sycl::info::device::name>();
125 m_deviceProperties->totalGlobalMem
126 = device.template get_info<sycl::info::device::global_mem_size>();
// The sub-group sizes reported by the device are used as alpaka warp sizes.
128 std::vector<std::size_t> warp_sizes
129 = device.template get_info<sycl::info::device::sub_group_sizes>();
// Removes the first occurrence of 64 from the reported sizes (the rationale
// is on lines omitted from this listing).
132 auto find64 = std::find(warp_sizes.begin(), warp_sizes.end(), 64);
133 if(find64 != warp_sizes.end())
134 warp_sizes.erase(find64);
// Sorts in decreasing order, so front() is the largest remaining size.
136 std::sort(warp_sizes.begin(), warp_sizes.end(), std::greater<>{});
137 m_deviceProperties->warpSizes = std::move(warp_sizes);
// NOTE(review): if the device reports no sub-group size other than 64,
// warpSizes is empty here and front() is undefined behaviour. Add a guard
// or a documented fallback.
138 m_deviceProperties->preferredWarpSize = m_deviceProperties->warpSizes.front();
141 return m_deviceProperties;
// The wrapped SYCL device and the context it is paired with.
145 sycl::device m_device;
146 sycl::context m_context;
// Non-owning references to the queues passed to register_queue().
147 std::vector<std::weak_ptr<QueueGenericSyclImpl>> m_queues;
// Cached result of deviceProperties(); empty until first populated.
148 std::optional<alpaka::DeviceProperties> m_deviceProperties;
// Guards m_queues: exclusive in register_queue, shared while iterating.
// It is declared mutable, so it may also be locked from const member functions.
149 std::shared_mutex
mutable m_mutex;
// One-time-initialisation flag; its use is on lines omitted from this listing.
150 std::once_flag m_onceFlag;
// Handle to a SYCL device. The visible data member is a std::shared_ptr to a
// detail::DevGenericSyclImpl, so copies of a handle share one implementation.
// NOTE(review): the class-head line (source line 156) is omitted from this
// listing; only the template head and the base-class list are visible.
155 template<concepts::Tag TTag>
157 :
public interface::Implements<ConceptCurrentThreadWaitFor, DevGenericSycl<TTag>>
158 ,
public interface::Implements<ConceptDev, DevGenericSycl<TTag>>
// Grants the platform's GetDevByIdx trait access to non-public members.
160 friend struct trait::GetDevByIdx<PlatformGenericSycl<TTag>>;
// Creates a new shared implementation that owns the given device and context.
163 DevGenericSycl(sycl::device device, sycl::context context)
164 : m_impl{
std::make_shared<detail::DevGenericSyclImpl>(
std::move(device),
std::move(context))}
// Identity comparison: two handles are equal iff they share the same
// implementation object. Each constructor call creates a new implementation,
// so two handles constructed separately from the same sycl::device compare
// unequal.
168 friend auto operator==(DevGenericSycl
const& lhs, DevGenericSycl
const& rhs) ->
bool
170 return (lhs.m_impl == rhs.m_impl);
173 friend auto operator!=(DevGenericSycl
const& lhs, DevGenericSycl
const& rhs) ->
bool
175 return !(lhs == rhs);
// Returns the (sycl::device, sycl::context) pair. The enclosing signature
// (source lines 176-179) is omitted from this listing; it is presumably
// getNativeHandle(), which the NativeHandle trait calls.
180 return std::make_pair(m_impl->get_device(), m_impl->get_context());
183 std::shared_ptr<detail::DevGenericSyclImpl> m_impl;
// Device name, read from the cached device properties
// (sycl::info::device::name).
190 template<concepts::Tag TTag>
191 struct GetName<DevGenericSycl<TTag>>
193 static auto getName(DevGenericSycl<TTag>
const& dev) -> std::string
195 return dev.m_impl->deviceProperties()->name;
// Total global memory of the device, read from the cached device properties
// (sycl::info::device::global_mem_size).
200 template<concepts::Tag TTag>
201 struct GetMemBytes<DevGenericSycl<TTag>>
203 static auto getMemBytes(DevGenericSycl<TTag>
const& dev) -> std::size_t
205 return dev.m_impl->deviceProperties()->totalGlobalMem;
// Free-memory queries are rejected at compile time for SYCL devices.
// `!sizeof(PlatformGenericSycl<TTag>)` is always false but depends on TTag, so
// the static_assert fires only when getFreeMemBytes is actually instantiated.
// The trailing `return std::size_t{};` only gives the function a return statement.
// NOTE(review): the line opening the static_assert (source lines 214-215) is
// omitted from this listing.
210 template<concepts::Tag TTag>
211 struct GetFreeMemBytes<DevGenericSycl<TTag>>
213 static auto getFreeMemBytes(DevGenericSycl<TTag>
const& ) -> std::size_t
216 !
sizeof(PlatformGenericSycl<TTag>),
217 "Querying free device memory not supported for SYCL devices.");
218 return std::size_t{};
// Returns a copy of the cached warpSizes (the device's sub-group sizes, stored
// in decreasing order when the properties are populated).
223 template<concepts::Tag TTag>
224 struct GetWarpSizes<DevGenericSycl<TTag>>
226 static auto getWarpSizes(DevGenericSycl<TTag>
const& dev) -> std::vector<std::size_t>
228 return dev.m_impl->deviceProperties()->warpSizes;
// Preferred warp size from the cached device properties.
// NOTE(review): the member-function signature (source lines 235-237) is
// omitted from this listing.
233 template<concepts::Tag TTag>
234 struct GetPreferredWarpSize<DevGenericSycl<TTag>>
238 return dev.m_impl->deviceProperties()->preferredWarpSize;
// Explicit device reset is rejected at compile time for SYCL devices. The
// TTag-dependent static_assert fires only if reset() is instantiated.
// NOTE(review): the line opening the static_assert (source lines 247-248) is
// omitted from this listing.
243 template<concepts::Tag TTag>
244 struct Reset<DevGenericSycl<TTag>>
246 static auto reset(DevGenericSycl<TTag>
const&) ->
void
249 !
sizeof(PlatformGenericSycl<TTag>),
250 "Explicit device reset not supported for SYCL devices");
// Native-handle trait: forwards to the device handle's getNativeHandle().
// NOTE(review): the struct head of this specialization (source lines 256-257)
// is omitted from this listing.
255 template<concepts::Tag TTag>
258 [[nodiscard]]
static auto getNativeHandle(DevGenericSycl<TTag>
const& dev)
260 return dev.getNativeHandle();
// Buffers for a SYCL device are BufGenericSycl with the same element type,
// dimension, index type and tag.
265 template<
typename TElem,
typename TDim,
typename TIdx, concepts::Tag TTag>
266 struct BufType<DevGenericSycl<TTag>, TElem, TDim, TIdx>
268 using type = BufGenericSycl<TElem, TDim, TIdx, TTag>;
// The platform type of DevGenericSycl<TTag> is PlatformGenericSycl<TTag>.
272 template<concepts::Tag TTag>
273 struct PlatformType<DevGenericSycl<TTag>>
275 using type = PlatformGenericSycl<TTag>;
// Wait-for-device trait. It presumably blocks the calling thread until work on
// the device has finished.
// NOTE(review): the body (source lines 283-288) is omitted from this listing,
// so the waiting mechanism is not visible.
279 template<concepts::Tag TTag>
280 struct CurrentThreadWaitFor<DevGenericSycl<TTag>>
282 static auto currentThreadWaitFor(DevGenericSycl<TTag>
const& dev) ->
void
// Queue type for the Blocking property.
289 template<concepts::Tag TTag>
290 struct QueueType<DevGenericSycl<TTag>, Blocking>
292 using type = QueueGenericSyclBlocking<TTag>;
// Queue type for the NonBlocking property.
296 template<concepts::Tag TTag>
297 struct QueueType<DevGenericSycl<TTag>, NonBlocking>
299 using type = QueueGenericSyclNonBlocking<TTag>;
constexpr ALPAKA_FN_HOST_ACC bool operator==(Complex< T > const &lhs, Complex< T > const &rhs)
Equality of two complex numbers.
constexpr ALPAKA_FN_HOST_ACC bool operator!=(Complex< T > const &lhs, Complex< T > const &rhs)
Inequality of two complex numbers.
The alpaka accelerator library.
ALPAKA_FN_HOST constexpr auto getPreferredWarpSize(TDev const &dev) -> std::size_t
ALPAKA_FN_HOST auto getName(TDev const &dev) -> std::string
ALPAKA_FN_HOST auto getWarpSizes(TDev const &dev) -> std::vector< std::size_t >
ALPAKA_FN_HOST auto reset(TDev const &dev) -> void
Resets the device. What this method does is dependent on the accelerator.
ALPAKA_FN_HOST auto getFreeMemBytes(TDev const &dev) -> std::size_t
ALPAKA_FN_HOST auto getMemBytes(TDev const &dev) -> std::size_t
decltype(getNativeHandle(std::declval< TImpl >())) NativeHandle
Alias to the type of the native handle.
ALPAKA_FN_HOST auto getNativeHandle(TImpl const &impl)
Get the native handle of the alpaka object. It will return the alpaka object handle if there is any,...
ALPAKA_FN_HOST auto wait(TAwaited const &awaited) -> void
Blocks the calling thread until the given awaited action has completed.
static auto getNativeHandle(TImpl const &)