alpaka
Abstraction Library for Parallel Kernel Acceleration
Loading...
Searching...
No Matches
MathGenericSycl.hpp
Go to the documentation of this file.
1/* Copyright 2023 Jan Stephan, Sergei Bastrakov, René Widera, Luca Ferragina, Andrea Bocci
2 * SPDX-License-Identifier: MPL-2.0
3 */
4
5#pragma once
6
10
11#include <type_traits>
12
13#ifdef ALPAKA_ACC_SYCL_ENABLED
14
15# include <sycl/sycl.hpp>
16
17//! The mathematical operation specifics.
18namespace alpaka::math
19{
20 //! The SYCL abs.
21 class AbsGenericSycl : public interface::Implements<alpaka::math::ConceptMathAbs, AbsGenericSycl>
22 {
23 };
24
25 //! The SYCL acos.
26 class AcosGenericSycl : public interface::Implements<alpaka::math::ConceptMathAcos, AcosGenericSycl>
27 {
28 };
29
30 //! The SYCL acosh.
31 class AcoshGenericSycl : public interface::Implements<alpaka::math::ConceptMathAcosh, AcoshGenericSycl>
32 {
33 };
34
35 //! The SYCL arg.
36 class ArgGenericSycl : public interface::Implements<alpaka::math::ConceptMathArg, ArgGenericSycl>
37 {
38 };
39
40 //! The SYCL asin.
41 class AsinGenericSycl : public interface::Implements<alpaka::math::ConceptMathAsin, AsinGenericSycl>
42 {
43 };
44
45 //! The SYCL asinh.
46 class AsinhGenericSycl : public interface::Implements<alpaka::math::ConceptMathAsinh, AsinhGenericSycl>
47 {
48 };
49
50 //! The SYCL atan.
51 class AtanGenericSycl : public interface::Implements<alpaka::math::ConceptMathAtan, AtanGenericSycl>
52 {
53 };
54
55 //! The SYCL atanh.
56 class AtanhGenericSycl : public interface::Implements<alpaka::math::ConceptMathAtanh, AtanhGenericSycl>
57 {
58 };
59
60 //! The SYCL atan2.
61 class Atan2GenericSycl : public interface::Implements<alpaka::math::ConceptMathAtan2, Atan2GenericSycl>
62 {
63 };
64
65 //! The SYCL cbrt.
66 class CbrtGenericSycl : public interface::Implements<alpaka::math::ConceptMathCbrt, CbrtGenericSycl>
67 {
68 };
69
70 //! The SYCL ceil.
71 class CeilGenericSycl : public interface::Implements<alpaka::math::ConceptMathCeil, CeilGenericSycl>
72 {
73 };
74
75 //! The SYCL conj.
76 class ConjGenericSycl : public interface::Implements<alpaka::math::ConceptMathConj, ConjGenericSycl>
77 {
78 };
79
80 //! The SYCL copysign.
81 class CopysignGenericSycl : public interface::Implements<alpaka::math::ConceptMathCopysign, CopysignGenericSycl>
82 {
83 };
84
85 //! The SYCL cos.
86 class CosGenericSycl : public interface::Implements<alpaka::math::ConceptMathCos, CosGenericSycl>
87 {
88 };
89
90 //! The SYCL cosh.
91 class CoshGenericSycl : public interface::Implements<alpaka::math::ConceptMathCosh, CoshGenericSycl>
92 {
93 };
94
95 //! The SYCL erf.
96 class ErfGenericSycl : public interface::Implements<alpaka::math::ConceptMathErf, ErfGenericSycl>
97 {
98 };
99
100 //! The SYCL exp.
101 class ExpGenericSycl : public interface::Implements<alpaka::math::ConceptMathExp, ExpGenericSycl>
102 {
103 };
104
105 //! The SYCL floor.
106 class FloorGenericSycl : public interface::Implements<alpaka::math::ConceptMathFloor, FloorGenericSycl>
107 {
108 };
109
110 //! The SYCL fma.
111 class FmaGenericSycl : public interface::Implements<alpaka::math::ConceptMathFma, FmaGenericSycl>
112 {
113 };
114
115 //! The SYCL fmod.
116 class FmodGenericSycl : public interface::Implements<alpaka::math::ConceptMathFmod, FmodGenericSycl>
117 {
118 };
119
120 //! The SYCL isfinite.
121 class IsfiniteGenericSycl : public interface::Implements<alpaka::math::ConceptMathIsfinite, IsfiniteGenericSycl>
122 {
123 };
124
125 //! The SYCL isfinite.
126 class IsinfGenericSycl : public interface::Implements<alpaka::math::ConceptMathIsinf, IsinfGenericSycl>
127 {
128 };
129
130 //! The SYCL isnan.
131 class IsnanGenericSycl : public interface::Implements<alpaka::math::ConceptMathIsnan, IsnanGenericSycl>
132 {
133 };
134
135 //! The SYCL log.
136 class LogGenericSycl : public interface::Implements<alpaka::math::ConceptMathLog, LogGenericSycl>
137 {
138 };
139
140 //! The SYCL log2.
141 class Log2GenericSycl : public interface::Implements<alpaka::math::ConceptMathLog2, Log2GenericSycl>
142 {
143 };
144
145 //! The SYCL log10.
146 class Log10GenericSycl : public interface::Implements<alpaka::math::ConceptMathLog10, Log10GenericSycl>
147 {
148 };
149
150 //! The SYCL max.
151 class MaxGenericSycl : public interface::Implements<alpaka::math::ConceptMathMax, MaxGenericSycl>
152 {
153 };
154
155 //! The SYCL min.
156 class MinGenericSycl : public interface::Implements<alpaka::math::ConceptMathMin, MinGenericSycl>
157 {
158 };
159
160 //! The SYCL pow.
161 class PowGenericSycl : public interface::Implements<alpaka::math::ConceptMathPow, PowGenericSycl>
162 {
163 };
164
165 //! The SYCL remainder.
166 class RemainderGenericSycl : public interface::Implements<alpaka::math::ConceptMathRemainder, RemainderGenericSycl>
167 {
168 };
169
170 //! The SYCL round.
171 class RoundGenericSycl : public interface::Implements<alpaka::math::ConceptMathRound, RoundGenericSycl>
172 {
173 };
174
175 //! The SYCL rsqrt.
176 class RsqrtGenericSycl : public interface::Implements<alpaka::math::ConceptMathRsqrt, RsqrtGenericSycl>
177 {
178 };
179
180 //! The SYCL sin.
181 class SinGenericSycl : public interface::Implements<alpaka::math::ConceptMathSin, SinGenericSycl>
182 {
183 };
184
185 //! The SYCL sinh.
186 class SinhGenericSycl : public interface::Implements<alpaka::math::ConceptMathSinh, SinhGenericSycl>
187 {
188 };
189
190 //! The SYCL sincos.
191 class SinCosGenericSycl : public interface::Implements<alpaka::math::ConceptMathSinCos, SinCosGenericSycl>
192 {
193 };
194
195 //! The SYCL sqrt.
196 class SqrtGenericSycl : public interface::Implements<alpaka::math::ConceptMathSqrt, SqrtGenericSycl>
197 {
198 };
199
200 //! The SYCL tan.
201 class TanGenericSycl : public interface::Implements<alpaka::math::ConceptMathTan, TanGenericSycl>
202 {
203 };
204
205 //! The SYCL tanh.
206 class TanhGenericSycl : public interface::Implements<alpaka::math::ConceptMathTanh, TanhGenericSycl>
207 {
208 };
209
210 //! The SYCL trunc.
211 class TruncGenericSycl : public interface::Implements<alpaka::math::ConceptMathTrunc, TruncGenericSycl>
212 {
213 };
214
215 //! The SYCL math trait specializations.
216 class MathGenericSycl
217 : public AbsGenericSycl
218 , public AcosGenericSycl
219 , public AcoshGenericSycl
220 , public ArgGenericSycl
221 , public AsinGenericSycl
222 , public AsinhGenericSycl
223 , public AtanGenericSycl
224 , public AtanhGenericSycl
225 , public Atan2GenericSycl
226 , public CbrtGenericSycl
227 , public CeilGenericSycl
228 , public ConjGenericSycl
229 , public CopysignGenericSycl
230 , public CosGenericSycl
231 , public CoshGenericSycl
232 , public ErfGenericSycl
233 , public ExpGenericSycl
234 , public FloorGenericSycl
235 , public FmaGenericSycl
236 , public FmodGenericSycl
237 , public IsfiniteGenericSycl
238 , public IsinfGenericSycl
239 , public IsnanGenericSycl
240 , public LogGenericSycl
241 , public Log2GenericSycl
242 , public Log10GenericSycl
243 , public MaxGenericSycl
244 , public MinGenericSycl
245 , public PowGenericSycl
246 , public RemainderGenericSycl
247 , public RoundGenericSycl
248 , public RsqrtGenericSycl
249 , public SinGenericSycl
250 , public SinhGenericSycl
251 , public SinCosGenericSycl
252 , public SqrtGenericSycl
253 , public TanGenericSycl
254 , public TanhGenericSycl
255 , public TruncGenericSycl
256 {
257 };
258} // namespace alpaka::math
259
260namespace alpaka::math::trait
261{
262 //! The SYCL abs trait specialization.
263 template<typename TArg>
264 struct Abs<math::AbsGenericSycl, TArg, std::enable_if_t<std::is_arithmetic_v<TArg>>>
265 {
266 auto operator()(math::AbsGenericSycl const&, TArg const& arg)
267 {
268 if constexpr(std::is_integral_v<TArg>)
269 return sycl::abs(arg);
270 else if constexpr(std::is_floating_point_v<TArg>)
271 return sycl::fabs(arg);
272 else
273 static_assert(!sizeof(TArg), "Unsupported data type");
274 }
275 };
276
277 //! The SYCL acos trait specialization.
278 template<typename TArg>
279 struct Acos<math::AcosGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
280 {
281 auto operator()(math::AcosGenericSycl const&, TArg const& arg)
282 {
283 return sycl::acos(arg);
284 }
285 };
286
287 //! The SYCL acosh trait specialization.
288 template<typename TArg>
289 struct Acosh<math::AcoshGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
290 {
291 auto operator()(math::AcoshGenericSycl const&, TArg const& arg)
292 {
293 return sycl::acosh(arg);
294 }
295 };
296
297 //! The SYCL arg trait specialization.
298 template<typename TArgument>
299 struct Arg<math::ArgGenericSycl, TArgument, std::enable_if_t<std::is_arithmetic_v<TArgument>>>
300 {
301 auto operator()(math::ArgGenericSycl const&, TArgument const& argument)
302 {
303 if constexpr(std::is_integral_v<TArgument>)
304 return sycl::atan2(0.0, static_cast<double>(argument));
305 else if constexpr(std::is_floating_point_v<TArgument>)
306 return sycl::atan2(static_cast<TArgument>(0.0), argument);
307 else
308 static_assert(!sizeof(TArgument), "Unsupported data type");
309 }
310 };
311
312 //! The SYCL asin trait specialization.
313 template<typename TArg>
314 struct Asin<math::AsinGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
315 {
316 auto operator()(math::AsinGenericSycl const&, TArg const& arg)
317 {
318 return sycl::asin(arg);
319 }
320 };
321
322 //! The SYCL asinh trait specialization.
323 template<typename TArg>
324 struct Asinh<math::AsinhGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
325 {
326 auto operator()(math::AsinhGenericSycl const&, TArg const& arg)
327 {
328 return sycl::asinh(arg);
329 }
330 };
331
332 //! The SYCL atan trait specialization.
333 template<typename TArg>
334 struct Atan<math::AtanGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
335 {
336 auto operator()(math::AtanGenericSycl const&, TArg const& arg)
337 {
338 return sycl::atan(arg);
339 }
340 };
341
342 //! The SYCL atanh trait specialization.
343 template<typename TArg>
344 struct Atanh<math::AtanhGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
345 {
346 auto operator()(math::AtanhGenericSycl const&, TArg const& arg)
347 {
348 return sycl::atanh(arg);
349 }
350 };
351
352 //! The SYCL atan2 trait specialization.
353 template<typename Ty, typename Tx>
354 struct Atan2<
355 math::Atan2GenericSycl,
356 Ty,
357 Tx,
358 std::enable_if_t<std::is_floating_point_v<Ty> && std::is_floating_point_v<Tx>>>
359 {
360 using TCommon = std::common_type_t<Ty, Tx>;
361
362 auto operator()(math::Atan2GenericSycl const&, Ty const& y, Tx const& x)
363 {
364 return sycl::atan2(static_cast<TCommon>(y), static_cast<TCommon>(x));
365 }
366 };
367
368 //! The SYCL cbrt trait specialization.
369 template<typename TArg>
370 struct Cbrt<math::CbrtGenericSycl, TArg, std::enable_if_t<std::is_arithmetic_v<TArg>>>
371 {
372 auto operator()(math::CbrtGenericSycl const&, TArg const& arg)
373 {
374 if constexpr(std::is_integral_v<TArg>)
375 return sycl::cbrt(static_cast<double>(arg)); // Mirror CUDA back-end and use double for ints
376 else if constexpr(std::is_floating_point_v<TArg>)
377 return sycl::cbrt(arg);
378 else
379 static_assert(!sizeof(TArg), "Unsupported data type");
380 }
381 };
382
383 //! The SYCL ceil trait specialization.
384 template<typename TArg>
385 struct Ceil<math::CeilGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
386 {
387 auto operator()(math::CeilGenericSycl const&, TArg const& arg)
388 {
389 return sycl::ceil(arg);
390 }
391 };
392
393 //! The SYCL conj trait specialization.
394 template<typename TArg>
395 struct Conj<math::ConjGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
396 {
397 auto operator()(math::ConjGenericSycl const&, TArg const& arg)
398 {
399 return Complex<TArg>{arg, TArg{0.0}};
400 }
401 };
402
403 //! The SYCL copysign trait specialization.
404 template<typename TMag, typename TSgn>
405 struct Copysign<
406 math::CopysignGenericSycl,
407 TMag,
408 TSgn,
409 std::enable_if_t<std::is_floating_point_v<TMag> && std::is_floating_point_v<TSgn>>>
410 {
411 using TCommon = std::common_type_t<TMag, TSgn>;
412
413 auto operator()(math::CopysignGenericSycl const&, TMag const& y, TSgn const& x)
414 {
415 return sycl::copysign(static_cast<TCommon>(y), static_cast<TCommon>(x));
416 }
417 };
418
419 //! The SYCL cos trait specialization.
420 template<typename TArg>
421 struct Cos<math::CosGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
422 {
423 auto operator()(math::CosGenericSycl const&, TArg const& arg)
424 {
425 return sycl::cos(arg);
426 }
427 };
428
429 //! The SYCL cos trait specialization.
430 template<typename TArg>
431 struct Cosh<math::CoshGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
432 {
433 auto operator()(math::CoshGenericSycl const&, TArg const& arg)
434 {
435 return sycl::cosh(arg);
436 }
437 };
438
439 //! The SYCL erf trait specialization.
440 template<typename TArg>
441 struct Erf<math::ErfGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
442 {
443 auto operator()(math::ErfGenericSycl const&, TArg const& arg)
444 {
445 return sycl::erf(arg);
446 }
447 };
448
449 //! The SYCL exp trait specialization.
450 template<typename TArg>
451 struct Exp<math::ExpGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
452 {
453 auto operator()(math::ExpGenericSycl const&, TArg const& arg)
454 {
455 return sycl::exp(arg);
456 }
457 };
458
459 //! The SYCL floor trait specialization.
460 template<typename TArg>
461 struct Floor<math::FloorGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
462 {
463 auto operator()(math::FloorGenericSycl const&, TArg const& arg)
464 {
465 return sycl::floor(arg);
466 }
467 };
468
469 //! The SYCL fma trait specialization.
470 template<typename Tx, typename Ty, typename Tz>
471 struct Fma<
472 math::FmaGenericSycl,
473 Tx,
474 Ty,
475 Tz,
476 std::enable_if_t<std::is_floating_point_v<Tx> && std::is_floating_point_v<Ty> && std::is_floating_point_v<Tz>>>
477 {
478 auto operator()(math::FmaGenericSycl const&, Tx const& x, Ty const& y, Tz const& z)
479 {
480 return sycl::fma(x, y, z);
481 }
482 };
483
484 //! The SYCL fmod trait specialization.
485 template<typename Tx, typename Ty>
486 struct Fmod<
487 math::FmodGenericSycl,
488 Tx,
489 Ty,
490 std::enable_if_t<std::is_floating_point_v<Tx> && std::is_floating_point_v<Ty>>>
491 {
492 using TCommon = std::common_type_t<Tx, Ty>;
493
494 auto operator()(math::FmodGenericSycl const&, Tx const& x, Ty const& y)
495 {
496 return sycl::fmod(static_cast<TCommon>(x), static_cast<TCommon>(y));
497 }
498 };
499
500 //! The SYCL isfinite trait specialization.
501 template<typename TArg>
502 struct Isfinite<math::IsfiniteGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
503 {
504 auto operator()(math::IsfiniteGenericSycl const&, TArg const& arg)
505 {
506 return static_cast<bool>(sycl::isfinite(arg));
507 }
508 };
509
510 //! The SYCL isinf trait specialization.
511 template<typename TArg>
512 struct Isinf<math::IsinfGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
513 {
514 auto operator()(math::IsinfGenericSycl const&, TArg const& arg)
515 {
516 return static_cast<bool>(sycl::isinf(arg));
517 }
518 };
519
520 //! The SYCL isnan trait specialization.
521 template<typename TArg>
522 struct Isnan<math::IsnanGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
523 {
524 auto operator()(math::IsnanGenericSycl const&, TArg const& arg)
525 {
526 return static_cast<bool>(sycl::isnan(arg));
527 }
528 };
529
530 //! The SYCL log trait specialization.
531 template<typename TArg>
532 struct Log<math::LogGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
533 {
534 auto operator()(math::LogGenericSycl const&, TArg const& arg)
535 {
536 return sycl::log(arg);
537 }
538 };
539
540 //! The SYCL log2 trait specialization.
541 template<typename TArg>
542 struct Log2<math::Log2GenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
543 {
544 auto operator()(math::Log2GenericSycl const&, TArg const& arg)
545 {
546 return sycl::log2(arg);
547 }
548 };
549
550 //! The SYCL log10 trait specialization.
551 template<typename TArg>
552 struct Log10<math::Log10GenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
553 {
554 auto operator()(math::Log10GenericSycl const&, TArg const& arg)
555 {
556 return sycl::log10(arg);
557 }
558 };
559
560 //! The SYCL max trait specialization.
561 template<typename Tx, typename Ty>
562 struct Max<math::MaxGenericSycl, Tx, Ty, std::enable_if_t<std::is_arithmetic_v<Tx> && std::is_arithmetic_v<Ty>>>
563 {
564 using TCommon = std::common_type_t<Tx, Ty>;
565
566 auto operator()(math::MaxGenericSycl const&, Tx const& x, Ty const& y)
567 {
568 if constexpr(std::is_integral_v<Tx> && std::is_integral_v<Ty>)
569 return sycl::max(static_cast<TCommon>(x), static_cast<TCommon>(y));
570 else if constexpr(std::is_floating_point_v<Tx> && std::is_floating_point_v<Ty>)
571 return sycl::fmax(static_cast<TCommon>(x), static_cast<TCommon>(y));
572 else if constexpr(
573 (std::is_floating_point_v<Tx> && std::is_integral_v<Ty>)
574 || (std::is_integral_v<Tx> && std::is_floating_point_v<Ty>) )
575 return sycl::fmax(static_cast<double>(x), static_cast<double>(y)); // mirror CUDA back-end
576 else
577 static_assert(!sizeof(Tx), "Unsupported data types");
578 }
579 };
580
581 //! The SYCL min trait specialization.
582 template<typename Tx, typename Ty>
583 struct Min<math::MinGenericSycl, Tx, Ty, std::enable_if_t<std::is_arithmetic_v<Tx> && std::is_arithmetic_v<Ty>>>
584 {
585 auto operator()(math::MinGenericSycl const&, Tx const& x, Ty const& y)
586 {
587 if constexpr(std::is_integral_v<Tx> && std::is_integral_v<Ty>)
588 return sycl::min(x, y);
589 else if constexpr(std::is_floating_point_v<Tx> || std::is_floating_point_v<Ty>)
590 return sycl::fmin(x, y);
591 else if constexpr(
592 (std::is_floating_point_v<Tx> && std::is_integral_v<Ty>)
593 || (std::is_integral_v<Tx> && std::is_floating_point_v<Ty>) )
594 return sycl::fmin(static_cast<double>(x), static_cast<double>(y)); // mirror CUDA back-end
595 else
596 static_assert(!sizeof(Tx), "Unsupported data types");
597 }
598 };
599
600 //! The SYCL pow trait specialization.
601 template<typename TBase, typename TExp>
602 struct Pow<
603 math::PowGenericSycl,
604 TBase,
605 TExp,
606 std::enable_if_t<std::is_floating_point_v<TBase> && std::is_floating_point_v<TExp>>>
607 {
608 using TCommon = std::common_type_t<TBase, TExp>;
609
610 auto operator()(math::PowGenericSycl const&, TBase const& base, TExp const& exp)
611 {
612 return sycl::pow(static_cast<TCommon>(base), static_cast<TCommon>(exp));
613 }
614 };
615
616 //! The SYCL remainder trait specialization.
617 template<typename Tx, typename Ty>
618 struct Remainder<
619 math::RemainderGenericSycl,
620 Tx,
621 Ty,
622 std::enable_if_t<std::is_floating_point_v<Tx> && std::is_floating_point_v<Ty>>>
623 {
624 using TCommon = std::common_type_t<Tx, Ty>;
625
626 auto operator()(math::RemainderGenericSycl const&, Tx const& x, Ty const& y)
627 {
628 return sycl::remainder(static_cast<TCommon>(x), static_cast<TCommon>(y));
629 }
630 };
631
632 //! The SYCL round trait specialization.
633 template<typename TArg>
634 struct Round<math::RoundGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
635 {
636 auto operator()(math::RoundGenericSycl const&, TArg const& arg)
637 {
638 return sycl::round(arg);
639 }
640 };
641
642 //! The SYCL lround trait specialization.
643 template<typename TArg>
644 struct Lround<math::RoundGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
645 {
646 auto operator()(math::RoundGenericSycl const&, TArg const& arg)
647 {
648 return static_cast<long>(sycl::round(arg));
649 }
650 };
651
652 //! The SYCL llround trait specialization.
653 template<typename TArg>
654 struct Llround<math::RoundGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
655 {
656 auto operator()(math::RoundGenericSycl const&, TArg const& arg)
657 {
658 return static_cast<long long>(sycl::round(arg));
659 }
660 };
661
662 //! The SYCL rsqrt trait specialization.
663 template<typename TArg>
664 struct Rsqrt<math::RsqrtGenericSycl, TArg, std::enable_if_t<std::is_arithmetic_v<TArg>>>
665 {
666 auto operator()(math::RsqrtGenericSycl const&, TArg const& arg)
667 {
668 if constexpr(std::is_floating_point_v<TArg>)
669 return sycl::rsqrt(arg);
670 else if constexpr(std::is_integral_v<TArg>)
671 return sycl::rsqrt(static_cast<double>(arg)); // mirror CUDA back-end and use double for ints
672 else
673 static_assert(!sizeof(TArg), "Unsupported data type");
674 }
675 };
676
677 //! The SYCL sin trait specialization.
678 template<typename TArg>
679 struct Sin<math::SinGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
680 {
681 auto operator()(math::SinGenericSycl const&, TArg const& arg)
682 {
683 return sycl::sin(arg);
684 }
685 };
686
687 //! The SYCL sinh trait specialization.
688 template<typename TArg>
689 struct Sinh<math::SinhGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
690 {
691 auto operator()(math::SinhGenericSycl const&, TArg const& arg)
692 {
693 return sycl::sinh(arg);
694 }
695 };
696
697 //! The SYCL sincos trait specialization.
698 template<typename TArg>
699 struct SinCos<math::SinCosGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
700 {
701 auto operator()(math::SinCosGenericSycl const&, TArg const& arg, TArg& result_sin, TArg& result_cos) -> void
702 {
703 result_sin = sycl::sincos(arg, &result_cos);
704 }
705 };
706
707 //! The SYCL sqrt trait specialization.
708 template<typename TArg>
709 struct Sqrt<math::SqrtGenericSycl, TArg, std::enable_if_t<std::is_arithmetic_v<TArg>>>
710 {
711 auto operator()(math::SqrtGenericSycl const&, TArg const& arg)
712 {
713 if constexpr(std::is_floating_point_v<TArg>)
714 return sycl::sqrt(arg);
715 else if constexpr(std::is_integral_v<TArg>)
716 return sycl::sqrt(static_cast<double>(arg)); // mirror CUDA back-end and use double for ints
717 }
718 };
719
720 //! The SYCL tan trait specialization.
721 template<typename TArg>
722 struct Tan<math::TanGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
723 {
724 auto operator()(math::TanGenericSycl const&, TArg const& arg)
725 {
726 return sycl::tan(arg);
727 }
728 };
729
730 //! The SYCL tanh trait specialization.
731 template<typename TArg>
732 struct Tanh<math::TanhGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
733 {
734 auto operator()(math::TanhGenericSycl const&, TArg const& arg)
735 {
736 return sycl::tanh(arg);
737 }
738 };
739
740 //! The SYCL trunc trait specialization.
741 template<typename TArg>
742 struct Trunc<math::TruncGenericSycl, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
743 {
744 auto operator()(math::TruncGenericSycl const&, TArg const& arg)
745 {
746 return sycl::trunc(arg);
747 }
748 };
749} // namespace alpaka::math::trait
750
751#endif
The math traits.
Definition Complex.hpp:588
ALPAKA_NO_HOST_ACC_WARNING ALPAKA_FN_HOST_ACC auto exp(T const &exp_ctx, TArg const &arg)
Computes the e (Euler's number, 2.7182818) raised to the given power arg.
Definition Traits.hpp:1102
ALPAKA_NO_HOST_ACC_WARNING ALPAKA_FN_HOST_ACC auto arg(T const &arg_ctx, TArgument const &argument)
Computes the complex argument of the value.
Definition Traits.hpp:912
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:298
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:311
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:324
ALPAKA_NO_HOST_ACC_WARNING ALPAKA_FN_HOST_ACC auto operator()(T const &, TArgument const &argument)
Definition Traits.hpp:340
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:353
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:366
ALPAKA_FN_HOST_ACC auto operator()(T const &, Ty const &y, Tx const &x)
Definition Traits.hpp:405
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:379
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:392
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:418
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:431
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:444
ALPAKA_FN_HOST_ACC auto operator()(T const &, TMag const &mag, TSgn const &sgn)
Definition Traits.hpp:457
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:470
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:483
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:495
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:508
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:521
ALPAKA_FN_HOST_ACC auto operator()(T const &, Tx const &x, Ty const &y, Tz const &z)
Definition Traits.hpp:534
ALPAKA_FN_HOST_ACC auto operator()(T const &, Tx const &x, Ty const &y)
Definition Traits.hpp:547
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:560
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:573
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:586
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:716
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:625
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:612
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:599
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:703
ALPAKA_FN_HOST_ACC auto operator()(T const &, Tx const &x, Ty const &y)
Definition Traits.hpp:638
ALPAKA_FN_HOST_ACC auto operator()(T const &, Tx const &x, Ty const &y)
Definition Traits.hpp:651
ALPAKA_FN_HOST_ACC auto operator()(T const &, TBase const &base, TExp const &exp)
Definition Traits.hpp:664
ALPAKA_FN_HOST_ACC auto operator()(T const &, Tx const &x, Ty const &y)
Definition Traits.hpp:677
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:690
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:741
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg, TArg &result_sin, TArg &result_cos)
Definition Traits.hpp:794
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:754
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:767
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:807
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:820
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:833
ALPAKA_FN_HOST_ACC auto operator()(T const &, TArg const &arg)
Definition Traits.hpp:846