task: add patch methods for mkl_random#90
task: add patch methods for mkl_random#90jharlow-intel wants to merge 1 commit intoIntelPython:dev-milestonefrom
Conversation
There was a problem hiding this comment.
Pull request overview
This PR adds monkey-patching functionality to mkl_random, allowing users to temporarily or permanently replace numpy.random functions with their mkl_random equivalents. The implementation provides both imperative (monkey_patch(), restore()) and context manager (mkl_random.mkl_random()) interfaces.
Changes:
- Adds new
_patch.pyxCython module implementing patching logic with thread-local state tracking - Extends
setup.pyto build the new_patchextension module - Exports patch functions (
monkey_patch,use_in_numpy,restore,is_patched,patched_names,mkl_random) in__init__.py - Adds comprehensive test suite in
test_patch.pycovering basic patching, restoration, and context manager functionality
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 13 comments.
| File | Description |
|---|---|
| setup.py | Adds Extension configuration for mkl_random._patch module |
| mkl_random/src/_patch.pyx | Implements core patching logic with thread-local storage, patch/unpatch methods, and context manager |
| mkl_random/init.py | Exports patch-related functions to public API |
| mkl_random/tests/test_patch.py | Adds tests for patching, restoration, context manager, and patched function behavior |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| @@ -0,0 +1,275 @@ | |||
| # Copyright (c) 2019, Intel Corporation | |||
There was a problem hiding this comment.
The copyright year is 2019, while other files in the codebase use 2017 (e.g., init.py, test_random.py). This inconsistency should be corrected to match the project's copyright dating convention, likely using the year the file was actually created or the standard project copyright year.
| # Copyright (c) 2019, Intel Corporation | |
| # Copyright (c) 2017, Intel Corporation |
| Extension( | ||
| "mkl_random._patch", | ||
| sources=[join("mkl_random", "src", "_patch.pyx")], | ||
| include_dirs=[np.get_include()], | ||
| define_macros=defs + [("NDEBUG", None)], | ||
| language="c", | ||
| ) |
There was a problem hiding this comment.
The _patch extension includes define_macros with PY_ARRAY_UNIQUE_SYMBOL, but _patch.pyx doesn't use the NumPy C API (no cimport numpy or usage of numpy C structures). While this doesn't cause harm, these macros are unnecessary for this extension and could be removed for clarity. The extension only needs include_dirs=[np.get_include()] to compile successfully with Cython's NumPy support.
| _tls = threading_local() | ||
|
|
||
|
|
||
| def _is_tls_initialized(): | ||
| return (getattr(_tls, "initialized", None) is not None) and (_tls.initialized is True) | ||
|
|
||
|
|
||
| def _initialize_tls(): | ||
| _tls.patch = patch() | ||
| _tls.initialized = True |
There was a problem hiding this comment.
The thread-local storage approach creates a mismatch between thread-local state tracking and global module patching. Since numpy.random is a module-level singleton, patching it in one thread affects all threads, but the is_patched() status is tracked per-thread. This means:
- Thread A can patch numpy.random and _tls.patch._is_patched becomes True for Thread A
- Thread B checks is_patched() and gets False (Thread B has its own _tls)
- But numpy.random is actually patched globally for both threads
Consider using a module-level (not thread-local) patch tracker, with appropriate locking if thread-safety is needed.
| def __enter__(self): | ||
| monkey_patch(numpy_module=self._numpy_module, names=self._names, strict=self._strict) | ||
| return self | ||
|
|
||
| def __exit__(self, *exc): | ||
| restore() | ||
| return False |
There was a problem hiding this comment.
Nested context managers will incorrectly restore the patch when the inner context exits. For example:
with mkl_random.mkl_random(np):
# numpy.random is patched
with mkl_random.mkl_random(np):
# still patched
pass
# Inner __exit__ calls restore(), unpatching numpy.random
# But outer context still expects it to be patchedConsider implementing reference counting or a stack-based approach to handle nested contexts correctly.
| # If already patched, only allow idempotent re-entry for the same numpy module. | ||
| if self._is_patched: | ||
| if self._numpy_module is numpy_module: |
There was a problem hiding this comment.
The identity check self._numpy_module is numpy_module may fail in some scenarios where the same module is imported differently (e.g., through different import mechanisms or in certain testing scenarios). Consider using self._numpy_module is not None and self._numpy_module.__name__ == numpy_module.__name__ for a more robust comparison, or document that the exact same module object must be passed.
| # If already patched, only allow idempotent re-entry for the same numpy module. | |
| if self._is_patched: | |
| if self._numpy_module is numpy_module: | |
| # If already patched, only allow idempotent re-entry for the same numpy module | |
| # (by module name) to avoid failures when the same module is imported via | |
| # different mechanisms. | |
| if self._is_patched: | |
| same_module = ( | |
| self._numpy_module is not None | |
| and numpy_module is not None | |
| and getattr(self._numpy_module, "__name__", None) | |
| == getattr(numpy_module, "__name__", None) | |
| ) | |
| if same_module: |
| assert "normal" in names | ||
| assert "RandomState" in names | ||
| finally: | ||
| mkl_random.restore() |
There was a problem hiding this comment.
The use_in_numpy function is exported in init.py and documented as a "backward-compatible alias" but has no test coverage. Consider adding a test to verify that use_in_numpy works identically to monkey_patch, ensuring the backward compatibility is maintained.
| mkl_random.restore() | |
| mkl_random.restore() | |
| def test_use_in_numpy_is_alias_for_monkey_patch(): | |
| """ | |
| Test that use_in_numpy is a backward-compatible alias for monkey_patch. | |
| """ | |
| assert hasattr(mkl_random, "use_in_numpy") | |
| assert mkl_random.use_in_numpy is mkl_random.monkey_patch |
|
|
||
|
|
||
| def _is_tls_initialized(): | ||
| return (getattr(_tls, "initialized", None) is not None) and (_tls.initialized is True) |
There was a problem hiding this comment.
The _is_tls_initialized() function uses getattr(_tls, "initialized", None) is not None followed by and (_tls.initialized is True). This is redundant - the second check alone (getattr(_tls, "initialized", False)) would be simpler and equivalent. Consider simplifying to: return getattr(_tls, "initialized", False)
| return (getattr(_tls, "initialized", None) is not None) and (_tls.initialized is True) | |
| return getattr(_tls, "initialized", False) |
| import numpy as np | ||
| import mkl_random | ||
| import pytest | ||
|
|
||
| def test_is_patched(): | ||
| """ | ||
| Test that is_patched() returns correct status. | ||
| """ | ||
| assert not mkl_random.is_patched() | ||
| mkl_random.monkey_patch(np) | ||
| assert mkl_random.is_patched() | ||
| mkl_random.restore() | ||
| assert not mkl_random.is_patched() | ||
|
|
||
| def test_monkey_patch_and_restore(): | ||
| """ | ||
| Test that monkey_patch replaces and restore brings back original functions. | ||
| """ | ||
| # Store original functions | ||
| orig_normal = np.random.normal | ||
| orig_randint = np.random.randint | ||
| orig_RandomState = np.random.RandomState | ||
|
|
||
| try: | ||
| mkl_random.monkey_patch(np) | ||
|
|
||
| # Check that functions are now different objects | ||
| assert np.random.normal is not orig_normal | ||
| assert np.random.randint is not orig_randint | ||
| assert np.random.RandomState is not orig_RandomState | ||
|
|
||
| # Check that they are from mkl_random | ||
| assert np.random.normal is mkl_random.mklrand.normal | ||
| assert np.random.RandomState is mkl_random.mklrand.RandomState | ||
|
|
||
| finally: | ||
| mkl_random.restore() | ||
|
|
||
| # Check that original functions are restored | ||
| assert mkl_random.is_patched() is False | ||
| assert np.random.normal is orig_normal | ||
| assert np.random.randint is orig_randint | ||
| assert np.random.RandomState is orig_RandomState | ||
|
|
||
| def test_context_manager(): | ||
| """ | ||
| Test that the context manager patches and automatically restores. | ||
| """ | ||
| orig_uniform = np.random.uniform | ||
| assert not mkl_random.is_patched() | ||
|
|
||
| with mkl_random.mkl_random(np): | ||
| assert mkl_random.is_patched() is True | ||
| assert np.random.uniform is not orig_uniform | ||
| # Smoke test inside context | ||
| arr = np.random.uniform(size=10) | ||
| assert arr.shape == (10,) | ||
|
|
||
| assert not mkl_random.is_patched() | ||
| assert np.random.uniform is orig_uniform | ||
|
|
||
| def test_patched_functions_callable(): | ||
| """ | ||
| Smoke test to ensure some patched functions can be called without error. | ||
| """ | ||
| mkl_random.monkey_patch(np) | ||
| try: | ||
| # These calls should now be routed to mkl_random's implementations | ||
| x = np.random.standard_normal(size=100) | ||
| assert x.shape == (100,) | ||
|
|
||
| y = np.random.randint(0, 100, size=50) | ||
| assert y.shape == (50,) | ||
| assert np.all(y >= 0) and np.all(y < 100) | ||
|
|
||
| st = np.random.RandomState(12345) | ||
| z = st.rand(10) | ||
| assert z.shape == (10,) | ||
|
|
||
| finally: | ||
| mkl_random.restore() | ||
|
|
||
| def test_patched_names(): | ||
| """ | ||
| Test that patched_names() returns a list of patched symbols. | ||
| """ | ||
| try: | ||
| mkl_random.monkey_patch(np) | ||
| names = mkl_random.patched_names() | ||
| assert isinstance(names, list) | ||
| assert len(names) > 0 | ||
| assert "normal" in names | ||
| assert "RandomState" in names | ||
| finally: | ||
| mkl_random.restore() |
There was a problem hiding this comment.
There is no test coverage for edge cases such as:
- Calling restore() when not patched (should be safe/no-op based on line 163-164 of _patch.pyx)
- Calling monkey_patch() multiple times with the same numpy module (idempotent behavior from lines 126-128)
- Calling monkey_patch() with different numpy modules sequentially
- Error handling when numpy_module doesn't have a 'random' attribute
Consider adding tests for these edge cases to ensure the implementation handles them correctly.
| -------- | ||
| >>> import numpy as np | ||
| >>> import mkl_random | ||
| >>> with mkl_random.mkl_random(): |
There was a problem hiding this comment.
The example in the docstring shows with mkl_random.mkl_random(): without passing numpy module, but the test at line 52 shows with mkl_random.mkl_random(np):. The example should clarify whether passing np is optional (defaulting to the imported numpy) or if it's recommended to be explicit. Consider updating the example to match the test pattern for consistency.
| >>> with mkl_random.mkl_random(): | |
| >>> with mkl_random.mkl_random(np): |
| for n, v in self._originals.items(): | ||
| setattr(np_random, n, v) | ||
|
|
||
| self._numpy_module = None | ||
| self._originals = {} | ||
| self._patched = [] | ||
| self._is_patched = False | ||
|
|
There was a problem hiding this comment.
The do_unpatch method doesn't handle potential exceptions when restoring attributes. If setattr() fails for any name (e.g., if the module structure changed or an attribute became read-only), the unpatch operation will be incomplete, leaving _is_patched as True and some attributes not restored. Consider wrapping the restore loop in error handling or using a try-finally to ensure state is cleaned up even if restoration fails.
| for n, v in self._originals.items(): | |
| setattr(np_random, n, v) | |
| self._numpy_module = None | |
| self._originals = {} | |
| self._patched = [] | |
| self._is_patched = False | |
| try: | |
| for n, v in self._originals.items(): | |
| setattr(np_random, n, v) | |
| finally: | |
| self._numpy_module = None | |
| self._originals = {} | |
| self._patched = [] | |
| self._is_patched = False |
Gonna need some expert's eyes on this one, but it built fine and the tests look okay.
Definitely going to have to make sure it properly interfaced numpy.random though