Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mkl_random/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,6 @@
test = PytestTester(__name__)
del PytestTester

from ._patch import monkey_patch, use_in_numpy, restore, is_patched, patched_names, mkl_random

del _init_helper
275 changes: 275 additions & 0 deletions mkl_random/src/_patch.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
# Copyright (c) 2019, Intel Corporation
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The copyright year is 2019, while other files in the codebase use 2017 (e.g., init.py, test_random.py). This inconsistency should be corrected to match the project's copyright dating convention, likely using the year the file was actually created or the standard project copyright year.

Suggested change
# Copyright (c) 2019, Intel Corporation
# Copyright (c) 2017, Intel Corporation

Copilot uses AI. Check for mistakes.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of Intel Corporation nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# distutils: language = c
# cython: language_level=3

"""
Patch NumPy's `numpy.random` symbols to use mkl_random implementations.

This is attribute-level monkey patching. It can replace legacy APIs like
`numpy.random.RandomState` and global distribution functions, but it does not
replace NumPy's `Generator`/`default_rng()` unless mkl_random provides fully
compatible replacements.
"""

from threading import local as threading_local
from contextlib import ContextDecorator

import numpy as _np
from . import mklrand as _mr


cdef tuple _DEFAULT_NAMES = (
# Legacy seeding / state
"seed",
"get_state",
"set_state",
"RandomState",

# Common global sampling helpers
"random",
"random_sample",
"sample",
"rand",
"randn",
"bytes",

# Integers
"randint",

# Common distributions (only patched if present on both sides)
"standard_normal",
"normal",
"uniform",
"exponential",
"gamma",
"beta",
"chisquare",
"f",
"lognormal",
"laplace",
"logistic",
"multivariate_normal",
"poisson",
"power",
"rayleigh",
"triangular",
"vonmises",
"wald",
"weibull",
"zipf",

# Permutations / choices
"choice",
"permutation",
"shuffle",
)


cdef class patch:
cdef bint _is_patched
cdef object _numpy_module
cdef object _originals # dict: name -> original object
cdef object _patched # list of names actually patched

def __cinit__(self):
self._is_patched = False
self._numpy_module = None
self._originals = {}
self._patched = []

def do_patch(self, numpy_module=None, names=None, bint strict=False):
"""
Patch the given numpy module (default: imported numpy) in-place.

Parameters
----------
numpy_module : module, optional
The numpy module to patch (e.g. `import numpy as np; use_in_numpy(np)`).
names : iterable[str], optional
Attributes under `numpy_module.random` to patch. Defaults to _DEFAULT_NAMES.
strict : bool
If True, raise if any requested symbol cannot be patched.
"""
if numpy_module is None:
numpy_module = _np
if names is None:
names = _DEFAULT_NAMES

if not hasattr(numpy_module, "random"):
raise TypeError("Expected a numpy-like module with a `.random` attribute.")

# If already patched, only allow idempotent re-entry for the same numpy module.
if self._is_patched:
if self._numpy_module is numpy_module:
Comment on lines +125 to +127
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The identity check self._numpy_module is numpy_module may fail in some scenarios where the same module is imported differently (e.g., through different import mechanisms or in certain testing scenarios). Consider using self._numpy_module is not None and self._numpy_module.__name__ == numpy_module.__name__ for a more robust comparison, or document that the exact same module object must be passed.

Suggested change
# If already patched, only allow idempotent re-entry for the same numpy module.
if self._is_patched:
if self._numpy_module is numpy_module:
# If already patched, only allow idempotent re-entry for the same numpy module
# (by module name) to avoid failures when the same module is imported via
# different mechanisms.
if self._is_patched:
same_module = (
self._numpy_module is not None
and numpy_module is not None
and getattr(self._numpy_module, "__name__", None)
== getattr(numpy_module, "__name__", None)
)
if same_module:

Copilot uses AI. Check for mistakes.
return
raise RuntimeError("Already patched a different numpy module; call restore() first.")

