@@ -335,25 +335,23 @@ template <typename T> class TryToGetPointerVecT {
335335 using type = decltype (check(T()));
336336};
337337
338- template <typename T, typename = typename detail::enable_if_t <
339- TryToGetPointerT<T>::value, std::true_type>>
340- typename TryToGetPointerVecT<T>::type TryToGetPointer (T &t) {
338+ template <
339+ typename To, typename From,
340+ typename = typename detail::enable_if_t <TryToGetPointerT<From>::value>>
341+ To ConvertNonVectorType (From &t) {
341342 // TODO find the better way to get the pointer to underlying data from vec
342343 // class
343- return reinterpret_cast <typename TryToGetPointerVecT<T>::type >(t.get ());
344+ return reinterpret_cast <To >(t.get ());
344345}
345346
346- template <typename T>
347- typename TryToGetPointerVecT<T *>::type TryToGetPointer (T *t) {
348- // TODO find the better way to get the pointer to underlying data from vec
349- // class
350- return reinterpret_cast <typename TryToGetPointerVecT<T *>::type>(t);
347+ template <typename To, typename From> To ConvertNonVectorType (From *t) {
348+ return reinterpret_cast <To>(t);
351349}
352350
353- template <typename T , typename = typename detail:: enable_if_t <
354- !TryToGetPointerT<T >::value, std::false_type> >
355- T TryToGetPointer (T &t) {
356- return t ;
351+ template <typename To , typename From>
352+ typename detail:: enable_if_t < !TryToGetPointerT<From >::value, To >
353+ ConvertNonVectorType (From &t) {
354+ return static_cast <To>(t) ;
357355}
358356
359357// select_apply_cl_scalar_t selects from T8/T16/T32/T64 basing on
@@ -398,13 +396,14 @@ using select_cl_scalar_t = conditional_t<
398396 conditional_t <std::is_same<T, half>::value,
399397 sycl::detail::half_impl::BIsRepresentationT, T>>>;
400398
401- // select_cl_vector_or_scalar does cl_* type selection for element type of
402- // a vector type T and does scalar type substitution. If T is not
403- // vector or scalar unmodified T is returned.
404- template <typename T, typename Enable = void > struct select_cl_vector_or_scalar ;
399+ // select_cl_vector_or_scalar_or_ptr does cl_* type selection for element type
400+ // of a vector type T, pointer type substitution, and scalar type substitution.
401+ // If T is not vector, scalar, or pointer unmodified T is returned.
402+ template <typename T, typename Enable = void >
403+ struct select_cl_vector_or_scalar_or_ptr ;
405404
406405template <typename T>
407- struct select_cl_vector_or_scalar <
406+ struct select_cl_vector_or_scalar_or_ptr <
408407 T, typename detail::enable_if_t <is_vgentype<T>::value>> {
409408 using type =
410409 // select_cl_scalar_t returns _Float16, so, we try to instantiate vec
@@ -417,17 +416,31 @@ struct select_cl_vector_or_scalar<
417416};
418417
419418template <typename T>
420- struct select_cl_vector_or_scalar <
421- T, typename detail::enable_if_t <!is_vgentype<T>::value>> {
419+ struct select_cl_vector_or_scalar_or_ptr <
420+ T, typename detail::enable_if_t <!is_vgentype<T>::value &&
421+ !std::is_pointer<T>::value>> {
422422 using type = select_cl_scalar_t <T>;
423423};
424424
425- // select_cl_mptr_or_vector_or_scalar does cl_* type selection for type
426- // pointed by multi_ptr or for element type of a vector type T and does
427- // scalar type substitution. If T is not mutlti_ptr or vector or scalar
428- // unmodified T is returned.
425+ template <typename T>
426+ struct select_cl_vector_or_scalar_or_ptr <
427+ T, typename detail::enable_if_t <!is_vgentype<T>::value &&
428+ std::is_pointer<T>::value>> {
429+ using elem_ptr_type = typename select_cl_vector_or_scalar_or_ptr<
430+ std::remove_pointer_t <T>>::type *;
431+ #ifdef __SYCL_DEVICE_ONLY__
432+ using type = typename DecoratedType<elem_ptr_type, deduce_AS<T>::value>::type;
433+ #else
434+ using type = elem_ptr_type;
435+ #endif
436+ };
437+
438+ // select_cl_mptr_or_vector_or_scalar_or_ptr does cl_* type selection for type
439+ // pointed by multi_ptr, for raw pointers, for element type of a vector type T,
440+ // and does scalar type substitution. If T is not mutlti_ptr or vector or
441+ // scalar or pointer unmodified T is returned.
429442template <typename T, typename Enable = void >
430- struct select_cl_mptr_or_vector_or_scalar ;
443+ struct select_cl_mptr_or_vector_or_scalar_or_ptr ;
431444
432445// this struct helps to use std::uint8_t instead of std::byte,
433446// which is not supported on device
@@ -444,25 +457,25 @@ template <> struct TypeHelper<std::byte> {
444457template <typename T> using type_helper = typename TypeHelper<T>::RetType;
445458
446459template <typename T>
447- struct select_cl_mptr_or_vector_or_scalar <
460+ struct select_cl_mptr_or_vector_or_scalar_or_ptr <
448461 T, typename detail::enable_if_t <is_genptr<T>::value &&
449462 !std::is_pointer<T>::value>> {
450- using type = multi_ptr<typename select_cl_vector_or_scalar <
463+ using type = multi_ptr<typename select_cl_vector_or_scalar_or_ptr <
451464 type_helper<typename T::element_type>>::type,
452465 T::address_space>;
453466};
454467
455468template <typename T>
456- struct select_cl_mptr_or_vector_or_scalar <
469+ struct select_cl_mptr_or_vector_or_scalar_or_ptr <
457470 T, typename detail::enable_if_t <!is_genptr<T>::value ||
458471 std::is_pointer<T>::value>> {
459- using type = typename select_cl_vector_or_scalar <T>::type;
472+ using type = typename select_cl_vector_or_scalar_or_ptr <T>::type;
460473};
461474
462475// All types converting shortcut.
463476template <typename T>
464477using SelectMatchingOpenCLType_t =
465- typename select_cl_mptr_or_vector_or_scalar <T>::type;
478+ typename select_cl_mptr_or_vector_or_scalar_or_ptr <T>::type;
466479
467480// Converts T to OpenCL friendly
468481//
@@ -492,7 +505,7 @@ typename detail::enable_if_t<!(is_vgentype<FROM>::value &&
492505 sizeof (TO) == sizeof (FROM),
493506 TO>
494507convertDataToType (FROM t) {
495- return TryToGetPointer (t);
508+ return ConvertNonVectorType<TO> (t);
496509}
497510
498511// Used for all, any and select relational built-in functions
0 commit comments