diff --git a/dpctl_ext/tensor/CMakeLists.txt b/dpctl_ext/tensor/CMakeLists.txt index 0b166a20273..6f823a818ce 100644 --- a/dpctl_ext/tensor/CMakeLists.txt +++ b/dpctl_ext/tensor/CMakeLists.txt @@ -58,10 +58,10 @@ set(_tensor_impl_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/zeros_ctor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/triul_ctor.cpp - # ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp - # ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp - # ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp ) set(_static_lib_trgt simplify_iteration_space) @@ -92,10 +92,10 @@ endif() set(_no_fast_math_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_and_cast_usm_to_usm.cpp - # ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp # ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp - # ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp - # ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp ) #list( #APPEND _no_fast_math_sources diff --git a/dpctl_ext/tensor/__init__.py b/dpctl_ext/tensor/__init__.py index fa76faccc63..8cd8a1896b2 100644 --- a/dpctl_ext/tensor/__init__.py +++ b/dpctl_ext/tensor/__init__.py @@ -27,6 +27,8 @@ # ***************************************************************************** +from dpctl.tensor._search_functions import where + from dpctl_ext.tensor._copy_utils import ( asnumpy, astype, @@ -50,27 +52,39 @@ take_along_axis, ) from dpctl_ext.tensor._manipulation_functions import ( + repeat, roll, ) from dpctl_ext.tensor._reshape import reshape +from ._clip import clip +from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type + __all__ = [ "asnumpy", "astype", + "can_cast", "copy", + "clip", "extract", "eye", + "finfo", "from_numpy", "full", + "iinfo", + "isdtype", "nonzero", "place", "put", "put_along_axis", + "repeat", "reshape", + "result_type", "roll", "take", "take_along_axis", "to_numpy", "tril", "triu", + "where", ] diff --git a/dpctl_ext/tensor/_clip.py b/dpctl_ext/tensor/_clip.py new file mode 100644 index 00000000000..50d3ecd568e --- /dev/null +++ b/dpctl_ext/tensor/_clip.py @@ -0,0 +1,781 @@ +# ***************************************************************************** +# Copyright (c) 2026, Intel Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# - Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. 
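
With the three kernels re-enabled in CMake and the new exports in `dpctl_ext/tensor/__init__.py`, `clip`, `repeat`, and `where` (plus the casting helpers `can_cast`, `result_type`, `finfo`, `iinfo`, `isdtype`) become reachable directly from `dpctl_ext.tensor`. A minimal usage sketch, assuming the extension is built and installed and a SYCL device is available; values shown follow ordinary NumPy-style semantics:

    >>> import dpctl.tensor as dpt
    >>> import dpctl_ext.tensor as dpt_ext
    >>> x = dpt.arange(6, dtype="int32")
    >>> dpt.asnumpy(dpt_ext.clip(x, 1, 4))    # clamp each element to [1, 4]
    array([1, 1, 2, 3, 4, 4], dtype=int32)
    >>> dpt.asnumpy(dpt_ext.repeat(x, 2))     # repeat each element twice
    array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5], dtype=int32)
    >>> dpt_ext.iinfo(dpt.int32).max          # iinfo is now exported as well
    2147483647
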
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. +# ***************************************************************************** + +import dpctl +import dpctl.tensor as dpt +import dpctl.tensor._tensor_elementwise_impl as tei +from dpctl.utils import ExecutionPlacementError, SequentialOrderManager + +# TODO: revert to `import dpctl.tensor...` +# when dpnp fully migrates dpctl/tensor +import dpctl_ext.tensor as dpt_ext +import dpctl_ext.tensor._tensor_impl as ti +from dpctl_ext.tensor._copy_utils import ( + _empty_like_orderK, + _empty_like_pair_orderK, + _empty_like_triple_orderK, +) +from dpctl_ext.tensor._manipulation_functions import _broadcast_shape_impl +from dpctl_ext.tensor._type_utils import _can_cast + +from ._scalar_utils import ( + _get_dtype, + _get_queue_usm_type, + _get_shape, + _validate_dtype, +) +from ._type_utils import ( + _resolve_one_strong_one_weak_types, + _resolve_one_strong_two_weak_types, +) + + +def _check_clip_dtypes(res_dtype, arg1_dtype, arg2_dtype, sycl_dev): + """ + Checks if both types `arg1_dtype` and `arg2_dtype` can be + cast to `res_dtype` according to the rule `safe` + """ + if arg1_dtype == res_dtype and arg2_dtype == res_dtype: + return None, None, res_dtype + + _fp16 = sycl_dev.has_aspect_fp16 + _fp64 = sycl_dev.has_aspect_fp64 + if _can_cast(arg1_dtype, res_dtype, _fp16, _fp64) and _can_cast( + arg2_dtype, res_dtype, _fp16, _fp64 + ): + # prevent unnecessary casting + ret_buf1_dt = None if res_dtype == arg1_dtype else res_dtype + ret_buf2_dt = None if res_dtype == arg2_dtype else res_dtype + return ret_buf1_dt, ret_buf2_dt, res_dtype + else: + return None, None, None + + +def _clip_none(x, val, out, order, _binary_fn): + q1, x_usm_type = x.sycl_queue, x.usm_type + q2, val_usm_type = _get_queue_usm_type(val) + if q2 is None: + exec_q = q1 + res_usm_type = x_usm_type + else: + exec_q = dpctl.utils.get_execution_queue((q1, q2)) + if exec_q is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + res_usm_type = dpctl.utils.get_coerced_usm_type( + ( + x_usm_type, + val_usm_type, + ) + ) + dpctl.utils.validate_usm_type(res_usm_type, allow_none=False) + x_shape = x.shape + val_shape = _get_shape(val) + if not isinstance(val_shape, (tuple, list)): + raise TypeError( + "Shape of arguments can not be inferred. 
" + "Arguments are expected to be " + "lists, tuples, or both" + ) + try: + res_shape = _broadcast_shape_impl( + [ + x_shape, + val_shape, + ] + ) + except ValueError: + raise ValueError( + "operands could not be broadcast together with shapes " + f"{x_shape} and {val_shape}" + ) + sycl_dev = exec_q.sycl_device + x_dtype = x.dtype + val_dtype = _get_dtype(val, sycl_dev) + if not _validate_dtype(val_dtype): + raise ValueError("Operands have unsupported data types") + + val_dtype = _resolve_one_strong_one_weak_types(x_dtype, val_dtype, sycl_dev) + + res_dt = x.dtype + _fp16 = sycl_dev.has_aspect_fp16 + _fp64 = sycl_dev.has_aspect_fp64 + if not _can_cast(val_dtype, res_dt, _fp16, _fp64): + raise ValueError( + f"function 'clip' does not support input types " + f"({x_dtype}, {val_dtype}), " + "and the inputs could not be safely coerced to any " + "supported types according to the casting rule ''safe''." + ) + + orig_out = out + if out is not None: + if not isinstance(out, dpt.usm_ndarray): + raise TypeError( + f"output array must be of usm_ndarray type, got {type(out)}" + ) + + if not out.flags.writable: + raise ValueError("provided `out` array is read-only") + + if out.shape != res_shape: + raise ValueError( + "The shape of input and output arrays are inconsistent. " + f"Expected output shape is {res_shape}, got {out.shape}" + ) + + if res_dt != out.dtype: + raise ValueError( + f"Output array of type {res_dt} is needed, got {out.dtype}" + ) + + if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None: + raise ExecutionPlacementError( + "Input and output allocation queues are not compatible" + ) + + if ti._array_overlap(x, out): + if not ti._same_logical_tensors(x, out): + out = dpt.empty_like(out) + + if isinstance(val, dpt.usm_ndarray): + if ( + ti._array_overlap(val, out) + and not ti._same_logical_tensors(val, out) + and val_dtype == res_dt + ): + out = dpt.empty_like(out) + + if isinstance(val, dpt.usm_ndarray): + val_ary = val + else: + val_ary = dpt.asarray(val, dtype=val_dtype, sycl_queue=exec_q) + + if order == "A": + order = ( + "F" + if all( + arr.flags.f_contiguous + for arr in ( + x, + val_ary, + ) + ) + else "C" + ) + if val_dtype == res_dt: + if out is None: + if order == "K": + out = _empty_like_pair_orderK( + x, val_ary, res_dt, res_shape, res_usm_type, exec_q + ) + else: + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + if x_shape != res_shape: + x = dpt.broadcast_to(x, res_shape) + if val_ary.shape != res_shape: + val_ary = dpt.broadcast_to(val_ary, res_shape) + _manager = SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + ht_binary_ev, binary_ev = _binary_fn( + src1=x, src2=val_ary, dst=out, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_binary_ev, binary_ev) + if not (orig_out is None or orig_out is out): + # Copy the out data from temporary buffer to original memory + ht_copy_out_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, + dst=orig_out, + sycl_queue=exec_q, + depends=[binary_ev], + ) + _manager.add_event_pair(ht_copy_out_ev, copy_ev) + out = orig_out + return out + else: + if order == "K": + buf = _empty_like_orderK(val_ary, res_dt) + else: + buf = dpt.empty_like(val_ary, dtype=res_dt, order=order) + _manager = SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=val_ary, dst=buf, sycl_queue=exec_q, depends=dep_evs + ) + 
_manager.add_event_pair(ht_copy_ev, copy_ev) + if out is None: + if order == "K": + out = _empty_like_pair_orderK( + x, buf, res_dt, res_shape, res_usm_type, exec_q + ) + else: + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + + if x_shape != res_shape: + x = dpt.broadcast_to(x, res_shape) + buf = dpt.broadcast_to(buf, res_shape) + ht_binary_ev, binary_ev = _binary_fn( + src1=x, + src2=buf, + dst=out, + sycl_queue=exec_q, + depends=[copy_ev], + ) + _manager.add_event_pair(ht_binary_ev, binary_ev) + if not (orig_out is None or orig_out is out): + # Copy the out data from temporary buffer to original memory + ht_copy_out_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, + dst=orig_out, + sycl_queue=exec_q, + depends=[binary_ev], + ) + _manager.add_event_pair(ht_copy_out_ev, cpy_ev) + out = orig_out + return out + + +def clip(x, /, min=None, max=None, out=None, order="K"): + """clip(x, min=None, max=None, out=None, order="K") + + Clips to the range [`min_i`, `max_i`] for each element `x_i` + in `x`. + + Args: + x (usm_ndarray): Array containing elements to clip. + Must be compatible with `min` and `max` according + to broadcasting rules. + min ({None, Union[usm_ndarray, bool, int, float, complex]}, optional): + Array containing minimum values. + Must be compatible with `x` and `max` according + to broadcasting rules. + max ({None, Union[usm_ndarray, bool, int, float, complex]}, optional): + Array containing maximum values. + Must be compatible with `x` and `min` according + to broadcasting rules. + out ({None, usm_ndarray}, optional): + Output array to populate. + Array must have the correct shape and the expected data type. + order ("C","F","A","K", optional): + Memory layout of the newly output array, if parameter `out` is + `None`. + Default: "K". + + Returns: + usm_ndarray: + An array with elements clipped to the range [`min`, `max`]. + The returned array has the same data type as `x`. + """ + if not isinstance(x, dpt.usm_ndarray): + raise TypeError( + "Expected `x` to be of dpctl.tensor.usm_ndarray type, got " + f"{type(x)}" + ) + if order not in ["K", "C", "F", "A"]: + order = "K" + if x.dtype.kind in "iu": + if isinstance(min, int) and min <= dpt_ext.iinfo(x.dtype).min: + min = None + if isinstance(max, int) and max >= dpt_ext.iinfo(x.dtype).max: + max = None + if min is None and max is None: + exec_q = x.sycl_queue + orig_out = out + if out is not None: + if not isinstance(out, dpt.usm_ndarray): + raise TypeError( + "output array must be of usm_ndarray type, got " + f"{type(out)}" + ) + + if not out.flags.writable: + raise ValueError("provided `out` array is read-only") + + if out.shape != x.shape: + raise ValueError( + "The shape of input and output arrays are " + f"inconsistent. 
Expected output shape is {x.shape}, " + f"got {out.shape}" + ) + + if x.dtype != out.dtype: + raise ValueError( + f"Output array of type {x.dtype} is needed, " + f"got {out.dtype}" + ) + + if ( + dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) + is None + ): + raise ExecutionPlacementError( + "Input and output allocation queues are not compatible" + ) + + if ti._array_overlap(x, out): + if not ti._same_logical_tensors(x, out): + out = dpt.empty_like(out) + else: + return out + else: + if order == "K": + out = _empty_like_orderK(x, x.dtype) + else: + out = dpt.empty_like(x, order=order) + + _manager = SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x, dst=out, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_copy_ev, copy_ev) + if not (orig_out is None or orig_out is out): + # Copy the out data from temporary buffer to original memory + ht_copy_out_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, + dst=orig_out, + sycl_queue=exec_q, + depends=[copy_ev], + ) + _manager.add_event_pair(ht_copy_ev, cpy_ev) + out = orig_out + return out + elif max is None: + return _clip_none(x, min, out, order, tei._maximum) + elif min is None: + return _clip_none(x, max, out, order, tei._minimum) + else: + q1, x_usm_type = x.sycl_queue, x.usm_type + q2, min_usm_type = _get_queue_usm_type(min) + q3, max_usm_type = _get_queue_usm_type(max) + if q2 is None and q3 is None: + exec_q = q1 + res_usm_type = x_usm_type + elif q3 is None: + exec_q = dpctl.utils.get_execution_queue((q1, q2)) + if exec_q is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + res_usm_type = dpctl.utils.get_coerced_usm_type( + ( + x_usm_type, + min_usm_type, + ) + ) + elif q2 is None: + exec_q = dpctl.utils.get_execution_queue((q1, q3)) + if exec_q is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + res_usm_type = dpctl.utils.get_coerced_usm_type( + ( + x_usm_type, + max_usm_type, + ) + ) + else: + exec_q = dpctl.utils.get_execution_queue((q1, q2, q3)) + if exec_q is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + res_usm_type = dpctl.utils.get_coerced_usm_type( + ( + x_usm_type, + min_usm_type, + max_usm_type, + ) + ) + dpctl.utils.validate_usm_type(res_usm_type, allow_none=False) + x_shape = x.shape + min_shape = _get_shape(min) + max_shape = _get_shape(max) + if not all( + isinstance(s, (tuple, list)) + for s in ( + min_shape, + max_shape, + ) + ): + raise TypeError( + "Shape of arguments can not be inferred. 
" + "Arguments are expected to be " + "lists, tuples, or both" + ) + try: + res_shape = _broadcast_shape_impl( + [ + x_shape, + min_shape, + max_shape, + ] + ) + except ValueError: + raise ValueError( + "operands could not be broadcast together with shapes " + f"{x_shape}, {min_shape}, and {max_shape}" + ) + sycl_dev = exec_q.sycl_device + x_dtype = x.dtype + min_dtype = _get_dtype(min, sycl_dev) + max_dtype = _get_dtype(max, sycl_dev) + if not all(_validate_dtype(o) for o in (min_dtype, max_dtype)): + raise ValueError("Operands have unsupported data types") + + min_dtype, max_dtype = _resolve_one_strong_two_weak_types( + x_dtype, min_dtype, max_dtype, sycl_dev + ) + + buf1_dt, buf2_dt, res_dt = _check_clip_dtypes( + x_dtype, + min_dtype, + max_dtype, + sycl_dev, + ) + + if res_dt is None: + raise ValueError( + f"function '{clip}' does not support input types " + f"({x_dtype}, {min_dtype}, {max_dtype}), " + "and the inputs could not be safely coerced to any " + "supported types according to the casting rule ''safe''." + ) + + orig_out = out + if out is not None: + if not isinstance(out, dpt.usm_ndarray): + raise TypeError( + "output array must be of usm_ndarray type, got " + f"{type(out)}" + ) + + if not out.flags.writable: + raise ValueError("provided `out` array is read-only") + + if out.shape != res_shape: + raise ValueError( + "The shape of input and output arrays are " + f"inconsistent. Expected output shape is {res_shape}, " + f"got {out.shape}" + ) + + if res_dt != out.dtype: + raise ValueError( + f"Output array of type {res_dt} is needed, " + f"got {out.dtype}" + ) + + if ( + dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) + is None + ): + raise ExecutionPlacementError( + "Input and output allocation queues are not compatible" + ) + + if ti._array_overlap(x, out): + if not ti._same_logical_tensors(x, out): + out = dpt.empty_like(out) + + if isinstance(min, dpt.usm_ndarray): + if ( + ti._array_overlap(min, out) + and not ti._same_logical_tensors(min, out) + and buf1_dt is None + ): + out = dpt.empty_like(out) + + if isinstance(max, dpt.usm_ndarray): + if ( + ti._array_overlap(max, out) + and not ti._same_logical_tensors(max, out) + and buf2_dt is None + ): + out = dpt.empty_like(out) + + if isinstance(min, dpt.usm_ndarray): + a_min = min + else: + a_min = dpt.asarray(min, dtype=min_dtype, sycl_queue=exec_q) + if isinstance(max, dpt.usm_ndarray): + a_max = max + else: + a_max = dpt.asarray(max, dtype=max_dtype, sycl_queue=exec_q) + + if order == "A": + order = ( + "F" + if all( + arr.flags.f_contiguous + for arr in ( + x, + a_min, + a_max, + ) + ) + else "C" + ) + if buf1_dt is None and buf2_dt is None: + if out is None: + if order == "K": + out = _empty_like_triple_orderK( + x, + a_min, + a_max, + res_dt, + res_shape, + res_usm_type, + exec_q, + ) + else: + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + if x_shape != res_shape: + x = dpt.broadcast_to(x, res_shape) + if a_min.shape != res_shape: + a_min = dpt.broadcast_to(a_min, res_shape) + if a_max.shape != res_shape: + a_max = dpt.broadcast_to(a_max, res_shape) + _manager = SequentialOrderManager[exec_q] + dep_ev = _manager.submitted_events + ht_binary_ev, binary_ev = ti._clip( + src=x, + min=a_min, + max=a_max, + dst=out, + sycl_queue=exec_q, + depends=dep_ev, + ) + _manager.add_event_pair(ht_binary_ev, binary_ev) + if not (orig_out is None or orig_out is out): + # Copy the out data from temporary buffer to original memory + ht_copy_out_ev, cpy_ev = 
ti._copy_usm_ndarray_into_usm_ndarray( + src=out, + dst=orig_out, + sycl_queue=exec_q, + depends=[binary_ev], + ) + _manager.add_event_pair(ht_copy_out_ev, cpy_ev) + out = orig_out + return out + + elif buf1_dt is None: + if order == "K": + buf2 = _empty_like_orderK(a_max, buf2_dt) + else: + buf2 = dpt.empty_like(a_max, dtype=buf2_dt, order=order) + _manager = SequentialOrderManager[exec_q] + dep_ev = _manager.submitted_events + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_max, dst=buf2, sycl_queue=exec_q, depends=dep_ev + ) + _manager.add_event_pair(ht_copy_ev, copy_ev) + if out is None: + if order == "K": + out = _empty_like_triple_orderK( + x, + a_min, + buf2, + res_dt, + res_shape, + res_usm_type, + exec_q, + ) + else: + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + + x = dpt.broadcast_to(x, res_shape) + if a_min.shape != res_shape: + a_min = dpt.broadcast_to(a_min, res_shape) + buf2 = dpt.broadcast_to(buf2, res_shape) + ht_binary_ev, binary_ev = ti._clip( + src=x, + min=a_min, + max=buf2, + dst=out, + sycl_queue=exec_q, + depends=[copy_ev], + ) + _manager.add_event_pair(ht_binary_ev, binary_ev) + if not (orig_out is None or orig_out is out): + # Copy the out data from temporary buffer to original memory + ht_copy_out_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, + dst=orig_out, + sycl_queue=exec_q, + depends=[binary_ev], + ) + _manager.add_event_pair(ht_copy_out_ev, cpy_ev) + out = orig_out + return out + + elif buf2_dt is None: + if order == "K": + buf1 = _empty_like_orderK(a_min, buf1_dt) + else: + buf1 = dpt.empty_like(a_min, dtype=buf1_dt, order=order) + _manager = SequentialOrderManager[exec_q] + dep_ev = _manager.submitted_events + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_min, dst=buf1, sycl_queue=exec_q, depends=dep_ev + ) + _manager.add_event_pair(ht_copy_ev, copy_ev) + if out is None: + if order == "K": + out = _empty_like_triple_orderK( + x, + buf1, + a_max, + res_dt, + res_shape, + res_usm_type, + exec_q, + ) + else: + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + + x = dpt.broadcast_to(x, res_shape) + buf1 = dpt.broadcast_to(buf1, res_shape) + if a_max.shape != res_shape: + a_max = dpt.broadcast_to(a_max, res_shape) + ht_binary_ev, binary_ev = ti._clip( + src=x, + min=buf1, + max=a_max, + dst=out, + sycl_queue=exec_q, + depends=[copy_ev], + ) + _manager.add_event_pair(ht_binary_ev, binary_ev) + if not (orig_out is None or orig_out is out): + # Copy the out data from temporary buffer to original memory + ht_copy_out_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, + dst=orig_out, + sycl_queue=exec_q, + depends=[binary_ev], + ) + _manager.add_event_pair(ht_copy_out_ev, cpy_ev) + out = orig_out + return out + + if order == "K": + if ( + x.flags.c_contiguous + and a_min.flags.c_contiguous + and a_max.flags.c_contiguous + ): + order = "C" + elif ( + x.flags.f_contiguous + and a_min.flags.f_contiguous + and a_max.flags.f_contiguous + ): + order = "F" + if order == "K": + buf1 = _empty_like_orderK(a_min, buf1_dt) + else: + buf1 = dpt.empty_like(a_min, dtype=buf1_dt, order=order) + + _manager = SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_min, dst=buf1, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_copy1_ev, copy1_ev) + if order == "K": + buf2 = 
_empty_like_orderK(a_max, buf2_dt) + else: + buf2 = dpt.empty_like(a_max, dtype=buf2_dt, order=order) + ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_max, dst=buf2, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_copy2_ev, copy2_ev) + if out is None: + if order == "K": + out = _empty_like_triple_orderK( + x, buf1, buf2, res_dt, res_shape, res_usm_type, exec_q + ) + else: + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + + x = dpt.broadcast_to(x, res_shape) + buf1 = dpt.broadcast_to(buf1, res_shape) + buf2 = dpt.broadcast_to(buf2, res_shape) + ht_, clip_ev = ti._clip( + src=x, + min=buf1, + max=buf2, + dst=out, + sycl_queue=exec_q, + depends=[copy1_ev, copy2_ev], + ) + _manager.add_event_pair(ht_, clip_ev) + return out diff --git a/dpctl_ext/tensor/_copy_utils.py b/dpctl_ext/tensor/_copy_utils.py index 5d1ac209c86..af72544a8b0 100644 --- a/dpctl_ext/tensor/_copy_utils.py +++ b/dpctl_ext/tensor/_copy_utils.py @@ -37,12 +37,12 @@ import numpy as np from dpctl.tensor._data_types import _get_dtype from dpctl.tensor._device import normalize_queue_device -from dpctl.tensor._type_utils import _dtype_supported_by_device_impl # TODO: revert to `import dpctl.tensor...` # when dpnp fully migrates dpctl/tensor import dpctl_ext.tensor as dpt_ext import dpctl_ext.tensor._tensor_impl as ti +from dpctl_ext.tensor._type_utils import _dtype_supported_by_device_impl from ._numpy_helper import normalize_axis_index @@ -291,7 +291,7 @@ def _prepare_indices_arrays(inds, q, usm_type): ) # promote to a common integral type if possible - ind_dt = dpt.result_type(*inds) + ind_dt = dpt_ext.result_type(*inds) if ind_dt.kind not in "ui": raise ValueError( "cannot safely promote indices to an integer data type" @@ -1013,7 +1013,7 @@ def astype( else: target_dtype = _get_dtype(newdtype, usm_ary.sycl_queue) - if not dpt.can_cast(ary_dtype, target_dtype, casting=casting): + if not dpt_ext.can_cast(ary_dtype, target_dtype, casting=casting): raise TypeError( f"Can not cast from {ary_dtype} to {newdtype} " f"according to rule {casting}." diff --git a/dpctl_ext/tensor/_manipulation_functions.py b/dpctl_ext/tensor/_manipulation_functions.py index fa8fc27876b..f1b8b46dbcb 100644 --- a/dpctl_ext/tensor/_manipulation_functions.py +++ b/dpctl_ext/tensor/_manipulation_functions.py @@ -26,17 +26,20 @@ # THE POSSIBILITY OF SUCH DAMAGE. 
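
The `_copy_utils.py` hunk above routes index promotion and `astype` cast validation through `dpt_ext.result_type` and `dpt_ext.can_cast` instead of the upstream `dpctl.tensor` entry points. The observable casting behavior should stay NumPy-compatible; a small sketch under that assumption (the error text comes from the f-string in the `astype` hunk above):

    >>> import dpctl.tensor as dpt
    >>> import dpctl_ext.tensor as dpt_ext
    >>> dpt_ext.can_cast(dpt.int64, dpt.int32)                    # "safe" by default
    False
    >>> dpt_ext.can_cast(dpt.int64, dpt.int32, casting="same_kind")
    True
    >>> x = dpt.ones(3, dtype="int64")
    >>> dpt_ext.astype(x, "int32", casting="safe")
    Traceback (most recent call last):
        ...
    TypeError: Can not cast from int64 to int32 according to rule safe.
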
# ***************************************************************************** +import itertools import operator +import dpctl import dpctl.tensor as dpt import dpctl.utils as dputils import numpy as np # TODO: revert to `import dpctl.tensor...` # when dpnp fully migrates dpctl/tensor +import dpctl_ext.tensor as dpt_ext import dpctl_ext.tensor._tensor_impl as ti -from ._numpy_helper import normalize_axis_tuple +from ._numpy_helper import normalize_axis_index, normalize_axis_tuple __doc__ = ( "Implementation module for array manipulation " @@ -44,6 +47,274 @@ ) +def _broadcast_shape_impl(shapes): + if len(set(shapes)) == 1: + return shapes[0] + mutable_shapes = False + nds = [len(s) for s in shapes] + biggest = max(nds) + sh_len = len(shapes) + for i in range(sh_len): + diff = biggest - nds[i] + if diff > 0: + ty = type(shapes[i]) + shapes[i] = ty( + itertools.chain(itertools.repeat(1, diff), shapes[i]) + ) + common_shape = [] + for axis in range(biggest): + lengths = [s[axis] for s in shapes] + unique = set(lengths + [1]) + if len(unique) > 2: + raise ValueError( + "Shape mismatch: two or more arrays have " + f"incompatible dimensions on axis ({axis},)" + ) + elif len(unique) == 2: + unique.remove(1) + new_length = unique.pop() + common_shape.append(new_length) + for i in range(sh_len): + if shapes[i][axis] == 1: + if not mutable_shapes: + shapes = [list(s) for s in shapes] + mutable_shapes = True + shapes[i][axis] = new_length + else: + common_shape.append(1) + + return tuple(common_shape) + + +def repeat(x, repeats, /, *, axis=None): + """repeat(x, repeats, axis=None) + + Repeat elements of an array on a per-element basis. + + Args: + x (usm_ndarray): input array + + repeats (Union[int, Sequence[int, ...], usm_ndarray]): + The number of repetitions for each element. + + `repeats` must be broadcast-compatible with `N` where `N` is + `prod(x.shape)` if `axis` is `None` and `x.shape[axis]` + otherwise. + + If `repeats` is an array, it must have an integer data type. + Otherwise, `repeats` must be a Python integer or sequence of + Python integers (i.e., a tuple, list, or range). + + axis (Optional[int]): + The axis along which to repeat values. If `axis` is `None`, the + function repeats elements of the flattened array. Default: `None`. + + Returns: + usm_ndarray: + output array with repeated elements. + + If `axis` is `None`, the returned array is one-dimensional, + otherwise, it has the same shape as `x`, except for the axis along + which elements were repeated. + + The returned array will have the same data type as `x`. + The returned array will be located on the same device as `x` and + have the same USM allocation type as `x`. + + Raises: + AxisError: if `axis` value is invalid. 
+ """ + if not isinstance(x, dpt.usm_ndarray): + raise TypeError(f"Expected usm_ndarray type, got {type(x)}.") + + x_ndim = x.ndim + x_shape = x.shape + if axis is not None: + axis = normalize_axis_index(operator.index(axis), x_ndim) + axis_size = x_shape[axis] + else: + axis_size = x.size + + scalar = False + if isinstance(repeats, int): + if repeats < 0: + raise ValueError("`repeats` must be a positive integer") + usm_type = x.usm_type + exec_q = x.sycl_queue + scalar = True + elif isinstance(repeats, dpt.usm_ndarray): + if repeats.ndim > 1: + raise ValueError( + "`repeats` array must be 0- or 1-dimensional, got " + f"{repeats.ndim}" + ) + exec_q = dpctl.utils.get_execution_queue( + (x.sycl_queue, repeats.sycl_queue) + ) + if exec_q is None: + raise dputils.ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + usm_type = dpctl.utils.get_coerced_usm_type( + ( + x.usm_type, + repeats.usm_type, + ) + ) + dpctl.utils.validate_usm_type(usm_type, allow_none=False) + if not dpt_ext.can_cast(repeats.dtype, dpt.int64, casting="same_kind"): + raise TypeError( + f"'repeats' data type {repeats.dtype} cannot be cast to " + "'int64' according to the casting rule ''safe.''" + ) + if repeats.size == 1: + scalar = True + # bring the single element to the host + if repeats.ndim == 0: + repeats = int(repeats) + else: + # Get the single element explicitly + # since non-0D arrays can not be converted to scalars + repeats = int(repeats[0]) + if repeats < 0: + raise ValueError("`repeats` elements must be positive") + else: + if repeats.size != axis_size: + raise ValueError( + "'repeats' array must be broadcastable to the size of " + "the repeated axis" + ) + if not dpt.all(repeats >= 0): + raise ValueError("'repeats' elements must be positive") + + elif isinstance(repeats, (tuple, list, range)): + usm_type = x.usm_type + exec_q = x.sycl_queue + + len_reps = len(repeats) + if len_reps == 1: + repeats = repeats[0] + if repeats < 0: + raise ValueError("`repeats` elements must be positive") + scalar = True + else: + if len_reps != axis_size: + raise ValueError( + "`repeats` sequence must have the same length as the " + "repeated axis" + ) + repeats = dpt.asarray( + repeats, dtype=dpt.int64, usm_type=usm_type, sycl_queue=exec_q + ) + if not dpt.all(repeats >= 0): + raise ValueError("`repeats` elements must be positive") + else: + raise TypeError( + "Expected int, sequence, or `usm_ndarray` for second argument," + f"got {type(repeats)}" + ) + + _manager = dputils.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + if scalar: + res_axis_size = repeats * axis_size + if axis is not None: + res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :] + else: + res_shape = (res_axis_size,) + res = dpt.empty( + res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q + ) + if res_axis_size > 0: + ht_rep_ev, rep_ev = ti._repeat_by_scalar( + src=x, + dst=res, + reps=repeats, + axis=axis, + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_rep_ev, rep_ev) + else: + if repeats.dtype != dpt.int64: + rep_buf = dpt.empty( + repeats.shape, + dtype=dpt.int64, + usm_type=usm_type, + sycl_queue=exec_q, + ) + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=repeats, dst=rep_buf, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_copy_ev, copy_ev) + cumsum = dpt.empty( + (axis_size,), + dtype=dpt.int64, + usm_type=usm_type, + sycl_queue=exec_q, + ) + # _cumsum_1d synchronizes so 
`depends` ends here safely + res_axis_size = ti._cumsum_1d( + rep_buf, cumsum, sycl_queue=exec_q, depends=[copy_ev] + ) + if axis is not None: + res_shape = ( + x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :] + ) + else: + res_shape = (res_axis_size,) + res = dpt.empty( + res_shape, + dtype=x.dtype, + usm_type=usm_type, + sycl_queue=exec_q, + ) + if res_axis_size > 0: + ht_rep_ev, rep_ev = ti._repeat_by_sequence( + src=x, + dst=res, + reps=rep_buf, + cumsum=cumsum, + axis=axis, + sycl_queue=exec_q, + ) + _manager.add_event_pair(ht_rep_ev, rep_ev) + else: + cumsum = dpt.empty( + (axis_size,), + dtype=dpt.int64, + usm_type=usm_type, + sycl_queue=exec_q, + ) + res_axis_size = ti._cumsum_1d( + repeats, cumsum, sycl_queue=exec_q, depends=dep_evs + ) + if axis is not None: + res_shape = ( + x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :] + ) + else: + res_shape = (res_axis_size,) + res = dpt.empty( + res_shape, + dtype=x.dtype, + usm_type=usm_type, + sycl_queue=exec_q, + ) + if res_axis_size > 0: + ht_rep_ev, rep_ev = ti._repeat_by_sequence( + src=x, + dst=res, + reps=repeats, + cumsum=cumsum, + axis=axis, + sycl_queue=exec_q, + ) + _manager.add_event_pair(ht_rep_ev, rep_ev) + return res + + def roll(x, /, shift, *, axis=None): """ roll(x, shift, axis) diff --git a/dpctl_ext/tensor/_scalar_utils.py b/dpctl_ext/tensor/_scalar_utils.py new file mode 100644 index 00000000000..86787baea8c --- /dev/null +++ b/dpctl_ext/tensor/_scalar_utils.py @@ -0,0 +1,122 @@ +# ***************************************************************************** +# Copyright (c) 2026, Intel Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# - Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. 
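
The `repeat` implementation above mirrors `numpy.repeat` semantics: a scalar count goes through `_repeat_by_scalar`, while per-element counts are validated, cast to `int64` if needed, and dispatched through a cumulative-sum pass (`_cumsum_1d` plus `_repeat_by_sequence`). A usage sketch, assuming a SYCL device is available:

    >>> import dpctl.tensor as dpt
    >>> import dpctl_ext.tensor as dpt_ext
    >>> x = dpt.asarray([[1, 2], [3, 4]], dtype="int64")
    >>> dpt.asnumpy(dpt_ext.repeat(x, 2, axis=0))       # repeat rows
    array([[1, 2],
           [1, 2],
           [3, 4],
           [3, 4]])
    >>> dpt.asnumpy(dpt_ext.repeat(x, [1, 3], axis=1))  # per-column counts
    array([[1, 2, 2, 2],
           [3, 4, 4, 4]])
    >>> dpt.asnumpy(dpt_ext.repeat(x, 2))               # axis=None flattens first
    array([1, 1, 2, 2, 3, 3, 4, 4])
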
+# ***************************************************************************** + +import numbers + +import dpctl.memory as dpm +import dpctl.tensor as dpt +import numpy as np +from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer + +from ._type_utils import ( + WeakBooleanType, + WeakComplexType, + WeakFloatingType, + WeakIntegralType, + _to_device_supported_dtype, +) + + +def _get_queue_usm_type(o): + """Return SYCL device where object `o` allocated memory, or None.""" + if isinstance(o, dpt.usm_ndarray): + return o.sycl_queue, o.usm_type + elif hasattr(o, "__sycl_usm_array_interface__"): + try: + m = dpm.as_usm_memory(o) + return m.sycl_queue, m.get_usm_type() + except Exception: + return None, None + return None, None + + +def _get_dtype(o, dev): + if isinstance(o, dpt.usm_ndarray): + return o.dtype + if hasattr(o, "__sycl_usm_array_interface__"): + return dpt.asarray(o).dtype + if _is_buffer(o): + host_dt = np.array(o).dtype + dev_dt = _to_device_supported_dtype(host_dt, dev) + return dev_dt + if hasattr(o, "dtype"): + dev_dt = _to_device_supported_dtype(o.dtype, dev) + return dev_dt + if isinstance(o, bool): + return WeakBooleanType(o) + if isinstance(o, int): + return WeakIntegralType(o) + if isinstance(o, float): + return WeakFloatingType(o) + if isinstance(o, complex): + return WeakComplexType(o) + return np.object_ + + +def _validate_dtype(dt) -> bool: + return isinstance( + dt, + (WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType), + ) or ( + isinstance(dt, dpt.dtype) + and dt + in [ + dpt.bool, + dpt.int8, + dpt.uint8, + dpt.int16, + dpt.uint16, + dpt.int32, + dpt.uint32, + dpt.int64, + dpt.uint64, + dpt.float16, + dpt.float32, + dpt.float64, + dpt.complex64, + dpt.complex128, + ] + ) + + +def _get_shape(o): + if isinstance(o, dpt.usm_ndarray): + return o.shape + if _is_buffer(o): + return memoryview(o).shape + if isinstance(o, numbers.Number): + return () + return getattr(o, "shape", tuple()) + + +__all__ = [ + "_get_dtype", + "_get_queue_usm_type", + "_get_shape", + "_validate_dtype", +] diff --git a/dpctl_ext/tensor/_search_functions.py b/dpctl_ext/tensor/_search_functions.py new file mode 100644 index 00000000000..053c68e1857 --- /dev/null +++ b/dpctl_ext/tensor/_search_functions.py @@ -0,0 +1,419 @@ +# ***************************************************************************** +# Copyright (c) 2026, Intel Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# - Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. +# ***************************************************************************** + +import dpctl +import dpctl.tensor as dpt +from dpctl.utils import ExecutionPlacementError, SequentialOrderManager + +# TODO: revert to `import dpctl.tensor...` +# when dpnp fully migrates dpctl/tensor +import dpctl_ext.tensor as dpt_ext +import dpctl_ext.tensor._tensor_impl as ti +from dpctl_ext.tensor._manipulation_functions import _broadcast_shape_impl + +from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK +from ._scalar_utils import ( + _get_dtype, + _get_queue_usm_type, + _get_shape, + _validate_dtype, +) +from ._type_utils import ( + WeakBooleanType, + WeakComplexType, + WeakFloatingType, + WeakIntegralType, + _all_data_types, + _can_cast, + _is_weak_dtype, + _strong_dtype_num_kind, + _to_device_supported_dtype, + _weak_type_num_kind, +) + + +def _default_dtype_from_weak_type(dt, dev): + if isinstance(dt, WeakBooleanType): + return dpt.bool + if isinstance(dt, WeakIntegralType): + return dpt.dtype(ti.default_device_int_type(dev)) + if isinstance(dt, WeakFloatingType): + return dpt.dtype(ti.default_device_fp_type(dev)) + if isinstance(dt, WeakComplexType): + return dpt.dtype(ti.default_device_complex_type(dev)) + + +def _resolve_two_weak_types(o1_dtype, o2_dtype, dev): + """Resolves two weak data types per NEP-0050""" + if _is_weak_dtype(o1_dtype): + if _is_weak_dtype(o2_dtype): + return _default_dtype_from_weak_type( + o1_dtype, dev + ), _default_dtype_from_weak_type(o2_dtype, dev) + o1_kind_num = _weak_type_num_kind(o1_dtype) + o2_kind_num = _strong_dtype_num_kind(o2_dtype) + if o1_kind_num > o2_kind_num: + if isinstance(o1_dtype, WeakIntegralType): + return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype + if isinstance(o1_dtype, WeakComplexType): + if o2_dtype is dpt.float16 or o2_dtype is dpt.float32: + return dpt.complex64, o2_dtype + return ( + _to_device_supported_dtype(dpt.complex128, dev), + o2_dtype, + ) + return _to_device_supported_dtype(dpt.float64, dev), o2_dtype + else: + return o2_dtype, o2_dtype + elif _is_weak_dtype(o2_dtype): + o1_kind_num = _strong_dtype_num_kind(o1_dtype) + o2_kind_num = _weak_type_num_kind(o2_dtype) + if o2_kind_num > o1_kind_num: + if isinstance(o2_dtype, WeakIntegralType): + return o1_dtype, dpt.dtype(ti.default_device_int_type(dev)) + if isinstance(o2_dtype, WeakComplexType): + if o1_dtype is dpt.float16 or o1_dtype is dpt.float32: + return o1_dtype, dpt.complex64 + return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev) + return ( + o1_dtype, + _to_device_supported_dtype(dpt.float64, dev), + ) + else: + return o1_dtype, o1_dtype + else: + return o1_dtype, o2_dtype + + +def _where_result_type(dt1, dt2, dev): + res_dtype = dpt_ext.result_type(dt1, dt2) + fp16 = dev.has_aspect_fp16 + fp64 = dev.has_aspect_fp64 + + all_dts = _all_data_types(fp16, fp64) + if res_dtype in all_dts: + return res_dtype + else: + for res_dtype_ in all_dts: + if _can_cast(dt1, res_dtype_, fp16, fp64) and _can_cast( + dt2, 
res_dtype_, fp16, fp64 + ): + return res_dtype_ + return None + + +def where(condition, x1, x2, /, *, order="K", out=None): + """ + Returns :class:`dpctl.tensor.usm_ndarray` with elements chosen + from ``x1`` or ``x2`` depending on ``condition``. + + Args: + condition (usm_ndarray): When ``True`` yields from ``x1``, + and otherwise yields from ``x2``. + Must be compatible with ``x1`` and ``x2`` according + to broadcasting rules. + x1 (Union[usm_ndarray, bool, int, float, complex]): + Array from which values are chosen when ``condition`` is ``True``. + Must be compatible with ``condition`` and ``x2`` according + to broadcasting rules. + x2 (Union[usm_ndarray, bool, int, float, complex]): + Array from which values are chosen when ``condition`` is not + ``True``. + Must be compatible with ``condition`` and ``x2`` according + to broadcasting rules. + order (``"K"``, ``"C"``, ``"F"``, ``"A"``, optional): + Memory layout of the new output array, + if parameter ``out`` is ``None``. + Default: ``"K"``. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of `out` must match the expected shape and the + expected data type of the result. + If ``None`` then a new array is returned. Default: ``None``. + + Returns: + usm_ndarray: + An array with elements from ``x1`` where ``condition`` is ``True``, + and elements from ``x2`` elsewhere. + + The data type of the returned array is determined by applying + the Type Promotion Rules to ``x1`` and ``x2``. + """ + if not isinstance(condition, dpt.usm_ndarray): + raise TypeError( + "Expecting dpctl.tensor.usm_ndarray type, " f"got {type(condition)}" + ) + if order not in ["K", "C", "F", "A"]: + order = "K" + q1, condition_usm_type = condition.sycl_queue, condition.usm_type + q2, x1_usm_type = _get_queue_usm_type(x1) + q3, x2_usm_type = _get_queue_usm_type(x2) + if q2 is None and q3 is None: + exec_q = q1 + out_usm_type = condition_usm_type + elif q3 is None: + exec_q = dpctl.utils.get_execution_queue((q1, q2)) + if exec_q is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + out_usm_type = dpctl.utils.get_coerced_usm_type( + ( + condition_usm_type, + x1_usm_type, + ) + ) + elif q2 is None: + exec_q = dpctl.utils.get_execution_queue((q1, q3)) + if exec_q is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + out_usm_type = dpctl.utils.get_coerced_usm_type( + ( + condition_usm_type, + x2_usm_type, + ) + ) + else: + exec_q = dpctl.utils.get_execution_queue((q1, q2, q3)) + if exec_q is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + out_usm_type = dpctl.utils.get_coerced_usm_type( + ( + condition_usm_type, + x1_usm_type, + x2_usm_type, + ) + ) + dpctl.utils.validate_usm_type(out_usm_type, allow_none=False) + condition_shape = condition.shape + x1_shape = _get_shape(x1) + x2_shape = _get_shape(x2) + if not all( + isinstance(s, (tuple, list)) + for s in ( + x1_shape, + x2_shape, + ) + ): + raise TypeError( + "Shape of arguments can not be inferred. 
" + "Arguments are expected to be " + "lists, tuples, or both" + ) + try: + res_shape = _broadcast_shape_impl( + [ + condition_shape, + x1_shape, + x2_shape, + ] + ) + except ValueError: + raise ValueError( + "operands could not be broadcast together with shapes " + f"{condition_shape}, {x1_shape}, and {x2_shape}" + ) + sycl_dev = exec_q.sycl_device + x1_dtype = _get_dtype(x1, sycl_dev) + x2_dtype = _get_dtype(x2, sycl_dev) + if not all(_validate_dtype(o) for o in (x1_dtype, x2_dtype)): + raise ValueError("Operands have unsupported data types") + x1_dtype, x2_dtype = _resolve_two_weak_types(x1_dtype, x2_dtype, sycl_dev) + out_dtype = _where_result_type(x1_dtype, x2_dtype, sycl_dev) + if out_dtype is None: + raise TypeError( + "function 'where' does not support input " + f"types ({x1_dtype}, {x2_dtype}), " + "and the inputs could not be safely coerced " + "to any supported types according to the casting rule ''safe''." + ) + + orig_out = out + if out is not None: + if not isinstance(out, dpt.usm_ndarray): + raise TypeError( + "output array must be of usm_ndarray type, got " f"{type(out)}" + ) + + if not out.flags.writable: + raise ValueError("provided `out` array is read-only") + + if out.shape != res_shape: + raise ValueError( + "The shape of input and output arrays are " + f"inconsistent. Expected output shape is {res_shape}, " + f"got {out.shape}" + ) + + if out_dtype != out.dtype: + raise ValueError( + f"Output array of type {out_dtype} is needed, " + f"got {out.dtype}" + ) + + if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None: + raise ExecutionPlacementError( + "Input and output allocation queues are not compatible" + ) + + if ti._array_overlap(condition, out) and not ti._same_logical_tensors( + condition, out + ): + out = dpt.empty_like(out) + + if isinstance(x1, dpt.usm_ndarray): + if ( + ti._array_overlap(x1, out) + and not ti._same_logical_tensors(x1, out) + and x1_dtype == out_dtype + ): + out = dpt.empty_like(out) + + if isinstance(x2, dpt.usm_ndarray): + if ( + ti._array_overlap(x2, out) + and not ti._same_logical_tensors(x2, out) + and x2_dtype == out_dtype + ): + out = dpt.empty_like(out) + + if order == "A": + order = ( + "F" + if all( + arr.flags.f_contiguous + for arr in ( + condition, + x1, + x2, + ) + ) + else "C" + ) + if not isinstance(x1, dpt.usm_ndarray): + x1 = dpt.asarray(x1, dtype=x1_dtype, sycl_queue=exec_q) + if not isinstance(x2, dpt.usm_ndarray): + x2 = dpt.asarray(x2, dtype=x2_dtype, sycl_queue=exec_q) + + if condition.size == 0: + if out is not None: + return out + else: + if order == "K": + return _empty_like_triple_orderK( + condition, + x1, + x2, + out_dtype, + res_shape, + out_usm_type, + exec_q, + ) + else: + return dpt.empty( + res_shape, + dtype=out_dtype, + order=order, + usm_type=out_usm_type, + sycl_queue=exec_q, + ) + + _manager = SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + if x1_dtype != out_dtype: + if order == "K": + _x1 = _empty_like_orderK(x1, out_dtype) + else: + _x1 = dpt.empty_like(x1, dtype=out_dtype, order=order) + ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x1, dst=_x1, sycl_queue=exec_q, depends=dep_evs + ) + x1 = _x1 + _manager.add_event_pair(ht_copy1_ev, copy1_ev) + + if x2_dtype != out_dtype: + if order == "K": + _x2 = _empty_like_orderK(x2, out_dtype) + else: + _x2 = dpt.empty_like(x2, dtype=out_dtype, order=order) + ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x2, dst=_x2, sycl_queue=exec_q, depends=dep_evs + ) + x2 = _x2 + 
_manager.add_event_pair(ht_copy2_ev, copy2_ev) + + if out is None: + if order == "K": + out = _empty_like_triple_orderK( + condition, x1, x2, out_dtype, res_shape, out_usm_type, exec_q + ) + else: + out = dpt.empty( + res_shape, + dtype=out_dtype, + order=order, + usm_type=out_usm_type, + sycl_queue=exec_q, + ) + + if condition_shape != res_shape: + condition = dpt.broadcast_to(condition, res_shape) + if x1_shape != res_shape: + x1 = dpt.broadcast_to(x1, res_shape) + if x2_shape != res_shape: + x2 = dpt.broadcast_to(x2, res_shape) + + dep_evs = _manager.submitted_events + hev, where_ev = ti._where( + condition=condition, + x1=x1, + x2=x2, + dst=out, + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(hev, where_ev) + if not (orig_out is None or orig_out is out): + # Copy the out data from temporary buffer to original memory + ht_copy_out_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, + dst=orig_out, + sycl_queue=exec_q, + depends=[where_ev], + ) + _manager.add_event_pair(ht_copy_out_ev, cpy_ev) + out = orig_out + + return out diff --git a/dpctl_ext/tensor/_type_utils.py b/dpctl_ext/tensor/_type_utils.py new file mode 100644 index 00000000000..1e386e15dfa --- /dev/null +++ b/dpctl_ext/tensor/_type_utils.py @@ -0,0 +1,999 @@ +# ***************************************************************************** +# Copyright (c) 2026, Intel Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# - Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. 
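
The `where` implementation above accepts Python scalars for either branch; their weak types are resolved against the array operand (per the NEP 50-style helpers) before the kernel is enqueued, so a plain Python `int` does not widen an integer array. A minimal sketch, assuming a SYCL device is available:

    >>> import dpctl.tensor as dpt
    >>> import dpctl_ext.tensor as dpt_ext
    >>> x = dpt.arange(5, dtype="int32")
    >>> cond = x > 2
    >>> dpt.asnumpy(dpt_ext.where(cond, x, 0))   # scalar 0 is a weak integer
    array([0, 0, 0, 3, 4], dtype=int32)
    >>> dpt_ext.where(cond, x, 0).dtype          # promotion keeps int32
    dtype('int32')
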
+# ***************************************************************************** + +from __future__ import annotations + +import dpctl.tensor as dpt +import numpy as np + +# TODO: revert to `import dpctl.tensor...` +# when dpnp fully migrates dpctl/tensor +import dpctl_ext.tensor as dpt_ext +import dpctl_ext.tensor._tensor_impl as ti + + +def _all_data_types(_fp16, _fp64): + _non_fp_types = [ + dpt.bool, + dpt.int8, + dpt.uint8, + dpt.int16, + dpt.uint16, + dpt.int32, + dpt.uint32, + dpt.int64, + dpt.uint64, + ] + if _fp64: + if _fp16: + return _non_fp_types + [ + dpt.float16, + dpt.float32, + dpt.float64, + dpt.complex64, + dpt.complex128, + ] + else: + return _non_fp_types + [ + dpt.float32, + dpt.float64, + dpt.complex64, + dpt.complex128, + ] + else: + if _fp16: + return _non_fp_types + [ + dpt.float16, + dpt.float32, + dpt.complex64, + ] + else: + return _non_fp_types + [ + dpt.float32, + dpt.complex64, + ] + + +def _acceptance_fn_default_binary( + arg1_dtype, arg2_dtype, ret_buf1_dt, ret_buf2_dt, res_dt, sycl_dev +): + return True + + +def _acceptance_fn_default_unary(arg_dtype, ret_buf_dt, res_dt, sycl_dev): + return True + + +def _acceptance_fn_divide( + arg1_dtype, arg2_dtype, ret_buf1_dt, ret_buf2_dt, res_dt, sycl_dev +): + # both are being promoted, if the kind of result is + # different than the kind of original input dtypes, + # we use default dtype for the resulting kind. + # This covers, e.g. (array_dtype_i1 / array_dtype_u1) + # result of which in divide is double (in NumPy), but + # regular type promotion rules peg at float16 + if (ret_buf1_dt.kind != arg1_dtype.kind) and ( + ret_buf2_dt.kind != arg2_dtype.kind + ): + default_dt = _get_device_default_dtype(res_dt.kind, sycl_dev) + if res_dt == default_dt: + return True + else: + return False + else: + return True + + +def _acceptance_fn_negative(arg_dtype, buf_dt, res_dt, sycl_dev): + # negative is not defined for boolean data type + if arg_dtype.char == "?": + raise ValueError( + "The `negative` function, the `-` operator, is not supported " + "for inputs of data type bool, use the `~` operator or the " + "`logical_not` function instead" + ) + else: + return True + + +def _acceptance_fn_reciprocal(arg_dtype, buf_dt, res_dt, sycl_dev): + # if the kind of result is different from the kind of input, we use the + # default floating-point dtype for the resulting kind. This guarantees + # alignment of reciprocal and divide output types. + if buf_dt.kind != arg_dtype.kind: + default_dt = _get_device_default_dtype(res_dt.kind, sycl_dev) + if res_dt == default_dt: + return True + else: + return False + else: + return True + + +def _acceptance_fn_subtract( + arg1_dtype, arg2_dtype, buf1_dt, buf2_dt, res_dt, sycl_dev +): + # subtract is not defined for boolean data type + if arg1_dtype.char == "?" and arg2_dtype.char == "?": + raise ValueError( + "The `subtract` function, the `-` operator, is not supported " + "for inputs of data type bool, use the `^` operator, the " + "`bitwise_xor`, or the `logical_xor` function instead" + ) + else: + return True + + +def _can_cast( + from_: dpt.dtype, to_: dpt.dtype, _fp16: bool, _fp64: bool, casting="safe" +) -> bool: + """ + Can `from_` be cast to `to_` safely on a device with + fp16 and fp64 aspects as given? 
+ """ + if not _dtype_supported_by_device_impl(to_, _fp16, _fp64): + return False + can_cast_v = np.can_cast(from_, to_, casting=casting) # ask NumPy + if _fp16 and _fp64: + return can_cast_v + if not can_cast_v: + if ( + from_.kind in "biu" + and to_.kind in "fc" + and _is_maximal_inexact_type(to_, _fp16, _fp64) + ): + return True + + return can_cast_v + + +def _dtype_supported_by_device_impl( + dt: dpt.dtype, has_fp16: bool, has_fp64: bool +) -> bool: + if has_fp64: + if not has_fp16: + if dt is dpt.float16: + return False + else: + if dt is dpt.float64: + return False + elif dt is dpt.complex128: + return False + if not has_fp16 and dt is dpt.float16: + return False + return True + + +def _find_buf_dtype(arg_dtype, query_fn, sycl_dev, acceptance_fn): + res_dt = query_fn(arg_dtype) + if res_dt: + return None, res_dt + + _fp16 = sycl_dev.has_aspect_fp16 + _fp64 = sycl_dev.has_aspect_fp64 + all_dts = _all_data_types(_fp16, _fp64) + for buf_dt in all_dts: + if _can_cast(arg_dtype, buf_dt, _fp16, _fp64): + res_dt = query_fn(buf_dt) + if res_dt: + acceptable = acceptance_fn(arg_dtype, buf_dt, res_dt, sycl_dev) + if acceptable: + return buf_dt, res_dt + else: + continue + + return None, None + + +def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn): + res_dt = query_fn(arg1_dtype, arg2_dtype) + if res_dt: + return None, None, res_dt + + _fp16 = sycl_dev.has_aspect_fp16 + _fp64 = sycl_dev.has_aspect_fp64 + all_dts = _all_data_types(_fp16, _fp64) + for buf1_dt in all_dts: + for buf2_dt in all_dts: + if _can_cast(arg1_dtype, buf1_dt, _fp16, _fp64) and _can_cast( + arg2_dtype, buf2_dt, _fp16, _fp64 + ): + res_dt = query_fn(buf1_dt, buf2_dt) + if res_dt: + ret_buf1_dt = None if buf1_dt == arg1_dtype else buf1_dt + ret_buf2_dt = None if buf2_dt == arg2_dtype else buf2_dt + if ret_buf1_dt is None or ret_buf2_dt is None: + return ret_buf1_dt, ret_buf2_dt, res_dt + else: + acceptable = acceptance_fn( + arg1_dtype, + arg2_dtype, + ret_buf1_dt, + ret_buf2_dt, + res_dt, + sycl_dev, + ) + if acceptable: + return ret_buf1_dt, ret_buf2_dt, res_dt + else: + continue + + return None, None, None + + +def _find_buf_dtype_in_place_op(arg1_dtype, arg2_dtype, query_fn, sycl_dev): + res_dt = query_fn(arg1_dtype, arg2_dtype) + if res_dt: + return None, res_dt + + _fp16 = sycl_dev.has_aspect_fp16 + _fp64 = sycl_dev.has_aspect_fp64 + if _can_cast(arg2_dtype, arg1_dtype, _fp16, _fp64, casting="same_kind"): + res_dt = query_fn(arg1_dtype, arg1_dtype) + if res_dt: + return arg1_dtype, res_dt + + return None, None + + +def _get_device_default_dtype(dt_kind, sycl_dev): + if dt_kind == "b": + return dpt.dtype(ti.default_device_bool_type(sycl_dev)) + elif dt_kind == "i": + return dpt.dtype(ti.default_device_int_type(sycl_dev)) + elif dt_kind == "u": + return dpt.dtype(ti.default_device_uint_type(sycl_dev)) + elif dt_kind == "f": + return dpt.dtype(ti.default_device_fp_type(sycl_dev)) + elif dt_kind == "c": + return dpt.dtype(ti.default_device_complex_type(sycl_dev)) + raise RuntimeError + + +def _is_maximal_inexact_type(dt: dpt.dtype, _fp16: bool, _fp64: bool): + """ + Return True if data type `dt` is the + maximal size inexact data type + """ + if _fp64: + return dt in [dpt.float64, dpt.complex128] + return dt in [dpt.float32, dpt.complex64] + + +def _to_device_supported_dtype(dt, dev): + has_fp16 = dev.has_aspect_fp16 + has_fp64 = dev.has_aspect_fp64 + + return _to_device_supported_dtype_impl(dt, has_fp16, has_fp64) + + +def _to_device_supported_dtype_impl(dt, has_fp16, has_fp64): + if has_fp64: + 
if not has_fp16: + if dt is dpt.float16: + return dpt.float32 + else: + if dt is dpt.float64: + return dpt.float32 + elif dt is dpt.complex128: + return dpt.complex64 + if not has_fp16 and dt is dpt.float16: + return dpt.float32 + return dt + + +class WeakBooleanType: + """Python type representing type of Python boolean objects""" + + def __init__(self, o): + self.o_ = o + + def get(self): + return self.o_ + + +class WeakIntegralType: + """Python type representing type of Python integral objects""" + + def __init__(self, o): + self.o_ = o + + def get(self): + return self.o_ + + +class WeakFloatingType: + """Python type representing type of Python floating point objects""" + + def __init__(self, o): + self.o_ = o + + def get(self): + return self.o_ + + +class WeakComplexType: + """Python type representing type of Python complex floating point objects""" + + def __init__(self, o): + self.o_ = o + + def get(self): + return self.o_ + + +def _weak_type_num_kind(o): + _map = {"?": 0, "i": 1, "f": 2, "c": 3} + if isinstance(o, WeakBooleanType): + return _map["?"] + if isinstance(o, WeakIntegralType): + return _map["i"] + if isinstance(o, WeakFloatingType): + return _map["f"] + if isinstance(o, WeakComplexType): + return _map["c"] + raise TypeError( + f"Unexpected type {o} while expecting " + "`WeakBooleanType`, `WeakIntegralType`," + "`WeakFloatingType`, or `WeakComplexType`." + ) + + +def _strong_dtype_num_kind(o): + _map = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 3} + if not isinstance(o, dpt.dtype): + raise TypeError + k = o.kind + if k in _map: + return _map[k] + raise ValueError(f"Unrecognized kind {k} for dtype {o}") + + +def _is_weak_dtype(dtype): + return isinstance( + dtype, + (WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType), + ) + + +def _resolve_weak_types(o1_dtype, o2_dtype, dev): + """Resolves weak data type per NEP-0050""" + if _is_weak_dtype(o1_dtype): + if _is_weak_dtype(o2_dtype): + raise ValueError + o1_kind_num = _weak_type_num_kind(o1_dtype) + o2_kind_num = _strong_dtype_num_kind(o2_dtype) + if o1_kind_num > o2_kind_num: + if isinstance(o1_dtype, WeakIntegralType): + return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype + if isinstance(o1_dtype, WeakComplexType): + if o2_dtype is dpt.float16 or o2_dtype is dpt.float32: + return dpt.complex64, o2_dtype + return ( + _to_device_supported_dtype(dpt.complex128, dev), + o2_dtype, + ) + return _to_device_supported_dtype(dpt.float64, dev), o2_dtype + else: + return o2_dtype, o2_dtype + elif _is_weak_dtype(o2_dtype): + o1_kind_num = _strong_dtype_num_kind(o1_dtype) + o2_kind_num = _weak_type_num_kind(o2_dtype) + if o2_kind_num > o1_kind_num: + if isinstance(o2_dtype, WeakIntegralType): + return o1_dtype, dpt.dtype(ti.default_device_int_type(dev)) + if isinstance(o2_dtype, WeakComplexType): + if o1_dtype is dpt.float16 or o1_dtype is dpt.float32: + return o1_dtype, dpt.complex64 + return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev) + return ( + o1_dtype, + _to_device_supported_dtype(dpt.float64, dev), + ) + else: + return o1_dtype, o1_dtype + else: + return o1_dtype, o2_dtype + + +def _resolve_weak_types_all_py_ints(o1_dtype, o2_dtype, dev): + """ + Resolves weak data type per NEP-0050 for comparisons and + divide, where result type is known and special behavior + is needed to handle mixed integer kinds and Python integers + without overflow + """ + if _is_weak_dtype(o1_dtype): + if _is_weak_dtype(o2_dtype): + raise ValueError + o1_kind_num = _weak_type_num_kind(o1_dtype) + o2_kind_num = 
_strong_dtype_num_kind(o2_dtype) + if o1_kind_num > o2_kind_num: + if isinstance(o1_dtype, WeakIntegralType): + return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype + if isinstance(o1_dtype, WeakComplexType): + if o2_dtype is dpt.float16 or o2_dtype is dpt.float32: + return dpt.complex64, o2_dtype + return ( + _to_device_supported_dtype(dpt.complex128, dev), + o2_dtype, + ) + return _to_device_supported_dtype(dpt.float64, dev), o2_dtype + else: + if o1_kind_num == o2_kind_num and isinstance( + o1_dtype, WeakIntegralType + ): + o1_val = o1_dtype.get() + o2_iinfo = dpt_ext.iinfo(o2_dtype) + if (o1_val < o2_iinfo.min) or (o1_val > o2_iinfo.max): + return dpt.dtype(np.min_scalar_type(o1_val)), o2_dtype + return o2_dtype, o2_dtype + elif _is_weak_dtype(o2_dtype): + o1_kind_num = _strong_dtype_num_kind(o1_dtype) + o2_kind_num = _weak_type_num_kind(o2_dtype) + if o2_kind_num > o1_kind_num: + if isinstance(o2_dtype, WeakIntegralType): + return o1_dtype, dpt.dtype(ti.default_device_int_type(dev)) + if isinstance(o2_dtype, WeakComplexType): + if o1_dtype is dpt.float16 or o1_dtype is dpt.float32: + return o1_dtype, dpt.complex64 + return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev) + return ( + o1_dtype, + _to_device_supported_dtype(dpt.float64, dev), + ) + else: + if o1_kind_num == o2_kind_num and isinstance( + o2_dtype, WeakIntegralType + ): + o2_val = o2_dtype.get() + o1_iinfo = dpt_ext.iinfo(o1_dtype) + if (o2_val < o1_iinfo.min) or (o2_val > o1_iinfo.max): + return o1_dtype, dpt.dtype(np.min_scalar_type(o2_val)) + return o1_dtype, o1_dtype + else: + return o1_dtype, o2_dtype + + +def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev): + """ + Resolves weak data types per NEP-0050, + where the second and third arguments are + permitted to be weak types + """ + if _is_weak_dtype(st_dtype): + raise ValueError + if _is_weak_dtype(dtype1): + if _is_weak_dtype(dtype2): + kind_num1 = _weak_type_num_kind(dtype1) + kind_num2 = _weak_type_num_kind(dtype2) + st_kind_num = _strong_dtype_num_kind(st_dtype) + + if kind_num1 > st_kind_num: + if isinstance(dtype1, WeakIntegralType): + ret_dtype1 = dpt.dtype(ti.default_device_int_type(dev)) + elif isinstance(dtype1, WeakComplexType): + if st_dtype is dpt.float16 or st_dtype is dpt.float32: + ret_dtype1 = dpt.complex64 + ret_dtype1 = _to_device_supported_dtype(dpt.complex128, dev) + else: + ret_dtype1 = _to_device_supported_dtype(dpt.float64, dev) + else: + ret_dtype1 = st_dtype + + if kind_num2 > st_kind_num: + if isinstance(dtype2, WeakIntegralType): + ret_dtype2 = dpt.dtype(ti.default_device_int_type(dev)) + elif isinstance(dtype2, WeakComplexType): + if st_dtype is dpt.float16 or st_dtype is dpt.float32: + ret_dtype2 = dpt.complex64 + ret_dtype2 = _to_device_supported_dtype(dpt.complex128, dev) + else: + ret_dtype2 = _to_device_supported_dtype(dpt.float64, dev) + else: + ret_dtype2 = st_dtype + + return ret_dtype1, ret_dtype2 + + max_dt_num_kind, max_dtype = max( + [ + (_strong_dtype_num_kind(st_dtype), st_dtype), + (_strong_dtype_num_kind(dtype2), dtype2), + ] + ) + dt1_kind_num = _weak_type_num_kind(dtype1) + if dt1_kind_num > max_dt_num_kind: + if isinstance(dtype1, WeakIntegralType): + return dpt.dtype(ti.default_device_int_type(dev)), dtype2 + if isinstance(dtype1, WeakComplexType): + if max_dtype is dpt.float16 or max_dtype is dpt.float32: + return dpt.complex64, dtype2 + return ( + _to_device_supported_dtype(dpt.complex128, dev), + dtype2, + ) + return _to_device_supported_dtype(dpt.float64, dev), dtype2 + else: + 
return max_dtype, dtype2 + elif _is_weak_dtype(dtype2): + max_dt_num_kind, max_dtype = max( + [ + (_strong_dtype_num_kind(st_dtype), st_dtype), + (_strong_dtype_num_kind(dtype1), dtype1), + ] + ) + dt2_kind_num = _weak_type_num_kind(dtype2) + if dt2_kind_num > max_dt_num_kind: + if isinstance(dtype2, WeakIntegralType): + return dtype1, dpt.dtype(ti.default_device_int_type(dev)) + if isinstance(dtype2, WeakComplexType): + if max_dtype is dpt.float16 or max_dtype is dpt.float32: + return dtype1, dpt.complex64 + return ( + dtype1, + _to_device_supported_dtype(dpt.complex128, dev), + ) + return dtype1, _to_device_supported_dtype(dpt.float64, dev) + else: + return dtype1, max_dtype + else: + # both are strong dtypes + # return unmodified + return dtype1, dtype2 + + +def _resolve_one_strong_one_weak_types(st_dtype, dtype, dev): + """Resolves one weak data type with one strong data type per NEP-0050""" + if _is_weak_dtype(st_dtype): + raise ValueError + if _is_weak_dtype(dtype): + st_kind_num = _strong_dtype_num_kind(st_dtype) + kind_num = _weak_type_num_kind(dtype) + if kind_num > st_kind_num: + if isinstance(dtype, WeakIntegralType): + return dpt.dtype(ti.default_device_int_type(dev)) + if isinstance(dtype, WeakComplexType): + if st_dtype is dpt.float16 or st_dtype is dpt.float32: + return dpt.complex64 + return _to_device_supported_dtype(dpt.complex128, dev) + return _to_device_supported_dtype(dpt.float64, dev) + else: + return st_dtype + else: + return dtype + + +class finfo_object: + """ + `numpy.finfo` subclass which returns Python floating-point scalars for + `eps`, `max`, `min`, and `smallest_normal` attributes. + """ + + def __init__(self, dtype): + _supported_dtype([dpt.dtype(dtype)]) + self._finfo = np.finfo(dtype) + + @property + def bits(self): + """Number of bits occupied by the real-valued floating-point data type.""" + return int(self._finfo.bits) + + @property + def smallest_normal(self): + """ + Smallest positive real-valued floating-point number with full + precision. + """ + return float(self._finfo.smallest_normal) + + @property + def tiny(self): + """An alias for `smallest_normal`""" + return float(self._finfo.tiny) + + @property + def eps(self): + """ + Difference between 1.0 and the next smallest representable real-valued + floating-point number larger than 1.0 according to the IEEE-754 + standard. + """ + return float(self._finfo.eps) + + @property + def epsneg(self): + """ + Difference between 1.0 and the next smallest representable real-valued + floating-point number smaller than 1.0 according to the IEEE-754 + standard. + """ + return float(self._finfo.epsneg) + + @property + def min(self): + """Smallest representable real-valued number.""" + return float(self._finfo.min) + + @property + def max(self): + """Largest representable real-valued number.""" + return float(self._finfo.max) + + @property + def resolution(self): + """The approximate decimal resolution of this type.""" + return float(self._finfo.resolution) + + @property + def precision(self): + """ + The approximate number of decimal digits to which this kind of + floating point type is precise. + """ + return float(self._finfo.precision) + + @property + def dtype(self): + """ + The dtype for which finfo returns information. For complex input, the + returned dtype is the associated floating point dtype for its real and + complex components. 
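+
+ For example, the reported dtype for a `complex64` input is `float32`.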
+ """ + return self._finfo.dtype + + def __str__(self): + return self._finfo.__str__() + + def __repr__(self): + return self._finfo.__repr__() + + +def can_cast(from_, to, /, *, casting="safe") -> bool: + """can_cast(from, to, casting="safe") + + Determines if one data type can be cast to another data type according \ + to Type Promotion Rules. + + Args: + from_ (Union[usm_ndarray, dtype]): + source data type. If `from_` is an array, a device-specific type + promotion rules apply. + to (dtype): + target data type + casting (Optional[str]): + controls what kind of data casting may occur. + + * "no" means data types should not be cast at all. + * "safe" means only casts that preserve values are allowed. + * "same_kind" means only safe casts and casts within a kind, + like `float64` to `float32`, are allowed. + * "unsafe" means any data conversion can be done. + + Default: `"safe"`. + + Returns: + bool: + Gives `True` if cast can occur according to the casting rule. + + Device-specific type promotion rules take into account which data type are + and are not supported by a specific device. + """ + if isinstance(to, dpt.usm_ndarray): + raise TypeError(f"Expected `dpt.dtype` type, got {type(to)}.") + + dtype_to = dpt.dtype(to) + _supported_dtype([dtype_to]) + + if isinstance(from_, dpt.usm_ndarray): + dtype_from = from_.dtype + return _can_cast( + dtype_from, + dtype_to, + from_.sycl_device.has_aspect_fp16, + from_.sycl_device.has_aspect_fp64, + casting=casting, + ) + else: + dtype_from = dpt.dtype(from_) + _supported_dtype([dtype_from]) + # query casting as if all dtypes are supported + return _can_cast(dtype_from, dtype_to, True, True, casting=casting) + + +def result_type(*arrays_and_dtypes): + """ + result_type(*arrays_and_dtypes) + + Returns the dtype that results from applying the Type Promotion Rules to \ + the arguments. + + Args: + arrays_and_dtypes (Union[usm_ndarray, dtype]): + An arbitrary length sequence of usm_ndarray objects or dtypes. + + Returns: + dtype: + The dtype resulting from an operation involving the + input arrays and dtypes. + """ + dtypes = [] + devices = [] + weak_dtypes = [] + for arg_i in arrays_and_dtypes: + if isinstance(arg_i, dpt.usm_ndarray): + devices.append(arg_i.sycl_device) + dtypes.append(arg_i.dtype) + elif isinstance(arg_i, int): + weak_dtypes.append(WeakIntegralType(arg_i)) + elif isinstance(arg_i, float): + weak_dtypes.append(WeakFloatingType(arg_i)) + elif isinstance(arg_i, complex): + weak_dtypes.append(WeakComplexType(arg_i)) + elif isinstance(arg_i, bool): + weak_dtypes.append(WeakBooleanType(arg_i)) + else: + dt = dpt.dtype(arg_i) + _supported_dtype([dt]) + dtypes.append(dt) + + has_fp16 = True + has_fp64 = True + target_dev = None + if devices: + inspected = False + for d in devices: + if inspected: + unsame_fp16_support = d.has_aspect_fp16 != has_fp16 + unsame_fp64_support = d.has_aspect_fp64 != has_fp64 + if unsame_fp16_support or unsame_fp64_support: + raise ValueError( + "Input arrays reside on devices " + "with different device supports; " + "unable to determine which " + "device-specific type promotion rules " + "to use." 
+ ) + else: + has_fp16 = d.has_aspect_fp16 + has_fp64 = d.has_aspect_fp64 + target_dev = d + inspected = True + + if not dtypes and weak_dtypes: + dtypes.append(weak_dtypes[0].get()) + + if not (has_fp16 and has_fp64): + for dt in dtypes: + if not _dtype_supported_by_device_impl(dt, has_fp16, has_fp64): + raise ValueError( + f"Argument {dt} is not supported by the device" + ) + res_dt = np.result_type(*dtypes) + res_dt = _to_device_supported_dtype_impl(res_dt, has_fp16, has_fp64) + for wdt in weak_dtypes: + pair = _resolve_weak_types(wdt, res_dt, target_dev) + res_dt = np.result_type(*pair) + res_dt = _to_device_supported_dtype_impl(res_dt, has_fp16, has_fp64) + else: + res_dt = np.result_type(*dtypes) + if weak_dtypes: + weak_dt_obj = [wdt.get() for wdt in weak_dtypes] + res_dt = np.result_type(res_dt, *weak_dt_obj) + + return res_dt + + +def iinfo(dtype, /): + """iinfo(dtype) + + Returns machine limits for integer data types. + + Args: + dtype (dtype, usm_ndarray): + integer dtype or + an array with integer dtype. + + Returns: + iinfo_object: + An object with the following attributes: + + * bits: int + number of bits occupied by the data type + * max: int + largest representable number. + * min: int + smallest representable number. + * dtype: dtype + integer data type. + """ + if isinstance(dtype, dpt.usm_ndarray): + dtype = dtype.dtype + _supported_dtype([dpt.dtype(dtype)]) + return np.iinfo(dtype) + + +def finfo(dtype, /): + """finfo(type) + + Returns machine limits for floating-point data types. + + Args: + dtype (dtype, usm_ndarray): floating-point dtype or + an array with floating point data type. + If complex, the information is about its component + data type. + + Returns: + finfo_object: + an object have the following attributes: + + * bits: int + number of bits occupied by dtype. + * eps: float + difference between 1.0 and the next smallest representable + real-valued floating-point number larger than 1.0 according + to the IEEE-754 standard. + * max: float + largest representable real-valued number. + * min: float + smallest representable real-valued number. + * smallest_normal: float + smallest positive real-valued floating-point number with + full precision. + * dtype: dtype + real-valued floating-point data type. + + """ + if isinstance(dtype, dpt.usm_ndarray): + dtype = dtype.dtype + _supported_dtype([dpt.dtype(dtype)]) + return finfo_object(dtype) + + +def _supported_dtype(dtypes): + for dtype in dtypes: + if dtype.char not in "?bBhHiIlLqQefdFD": + raise ValueError(f"Dpctl doesn't support dtype {dtype}.") + return True + + +def isdtype(dtype, kind): + """isdtype(dtype, kind) + + Returns a boolean indicating whether a provided `dtype` is + of a specified data type `kind`. + + See [array API](array_api) for more information. 
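+
+ For example, `isdtype(dpt.dtype("int32"), "integral")` evaluates to `True`,
+ while `isdtype(dpt.dtype("float32"), ("bool", "integral"))` evaluates to
+ `False`.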
+ + [array_api]: https://data-apis.org/array-api/latest/ + """ + + if not isinstance(dtype, np.dtype): + raise TypeError(f"Expected instance of `dpt.dtype`, got {dtype}") + + if isinstance(kind, np.dtype): + return dtype == kind + + elif isinstance(kind, str): + if kind == "bool": + return dtype == np.dtype("bool") + elif kind == "signed integer": + return dtype.kind == "i" + elif kind == "unsigned integer": + return dtype.kind == "u" + elif kind == "integral": + return dtype.kind in "iu" + elif kind == "real floating": + return dtype.kind == "f" + elif kind == "complex floating": + return dtype.kind == "c" + elif kind == "numeric": + return dtype.kind in "iufc" + else: + raise ValueError(f"Unrecognized data type kind: {kind}") + + elif isinstance(kind, tuple): + return any(isdtype(dtype, k) for k in kind) + + else: + raise TypeError(f"Unsupported data type kind: {kind}") + + +def _default_accumulation_dtype(inp_dt, q): + """Gives default output data type for given input data + type `inp_dt` when accumulation is performed on queue `q` + """ + inp_kind = inp_dt.kind + if inp_kind in "bi": + res_dt = dpt.dtype(ti.default_device_int_type(q)) + if inp_dt.itemsize > res_dt.itemsize: + res_dt = inp_dt + elif inp_kind in "u": + res_dt = dpt.dtype(ti.default_device_uint_type(q)) + res_ii = dpt_ext.iinfo(res_dt) + inp_ii = dpt_ext.iinfo(inp_dt) + if inp_ii.min >= res_ii.min and inp_ii.max <= res_ii.max: + pass + else: + res_dt = inp_dt + elif inp_kind in "fc": + res_dt = inp_dt + + return res_dt + + +def _default_accumulation_dtype_fp_types(inp_dt, q): + """Gives default output data type for given input data + type `inp_dt` when accumulation is performed on queue `q` + and the accumulation supports only floating-point data types + """ + inp_kind = inp_dt.kind + if inp_kind in "biu": + res_dt = dpt.dtype(ti.default_device_fp_type(q)) + can_cast_v = dpt_ext.can_cast(inp_dt, res_dt) + if not can_cast_v: + _fp64 = q.sycl_device.has_aspect_fp64 + res_dt = dpt.float64 if _fp64 else dpt.float32 + elif inp_kind in "f": + res_dt = inp_dt + elif inp_kind in "c": + raise ValueError("function not defined for complex types") + + return res_dt + + +__all__ = [ + "_find_buf_dtype", + "_find_buf_dtype2", + "_to_device_supported_dtype", + "_acceptance_fn_default_unary", + "_acceptance_fn_reciprocal", + "_acceptance_fn_default_binary", + "_acceptance_fn_divide", + "_acceptance_fn_negative", + "_acceptance_fn_subtract", + "_resolve_one_strong_one_weak_types", + "_resolve_one_strong_two_weak_types", + "_resolve_weak_types", + "_resolve_weak_types_all_py_ints", + "_weak_type_num_kind", + "_strong_dtype_num_kind", + "can_cast", + "finfo", + "iinfo", + "isdtype", + "result_type", + "WeakBooleanType", + "WeakIntegralType", + "WeakFloatingType", + "WeakComplexType", + "_default_accumulation_dtype", + "_default_accumulation_dtype_fp_types", + "_find_buf_dtype_in_place_op", +] diff --git a/dpctl_ext/tensor/libtensor/include/kernels/clip.hpp b/dpctl_ext/tensor/libtensor/include/kernels/clip.hpp new file mode 100644 index 00000000000..7ce50b2b62b --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/kernels/clip.hpp @@ -0,0 +1,357 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for dpctl.tensor.clip. +//===---------------------------------------------------------------------===// + +#pragma once +#include +#include +#include +#include +#include +#include + +#include + +#include "dpctl_tensor_types.hpp" +#include "kernels/alignment.hpp" +#include "utils/math_utils.hpp" +#include "utils/offset_utils.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_utils.hpp" + +namespace dpctl::tensor::kernels::clip +{ + +using dpctl::tensor::ssize_t; +using namespace dpctl::tensor::offset_utils; + +using dpctl::tensor::kernels::alignment_utils:: + disabled_sg_loadstore_wrapper_krn; +using dpctl::tensor::kernels::alignment_utils::is_aligned; +using dpctl::tensor::kernels::alignment_utils::required_alignment; + +using dpctl::tensor::sycl_utils::sub_group_load; +using dpctl::tensor::sycl_utils::sub_group_store; + +template +T clip(const T &x, const T &min, const T &max) +{ + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + using dpctl::tensor::math_utils::max_complex; + using dpctl::tensor::math_utils::min_complex; + return min_complex(max_complex(x, min), max); + } + else if constexpr (std::is_floating_point_v || + std::is_same_v) { + auto tmp = (std::isnan(x) || x > min) ? x : min; + return (std::isnan(tmp) || tmp < max) ? tmp : max; + } + else if constexpr (std::is_same_v) { + return (x || min) && max; + } + else { + auto tmp = (x > min) ? x : min; + return (tmp < max) ? 
tmp : max; + } +} + +template +class ClipContigFunctor +{ +private: + std::size_t nelems = 0; + const T *x_p = nullptr; + const T *min_p = nullptr; + const T *max_p = nullptr; + T *dst_p = nullptr; + +public: + ClipContigFunctor(std::size_t nelems_, + const T *x_p_, + const T *min_p_, + const T *max_p_, + T *dst_p_) + : nelems(nelems_), x_p(x_p_), min_p(min_p_), max_p(max_p_), + dst_p(dst_p_) + { + } + + void operator()(sycl::nd_item<1> ndit) const + { + static constexpr std::uint8_t nelems_per_wi = n_vecs * vec_sz; + + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value || !enable_sg_loadstore) { + const std::uint16_t sgSize = + ndit.get_sub_group().get_local_range()[0]; + const std::size_t gid = ndit.get_global_linear_id(); + const std::uint16_t nelems_per_sg = sgSize * nelems_per_wi; + + const std::size_t start = + (gid / sgSize) * (nelems_per_sg - sgSize) + gid; + const std::size_t end = std::min(nelems, start + nelems_per_sg); + + for (std::size_t offset = start; offset < end; offset += sgSize) { + dst_p[offset] = clip(x_p[offset], min_p[offset], max_p[offset]); + } + } + else { + auto sg = ndit.get_sub_group(); + const std::uint16_t sgSize = sg.get_max_local_range()[0]; + + const std::size_t base = + nelems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * sgSize); + + if (base + nelems_per_wi * sgSize < nelems) { + sycl::vec dst_vec; +#pragma unroll + for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { + const std::size_t idx = base + it * sgSize; + auto x_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&x_p[idx]); + auto min_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&min_p[idx]); + auto max_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&max_p[idx]); + auto dst_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&dst_p[idx]); + + const sycl::vec x_vec = + sub_group_load(sg, x_multi_ptr); + const sycl::vec min_vec = + sub_group_load(sg, min_multi_ptr); + const sycl::vec max_vec = + sub_group_load(sg, max_multi_ptr); +#pragma unroll + for (std::uint8_t vec_id = 0; vec_id < vec_sz; ++vec_id) { + dst_vec[vec_id] = clip(x_vec[vec_id], min_vec[vec_id], + max_vec[vec_id]); + } + sub_group_store(sg, dst_vec, dst_multi_ptr); + } + } + else { + const std::size_t lane_id = sg.get_local_id()[0]; + for (std::size_t k = base + lane_id; k < nelems; k += sgSize) { + dst_p[k] = clip(x_p[k], min_p[k], max_p[k]); + } + } + } + } +}; + +template +class clip_contig_kernel; + +typedef sycl::event (*clip_contig_impl_fn_ptr_t)( + sycl::queue &, + std::size_t, + const char *, + const char *, + const char *, + char *, + const std::vector &); + +template +sycl::event clip_contig_impl(sycl::queue &q, + std::size_t nelems, + const char *x_cp, + const char *min_cp, + const char *max_cp, + char *dst_cp, + const std::vector &depends) +{ + const T *x_tp = reinterpret_cast(x_cp); + const T *min_tp = reinterpret_cast(min_cp); + const T *max_tp = reinterpret_cast(max_cp); + T *dst_tp = reinterpret_cast(dst_cp); + + sycl::event clip_ev = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + std::size_t lws = 64; + static constexpr std::uint8_t vec_sz = 4; + static constexpr std::uint8_t n_vecs = 2; + const std::size_t n_groups = + ((nelems + lws * n_vecs * vec_sz - 1) / 
(lws * n_vecs * vec_sz)); + const auto gws_range = sycl::range<1>(n_groups * lws); + const auto lws_range = sycl::range<1>(lws); + + if (is_aligned(x_cp) && + is_aligned(min_cp) && + is_aligned(max_cp) && + is_aligned(dst_cp)) + { + static constexpr bool enable_sg_loadstore = true; + using KernelName = clip_contig_kernel; + using Impl = + ClipContigFunctor; + + cgh.parallel_for( + sycl::nd_range<1>(gws_range, lws_range), + Impl(nelems, x_tp, min_tp, max_tp, dst_tp)); + } + else { + static constexpr bool disable_sg_loadstore = false; + using InnerKernelName = clip_contig_kernel; + using KernelName = + disabled_sg_loadstore_wrapper_krn; + using Impl = + ClipContigFunctor; + + cgh.parallel_for( + sycl::nd_range<1>(gws_range, lws_range), + Impl(nelems, x_tp, min_tp, max_tp, dst_tp)); + } + }); + + return clip_ev; +} + +template +class ClipStridedFunctor +{ +private: + const T *x_p = nullptr; + const T *min_p = nullptr; + const T *max_p = nullptr; + T *dst_p = nullptr; + IndexerT indexer; + +public: + ClipStridedFunctor(const T *x_p_, + const T *min_p_, + const T *max_p_, + T *dst_p_, + const IndexerT &indexer_) + : x_p(x_p_), min_p(min_p_), max_p(max_p_), dst_p(dst_p_), + indexer(indexer_) + { + } + + void operator()(sycl::id<1> id) const + { + std::size_t gid = id[0]; + auto offsets = indexer(static_cast(gid)); + dst_p[offsets.get_fourth_offset()] = clip( + x_p[offsets.get_first_offset()], min_p[offsets.get_second_offset()], + max_p[offsets.get_third_offset()]); + } +}; + +template +class clip_strided_kernel; + +typedef sycl::event (*clip_strided_impl_fn_ptr_t)( + sycl::queue &, + std::size_t, + int, + const char *, + const char *, + const char *, + char *, + const ssize_t *, + ssize_t, + ssize_t, + ssize_t, + ssize_t, + const std::vector &); + +template +sycl::event clip_strided_impl(sycl::queue &q, + std::size_t nelems, + int nd, + const char *x_cp, + const char *min_cp, + const char *max_cp, + char *dst_cp, + const ssize_t *shape_strides, + ssize_t x_offset, + ssize_t min_offset, + ssize_t max_offset, + ssize_t dst_offset, + const std::vector &depends) +{ + const T *x_tp = reinterpret_cast(x_cp); + const T *min_tp = reinterpret_cast(min_cp); + const T *max_tp = reinterpret_cast(max_cp); + T *dst_tp = reinterpret_cast(dst_cp); + + sycl::event clip_ev = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + const FourOffsets_StridedIndexer indexer{ + nd, x_offset, min_offset, max_offset, dst_offset, shape_strides}; + + using KernelName = clip_strided_kernel; + using Impl = ClipStridedFunctor; + + cgh.parallel_for( + sycl::range<1>(nelems), + Impl(x_tp, min_tp, max_tp, dst_tp, indexer)); + }); + + return clip_ev; +} + +template +struct ClipStridedFactory +{ + fnT get() + { + fnT fn = clip_strided_impl; + return fn; + } +}; + +template +struct ClipContigFactory +{ + fnT get() + { + + fnT fn = clip_contig_impl; + return fn; + } +}; + +} // namespace dpctl::tensor::kernels::clip diff --git a/dpctl_ext/tensor/libtensor/include/kernels/repeat.hpp b/dpctl_ext/tensor/libtensor/include/kernels/repeat.hpp new file mode 100644 index 00000000000..aab9a709f01 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/kernels/repeat.hpp @@ -0,0 +1,462 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for tensor repeating operations. +//===----------------------------------------------------------------------===// + +#pragma once +#include +#include +#include +#include + +#include + +#include "dpctl_tensor_types.hpp" +#include "utils/offset_utils.hpp" +#include "utils/type_utils.hpp" + +namespace dpctl::tensor::kernels::repeat +{ + +using dpctl::tensor::ssize_t; +using namespace dpctl::tensor::offset_utils; + +template +class repeat_by_sequence_kernel; + +template +class RepeatSequenceFunctor +{ +private: + const T *src = nullptr; + T *dst = nullptr; + const repT *reps = nullptr; + const repT *cumsum = nullptr; + std::size_t src_axis_nelems = 1; + OrthogIndexer orthog_strider; + SrcAxisIndexer src_axis_strider; + DstAxisIndexer dst_axis_strider; + RepIndexer reps_strider; + +public: + RepeatSequenceFunctor(const T *src_, + T *dst_, + const repT *reps_, + const repT *cumsum_, + std::size_t src_axis_nelems_, + const OrthogIndexer &orthog_strider_, + const SrcAxisIndexer &src_axis_strider_, + const DstAxisIndexer &dst_axis_strider_, + const RepIndexer &reps_strider_) + : src(src_), dst(dst_), reps(reps_), cumsum(cumsum_), + src_axis_nelems(src_axis_nelems_), orthog_strider(orthog_strider_), + src_axis_strider(src_axis_strider_), + dst_axis_strider(dst_axis_strider_), reps_strider(reps_strider_) + { + } + + void operator()(sycl::id<1> idx) const + { + std::size_t id = idx[0]; + auto i_orthog = id / src_axis_nelems; + auto i_along = id - (i_orthog * src_axis_nelems); + + auto orthog_offsets = orthog_strider(i_orthog); + auto src_offset = orthog_offsets.get_first_offset(); + auto dst_offset = orthog_offsets.get_second_offset(); + + auto val = src[src_offset + src_axis_strider(i_along)]; + auto last = cumsum[i_along]; + auto first = last - reps[reps_strider(i_along)]; + for (auto i = first; i < last; 
++i) { + dst[dst_offset + dst_axis_strider(i)] = val; + } + } +}; + +typedef sycl::event (*repeat_by_sequence_fn_ptr_t)( + sycl::queue &, + std::size_t, + std::size_t, + const char *, + char *, + const char *, + const char *, + int, + const ssize_t *, + ssize_t, + ssize_t, + ssize_t, + ssize_t, + ssize_t, + ssize_t, + ssize_t, + ssize_t, + const std::vector &); + +template +sycl::event + repeat_by_sequence_impl(sycl::queue &q, + std::size_t orthog_nelems, + std::size_t src_axis_nelems, + const char *src_cp, + char *dst_cp, + const char *reps_cp, + const char *cumsum_cp, + int orthog_nd, + const ssize_t *orthog_src_dst_shape_and_strides, + ssize_t src_offset, + ssize_t dst_offset, + ssize_t src_axis_shape, + ssize_t src_axis_stride, + ssize_t dst_axis_shape, + ssize_t dst_axis_stride, + ssize_t reps_shape, + ssize_t reps_stride, + const std::vector &depends) +{ + sycl::event repeat_ev = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + const T *src_tp = reinterpret_cast(src_cp); + const repT *reps_tp = reinterpret_cast(reps_cp); + const repT *cumsum_tp = reinterpret_cast(cumsum_cp); + T *dst_tp = reinterpret_cast(dst_cp); + + // orthog ndim indexer + const TwoOffsets_StridedIndexer orthog_indexer{ + orthog_nd, src_offset, dst_offset, + orthog_src_dst_shape_and_strides}; + // indexers along repeated axis + const Strided1DIndexer src_axis_indexer{/* size */ src_axis_shape, + /* step */ src_axis_stride}; + const Strided1DIndexer dst_axis_indexer{/* size */ dst_axis_shape, + /* step */ dst_axis_stride}; + // indexer along reps array + const Strided1DIndexer reps_indexer{/* size */ reps_shape, + /* step */ reps_stride}; + + const std::size_t gws = orthog_nelems * src_axis_nelems; + + cgh.parallel_for>( + sycl::range<1>(gws), + RepeatSequenceFunctor( + src_tp, dst_tp, reps_tp, cumsum_tp, src_axis_nelems, + orthog_indexer, src_axis_indexer, dst_axis_indexer, + reps_indexer)); + }); + + return repeat_ev; +} + +template +struct RepeatSequenceFactory +{ + fnT get() + { + fnT fn = repeat_by_sequence_impl; + return fn; + } +}; + +typedef sycl::event (*repeat_by_sequence_1d_fn_ptr_t)( + sycl::queue &, + std::size_t, + const char *, + char *, + const char *, + const char *, + int, + const ssize_t *, + ssize_t, + ssize_t, + ssize_t, + ssize_t, + const std::vector &); + +template +sycl::event repeat_by_sequence_1d_impl(sycl::queue &q, + std::size_t src_nelems, + const char *src_cp, + char *dst_cp, + const char *reps_cp, + const char *cumsum_cp, + int src_nd, + const ssize_t *src_shape_strides, + ssize_t dst_shape, + ssize_t dst_stride, + ssize_t reps_shape, + ssize_t reps_stride, + const std::vector &depends) +{ + sycl::event repeat_ev = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + const T *src_tp = reinterpret_cast(src_cp); + const repT *reps_tp = reinterpret_cast(reps_cp); + const repT *cumsum_tp = reinterpret_cast(cumsum_cp); + T *dst_tp = reinterpret_cast(dst_cp); + + // orthog ndim indexer + static constexpr TwoZeroOffsets_Indexer orthog_indexer{}; + // indexers along repeated axis + const StridedIndexer src_indexer{src_nd, 0, src_shape_strides}; + const Strided1DIndexer dst_indexer{/* size */ dst_shape, + /* step */ dst_stride}; + // indexer along reps array + const Strided1DIndexer reps_indexer{/* size */ reps_shape, + /* step */ reps_stride}; + + const std::size_t gws = src_nelems; + + cgh.parallel_for>( + sycl::range<1>(gws), + RepeatSequenceFunctor( + src_tp, dst_tp, reps_tp, cumsum_tp, src_nelems, orthog_indexer, + src_indexer, dst_indexer, reps_indexer)); + 
}); + + return repeat_ev; +} + +template +struct RepeatSequence1DFactory +{ + fnT get() + { + fnT fn = repeat_by_sequence_1d_impl; + return fn; + } +}; + +template +class repeat_by_scalar_kernel; + +template +class RepeatScalarFunctor +{ +private: + const T *src = nullptr; + T *dst = nullptr; + ssize_t reps = 1; + std::size_t dst_axis_nelems = 0; + OrthogIndexer orthog_strider; + SrcAxisIndexer src_axis_strider; + DstAxisIndexer dst_axis_strider; + +public: + RepeatScalarFunctor(const T *src_, + T *dst_, + const ssize_t reps_, + std::size_t dst_axis_nelems_, + const OrthogIndexer &orthog_strider_, + const SrcAxisIndexer &src_axis_strider_, + const DstAxisIndexer &dst_axis_strider_) + : src(src_), dst(dst_), reps(reps_), dst_axis_nelems(dst_axis_nelems_), + orthog_strider(orthog_strider_), src_axis_strider(src_axis_strider_), + dst_axis_strider(dst_axis_strider_) + { + } + + void operator()(sycl::id<1> idx) const + { + std::size_t id = idx[0]; + auto i_orthog = id / dst_axis_nelems; + auto i_along = id - (i_orthog * dst_axis_nelems); + + auto orthog_offsets = orthog_strider(i_orthog); + auto src_offset = orthog_offsets.get_first_offset(); + auto dst_offset = orthog_offsets.get_second_offset(); + + auto dst_axis_offset = dst_axis_strider(i_along); + auto src_axis_offset = src_axis_strider(i_along / reps); + dst[dst_offset + dst_axis_offset] = src[src_offset + src_axis_offset]; + } +}; + +typedef sycl::event (*repeat_by_scalar_fn_ptr_t)( + sycl::queue &, + std::size_t, + std::size_t, + const char *, + char *, + const ssize_t, + int, + const ssize_t *, + ssize_t, + ssize_t, + ssize_t, + ssize_t, + ssize_t, + ssize_t, + const std::vector &); + +template +sycl::event repeat_by_scalar_impl(sycl::queue &q, + std::size_t orthog_nelems, + std::size_t dst_axis_nelems, + const char *src_cp, + char *dst_cp, + const ssize_t reps, + int orthog_nd, + const ssize_t *orthog_shape_and_strides, + ssize_t src_offset, + ssize_t dst_offset, + ssize_t src_axis_shape, + ssize_t src_axis_stride, + ssize_t dst_axis_shape, + ssize_t dst_axis_stride, + const std::vector &depends) +{ + sycl::event repeat_ev = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + const T *src_tp = reinterpret_cast(src_cp); + T *dst_tp = reinterpret_cast(dst_cp); + + // orthog ndim indexer + const TwoOffsets_StridedIndexer orthog_indexer{ + orthog_nd, src_offset, dst_offset, orthog_shape_and_strides}; + // indexers along repeated axis + const Strided1DIndexer src_axis_indexer{/* size */ src_axis_shape, + /* step */ src_axis_stride}; + const Strided1DIndexer dst_axis_indexer{/* size */ dst_axis_shape, + /* step */ dst_axis_stride}; + + const std::size_t gws = orthog_nelems * dst_axis_nelems; + + cgh.parallel_for>( + sycl::range<1>(gws), + RepeatScalarFunctor( + src_tp, dst_tp, reps, dst_axis_nelems, orthog_indexer, + src_axis_indexer, dst_axis_indexer)); + }); + + return repeat_ev; +} + +template +struct RepeatScalarFactory +{ + fnT get() + { + fnT fn = repeat_by_scalar_impl; + return fn; + } +}; + +typedef sycl::event (*repeat_by_scalar_1d_fn_ptr_t)( + sycl::queue &, + std::size_t, + const char *, + char *, + const ssize_t, + int, + const ssize_t *, + ssize_t, + ssize_t, + const std::vector &); + +template +sycl::event repeat_by_scalar_1d_impl(sycl::queue &q, + std::size_t dst_nelems, + const char *src_cp, + char *dst_cp, + const ssize_t reps, + int src_nd, + const ssize_t *src_shape_strides, + ssize_t dst_shape, + ssize_t dst_stride, + const std::vector &depends) +{ + sycl::event repeat_ev = q.submit([&](sycl::handler &cgh) 
{ + cgh.depends_on(depends); + + const T *src_tp = reinterpret_cast(src_cp); + T *dst_tp = reinterpret_cast(dst_cp); + + // orthog ndim indexer + static constexpr TwoZeroOffsets_Indexer orthog_indexer{}; + // indexers along repeated axis + const StridedIndexer src_indexer(src_nd, 0, src_shape_strides); + const Strided1DIndexer dst_indexer{/* size */ dst_shape, + /* step */ dst_stride}; + + const std::size_t gws = dst_nelems; + + cgh.parallel_for>( + sycl::range<1>(gws), + RepeatScalarFunctor(src_tp, dst_tp, reps, + dst_nelems, orthog_indexer, + src_indexer, dst_indexer)); + }); + + return repeat_ev; +} + +template +struct RepeatScalar1DFactory +{ + fnT get() + { + fnT fn = repeat_by_scalar_1d_impl; + return fn; + } +}; + +} // namespace dpctl::tensor::kernels::repeat diff --git a/dpctl_ext/tensor/libtensor/include/kernels/where.hpp b/dpctl_ext/tensor/libtensor/include/kernels/where.hpp new file mode 100644 index 00000000000..b92a3a76c9c --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/kernels/where.hpp @@ -0,0 +1,339 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for dpctl.tensor.where. 
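+/// The kernels perform an elementwise selection: for every output element,
+/// the boolean condition array determines whether the result is taken from
+/// the first or the second input array.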
+//===---------------------------------------------------------------------===// + +#pragma once +#include +#include +#include +#include +#include + +#include + +#include "dpctl_tensor_types.hpp" +#include "kernels/alignment.hpp" +#include "utils/offset_utils.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_utils.hpp" + +namespace dpctl::tensor::kernels::search +{ + +using dpctl::tensor::ssize_t; +using namespace dpctl::tensor::offset_utils; + +using dpctl::tensor::kernels::alignment_utils:: + disabled_sg_loadstore_wrapper_krn; +using dpctl::tensor::kernels::alignment_utils::is_aligned; +using dpctl::tensor::kernels::alignment_utils::required_alignment; + +using dpctl::tensor::sycl_utils::sub_group_load; +using dpctl::tensor::sycl_utils::sub_group_store; + +template +class where_strided_kernel; +template +class where_contig_kernel; + +template +class WhereContigFunctor +{ +private: + std::size_t nelems = 0; + const condT *cond_p = nullptr; + const T *x1_p = nullptr; + const T *x2_p = nullptr; + T *dst_p = nullptr; + +public: + WhereContigFunctor(std::size_t nelems_, + const condT *cond_p_, + const T *x1_p_, + const T *x2_p_, + T *dst_p_) + : nelems(nelems_), cond_p(cond_p_), x1_p(x1_p_), x2_p(x2_p_), + dst_p(dst_p_) + { + } + + void operator()(sycl::nd_item<1> ndit) const + { + static constexpr std::uint8_t nelems_per_wi = n_vecs * vec_sz; + + using dpctl::tensor::type_utils::is_complex; + if constexpr (!enable_sg_loadstore || is_complex::value || + is_complex::value) + { + const std::uint16_t sgSize = + ndit.get_sub_group().get_local_range()[0]; + const std::size_t gid = ndit.get_global_linear_id(); + + const std::uint16_t nelems_per_sg = sgSize * nelems_per_wi; + const std::size_t start = + (gid / sgSize) * (nelems_per_sg - sgSize) + gid; + const std::size_t end = std::min(nelems, start + nelems_per_sg); + for (std::size_t offset = start; offset < end; offset += sgSize) { + using dpctl::tensor::type_utils::convert_impl; + const bool check = convert_impl(cond_p[offset]); + dst_p[offset] = check ? x1_p[offset] : x2_p[offset]; + } + } + else { + auto sg = ndit.get_sub_group(); + const std::uint16_t sgSize = sg.get_max_local_range()[0]; + + const std::size_t base = + nelems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * sgSize); + + if (base + nelems_per_wi * sgSize < nelems) { + sycl::vec dst_vec; + +#pragma unroll + for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { + const std::size_t idx = base + it * sgSize; + auto x1_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&x1_p[idx]); + auto x2_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&x2_p[idx]); + auto cond_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&cond_p[idx]); + auto dst_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&dst_p[idx]); + + const sycl::vec x1_vec = + sub_group_load(sg, x1_multi_ptr); + const sycl::vec x2_vec = + sub_group_load(sg, x2_multi_ptr); + const sycl::vec cond_vec = + sub_group_load(sg, cond_multi_ptr); +#pragma unroll + for (std::uint8_t k = 0; k < vec_sz; ++k) { + dst_vec[k] = cond_vec[k] ? 
x1_vec[k] : x2_vec[k]; + } + sub_group_store(sg, dst_vec, dst_multi_ptr); + } + } + else { + const std::size_t lane_id = sg.get_local_id()[0]; + for (std::size_t k = base + lane_id; k < nelems; k += sgSize) { + dst_p[k] = cond_p[k] ? x1_p[k] : x2_p[k]; + } + } + } + } +}; + +typedef sycl::event (*where_contig_impl_fn_ptr_t)( + sycl::queue &, + std::size_t, + const char *, + const char *, + const char *, + char *, + const std::vector &); + +template +sycl::event where_contig_impl(sycl::queue &q, + std::size_t nelems, + const char *cond_cp, + const char *x1_cp, + const char *x2_cp, + char *dst_cp, + const std::vector &depends) +{ + const condT *cond_tp = reinterpret_cast(cond_cp); + const T *x1_tp = reinterpret_cast(x1_cp); + const T *x2_tp = reinterpret_cast(x2_cp); + T *dst_tp = reinterpret_cast(dst_cp); + + sycl::event where_ev = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + std::size_t lws = 64; + static constexpr std::uint8_t vec_sz = 4u; + static constexpr std::uint8_t n_vecs = 2u; + const std::size_t n_groups = + ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz)); + const auto gws_range = sycl::range<1>(n_groups * lws); + const auto lws_range = sycl::range<1>(lws); + + if (is_aligned(cond_cp) && + is_aligned(x1_cp) && + is_aligned(x2_cp) && + is_aligned(dst_cp)) + { + static constexpr bool enable_sg_loadstore = true; + using KernelName = where_contig_kernel; + + cgh.parallel_for( + sycl::nd_range<1>(gws_range, lws_range), + WhereContigFunctor(nelems, cond_tp, x1_tp, + x2_tp, dst_tp)); + } + else { + static constexpr bool disable_sg_loadstore = false; + using InnerKernelName = + where_contig_kernel; + using KernelName = + disabled_sg_loadstore_wrapper_krn; + + cgh.parallel_for( + sycl::nd_range<1>(gws_range, lws_range), + WhereContigFunctor(nelems, cond_tp, x1_tp, + x2_tp, dst_tp)); + } + }); + + return where_ev; +} + +template +class WhereStridedFunctor +{ +private: + const T *x1_p = nullptr; + const T *x2_p = nullptr; + T *dst_p = nullptr; + const condT *cond_p = nullptr; + IndexerT indexer; + +public: + WhereStridedFunctor(const condT *cond_p_, + const T *x1_p_, + const T *x2_p_, + T *dst_p_, + const IndexerT &indexer_) + : x1_p(x1_p_), x2_p(x2_p_), dst_p(dst_p_), cond_p(cond_p_), + indexer(indexer_) + { + } + + void operator()(sycl::id<1> id) const + { + std::size_t gid = id[0]; + auto offsets = indexer(static_cast(gid)); + + using dpctl::tensor::type_utils::convert_impl; + bool check = + convert_impl(cond_p[offsets.get_first_offset()]); + + dst_p[offsets.get_fourth_offset()] = + check ? 
x1_p[offsets.get_second_offset()] + : x2_p[offsets.get_third_offset()]; + } +}; + +typedef sycl::event (*where_strided_impl_fn_ptr_t)( + sycl::queue &, + std::size_t, + int, + const char *, + const char *, + const char *, + char *, + const ssize_t *, + ssize_t, + ssize_t, + ssize_t, + ssize_t, + const std::vector &); + +template +sycl::event where_strided_impl(sycl::queue &q, + std::size_t nelems, + int nd, + const char *cond_cp, + const char *x1_cp, + const char *x2_cp, + char *dst_cp, + const ssize_t *shape_strides, + ssize_t x1_offset, + ssize_t x2_offset, + ssize_t cond_offset, + ssize_t dst_offset, + const std::vector &depends) +{ + const condT *cond_tp = reinterpret_cast(cond_cp); + const T *x1_tp = reinterpret_cast(x1_cp); + const T *x2_tp = reinterpret_cast(x2_cp); + T *dst_tp = reinterpret_cast(dst_cp); + + sycl::event where_ev = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + const FourOffsets_StridedIndexer indexer{ + nd, cond_offset, x1_offset, x2_offset, dst_offset, shape_strides}; + + cgh.parallel_for< + where_strided_kernel>( + sycl::range<1>(nelems), + WhereStridedFunctor( + cond_tp, x1_tp, x2_tp, dst_tp, indexer)); + }); + + return where_ev; +} + +template +struct WhereStridedFactory +{ + fnT get() + { + fnT fn = where_strided_impl; + return fn; + } +}; + +template +struct WhereContigFactory +{ + fnT get() + { + fnT fn = where_contig_impl; + return fn; + } +}; + +} // namespace dpctl::tensor::kernels::search diff --git a/dpctl_ext/tensor/libtensor/source/clip.cpp b/dpctl_ext/tensor/libtensor/source/clip.cpp new file mode 100644 index 00000000000..1414689bc4b --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/clip.cpp @@ -0,0 +1,267 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines Python API for implementation functions of +/// dpctl.tensor.clip +//===---------------------------------------------------------------------===// + +#include +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include +#include + +#include "clip.hpp" +#include "kernels/clip.hpp" +#include "simplify_iteration_space.hpp" +#include "utils/memory_overlap.hpp" +#include "utils/offset_utils.hpp" +#include "utils/output_validation.hpp" +#include "utils/sycl_alloc_utils.hpp" +#include "utils/type_dispatch.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +using dpctl::tensor::kernels::clip::clip_contig_impl_fn_ptr_t; +using dpctl::tensor::kernels::clip::clip_strided_impl_fn_ptr_t; + +static clip_contig_impl_fn_ptr_t clip_contig_dispatch_vector[td_ns::num_types]; +static clip_strided_impl_fn_ptr_t + clip_strided_dispatch_vector[td_ns::num_types]; + +void init_clip_dispatch_vectors(void) +{ + using namespace td_ns; + using dpctl::tensor::kernels::clip::ClipContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(clip_contig_dispatch_vector); + + using dpctl::tensor::kernels::clip::ClipStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(clip_strided_dispatch_vector); +} + +using dpctl::utils::keep_args_alive; + +std::pair + py_clip(const dpctl::tensor::usm_ndarray &src, + const dpctl::tensor::usm_ndarray &min, + const dpctl::tensor::usm_ndarray &max, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) +{ + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, min, max, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + + int nd = src.get_ndim(); + int min_nd = min.get_ndim(); + int max_nd = max.get_ndim(); + int dst_nd = dst.get_ndim(); + + if (nd != min_nd || nd != max_nd) { + throw py::value_error( + "Input arrays are not of appropriate dimension for clip kernel."); + } + + if (nd != dst_nd) { + throw py::value_error( + "Destination is not of appropriate dimension for clip kernel."); + } + + const py::ssize_t *src_shape = src.get_shape_raw(); + const py::ssize_t *min_shape = min.get_shape_raw(); + const py::ssize_t *max_shape = max.get_shape_raw(); + const py::ssize_t *dst_shape = dst.get_shape_raw(); + + bool shapes_equal(true); + std::size_t nelems(1); + for (int i = 0; i < nd; ++i) { + const auto &sh_i = dst_shape[i]; + nelems *= static_cast(sh_i); + shapes_equal = shapes_equal && (min_shape[i] == sh_i) && + (max_shape[i] == sh_i) && (src_shape[i] == sh_i); + } + + if (!shapes_equal) { + throw py::value_error("Arrays are not of matching shapes."); + } + + if (nelems == 0) { + return std::make_pair(sycl::event{}, sycl::event{}); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + auto const &same_logical_tensors = + dpctl::tensor::overlap::SameLogicalTensors(); + if ((overlap(dst, src) && !same_logical_tensors(dst, src)) || + (overlap(dst, min) && !same_logical_tensors(dst, min)) || + (overlap(dst, max) && !same_logical_tensors(dst, max))) + { + throw py::value_error("Destination array overlaps with input."); + } + + int min_typenum = 
min.get_typenum(); + int max_typenum = max.get_typenum(); + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + auto const &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int min_typeid = array_types.typenum_to_lookup_id(min_typenum); + int max_typeid = array_types.typenum_to_lookup_id(max_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + if (src_typeid != dst_typeid || src_typeid != min_typeid || + src_typeid != max_typeid) + { + throw py::value_error("Input, min, max, and destination arrays must " + "have the same data type"); + } + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, nelems); + + char *src_data = src.get_data(); + char *min_data = min.get_data(); + char *max_data = max.get_data(); + char *dst_data = dst.get_data(); + + bool is_min_c_contig = min.is_c_contiguous(); + bool is_min_f_contig = min.is_f_contiguous(); + + bool is_max_c_contig = max.is_c_contiguous(); + bool is_max_f_contig = max.is_f_contiguous(); + + bool is_src_c_contig = src.is_c_contiguous(); + bool is_src_f_contig = src.is_f_contiguous(); + + bool is_dst_c_contig = dst.is_c_contiguous(); + bool is_dst_f_contig = dst.is_f_contiguous(); + + bool all_c_contig = (is_min_c_contig && is_max_c_contig && + is_src_c_contig && is_dst_c_contig); + bool all_f_contig = (is_min_f_contig && is_max_f_contig && + is_src_f_contig && is_dst_f_contig); + + if (all_c_contig || all_f_contig) { + auto fn = clip_contig_dispatch_vector[src_typeid]; + + sycl::event clip_ev = + fn(exec_q, nelems, src_data, min_data, max_data, dst_data, depends); + sycl::event ht_ev = + keep_args_alive(exec_q, {src, min, max, dst}, {clip_ev}); + + return std::make_pair(ht_ev, clip_ev); + } + + auto const &src_strides = src.get_strides_vector(); + auto const &min_strides = min.get_strides_vector(); + auto const &max_strides = max.get_strides_vector(); + auto const &dst_strides = dst.get_strides_vector(); + + using shT = std::vector; + shT simplified_shape; + shT simplified_src_strides; + shT simplified_min_strides; + shT simplified_max_strides; + shT simplified_dst_strides; + py::ssize_t src_offset(0); + py::ssize_t min_offset(0); + py::ssize_t max_offset(0); + py::ssize_t dst_offset(0); + + dpctl::tensor::py_internal::simplify_iteration_space_4( + nd, src_shape, src_strides, min_strides, max_strides, dst_strides, + // outputs + simplified_shape, simplified_src_strides, simplified_min_strides, + simplified_max_strides, simplified_dst_strides, src_offset, min_offset, + max_offset, dst_offset); + + auto fn = clip_strided_dispatch_vector[src_typeid]; + + std::vector host_task_events; + host_task_events.reserve(2); + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + auto ptr_size_event_tuple = device_allocate_and_pack( + exec_q, host_task_events, + // common shape and strides + simplified_shape, simplified_src_strides, simplified_min_strides, + simplified_max_strides, simplified_dst_strides); + auto packed_shape_strides_owner = + std::move(std::get<0>(ptr_size_event_tuple)); + sycl::event copy_shape_strides_ev = std::get<2>(ptr_size_event_tuple); + const py::ssize_t *packed_shape_strides = packed_shape_strides_owner.get(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shape_strides_ev); + + assert(all_deps.size() == depends.size() + 1); + + sycl::event clip_ev = fn(exec_q, nelems, nd, src_data, 
min_data, max_data, + dst_data, packed_shape_strides, src_offset, + min_offset, max_offset, dst_offset, all_deps); + + // free packed temporaries + sycl::event temporaries_cleanup_ev = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {clip_ev}, packed_shape_strides_owner); + host_task_events.push_back(temporaries_cleanup_ev); + + sycl::event arg_cleanup_ev = + keep_args_alive(exec_q, {src, min, max, dst}, host_task_events); + + return std::make_pair(arg_cleanup_ev, clip_ev); +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/clip.hpp b/dpctl_ext/tensor/libtensor/source/clip.hpp new file mode 100644 index 00000000000..de8f0e559b6 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/clip.hpp @@ -0,0 +1,57 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//*****************************************************************************
+//
+//===---------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines Python API for implementation functions of
+/// dpctl.tensor.clip
+//===---------------------------------------------------------------------===//
+
+#pragma once
+#include <utility>
+#include <vector>
+
+#include <sycl/sycl.hpp>
+
+#include "dpnp4pybind11.hpp"
+
+namespace dpctl::tensor::py_internal
+{
+
+extern std::pair<sycl::event, sycl::event>
+    py_clip(const dpctl::tensor::usm_ndarray &src,
+            const dpctl::tensor::usm_ndarray &min,
+            const dpctl::tensor::usm_ndarray &max,
+            const dpctl::tensor::usm_ndarray &dst,
+            sycl::queue &exec_q,
+            const std::vector<sycl::event> &depends);
+
+extern void init_clip_dispatch_vectors(void);
+
+} // namespace dpctl::tensor::py_internal
diff --git a/dpctl_ext/tensor/libtensor/source/repeat.cpp b/dpctl_ext/tensor/libtensor/source/repeat.cpp
new file mode 100644
index 00000000000..4bba1d35a08
--- /dev/null
+++ b/dpctl_ext/tensor/libtensor/source/repeat.cpp
@@ -0,0 +1,820 @@
+//*****************************************************************************
+// Copyright (c) 2026, Intel Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+// - Redistributions of source code must retain the above copyright notice,
+//   this list of conditions and the following disclaimer.
+// - Redistributions in binary form must reproduce the above copyright notice,
+//   this list of conditions and the following disclaimer in the documentation
+//   and/or other materials provided with the distribution.
+// - Neither the name of the copyright holder nor the names of its contributors
+//   may be used to endorse or promote products derived from this software
+//   without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+// THE POSSIBILITY OF SUCH DAMAGE.
+//***************************************************************************** +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#include +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/repeat.hpp" +#include "utils/memory_overlap.hpp" +#include "utils/offset_utils.hpp" +#include "utils/output_validation.hpp" +#include "utils/sycl_alloc_utils.hpp" +#include "utils/type_dispatch.hpp" + +#include "simplify_iteration_space.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +using dpctl::tensor::kernels::repeat::repeat_by_sequence_fn_ptr_t; +static repeat_by_sequence_fn_ptr_t + repeat_by_sequence_dispatch_vector[td_ns::num_types]; + +using dpctl::tensor::kernels::repeat::repeat_by_sequence_1d_fn_ptr_t; +static repeat_by_sequence_1d_fn_ptr_t + repeat_by_sequence_1d_dispatch_vector[td_ns::num_types]; + +using dpctl::tensor::kernels::repeat::repeat_by_scalar_fn_ptr_t; +static repeat_by_scalar_fn_ptr_t + repeat_by_scalar_dispatch_vector[td_ns::num_types]; + +using dpctl::tensor::kernels::repeat::repeat_by_scalar_1d_fn_ptr_t; +static repeat_by_scalar_1d_fn_ptr_t + repeat_by_scalar_1d_dispatch_vector[td_ns::num_types]; + +void init_repeat_dispatch_vectors(void) +{ + using dpctl::tensor::kernels::repeat::RepeatSequenceFactory; + td_ns::DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(repeat_by_sequence_dispatch_vector); + + using dpctl::tensor::kernels::repeat::RepeatSequence1DFactory; + td_ns::DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(repeat_by_sequence_1d_dispatch_vector); + + using dpctl::tensor::kernels::repeat::RepeatScalarFactory; + td_ns::DispatchVectorBuilder + dvb3; + dvb3.populate_dispatch_vector(repeat_by_scalar_dispatch_vector); + + using dpctl::tensor::kernels::repeat::RepeatScalar1DFactory; + td_ns::DispatchVectorBuilder + dvb4; + dvb4.populate_dispatch_vector(repeat_by_scalar_1d_dispatch_vector); +} + +std::pair + py_repeat_by_sequence(const dpctl::tensor::usm_ndarray &src, + const dpctl::tensor::usm_ndarray &dst, + const dpctl::tensor::usm_ndarray &reps, + const dpctl::tensor::usm_ndarray &cumsum, + int axis, + sycl::queue &exec_q, + const std::vector &depends) +{ + int src_nd = src.get_ndim(); + if (axis < 0 || (axis + 1 > src_nd && src_nd > 0) || + (axis > 0 && src_nd == 0)) { + throw py::value_error("Specified axis is invalid."); + } + + int dst_nd = dst.get_ndim(); + if ((src_nd != dst_nd && src_nd > 0) || (src_nd == 0 && dst_nd > 1)) { + throw py::value_error("Number of dimensions of source and destination " + "arrays is not consistent"); + } + + int reps_nd = reps.get_ndim(); + if (reps_nd != 1) { + throw py::value_error("`reps` array must be 1-dimensional"); + } + + if (cumsum.get_ndim() != 1) { + throw py::value_error("`cumsum` array must be 1-dimensional."); + } + + if (!cumsum.is_c_contiguous()) { + throw py::value_error("Expecting `cumsum` array to be C-contiguous."); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, reps, cumsum, dst})) + { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + + std::size_t reps_sz = reps.get_size(); + std::size_t cumsum_sz = 
cumsum.get_size(); + + const py::ssize_t *src_shape = src.get_shape_raw(); + const py::ssize_t *dst_shape = dst.get_shape_raw(); + bool same_orthog_dims(true); + std::size_t orthog_nelems(1); // number of orthogonal iterations + for (auto i = 0; i < axis; ++i) { + auto src_sh_i = src_shape[i]; + orthog_nelems *= src_sh_i; + same_orthog_dims = same_orthog_dims && (src_sh_i == dst_shape[i]); + } + for (auto i = axis + 1; i < src_nd; ++i) { + auto src_sh_i = src_shape[i]; + orthog_nelems *= src_sh_i; + same_orthog_dims = same_orthog_dims && (src_sh_i == dst_shape[i]); + } + + std::size_t src_axis_nelems(1); + if (src_nd > 0) { + src_axis_nelems = src_shape[axis]; + } + std::size_t dst_axis_nelems(dst_shape[axis]); + + // shape at repeated axis must be equal to the sum of reps + if (!same_orthog_dims || src_axis_nelems != reps_sz || + src_axis_nelems != cumsum_sz) + { + throw py::value_error("Inconsistent array dimensions"); + } + + if (orthog_nelems == 0 || src_axis_nelems == 0) { + return std::make_pair(sycl::event(), sycl::event()); + } + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample( + dst, orthog_nelems * dst_axis_nelems); + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + // check that dst does not intersect with src or reps + if (overlap(dst, src) || overlap(dst, reps) || overlap(dst, cumsum)) { + throw py::value_error("Destination array overlaps with inputs"); + } + + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + int reps_typenum = reps.get_typenum(); + int cumsum_typenum = cumsum.get_typenum(); + + auto const &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + int reps_typeid = array_types.typenum_to_lookup_id(reps_typenum); + int cumsum_typeid = array_types.typenum_to_lookup_id(cumsum_typenum); + + if (src_typeid != dst_typeid) { + throw py::value_error( + "Destination array must have the same elemental data type"); + } + + static constexpr int int64_typeid = + static_cast(td_ns::typenum_t::INT64); + if (cumsum_typeid != int64_typeid) { + throw py::value_error( + "Unexpected data type of `cumsum` array, expecting " + "'int64'"); + } + + if (reps_typeid != cumsum_typeid) { + throw py::value_error("`reps` array must have the same elemental " + "data type as cumsum"); + } + + const char *src_data_p = src.get_data(); + const char *reps_data_p = reps.get_data(); + const char *cumsum_data_p = cumsum.get_data(); + char *dst_data_p = dst.get_data(); + + auto src_shape_vec = src.get_shape_vector(); + auto src_strides_vec = src.get_strides_vector(); + + auto dst_shape_vec = dst.get_shape_vector(); + auto dst_strides_vec = dst.get_strides_vector(); + + auto reps_shape_vec = reps.get_shape_vector(); + auto reps_strides_vec = reps.get_strides_vector(); + + sycl::event repeat_ev; + std::vector host_task_events{}; + if (axis == 0 && src_nd < 2) { + // empty orthogonal directions + + auto fn = repeat_by_sequence_1d_dispatch_vector[src_typeid]; + + assert(dst_shape_vec.size() == 1); + assert(dst_strides_vec.size() == 1); + + if (src_nd == 0) { + src_shape_vec = {0}; + src_strides_vec = {0}; + } + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + auto ptr_size_event_tuple1 = device_allocate_and_pack( + exec_q, host_task_events, src_shape_vec, src_strides_vec); + auto packed_src_shape_strides_owner = + std::move(std::get<0>(ptr_size_event_tuple1)); + sycl::event copy_shapes_strides_ev = 
std::get<2>(ptr_size_event_tuple1); + const py::ssize_t *packed_src_shape_strides = + packed_src_shape_strides_owner.get(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shapes_strides_ev); + + assert(all_deps.size() == depends.size() + 1); + + repeat_ev = + fn(exec_q, src_axis_nelems, src_data_p, dst_data_p, reps_data_p, + cumsum_data_p, src_nd, packed_src_shape_strides, + dst_shape_vec[0], dst_strides_vec[0], reps_shape_vec[0], + reps_strides_vec[0], all_deps); + + sycl::event cleanup_tmp_allocations_ev = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {repeat_ev}, packed_src_shape_strides_owner); + host_task_events.push_back(cleanup_tmp_allocations_ev); + } + else { + // non-empty orthogonal directions + + auto fn = repeat_by_sequence_dispatch_vector[src_typeid]; + + int orthog_nd = src_nd - 1; + + using shT = std::vector; + shT orthog_src_shape; + shT orthog_src_strides; + shT axis_src_shape; + shT axis_src_stride; + dpctl::tensor::py_internal::split_iteration_space( + src_shape_vec, src_strides_vec, axis, axis + 1, orthog_src_shape, + axis_src_shape, orthog_src_strides, axis_src_stride); + + shT orthog_dst_shape; + shT orthog_dst_strides; + shT axis_dst_shape; + shT axis_dst_stride; + dpctl::tensor::py_internal::split_iteration_space( + dst_shape_vec, dst_strides_vec, axis, axis + 1, orthog_dst_shape, + axis_dst_shape, orthog_dst_strides, axis_dst_stride); + + assert(orthog_src_shape.size() == static_cast(orthog_nd)); + assert(orthog_dst_shape.size() == static_cast(orthog_nd)); + assert(std::equal(orthog_src_shape.begin(), orthog_src_shape.end(), + orthog_dst_shape.begin())); + + shT simplified_orthog_shape; + shT simplified_orthog_src_strides; + shT simplified_orthog_dst_strides; + + const py::ssize_t *_shape = orthog_src_shape.data(); + + py::ssize_t orthog_src_offset(0); + py::ssize_t orthog_dst_offset(0); + dpctl::tensor::py_internal::simplify_iteration_space( + orthog_nd, _shape, orthog_src_strides, orthog_dst_strides, + // output + simplified_orthog_shape, simplified_orthog_src_strides, + simplified_orthog_dst_strides, orthog_src_offset, + orthog_dst_offset); + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + auto ptr_size_event_tuple1 = device_allocate_and_pack( + exec_q, host_task_events, simplified_orthog_shape, + simplified_orthog_src_strides, simplified_orthog_dst_strides); + auto packed_shapes_strides_owner = + std::move(std::get<0>(ptr_size_event_tuple1)); + sycl::event copy_shapes_strides_ev = std::get<2>(ptr_size_event_tuple1); + const py::ssize_t *packed_shapes_strides = + packed_shapes_strides_owner.get(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shapes_strides_ev); + + assert(all_deps.size() == depends.size() + 1); + + repeat_ev = fn(exec_q, orthog_nelems, src_axis_nelems, src_data_p, + dst_data_p, reps_data_p, cumsum_data_p, + // data to build orthog indexer + orthog_nd, packed_shapes_strides, orthog_src_offset, + orthog_dst_offset, + // data to build indexers along repeated axis in src + axis_src_shape[0], axis_src_stride[0], + // data to build indexer along repeated axis in dst + axis_dst_shape[0], axis_dst_stride[0], + // data to build indexer for reps array + reps_shape_vec[0], reps_strides_vec[0], all_deps); + + sycl::event cleanup_tmp_allocations_ev = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, 
{repeat_ev}, packed_shapes_strides_owner); + host_task_events.push_back(cleanup_tmp_allocations_ev); + } + + sycl::event py_obj_management_host_task_ev = dpctl::utils::keep_args_alive( + exec_q, {src, reps, cumsum, dst}, host_task_events); + + return std::make_pair(py_obj_management_host_task_ev, repeat_ev); +} + +std::pair + py_repeat_by_sequence(const dpctl::tensor::usm_ndarray &src, + const dpctl::tensor::usm_ndarray &dst, + const dpctl::tensor::usm_ndarray &reps, + const dpctl::tensor::usm_ndarray &cumsum, + sycl::queue &exec_q, + const std::vector &depends) +{ + + int dst_nd = dst.get_ndim(); + if (dst_nd != 1) { + throw py::value_error( + "`dst` array must be 1-dimensional when repeating a full array"); + } + + int reps_nd = reps.get_ndim(); + if (reps_nd != 1) { + throw py::value_error("`reps` array must be 1-dimensional"); + } + + if (cumsum.get_ndim() != 1) { + throw py::value_error("`cumsum` array must be 1-dimensional."); + } + + if (!cumsum.is_c_contiguous()) { + throw py::value_error("Expecting `cumsum` array to be C-contiguous."); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, reps, cumsum, dst})) + { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + + std::size_t src_sz = src.get_size(); + std::size_t reps_sz = reps.get_size(); + std::size_t cumsum_sz = cumsum.get_size(); + + // shape at repeated axis must be equal to the sum of reps + if (src_sz != reps_sz || src_sz != cumsum_sz) { + throw py::value_error("Inconsistent array dimensions"); + } + + if (src_sz == 0) { + return std::make_pair(sycl::event(), sycl::event()); + } + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, + dst.get_size()); + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + // check that dst does not intersect with src, cumsum, or reps + if (overlap(dst, src) || overlap(dst, reps) || overlap(dst, cumsum)) { + throw py::value_error("Destination array overlaps with inputs"); + } + + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + int reps_typenum = reps.get_typenum(); + int cumsum_typenum = cumsum.get_typenum(); + + auto const &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + int reps_typeid = array_types.typenum_to_lookup_id(reps_typenum); + int cumsum_typeid = array_types.typenum_to_lookup_id(cumsum_typenum); + + if (src_typeid != dst_typeid) { + throw py::value_error( + "Destination array must have the same elemental data type"); + } + + static constexpr int int64_typeid = + static_cast(td_ns::typenum_t::INT64); + if (cumsum_typeid != int64_typeid) { + throw py::value_error( + "Unexpected data type of `cumsum` array, expecting " + "'int64'"); + } + + if (reps_typeid != cumsum_typeid) { + throw py::value_error("`reps` array must have the same elemental " + "data type as cumsum"); + } + + const char *src_data_p = src.get_data(); + const char *reps_data_p = reps.get_data(); + const char *cumsum_data_p = cumsum.get_data(); + char *dst_data_p = dst.get_data(); + + int src_nd = src.get_ndim(); + auto src_shape_vec = src.get_shape_vector(); + auto src_strides_vec = src.get_strides_vector(); + if (src_nd == 0) { + src_shape_vec = {0}; + src_strides_vec = {0}; + } + + auto dst_shape_vec = dst.get_shape_vector(); + auto dst_strides_vec = dst.get_strides_vector(); + + auto reps_shape_vec = 
reps.get_shape_vector(); + auto reps_strides_vec = reps.get_strides_vector(); + + std::vector host_task_events{}; + + auto fn = repeat_by_sequence_1d_dispatch_vector[src_typeid]; + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + auto ptr_size_event_tuple1 = device_allocate_and_pack( + exec_q, host_task_events, src_shape_vec, src_strides_vec); + auto packed_src_shapes_strides_owner = + std::move(std::get<0>(ptr_size_event_tuple1)); + sycl::event copy_shapes_strides_ev = std::get<2>(ptr_size_event_tuple1); + const py::ssize_t *packed_src_shapes_strides = + packed_src_shapes_strides_owner.get(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shapes_strides_ev); + + assert(all_deps.size() == depends.size() + 1); + + sycl::event repeat_ev = fn( + exec_q, src_sz, src_data_p, dst_data_p, reps_data_p, cumsum_data_p, + src_nd, packed_src_shapes_strides, dst_shape_vec[0], dst_strides_vec[0], + reps_shape_vec[0], reps_strides_vec[0], all_deps); + + sycl::event cleanup_tmp_allocations_ev = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {repeat_ev}, packed_src_shapes_strides_owner); + host_task_events.push_back(cleanup_tmp_allocations_ev); + + sycl::event py_obj_management_host_task_ev = dpctl::utils::keep_args_alive( + exec_q, {src, reps, cumsum, dst}, host_task_events); + + return std::make_pair(py_obj_management_host_task_ev, repeat_ev); +} + +std::pair + py_repeat_by_scalar(const dpctl::tensor::usm_ndarray &src, + const dpctl::tensor::usm_ndarray &dst, + const py::ssize_t reps, + int axis, + sycl::queue &exec_q, + const std::vector &depends) +{ + int src_nd = src.get_ndim(); + if (axis < 0 || (axis + 1 > src_nd && src_nd > 0) || + (axis > 0 && src_nd == 0)) { + throw py::value_error("Specified axis is invalid."); + } + + int dst_nd = dst.get_ndim(); + if ((src_nd != dst_nd && src_nd > 0) || (src_nd == 0 && dst_nd > 1)) { + throw py::value_error("Number of dimensions of source and destination " + "arrays is not consistent"); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + + const py::ssize_t *src_shape = src.get_shape_raw(); + const py::ssize_t *dst_shape = dst.get_shape_raw(); + bool same_orthog_dims(true); + std::size_t orthog_nelems(1); // number of orthogonal iterations + for (auto i = 0; i < axis; ++i) { + auto src_sh_i = src_shape[i]; + orthog_nelems *= src_sh_i; + same_orthog_dims = same_orthog_dims && (src_sh_i == dst_shape[i]); + } + for (auto i = axis + 1; i < src_nd; ++i) { + auto src_sh_i = src_shape[i]; + orthog_nelems *= src_sh_i; + same_orthog_dims = same_orthog_dims && (src_sh_i == dst_shape[i]); + } + + std::size_t src_axis_nelems(1); + if (src_nd > 0) { + src_axis_nelems = src_shape[axis]; + } + std::size_t dst_axis_nelems(dst_shape[axis]); + + // shape at repeated axis must be equal to the shape of src at the axis * + // reps + if (!same_orthog_dims || (src_axis_nelems * reps) != dst_axis_nelems) { + throw py::value_error("Inconsistent array dimensions"); + } + + if (orthog_nelems == 0 || src_axis_nelems == 0) { + return std::make_pair(sycl::event(), sycl::event()); + } + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample( + dst, orthog_nelems * (src_axis_nelems * reps)); + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + // 
check that dst does not intersect with src + if (overlap(dst, src)) { + throw py::value_error("Destination array overlaps with inputs"); + } + + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + auto const &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + if (src_typeid != dst_typeid) { + throw py::value_error( + "Destination array must have the same elemental data type"); + } + + const char *src_data_p = src.get_data(); + char *dst_data_p = dst.get_data(); + + auto src_shape_vec = src.get_shape_vector(); + auto src_strides_vec = src.get_strides_vector(); + + auto dst_shape_vec = dst.get_shape_vector(); + auto dst_strides_vec = dst.get_strides_vector(); + + sycl::event repeat_ev; + std::vector host_task_events{}; + if (axis == 0 && src_nd < 2) { + // empty orthogonal directions + + auto fn = repeat_by_scalar_1d_dispatch_vector[src_typeid]; + + assert(dst_shape_vec.size() == 1); + assert(dst_strides_vec.size() == 1); + + if (src_nd == 0) { + src_shape_vec = {0}; + src_strides_vec = {0}; + } + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + auto ptr_size_event_tuple1 = device_allocate_and_pack( + exec_q, host_task_events, src_shape_vec, src_strides_vec); + auto packed_src_shape_strides_owner = + std::move(std::get<0>(ptr_size_event_tuple1)); + sycl::event copy_shapes_strides_ev = std::get<2>(ptr_size_event_tuple1); + const py::ssize_t *packed_src_shape_strides = + packed_src_shape_strides_owner.get(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shapes_strides_ev); + + assert(all_deps.size() == depends.size() + 1); + + repeat_ev = fn(exec_q, dst_axis_nelems, src_data_p, dst_data_p, reps, + src_nd, packed_src_shape_strides, dst_shape_vec[0], + dst_strides_vec[0], all_deps); + + sycl::event cleanup_tmp_allocations_ev = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {repeat_ev}, packed_src_shape_strides_owner); + + host_task_events.push_back(cleanup_tmp_allocations_ev); + } + else { + // non-empty orthogonal directions + + auto fn = repeat_by_scalar_dispatch_vector[src_typeid]; + + int orthog_nd = src_nd - 1; + + using shT = std::vector; + shT orthog_src_shape; + shT orthog_src_strides; + shT axis_src_shape; + shT axis_src_stride; + dpctl::tensor::py_internal::split_iteration_space( + src_shape_vec, src_strides_vec, axis, axis + 1, orthog_src_shape, + axis_src_shape, orthog_src_strides, axis_src_stride); + + shT orthog_dst_shape; + shT orthog_dst_strides; + shT axis_dst_shape; + shT axis_dst_stride; + dpctl::tensor::py_internal::split_iteration_space( + dst_shape_vec, dst_strides_vec, axis, axis + 1, orthog_dst_shape, + axis_dst_shape, orthog_dst_strides, axis_dst_stride); + + assert(orthog_src_shape.size() == static_cast(orthog_nd)); + assert(orthog_dst_shape.size() == static_cast(orthog_nd)); + assert(std::equal(orthog_src_shape.begin(), orthog_src_shape.end(), + orthog_dst_shape.begin())); + + shT simplified_orthog_shape; + shT simplified_orthog_src_strides; + shT simplified_orthog_dst_strides; + + const py::ssize_t *_shape = orthog_src_shape.data(); + + py::ssize_t orthog_src_offset(0); + py::ssize_t orthog_dst_offset(0); + + dpctl::tensor::py_internal::simplify_iteration_space( + orthog_nd, _shape, orthog_src_strides, orthog_dst_strides, + // output + simplified_orthog_shape, 
simplified_orthog_src_strides, + simplified_orthog_dst_strides, orthog_src_offset, + orthog_dst_offset); + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + auto ptr_size_event_tuple1 = device_allocate_and_pack( + exec_q, host_task_events, simplified_orthog_shape, + simplified_orthog_src_strides, simplified_orthog_dst_strides); + auto packed_shapes_strides_owner = + std::move(std::get<0>(ptr_size_event_tuple1)); + sycl::event copy_shapes_strides_ev = std::get<2>(ptr_size_event_tuple1); + const py::ssize_t *packed_shapes_strides = + packed_shapes_strides_owner.get(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shapes_strides_ev); + + assert(all_deps.size() == depends.size() + 1); + + repeat_ev = fn(exec_q, orthog_nelems, dst_axis_nelems, src_data_p, + dst_data_p, reps, + // data to build orthog indexer + orthog_nd, packed_shapes_strides, orthog_src_offset, + orthog_dst_offset, + // data to build indexer along repeated axis in src + axis_src_shape[0], axis_src_stride[0], + // data to build indexer along repeated axis in dst + axis_dst_shape[0], axis_dst_stride[0], all_deps); + + sycl::event cleanup_tmp_allocations_ev = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {repeat_ev}, packed_shapes_strides_owner); + host_task_events.push_back(cleanup_tmp_allocations_ev); + } + + sycl::event py_obj_management_host_task_ev = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events); + + return std::make_pair(py_obj_management_host_task_ev, repeat_ev); +} + +std::pair + py_repeat_by_scalar(const dpctl::tensor::usm_ndarray &src, + const dpctl::tensor::usm_ndarray &dst, + const py::ssize_t reps, + sycl::queue &exec_q, + const std::vector &depends) +{ + int dst_nd = dst.get_ndim(); + if (dst_nd != 1) { + throw py::value_error( + "`dst` array must be 1-dimensional when repeating a full array"); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + + std::size_t src_sz = src.get_size(); + std::size_t dst_sz = dst.get_size(); + + // shape at repeated axis must be equal to the shape of src at the axis * + // reps + if ((src_sz * reps) != dst_sz) { + throw py::value_error("Inconsistent array dimensions"); + } + + if (src_sz == 0) { + return std::make_pair(sycl::event(), sycl::event()); + } + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, + src_sz * reps); + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + // check that dst does not intersect with src + if (overlap(dst, src)) { + throw py::value_error("Destination array overlaps with inputs"); + } + + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + auto const &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + if (src_typeid != dst_typeid) { + throw py::value_error( + "Destination array must have the same elemental data type"); + } + + const char *src_data_p = src.get_data(); + char *dst_data_p = dst.get_data(); + + int src_nd = src.get_ndim(); + auto src_shape_vec = src.get_shape_vector(); + auto src_strides_vec = src.get_strides_vector(); + + if (src_nd == 0) { + src_shape_vec = {0}; + src_strides_vec = {0}; + } + + auto 
dst_shape_vec = dst.get_shape_vector(); + auto dst_strides_vec = dst.get_strides_vector(); + + std::vector host_task_events{}; + + auto fn = repeat_by_scalar_1d_dispatch_vector[src_typeid]; + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + auto ptr_size_event_tuple1 = device_allocate_and_pack( + exec_q, host_task_events, src_shape_vec, src_strides_vec); + auto packed_src_shape_strides_owner = + std::move(std::get<0>(ptr_size_event_tuple1)); + sycl::event copy_shapes_strides_ev = std::get<2>(ptr_size_event_tuple1); + const py::ssize_t *packed_src_shape_strides = + packed_src_shape_strides_owner.get(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shapes_strides_ev); + + assert(all_deps.size() == depends.size() + 1); + + sycl::event repeat_ev = fn(exec_q, dst_sz, src_data_p, dst_data_p, reps, + src_nd, packed_src_shape_strides, + dst_shape_vec[0], dst_strides_vec[0], all_deps); + + sycl::event cleanup_tmp_allocations_ev = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {repeat_ev}, packed_src_shape_strides_owner); + host_task_events.push_back(cleanup_tmp_allocations_ev); + + sycl::event py_obj_management_host_task_ev = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events); + + return std::make_pair(py_obj_management_host_task_ev, repeat_ev); +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/repeat.hpp b/dpctl_ext/tensor/libtensor/source/repeat.hpp new file mode 100644 index 00000000000..5835377fb29 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/repeat.hpp @@ -0,0 +1,83 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//*****************************************************************************
+//
+//===--------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines functions of dpctl.tensor._tensor_impl extensions
+//===--------------------------------------------------------------------===//
+
+#pragma once
+#include <utility>
+#include <vector>
+
+#include <sycl/sycl.hpp>
+
+#include "dpnp4pybind11.hpp"
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+
+namespace dpctl::tensor::py_internal
+{
+
+extern void init_repeat_dispatch_vectors(void);
+
+extern std::pair<sycl::event, sycl::event>
+    py_repeat_by_sequence(const dpctl::tensor::usm_ndarray &src,
+                          const dpctl::tensor::usm_ndarray &dst,
+                          const dpctl::tensor::usm_ndarray &reps,
+                          const dpctl::tensor::usm_ndarray &cumsum,
+                          int axis,
+                          sycl::queue &exec_q,
+                          const std::vector<sycl::event> &depends);
+
+extern std::pair<sycl::event, sycl::event>
+    py_repeat_by_sequence(const dpctl::tensor::usm_ndarray &src,
+                          const dpctl::tensor::usm_ndarray &dst,
+                          const dpctl::tensor::usm_ndarray &reps,
+                          const dpctl::tensor::usm_ndarray &cumsum,
+                          sycl::queue &exec_q,
+                          const std::vector<sycl::event> &depends);
+
+extern std::pair<sycl::event, sycl::event>
+    py_repeat_by_scalar(const dpctl::tensor::usm_ndarray &src,
+                        const dpctl::tensor::usm_ndarray &dst,
+                        const py::ssize_t reps,
+                        int axis,
+                        sycl::queue &exec_q,
+                        const std::vector<sycl::event> &depends);
+
+extern std::pair<sycl::event, sycl::event>
+    py_repeat_by_scalar(const dpctl::tensor::usm_ndarray &src,
+                        const dpctl::tensor::usm_ndarray &dst,
+                        const py::ssize_t reps,
+                        sycl::queue &exec_q,
+                        const std::vector<sycl::event> &depends);
+
+} // namespace dpctl::tensor::py_internal
diff --git a/dpctl_ext/tensor/libtensor/source/tensor_ctors.cpp b/dpctl_ext/tensor/libtensor/source/tensor_ctors.cpp
index 98ab488e587..e6bc3b5dfb6 100644
--- a/dpctl_ext/tensor/libtensor/source/tensor_ctors.cpp
+++ b/dpctl_ext/tensor/libtensor/source/tensor_ctors.cpp
@@ -45,7 +45,7 @@
 
 #include "accumulators.hpp"
 #include "boolean_advanced_indexing.hpp"
-// #include "clip.hpp"
+#include "clip.hpp"
 #include "copy_and_cast_usm_to_usm.hpp"
 #include "copy_as_contig.hpp"
 #include "copy_for_reshape.hpp"
@@ -57,12 +57,12 @@
 #include "integer_advanced_indexing.hpp"
 #include "kernels/dpctl_tensor_types.hpp"
 // #include "linear_sequences.hpp"
-// #include "repeat.hpp"
+#include "repeat.hpp"
 #include "simplify_iteration_space.hpp"
 #include "triul_ctor.hpp"
 #include "utils/memory_overlap.hpp"
 #include "utils/strided_iters.hpp"
-// #include "where.hpp"
+#include "where.hpp"
 #include "zeros_ctor.hpp"
 
 namespace py = pybind11;
@@ -119,8 +119,8 @@ using dpctl::tensor::py_internal::py_place;
 /* ================= Repeat ====================*/
 
 using dpctl::tensor::py_internal::py_cumsum_1d;
-// using dpctl::tensor::py_internal::py_repeat_by_scalar;
-// using dpctl::tensor::py_internal::py_repeat_by_sequence;
+using dpctl::tensor::py_internal::py_repeat_by_scalar;
+using dpctl::tensor::py_internal::py_repeat_by_sequence;
 
 /* ================ Eye ================== */
@@ -132,10 +132,10 @@ using dpctl::tensor::py_internal::usm_ndarray_triul;
 
 /* =========================== Where ============================== */
 
-// using dpctl::tensor::py_internal::py_where;
+using dpctl::tensor::py_internal::py_where;
 
 /* =========================== Clip ============================== */
 
-// using dpctl::tensor::py_internal::py_clip;
+using dpctl::tensor::py_internal::py_clip;
 
 // populate dispatch tables
 void init_dispatch_tables(void)
 {
@@ -145,7 +145,7 @@ void init_dispatch_tables(void)
     init_copy_and_cast_usm_to_usm_dispatch_tables();
     init_copy_numpy_ndarray_into_usm_ndarray_dispatch_tables();
init_advanced_indexing_dispatch_tables(); - // init_where_dispatch_tables(); + init_where_dispatch_tables(); return; } @@ -169,9 +169,9 @@ void init_dispatch_vectors(void) populate_mask_positions_dispatch_vectors(); populate_cumsum_1d_dispatch_vectors(); - // init_repeat_dispatch_vectors(); + init_repeat_dispatch_vectors(); - // init_clip_dispatch_vectors(); + init_clip_dispatch_vectors(); return; } @@ -446,55 +446,53 @@ PYBIND11_MODULE(_tensor_impl, m) py::arg("mask_shape"), py::arg("sycl_queue"), py::arg("depends") = py::list()); - // m.def("_where", &py_where, "", py::arg("condition"), py::arg("x1"), - // py::arg("x2"), py::arg("dst"), py::arg("sycl_queue"), - // py::arg("depends") = py::list()); - - // auto repeat_sequence = [](const dpctl::tensor::usm_ndarray &src, - // const dpctl::tensor::usm_ndarray &dst, - // const dpctl::tensor::usm_ndarray &reps, - // const dpctl::tensor::usm_ndarray &cumsum, - // std::optional axis, sycl::queue &exec_q, - // const std::vector depends) - // -> std::pair { - // if (axis) { - // return py_repeat_by_sequence(src, dst, reps, cumsum, - // axis.value(), - // exec_q, depends); - // } - // else { - // return py_repeat_by_sequence(src, dst, reps, cumsum, exec_q, - // depends); - // } - // }; - // m.def("_repeat_by_sequence", repeat_sequence, py::arg("src"), - // py::arg("dst"), py::arg("reps"), py::arg("cumsum"), - // py::arg("axis"), py::arg("sycl_queue"), py::arg("depends") = - // py::list()); - - // auto repeat_scalar = [](const dpctl::tensor::usm_ndarray &src, - // const dpctl::tensor::usm_ndarray &dst, - // const py::ssize_t reps, std::optional axis, - // sycl::queue &exec_q, - // const std::vector depends) - // -> std::pair { - // if (axis) { - // return py_repeat_by_scalar(src, dst, reps, axis.value(), exec_q, - // depends); - // } - // else { - // return py_repeat_by_scalar(src, dst, reps, exec_q, depends); - // } - // }; - // m.def("_repeat_by_scalar", repeat_scalar, py::arg("src"), py::arg("dst"), - // py::arg("reps"), py::arg("axis"), py::arg("sycl_queue"), - // py::arg("depends") = py::list()); - - // m.def("_clip", &py_clip, - // "Clamps elements of array `x` to the range " - // "[`min`, `max] and writes the result to the " - // "array `dst` for each element of `x`, `min`, and `max`." 
- // "Returns a tuple of events: (hev, ev)", - // py::arg("src"), py::arg("min"), py::arg("max"), py::arg("dst"), - // py::arg("sycl_queue"), py::arg("depends") = py::list()); + m.def("_where", &py_where, "", py::arg("condition"), py::arg("x1"), + py::arg("x2"), py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + + auto repeat_sequence = [](const dpctl::tensor::usm_ndarray &src, + const dpctl::tensor::usm_ndarray &dst, + const dpctl::tensor::usm_ndarray &reps, + const dpctl::tensor::usm_ndarray &cumsum, + std::optional axis, sycl::queue &exec_q, + const std::vector depends) + -> std::pair { + if (axis) { + return py_repeat_by_sequence(src, dst, reps, cumsum, axis.value(), + exec_q, depends); + } + else { + return py_repeat_by_sequence(src, dst, reps, cumsum, exec_q, + depends); + } + }; + m.def("_repeat_by_sequence", repeat_sequence, py::arg("src"), + py::arg("dst"), py::arg("reps"), py::arg("cumsum"), py::arg("axis"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto repeat_scalar = [](const dpctl::tensor::usm_ndarray &src, + const dpctl::tensor::usm_ndarray &dst, + const py::ssize_t reps, std::optional axis, + sycl::queue &exec_q, + const std::vector depends) + -> std::pair { + if (axis) { + return py_repeat_by_scalar(src, dst, reps, axis.value(), exec_q, + depends); + } + else { + return py_repeat_by_scalar(src, dst, reps, exec_q, depends); + } + }; + m.def("_repeat_by_scalar", repeat_scalar, py::arg("src"), py::arg("dst"), + py::arg("reps"), py::arg("axis"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + + m.def("_clip", &py_clip, + "Clamps elements of array `x` to the range " + "[`min`, `max] and writes the result to the " + "array `dst` for each element of `x`, `min`, and `max`." + "Returns a tuple of events: (hev, ev)", + py::arg("src"), py::arg("min"), py::arg("max"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); } diff --git a/dpctl_ext/tensor/libtensor/source/where.cpp b/dpctl_ext/tensor/libtensor/source/where.cpp new file mode 100644 index 00000000000..1afdbf45c66 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/where.cpp @@ -0,0 +1,266 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines Python API for implementation functions of +/// dpctl.tensor.where +//===---------------------------------------------------------------------===// + +#include +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/where.hpp" +#include "utils/memory_overlap.hpp" +#include "utils/offset_utils.hpp" +#include "utils/output_validation.hpp" +#include "utils/sycl_alloc_utils.hpp" +#include "utils/type_dispatch.hpp" + +#include "simplify_iteration_space.hpp" +#include "where.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +using dpctl::tensor::kernels::search::where_contig_impl_fn_ptr_t; +using dpctl::tensor::kernels::search::where_strided_impl_fn_ptr_t; + +static where_contig_impl_fn_ptr_t where_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static where_strided_impl_fn_ptr_t + where_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::utils::keep_args_alive; + +std::pair + py_where(const dpctl::tensor::usm_ndarray &condition, + const dpctl::tensor::usm_ndarray &x1, + const dpctl::tensor::usm_ndarray &x2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) +{ + + if (!dpctl::utils::queues_are_compatible(exec_q, {x1, x2, condition, dst})) + { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + + int nd = condition.get_ndim(); + int x1_nd = x1.get_ndim(); + int x2_nd = x2.get_ndim(); + int dst_nd = dst.get_ndim(); + + if (nd != x1_nd || nd != x2_nd) { + throw py::value_error( + "Input arrays are not of appropriate dimension for where kernel."); + } + + if (nd != dst_nd) { + throw py::value_error( + "Destination is not of appropriate dimension for where kernel."); + } + + const py::ssize_t *x1_shape = x1.get_shape_raw(); + const py::ssize_t *x2_shape = x2.get_shape_raw(); + const py::ssize_t *dst_shape = dst.get_shape_raw(); + const py::ssize_t *cond_shape = condition.get_shape_raw(); + + bool shapes_equal(true); + std::size_t nelems(1); + for (int i = 0; i < nd; ++i) { + const auto &sh_i = dst_shape[i]; + nelems *= static_cast(sh_i); + shapes_equal = shapes_equal && (x1_shape[i] == sh_i) && + (x2_shape[i] == sh_i) && (cond_shape[i] == sh_i); + } + + if (!shapes_equal) { + throw py::value_error("Axes are not of matching shapes."); + } + + if (nelems == 0) { + return std::make_pair(sycl::event{}, sycl::event{}); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + auto const &same_logical_tensors = + dpctl::tensor::overlap::SameLogicalTensors(); + if ((overlap(dst, condition) && !same_logical_tensors(dst, 
condition)) || + (overlap(dst, x1) && !same_logical_tensors(dst, x1)) || + (overlap(dst, x2) && !same_logical_tensors(dst, x2))) + { + throw py::value_error("Destination array overlaps with input."); + } + + int x1_typenum = x1.get_typenum(); + int x2_typenum = x2.get_typenum(); + int cond_typenum = condition.get_typenum(); + int dst_typenum = dst.get_typenum(); + + auto const &array_types = td_ns::usm_ndarray_types(); + int cond_typeid = array_types.typenum_to_lookup_id(cond_typenum); + int x1_typeid = array_types.typenum_to_lookup_id(x1_typenum); + int x2_typeid = array_types.typenum_to_lookup_id(x2_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + if (x1_typeid != x2_typeid || x1_typeid != dst_typeid) { + throw py::value_error("Value arrays must have the same data type"); + } + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, nelems); + + char *cond_data = condition.get_data(); + char *x1_data = x1.get_data(); + char *x2_data = x2.get_data(); + char *dst_data = dst.get_data(); + + bool is_x1_c_contig = x1.is_c_contiguous(); + bool is_x1_f_contig = x1.is_f_contiguous(); + + bool is_x2_c_contig = x2.is_c_contiguous(); + bool is_x2_f_contig = x2.is_f_contiguous(); + + bool is_cond_c_contig = condition.is_c_contiguous(); + bool is_cond_f_contig = condition.is_f_contiguous(); + + bool is_dst_c_contig = dst.is_c_contiguous(); + bool is_dst_f_contig = dst.is_f_contiguous(); + + bool all_c_contig = (is_x1_c_contig && is_x2_c_contig && is_cond_c_contig && + is_dst_c_contig); + bool all_f_contig = (is_x1_f_contig && is_x2_f_contig && is_cond_f_contig && + is_dst_f_contig); + + if (all_c_contig || all_f_contig) { + auto contig_fn = where_contig_dispatch_table[x1_typeid][cond_typeid]; + + auto where_ev = contig_fn(exec_q, nelems, cond_data, x1_data, x2_data, + dst_data, depends); + sycl::event ht_ev = + keep_args_alive(exec_q, {x1, x2, dst, condition}, {where_ev}); + + return std::make_pair(ht_ev, where_ev); + } + + auto const &cond_strides = condition.get_strides_vector(); + auto const &x1_strides = x1.get_strides_vector(); + auto const &x2_strides = x2.get_strides_vector(); + auto const &dst_strides = dst.get_strides_vector(); + + using shT = std::vector; + shT simplified_shape; + shT simplified_cond_strides; + shT simplified_x1_strides; + shT simplified_x2_strides; + shT simplified_dst_strides; + py::ssize_t cond_offset(0); + py::ssize_t x1_offset(0); + py::ssize_t x2_offset(0); + py::ssize_t dst_offset(0); + + dpctl::tensor::py_internal::simplify_iteration_space_4( + nd, x1_shape, cond_strides, x1_strides, x2_strides, dst_strides, + // outputs + simplified_shape, simplified_cond_strides, simplified_x1_strides, + simplified_x2_strides, simplified_dst_strides, cond_offset, x1_offset, + x2_offset, dst_offset); + + auto fn = where_strided_dispatch_table[x1_typeid][cond_typeid]; + + std::vector host_task_events; + host_task_events.reserve(2); + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + auto ptr_size_event_tuple = device_allocate_and_pack( + exec_q, host_task_events, + // common shape and strides + simplified_shape, simplified_cond_strides, simplified_x1_strides, + simplified_x2_strides, simplified_dst_strides); + auto packed_shape_strides_owner = + std::move(std::get<0>(ptr_size_event_tuple)); + sycl::event copy_shape_strides_ev = std::get<2>(ptr_size_event_tuple); + const py::ssize_t *packed_shape_strides = packed_shape_strides_owner.get(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + 
all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shape_strides_ev); + + assert(all_deps.size() == depends.size() + 1); + + sycl::event where_ev = fn(exec_q, nelems, nd, cond_data, x1_data, x2_data, + dst_data, packed_shape_strides, cond_offset, + x1_offset, x2_offset, dst_offset, all_deps); + + // free packed temporaries + sycl::event temporaries_cleanup_ev = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {where_ev}, packed_shape_strides_owner); + host_task_events.push_back(temporaries_cleanup_ev); + + sycl::event arg_cleanup_ev = + keep_args_alive(exec_q, {x1, x2, condition, dst}, host_task_events); + + return std::make_pair(arg_cleanup_ev, where_ev); +} + +void init_where_dispatch_tables(void) +{ + using namespace td_ns; + using dpctl::tensor::kernels::search::WhereContigFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(where_contig_dispatch_table); + + using dpctl::tensor::kernels::search::WhereStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(where_strided_dispatch_table); +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/where.hpp b/dpctl_ext/tensor/libtensor/source/where.hpp new file mode 100644 index 00000000000..ba81d8b1164 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/where.hpp @@ -0,0 +1,57 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//*****************************************************************************
+//
+//===---------------------------------------------------------------------===//
+///
+/// \file
+/// This file declares Python API for implementation functions of
+/// dpctl.tensor.where
+//===---------------------------------------------------------------------===//
+
+#pragma once
+#include <utility>
+#include <vector>
+
+#include <sycl/sycl.hpp>
+
+#include "dpnp4pybind11.hpp"
+
+namespace dpctl::tensor::py_internal
+{
+
+extern std::pair<sycl::event, sycl::event>
+    py_where(const dpctl::tensor::usm_ndarray &,
+             const dpctl::tensor::usm_ndarray &,
+             const dpctl::tensor::usm_ndarray &,
+             const dpctl::tensor::usm_ndarray &,
+             sycl::queue &,
+             const std::vector<sycl::event> &);
+
+extern void init_where_dispatch_tables(void);
+
+} // namespace dpctl::tensor::py_internal
diff --git a/dpnp/dpnp_algo/dpnp_arraycreation.py b/dpnp/dpnp_algo/dpnp_arraycreation.py
index 47edf63a68b..f3dd1815356 100644
--- a/dpnp/dpnp_algo/dpnp_arraycreation.py
+++ b/dpnp/dpnp_algo/dpnp_arraycreation.py
@@ -243,7 +243,7 @@ def dpnp_linspace(
             # Needed a special handling for denormal numbers (when step == 0),
             # see numpy#5437 for more details.
             # Note, dpt.where() is used to avoid a synchronization branch.
-            usm_res = dpt.where(
+            usm_res = dpt_ext.where(
                 step == 0, (usm_res / step_num) * delta, usm_res * step
             )
         else:
diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py
index 0b6d882c53d..564627bacf2 100644
--- a/dpnp/dpnp_array.py
+++ b/dpnp/dpnp_array.py
@@ -38,12 +38,12 @@
 import warnings
 
 import dpctl.tensor as dpt
-import dpctl.tensor._type_utils as dtu
 from dpctl.tensor._numpy_helper import AxisError
 
 # TODO: revert to `import dpctl.tensor...`
 # when dpnp fully migrates dpctl/tensor
 import dpctl_ext.tensor as dpt_ext
+import dpctl_ext.tensor._type_utils as dtu
 import dpnp
 
 from .
import memory as dpm diff --git a/dpnp/dpnp_iface_indexing.py b/dpnp/dpnp_iface_indexing.py index f305b106221..bc190db70c4 100644 --- a/dpnp/dpnp_iface_indexing.py +++ b/dpnp/dpnp_iface_indexing.py @@ -250,7 +250,7 @@ def choose(a, choices, out=None, mode="wrap"): res_usm_type, exec_q = get_usm_allocations(choices + [inds]) # apply type promotion to input choices - res_dt = dpt.result_type(*choices) + res_dt = dpt_ext.result_type(*choices) if len(choices) > 1: choices = tuple( map( diff --git a/dpnp/dpnp_iface_manipulation.py b/dpnp/dpnp_iface_manipulation.py index e988bbaa237..08fd55c58ac 100644 --- a/dpnp/dpnp_iface_manipulation.py +++ b/dpnp/dpnp_iface_manipulation.py @@ -1270,7 +1270,7 @@ def can_cast(from_, to, casting="safe"): if dpnp.is_supported_array_type(from_) else dpnp.dtype(from_) ) - return dpt.can_cast(dtype_from, to, casting=casting) + return dpt_ext.can_cast(dtype_from, to, casting=casting) def column_stack(tup): @@ -2837,7 +2837,7 @@ def repeat(a, repeats, axis=None): a = dpnp.ravel(a) usm_arr = dpnp.get_usm_ndarray(a) - usm_res = dpt.repeat(usm_arr, repeats, axis=axis) + usm_res = dpt_ext.repeat(usm_arr, repeats, axis=axis) return dpnp_array._create_from_usm_ndarray(usm_res) @@ -3195,7 +3195,7 @@ def result_type(*arrays_and_dtypes): ) for X in arrays_and_dtypes ] - return dpt.result_type(*usm_arrays_and_dtypes) + return dpt_ext.result_type(*usm_arrays_and_dtypes) def roll(x, shift, axis=None): diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index e339c24d384..3dbf07be080 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -40,6 +40,7 @@ """ # pylint: disable=protected-access +# pylint: disable=duplicate-code # pylint: disable=no-name-in-module @@ -48,15 +49,17 @@ import dpctl.tensor as dpt import dpctl.tensor._tensor_elementwise_impl as ti -import dpctl.tensor._type_utils as dtu import dpctl.utils as dpu import numpy from dpctl.tensor._numpy_helper import ( normalize_axis_index, normalize_axis_tuple, ) -from dpctl.tensor._type_utils import _acceptance_fn_divide +# TODO: revert to `import dpctl.tensor...` +# when dpnp fully migrates dpctl/tensor +import dpctl_ext.tensor as dpt_ext +import dpctl_ext.tensor._type_utils as dtu import dpnp import dpnp.backend.extensions.ufunc._ufunc_impl as ufi @@ -727,7 +730,7 @@ def clip(a, /, min=None, max=None, *, out=None, order="K", **kwargs): usm_max = None if max is None else dpnp.get_usm_ndarray_or_scalar(max) usm_out = None if out is None else dpnp.get_usm_ndarray(out) - usm_res = dpt.clip(usm_arr, usm_min, usm_max, out=usm_out, order=order) + usm_res = dpt_ext.clip(usm_arr, usm_min, usm_max, out=usm_out, order=order) if out is not None and isinstance(out, dpnp_array): return out return dpnp_array._create_from_usm_ndarray(usm_res) @@ -1561,7 +1564,7 @@ def diff(a, n=1, axis=-1, prepend=None, append=None): mkl_fn_to_call="_mkl_div_to_call", mkl_impl_fn="_div", binary_inplace_fn=ti._divide_inplace, - acceptance_fn=_acceptance_fn_divide, + acceptance_fn=dtu._acceptance_fn_divide, ) diff --git a/dpnp/dpnp_iface_searching.py b/dpnp/dpnp_iface_searching.py index 16ab633d506..a2389978d50 100644 --- a/dpnp/dpnp_iface_searching.py +++ b/dpnp/dpnp_iface_searching.py @@ -44,6 +44,7 @@ # pylint: disable=no-name-in-module # TODO: revert to `import dpctl.tensor...` # when dpnp fully migrates dpctl/tensor +import dpctl_ext.tensor as dpt_ext import dpctl_ext.tensor._tensor_impl as dti import dpnp @@ -473,5 +474,7 @@ def where(condition, x=None, y=None, /, *, order="K", out=None): 
usm_condition = dpnp.get_usm_ndarray(condition) usm_out = None if out is None else dpnp.get_usm_ndarray(out) - usm_res = dpt.where(usm_condition, usm_x, usm_y, order=order, out=usm_out) + usm_res = dpt_ext.where( + usm_condition, usm_x, usm_y, order=order, out=usm_out + ) return dpnp.get_result_array(usm_res, out) diff --git a/dpnp/dpnp_iface_trigonometric.py b/dpnp/dpnp_iface_trigonometric.py index a46f06c10e0..9894bd30470 100644 --- a/dpnp/dpnp_iface_trigonometric.py +++ b/dpnp/dpnp_iface_trigonometric.py @@ -45,8 +45,10 @@ import dpctl.tensor as dpt import dpctl.tensor._tensor_elementwise_impl as ti -import dpctl.tensor._type_utils as dtu +# TODO: revert to `import dpctl.tensor...` +# when dpnp fully migrates dpctl/tensor +import dpctl_ext.tensor._type_utils as dtu import dpnp import dpnp.backend.extensions.ufunc._ufunc_impl as ufi diff --git a/dpnp/dpnp_iface_types.py b/dpnp/dpnp_iface_types.py index 8fdb9e1d3d3..f133333d6b8 100644 --- a/dpnp/dpnp_iface_types.py +++ b/dpnp/dpnp_iface_types.py @@ -40,6 +40,9 @@ import dpctl.tensor as dpt import numpy +# TODO: revert to `import dpctl.tensor...` +# when dpnp fully migrates dpctl/tensor +import dpctl_ext.tensor as dpt_ext import dpnp from .dpnp_array import dpnp_array @@ -211,7 +214,7 @@ def finfo(dtype): """ if isinstance(dtype, dpnp_array): dtype = dtype.dtype - return dpt.finfo(dtype) + return dpt_ext.finfo(dtype) # pylint: disable=redefined-outer-name @@ -244,7 +247,7 @@ def iinfo(dtype): if isinstance(dtype, dpnp_array): dtype = dtype.dtype - return dpt.iinfo(dtype) + return dpt_ext.iinfo(dtype) def isdtype(dtype, kind): @@ -298,7 +301,7 @@ def isdtype(dtype, kind): elif isinstance(kind, tuple): kind = tuple(dpt.dtype(k) if isinstance(k, type) else k for k in kind) - return dpt.isdtype(dtype, kind) + return dpt_ext.isdtype(dtype, kind) def issubdtype(arg1, arg2): diff --git a/dpnp/dpnp_utils/dpnp_utils_common.py b/dpnp/dpnp_utils/dpnp_utils_common.py index e4bde2e1ec8..aa294fefe27 100644 --- a/dpnp/dpnp_utils/dpnp_utils_common.py +++ b/dpnp/dpnp_utils/dpnp_utils_common.py @@ -29,8 +29,9 @@ from collections.abc import Iterable -import dpctl.tensor._type_utils as dtu - +# TODO: revert to `import dpctl.tensor...` +# when dpnp fully migrates dpctl/tensor +import dpctl_ext.tensor._type_utils as dtu import dpnp from dpnp.dpnp_utils import map_dtype_to_device diff --git a/dpnp/tests/test_indexing.py b/dpnp/tests/test_indexing.py index 9a55efe138b..79c41a2f45f 100644 --- a/dpnp/tests/test_indexing.py +++ b/dpnp/tests/test_indexing.py @@ -4,8 +4,6 @@ import dpctl.tensor as dpt import numpy import pytest -from dpctl.tensor._numpy_helper import AxisError -from dpctl.tensor._type_utils import _to_device_supported_dtype from dpctl.utils import ExecutionPlacementError from numpy.testing import ( assert_, @@ -16,6 +14,11 @@ ) import dpnp + +# TODO: revert to `import dpctl.tensor...` +# when dpnp fully migrates dpctl/tensor +from dpctl_ext.tensor._numpy_helper import AxisError +from dpctl_ext.tensor._type_utils import _to_device_supported_dtype from dpnp.dpnp_array import dpnp_array from .helper import (
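For reference, a minimal usage sketch of the entry points this patch re-enables at the dpctl_ext.tensor level (clip, repeat, where, and the dtype helpers). It is not part of the patch: it assumes dpctl_ext is built with clip.cpp, repeat.cpp and where.cpp compiled in and that a SYCL device is available; the calls are expected to mirror the dpctl.tensor API, and the argument values are purely illustrative.

import dpctl.tensor as dpt

import dpctl_ext.tensor as dpt_ext

x = dpt.arange(10, dtype="f4")

# clamp values of x into [2, 7]
y = dpt_ext.clip(x, 2, 7)

# repeat with a scalar count and with per-element counts along axis 0
r1 = dpt_ext.repeat(x, 3)
r2 = dpt_ext.repeat(x, list(range(1, 11)), axis=0)

# elementwise selection between x and y based on a boolean mask
z = dpt_ext.where(x > 5, x, y)

# dtype utilities are also exposed from dpctl_ext.tensor
assert dpt_ext.can_cast(dpt.int32, dpt.float64)
print(dpt_ext.result_type(x, dpt.float64), dpt_ext.finfo(dpt.float32).eps)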