np_random = numpy_module.random

originals = {}
patched = []
missing = []

for name in names:
if not hasattr(np_random, name) or not hasattr(_mr, name):
missing.append(name)
continue
originals[name] = getattr(np_random, name)
setattr(np_random, name, getattr(_mr, name))
patched.append(name)

if strict and missing:
# revert partial patch before raising
for n, v in originals.items():
setattr(np_random, n, v)
raise AttributeError(
"Could not patch these names (missing on numpy.random or mkl_random.mklrand): "
+ ", ".join([str(x) for x in missing])
)

self._numpy_module = numpy_module
self._originals = originals
self._patched = patched
self._is_patched = True

def do_unpatch(self):
"""
Restore the previously patched numpy module.
"""
if not self._is_patched:
return
numpy_module = self._numpy_module
np_random = numpy_module.random
for n, v in self._originals.items():
setattr(np_random, n, v)

self._numpy_module = None
self._originals = {}
self._patched = []
self._is_patched = False

Comment on lines +167 to +174
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The do_unpatch method doesn't handle potential exceptions when restoring attributes. If setattr() fails for any name (e.g., if the module structure changed or an attribute became read-only), the unpatch operation will be incomplete, leaving _is_patched as True and some attributes not restored. Consider wrapping the restore loop in error handling or using a try-finally to ensure state is cleaned up even if restoration fails.

Suggested change
for n, v in self._originals.items():
setattr(np_random, n, v)
self._numpy_module = None
self._originals = {}
self._patched = []
self._is_patched = False
try:
for n, v in self._originals.items():
setattr(np_random, n, v)
finally:
self._numpy_module = None
self._originals = {}
self._patched = []
self._is_patched = False

Copilot uses AI. Check for mistakes.
def is_patched(self):
return self._is_patched

def patched_names(self):
"""
Returns list of names that were actually patched.
"""
return list(self._patched)


_tls = threading_local()


def _is_tls_initialized():
return (getattr(_tls, "initialized", None) is not None) and (_tls.initialized is True)
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _is_tls_initialized() function uses getattr(_tls, "initialized", None) is not None followed by and (_tls.initialized is True). This is redundant - the second check alone (getattr(_tls, "initialized", False)) would be simpler and equivalent. Consider simplifying to: return getattr(_tls, "initialized", False)

Suggested change
return (getattr(_tls, "initialized", None) is not None) and (_tls.initialized is True)
return getattr(_tls, "initialized", False)

Copilot uses AI. Check for mistakes.


def _initialize_tls():
_tls.patch = patch()
_tls.initialized = True
Comment on lines +185 to +194
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The thread-local storage approach creates a mismatch between thread-local state tracking and global module patching. Since numpy.random is a module-level singleton, patching it in one thread affects all threads, but the is_patched() status is tracked per-thread. This means:

  • Thread A can patch numpy.random and _tls.patch._is_patched becomes True for Thread A
  • Thread B checks is_patched() and gets False (Thread B has its own _tls)
  • But numpy.random is actually patched globally for both threads

Consider using a module-level (not thread-local) patch tracker, with appropriate locking if thread-safety is needed.

Copilot uses AI. Check for mistakes.


def monkey_patch(numpy_module=None, names=None, strict=False):
"""
Enables using mkl_random in the given NumPy module by patching `numpy.random`.

Examples
--------
>>> import numpy as np
>>> import mkl_random
>>> mkl_random.is_patched()
False
>>> mkl_random.monkey_patch(np)
>>> mkl_random.is_patched()
True
>>> mkl_random.restore()
>>> mkl_random.is_patched()
False
"""
if not _is_tls_initialized():
_initialize_tls()
_tls.patch.do_patch(numpy_module=numpy_module, names=names, strict=bool(strict))


def use_in_numpy(numpy_module=None, names=None, strict=False):
"""
Backward-compatible alias for monkey_patch().
"""
monkey_patch(numpy_module=numpy_module, names=names, strict=strict)


