-
Notifications
You must be signed in to change notification settings - Fork 14
task: add patch methods for mkl_random #90
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev-milestone
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,275 @@ | ||||||||||||||||||||||||||||||||
| # Copyright (c) 2019, Intel Corporation | ||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||
| # Redistribution and use in source and binary forms, with or without | ||||||||||||||||||||||||||||||||
| # modification, are permitted provided that the following conditions are met: | ||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||
| # * Redistributions of source code must retain the above copyright notice, | ||||||||||||||||||||||||||||||||
| # this list of conditions and the following disclaimer. | ||||||||||||||||||||||||||||||||
| # * Redistributions in binary form must reproduce the above copyright | ||||||||||||||||||||||||||||||||
| # notice, this list of conditions and the following disclaimer in the | ||||||||||||||||||||||||||||||||
| # documentation and/or other materials provided with the distribution. | ||||||||||||||||||||||||||||||||
| # * Neither the name of Intel Corporation nor the names of its contributors | ||||||||||||||||||||||||||||||||
| # may be used to endorse or promote products derived from this software | ||||||||||||||||||||||||||||||||
| # without specific prior written permission. | ||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||||||||||||||||||||||||||||||||
| # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||||||||||||||||||||||||||||||||
| # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||||||||||||||||||||||||||||||||
| # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE | ||||||||||||||||||||||||||||||||
| # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | ||||||||||||||||||||||||||||||||
| # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | ||||||||||||||||||||||||||||||||
| # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | ||||||||||||||||||||||||||||||||
| # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | ||||||||||||||||||||||||||||||||
| # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||||||||||||||||||||||||||||||||
| # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # distutils: language = c | ||||||||||||||||||||||||||||||||
| # cython: language_level=3 | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||
| Patch NumPy's `numpy.random` symbols to use mkl_random implementations. | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| This is attribute-level monkey patching. It can replace legacy APIs like | ||||||||||||||||||||||||||||||||
| `numpy.random.RandomState` and global distribution functions, but it does not | ||||||||||||||||||||||||||||||||
| replace NumPy's `Generator`/`default_rng()` unless mkl_random provides fully | ||||||||||||||||||||||||||||||||
| compatible replacements. | ||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| from threading import local as threading_local | ||||||||||||||||||||||||||||||||
| from contextlib import ContextDecorator | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| import numpy as _np | ||||||||||||||||||||||||||||||||
| from . import mklrand as _mr | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| cdef tuple _DEFAULT_NAMES = ( | ||||||||||||||||||||||||||||||||
| # Legacy seeding / state | ||||||||||||||||||||||||||||||||
| "seed", | ||||||||||||||||||||||||||||||||
| "get_state", | ||||||||||||||||||||||||||||||||
| "set_state", | ||||||||||||||||||||||||||||||||
| "RandomState", | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Common global sampling helpers | ||||||||||||||||||||||||||||||||
| "random", | ||||||||||||||||||||||||||||||||
| "random_sample", | ||||||||||||||||||||||||||||||||
| "sample", | ||||||||||||||||||||||||||||||||
| "rand", | ||||||||||||||||||||||||||||||||
| "randn", | ||||||||||||||||||||||||||||||||
| "bytes", | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Integers | ||||||||||||||||||||||||||||||||
| "randint", | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Common distributions (only patched if present on both sides) | ||||||||||||||||||||||||||||||||
| "standard_normal", | ||||||||||||||||||||||||||||||||
| "normal", | ||||||||||||||||||||||||||||||||
| "uniform", | ||||||||||||||||||||||||||||||||
| "exponential", | ||||||||||||||||||||||||||||||||
| "gamma", | ||||||||||||||||||||||||||||||||
| "beta", | ||||||||||||||||||||||||||||||||
| "chisquare", | ||||||||||||||||||||||||||||||||
| "f", | ||||||||||||||||||||||||||||||||
| "lognormal", | ||||||||||||||||||||||||||||||||
| "laplace", | ||||||||||||||||||||||||||||||||
| "logistic", | ||||||||||||||||||||||||||||||||
| "multivariate_normal", | ||||||||||||||||||||||||||||||||
| "poisson", | ||||||||||||||||||||||||||||||||
| "power", | ||||||||||||||||||||||||||||||||
| "rayleigh", | ||||||||||||||||||||||||||||||||
| "triangular", | ||||||||||||||||||||||||||||||||
| "vonmises", | ||||||||||||||||||||||||||||||||
| "wald", | ||||||||||||||||||||||||||||||||
| "weibull", | ||||||||||||||||||||||||||||||||
| "zipf", | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Permutations / choices | ||||||||||||||||||||||||||||||||
| "choice", | ||||||||||||||||||||||||||||||||
| "permutation", | ||||||||||||||||||||||||||||||||
| "shuffle", | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| cdef class patch: | ||||||||||||||||||||||||||||||||
| cdef bint _is_patched | ||||||||||||||||||||||||||||||||
| cdef object _numpy_module | ||||||||||||||||||||||||||||||||
| cdef object _originals # dict: name -> original object | ||||||||||||||||||||||||||||||||
| cdef object _patched # list of names actually patched | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def __cinit__(self): | ||||||||||||||||||||||||||||||||
| self._is_patched = False | ||||||||||||||||||||||||||||||||
| self._numpy_module = None | ||||||||||||||||||||||||||||||||
| self._originals = {} | ||||||||||||||||||||||||||||||||
| self._patched = [] | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def do_patch(self, numpy_module=None, names=None, bint strict=False): | ||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||
| Patch the given numpy module (default: imported numpy) in-place. | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| Parameters | ||||||||||||||||||||||||||||||||
| ---------- | ||||||||||||||||||||||||||||||||
| numpy_module : module, optional | ||||||||||||||||||||||||||||||||
| The numpy module to patch (e.g. `import numpy as np; use_in_numpy(np)`). | ||||||||||||||||||||||||||||||||
| names : iterable[str], optional | ||||||||||||||||||||||||||||||||
| Attributes under `numpy_module.random` to patch. Defaults to _DEFAULT_NAMES. | ||||||||||||||||||||||||||||||||
| strict : bool | ||||||||||||||||||||||||||||||||
| If True, raise if any requested symbol cannot be patched. | ||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||
| if numpy_module is None: | ||||||||||||||||||||||||||||||||
| numpy_module = _np | ||||||||||||||||||||||||||||||||
| if names is None: | ||||||||||||||||||||||||||||||||
| names = _DEFAULT_NAMES | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| if not hasattr(numpy_module, "random"): | ||||||||||||||||||||||||||||||||
| raise TypeError("Expected a numpy-like module with a `.random` attribute.") | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # If already patched, only allow idempotent re-entry for the same numpy module. | ||||||||||||||||||||||||||||||||
| if self._is_patched: | ||||||||||||||||||||||||||||||||
| if self._numpy_module is numpy_module: | ||||||||||||||||||||||||||||||||
|
Comment on lines
+125
to
+127
|
||||||||||||||||||||||||||||||||
| # If already patched, only allow idempotent re-entry for the same numpy module. | |
| if self._is_patched: | |
| if self._numpy_module is numpy_module: | |
| # If already patched, only allow idempotent re-entry for the same numpy module | |
| # (by module name) to avoid failures when the same module is imported via | |
| # different mechanisms. | |
| if self._is_patched: | |
| same_module = ( | |
| self._numpy_module is not None | |
| and numpy_module is not None | |
| and getattr(self._numpy_module, "__name__", None) | |
| == getattr(numpy_module, "__name__", None) | |
| ) | |
| if same_module: |
Copilot
AI
Feb 20, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The do_unpatch method doesn't handle potential exceptions when restoring attributes. If setattr() fails for any name (e.g., if the module structure changed or an attribute became read-only), the unpatch operation will be incomplete, leaving _is_patched as True and some attributes not restored. Consider wrapping the restore loop in error handling or using a try-finally to ensure state is cleaned up even if restoration fails.
| for n, v in self._originals.items(): | |
| setattr(np_random, n, v) | |
| self._numpy_module = None | |
| self._originals = {} | |
| self._patched = [] | |
| self._is_patched = False | |
| try: | |
| for n, v in self._originals.items(): | |
| setattr(np_random, n, v) | |
| finally: | |
| self._numpy_module = None | |
| self._originals = {} | |
| self._patched = [] | |
| self._is_patched = False |
Copilot
AI
Feb 20, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The _is_tls_initialized() function uses getattr(_tls, "initialized", None) is not None followed by and (_tls.initialized is True). This is redundant - the second check alone (getattr(_tls, "initialized", False)) would be simpler and equivalent. Consider simplifying to: return getattr(_tls, "initialized", False)
| return (getattr(_tls, "initialized", None) is not None) and (_tls.initialized is True) | |
| return getattr(_tls, "initialized", False) |
Copilot
AI
Feb 20, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The thread-local storage approach creates a mismatch between thread-local state tracking and global module patching. Since numpy.random is a module-level singleton, patching it in one thread affects all threads, but the is_patched() status is tracked per-thread. This means:
- Thread A can patch numpy.random and _tls.patch._is_patched becomes True for Thread A
- Thread B checks is_patched() and gets False (Thread B has its own _tls)
- But numpy.random is actually patched globally for both threads
Consider using a module-level (not thread-local) patch tracker, with appropriate locking if thread-safety is needed.
Copilot
AI
Feb 20, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The class name mkl_random creates a naming collision with the module name. While this works (accessed as mkl_random.mkl_random), it can be confusing. Consider renaming the class to something like MklRandomContext or patch_context to avoid the collision and improve clarity. However, this would be a breaking API change if the current naming is intentional.
Copilot
AI
Feb 20, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The example in the docstring shows with mkl_random.mkl_random(): without passing numpy module, but the test at line 52 shows with mkl_random.mkl_random(np):. The example should clarify whether passing np is optional (defaulting to the imported numpy) or if it's recommended to be explicit. Consider updating the example to match the test pattern for consistency.
| >>> with mkl_random.mkl_random(): | |
| >>> with mkl_random.mkl_random(np): |
Copilot
AI
Feb 20, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nested context managers will incorrectly restore the patch when the inner context exits. For example:
with mkl_random.mkl_random(np):
# numpy.random is patched
with mkl_random.mkl_random(np):
# still patched
pass
# Inner __exit__ calls restore(), unpatching numpy.random
# But outer context still expects it to be patchedConsider implementing reference counting or a stack-based approach to handle nested contexts correctly.
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,95 @@ | ||||||||||||||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
| import mkl_random | ||||||||||||||||||||||||||||||||||||||||||||
| import pytest | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| def test_is_patched(): | ||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||
| Test that is_patched() returns correct status. | ||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||
| assert not mkl_random.is_patched() | ||||||||||||||||||||||||||||||||||||||||||||
| mkl_random.monkey_patch(np) | ||||||||||||||||||||||||||||||||||||||||||||
| assert mkl_random.is_patched() | ||||||||||||||||||||||||||||||||||||||||||||
| mkl_random.restore() | ||||||||||||||||||||||||||||||||||||||||||||
| assert not mkl_random.is_patched() | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| def test_monkey_patch_and_restore(): | ||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||
| Test that monkey_patch replaces and restore brings back original functions. | ||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||
| # Store original functions | ||||||||||||||||||||||||||||||||||||||||||||
| orig_normal = np.random.normal | ||||||||||||||||||||||||||||||||||||||||||||
| orig_randint = np.random.randint | ||||||||||||||||||||||||||||||||||||||||||||
| orig_RandomState = np.random.RandomState | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||
| mkl_random.monkey_patch(np) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| # Check that functions are now different objects | ||||||||||||||||||||||||||||||||||||||||||||
| assert np.random.normal is not orig_normal | ||||||||||||||||||||||||||||||||||||||||||||
| assert np.random.randint is not orig_randint | ||||||||||||||||||||||||||||||||||||||||||||
| assert np.random.RandomState is not orig_RandomState | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| # Check that they are from mkl_random | ||||||||||||||||||||||||||||||||||||||||||||
| assert np.random.normal is mkl_random.mklrand.normal | ||||||||||||||||||||||||||||||||||||||||||||
| assert np.random.RandomState is mkl_random.mklrand.RandomState | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| finally: | ||||||||||||||||||||||||||||||||||||||||||||
| mkl_random.restore() | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| # Check that original functions are restored | ||||||||||||||||||||||||||||||||||||||||||||
| assert mkl_random.is_patched() is False | ||||||||||||||||||||||||||||||||||||||||||||
| assert np.random.normal is orig_normal | ||||||||||||||||||||||||||||||||||||||||||||
| assert np.random.randint is orig_randint | ||||||||||||||||||||||||||||||||||||||||||||
| assert np.random.RandomState is orig_RandomState | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| def test_context_manager(): | ||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||
| Test that the context manager patches and automatically restores. | ||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||
| orig_uniform = np.random.uniform | ||||||||||||||||||||||||||||||||||||||||||||
| assert not mkl_random.is_patched() | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| with mkl_random.mkl_random(np): | ||||||||||||||||||||||||||||||||||||||||||||
| assert mkl_random.is_patched() is True | ||||||||||||||||||||||||||||||||||||||||||||
| assert np.random.uniform is not orig_uniform | ||||||||||||||||||||||||||||||||||||||||||||
| # Smoke test inside context | ||||||||||||||||||||||||||||||||||||||||||||
| arr = np.random.uniform(size=10) | ||||||||||||||||||||||||||||||||||||||||||||
| assert arr.shape == (10,) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| assert not mkl_random.is_patched() | ||||||||||||||||||||||||||||||||||||||||||||
| assert np.random.uniform is orig_uniform | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| def test_patched_functions_callable(): | ||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||
| Smoke test to ensure some patched functions can be called without error. | ||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||
| mkl_random.monkey_patch(np) | ||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||
| # These calls should now be routed to mkl_random's implementations | ||||||||||||||||||||||||||||||||||||||||||||
| x = np.random.standard_normal(size=100) | ||||||||||||||||||||||||||||||||||||||||||||
| assert x.shape == (100,) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| y = np.random.randint(0, 100, size=50) | ||||||||||||||||||||||||||||||||||||||||||||
| assert y.shape == (50,) | ||||||||||||||||||||||||||||||||||||||||||||
| assert np.all(y >= 0) and np.all(y < 100) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| st = np.random.RandomState(12345) | ||||||||||||||||||||||||||||||||||||||||||||
| z = st.rand(10) | ||||||||||||||||||||||||||||||||||||||||||||
| assert z.shape == (10,) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| finally: | ||||||||||||||||||||||||||||||||||||||||||||
| mkl_random.restore() | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| def test_patched_names(): | ||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||
| Test that patched_names() returns a list of patched symbols. | ||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||
| mkl_random.monkey_patch(np) | ||||||||||||||||||||||||||||||||||||||||||||
| names = mkl_random.patched_names() | ||||||||||||||||||||||||||||||||||||||||||||
| assert isinstance(names, list) | ||||||||||||||||||||||||||||||||||||||||||||
| assert len(names) > 0 | ||||||||||||||||||||||||||||||||||||||||||||
| assert "normal" in names | ||||||||||||||||||||||||||||||||||||||||||||
| assert "RandomState" in names | ||||||||||||||||||||||||||||||||||||||||||||
| finally: | ||||||||||||||||||||||||||||||||||||||||||||
| mkl_random.restore() | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
| mkl_random.restore() | |
| mkl_random.restore() | |
| def test_monkey_patch_strict_raises_attribute_error(): | |
| """ | |
| Test that strict mode raises AttributeError when patching non-existent names. | |
| """ | |
| # Attempt to patch a clearly non-existent symbol in strict mode. | |
| with pytest.raises(AttributeError): | |
| mkl_random.monkey_patch(np, strict=True, names=["nonexistent_symbol"]) |
Copilot
AI
Feb 20, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The use_in_numpy function is exported in init.py and documented as a "backward-compatible alias" but has no test coverage. Consider adding a test to verify that use_in_numpy works identically to monkey_patch, ensuring the backward compatibility is maintained.
| mkl_random.restore() | |
| mkl_random.restore() | |
| def test_use_in_numpy_is_alias_for_monkey_patch(): | |
| """ | |
| Test that use_in_numpy is a backward-compatible alias for monkey_patch. | |
| """ | |
| assert hasattr(mkl_random, "use_in_numpy") | |
| assert mkl_random.use_in_numpy is mkl_random.monkey_patch |
Copilot
AI
Feb 20, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is no test coverage for edge cases such as:
- Calling restore() when not patched (should be safe/no-op based on line 163-164 of _patch.pyx)
- Calling monkey_patch() multiple times with the same numpy module (idempotent behavior from lines 126-128)
- Calling monkey_patch() with different numpy modules sequentially
- Error handling when numpy_module doesn't have a 'random' attribute
Consider adding tests for these edge cases to ensure the implementation handles them correctly.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -81,6 +81,14 @@ def extensions(): | |
| extra_compile_args = eca, | ||
| define_macros=defs + [("NDEBUG", None)], | ||
| language="c++" | ||
| ), | ||
|
|
||
| Extension( | ||
| "mkl_random._patch", | ||
| sources=[join("mkl_random", "src", "_patch.pyx")], | ||
| include_dirs=[np.get_include()], | ||
| define_macros=defs + [("NDEBUG", None)], | ||
| language="c", | ||
| ) | ||
|
Comment on lines
+86
to
92
|
||
| ] | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The copyright year is 2019, while other files in the codebase use 2017 (e.g., init.py, test_random.py). This inconsistency should be corrected to match the project's copyright dating convention, likely using the year the file was actually created or the standard project copyright year.