def restore():
"""
Disables using mkl_random in NumPy by restoring the original `numpy.random` symbols.
"""
if not _is_tls_initialized():
_initialize_tls()
_tls.patch.do_unpatch()


def is_patched():
"""
Returns whether NumPy has been patched with mkl_random.
"""
if not _is_tls_initialized():
_initialize_tls()
return bool(_tls.patch.is_patched())


def patched_names():
"""
Returns the names actually patched in `numpy.random`.
"""
if not _is_tls_initialized():
_initialize_tls()
return _tls.patch.patched_names()


class mkl_random(ContextDecorator):
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The class name mkl_random creates a naming collision with the module name. While this works (accessed as mkl_random.mkl_random), it can be confusing. Consider renaming the class to something like MklRandomContext or patch_context to avoid the collision and improve clarity. However, this would be a breaking API change if the current naming is intentional.

Copilot uses AI. Check for mistakes.
"""
Context manager and decorator to temporarily patch NumPy's `numpy.random`.

Examples
--------
>>> import numpy as np
>>> import mkl_random
>>> with mkl_random.mkl_random():
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The example in the docstring shows with mkl_random.mkl_random(): without passing numpy module, but the test at line 52 shows with mkl_random.mkl_random(np):. The example should clarify whether passing np is optional (defaulting to the imported numpy) or if it's recommended to be explicit. Consider updating the example to match the test pattern for consistency.

Suggested change
>>> with mkl_random.mkl_random():
>>> with mkl_random.mkl_random(np):

Copilot uses AI. Check for mistakes.
... x = np.random.normal(size=10)
"""
def __init__(self, numpy_module=None, names=None, strict=False):
self._numpy_module = numpy_module
self._names = names
self._strict = strict

def __enter__(self):
monkey_patch(numpy_module=self._numpy_module, names=self._names, strict=self._strict)
return self

def __exit__(self, *exc):
restore()
return False
Comment on lines +269 to +275
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nested context managers will incorrectly restore the patch when the inner context exits. For example:

with mkl_random.mkl_random(np):
    # numpy.random is patched
    with mkl_random.mkl_random(np):
        # still patched
        pass
    # Inner __exit__ calls restore(), unpatching numpy.random
    # But outer context still expects it to be patched

Consider implementing reference counting or a stack-based approach to handle nested contexts correctly.

Copilot uses AI. Check for mistakes.
95 changes: 95 additions & 0 deletions mkl_random/tests/test_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import numpy as np
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test file is missing the Intel Corporation copyright and BSD license header that is present in all other Python/Cython files in the codebase (see mkl_random/init.py:1-25, mkl_random/tests/test_random.py:1-25, mkl_random/src/_patch.pyx:1-24). Add the appropriate copyright header to maintain consistency with the project's licensing requirements.

Copilot uses AI. Check for mistakes.
import mkl_random
import pytest

def test_is_patched():
"""
Test that is_patched() returns correct status.
"""
assert not mkl_random.is_patched()
mkl_random.monkey_patch(np)
assert mkl_random.is_patched()
mkl_random.restore()
assert not mkl_random.is_patched()

def test_monkey_patch_and_restore():
"""
Test that monkey_patch replaces and restore brings back original functions.
"""
# Store original functions
orig_normal = np.random.normal
orig_randint = np.random.randint
orig_RandomState = np.random.RandomState

try:
mkl_random.monkey_patch(np)

# Check that functions are now different objects
assert np.random.normal is not orig_normal
assert np.random.randint is not orig_randint
assert np.random.RandomState is not orig_RandomState

# Check that they are from mkl_random
assert np.random.normal is mkl_random.mklrand.normal
assert np.random.RandomState is mkl_random.mklrand.RandomState

finally:
mkl_random.restore()

# Check that original functions are restored
assert mkl_random.is_patched() is False
assert np.random.normal is orig_normal
assert np.random.randint is orig_randint
assert np.random.RandomState is orig_RandomState

def test_context_manager():
"""
Test that the context manager patches and automatically restores.
"""
orig_uniform = np.random.uniform
assert not mkl_random.is_patched()

with mkl_random.mkl_random(np):
assert mkl_random.is_patched() is True
assert np.random.uniform is not orig_uniform
# Smoke test inside context
arr = np.random.uniform(size=10)
assert arr.shape == (10,)

assert not mkl_random.is_patched()
assert np.random.uniform is orig_uniform

def test_patched_functions_callable():
"""
Smoke test to ensure some patched functions can be called without error.
"""
mkl_random.monkey_patch(np)
try:
# These calls should now be routed to mkl_random's implementations
x = np.random.standard_normal(size=100)
assert x.shape == (100,)

y = np.random.randint(0, 100, size=50)
assert y.shape == (50,)
assert np.all(y >= 0) and np.all(y < 100)

st = np.random.RandomState(12345)
z = st.rand(10)
assert z.shape == (10,)

finally:
mkl_random.restore()

def test_patched_names():
"""
Test that patched_names() returns a list of patched symbols.
"""
try:
mkl_random.monkey_patch(np)
names = mkl_random.patched_names()
assert isinstance(names, list)
assert len(names) > 0
assert "normal" in names
assert "RandomState" in names
finally:
mkl_random.restore()
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no test coverage for the strict parameter functionality. The strict mode is designed to raise AttributeError when symbols cannot be patched, but this behavior is not tested. Consider adding a test that verifies strict mode works correctly, for example by trying to patch with a list that includes non-existent names.

Suggested change
mkl_random.restore()
mkl_random.restore()
def test_monkey_patch_strict_raises_attribute_error():
"""
Test that strict mode raises AttributeError when patching non-existent names.
"""
# Attempt to patch a clearly non-existent symbol in strict mode.
with pytest.raises(AttributeError):
mkl_random.monkey_patch(np, strict=True, names=["nonexistent_symbol"])

Copilot uses AI. Check for mistakes.
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The use_in_numpy function is exported in init.py and documented as a "backward-compatible alias" but has no test coverage. Consider adding a test to verify that use_in_numpy works identically to monkey_patch, ensuring the backward compatibility is maintained.

Suggested change
mkl_random.restore()
mkl_random.restore()
def test_use_in_numpy_is_alias_for_monkey_patch():
"""
Test that use_in_numpy is a backward-compatible alias for monkey_patch.
"""
assert hasattr(mkl_random, "use_in_numpy")
assert mkl_random.use_in_numpy is mkl_random.monkey_patch

Copilot uses AI. Check for mistakes.
Comment on lines +1 to +95
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no test coverage for edge cases such as:

  • Calling restore() when not patched (should be safe/no-op based on line 163-164 of _patch.pyx)
  • Calling monkey_patch() multiple times with the same numpy module (idempotent behavior from lines 126-128)
  • Calling monkey_patch() with different numpy modules sequentially
  • Error handling when numpy_module doesn't have a 'random' attribute

Consider adding tests for these edge cases to ensure the implementation handles them correctly.

Copilot uses AI. Check for mistakes.
8 changes: 8 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ def extensions():
extra_compile_args = eca,
define_macros=defs + [("NDEBUG", None)],
language="c++"
),

Extension(
"mkl_random._patch",
sources=[join("mkl_random", "src", "_patch.pyx")],
include_dirs=[np.get_include()],
define_macros=defs + [("NDEBUG", None)],
language="c",
)
Comment on lines +86 to 92
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _patch extension includes define_macros with PY_ARRAY_UNIQUE_SYMBOL, but _patch.pyx doesn't use the NumPy C API (no cimport numpy or usage of numpy C structures). While this doesn't cause harm, these macros are unnecessary for this extension and could be removed for clarity. The extension only needs include_dirs=[np.get_include()] to compile successfully with Cython's NumPy support.

Copilot uses AI. Check for mistakes.
]

Expand Down
Loading