Skip to content

Comments

task: add patch methods for mkl_random#90

Open
jharlow-intel wants to merge 1 commit intoIntelPython:dev-milestonefrom
jharlow-intel:task/patch-numpy
Open

task: add patch methods for mkl_random#90
jharlow-intel wants to merge 1 commit intoIntelPython:dev-milestonefrom
jharlow-intel:task/patch-numpy

Conversation

@jharlow-intel
Copy link
Collaborator

Gonna need some expert's eyes on this one, but it built fine and the tests look okay.

Definitely going to have to make sure it properly interfaced numpy.random though

Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds monkey-patching functionality to mkl_random, allowing users to temporarily or permanently replace numpy.random functions with their mkl_random equivalents. The implementation provides both imperative (monkey_patch(), restore()) and context manager (mkl_random.mkl_random()) interfaces.

Changes:

  • Adds new _patch.pyx Cython module implementing patching logic with thread-local state tracking
  • Extends setup.py to build the new _patch extension module
  • Exports patch functions (monkey_patch, use_in_numpy, restore, is_patched, patched_names, mkl_random) in __init__.py
  • Adds comprehensive test suite in test_patch.py covering basic patching, restoration, and context manager functionality

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 13 comments.

File Description
setup.py Adds Extension configuration for mkl_random._patch module
mkl_random/src/_patch.pyx Implements core patching logic with thread-local storage, patch/unpatch methods, and context manager
mkl_random/init.py Exports patch-related functions to public API
mkl_random/tests/test_patch.py Adds tests for patching, restoration, context manager, and patched function behavior

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@@ -0,0 +1,275 @@
# Copyright (c) 2019, Intel Corporation
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The copyright year is 2019, while other files in the codebase use 2017 (e.g., init.py, test_random.py). This inconsistency should be corrected to match the project's copyright dating convention, likely using the year the file was actually created or the standard project copyright year.

Suggested change
# Copyright (c) 2019, Intel Corporation
# Copyright (c) 2017, Intel Corporation

Copilot uses AI. Check for mistakes.
Comment on lines +86 to 92
Extension(
"mkl_random._patch",
sources=[join("mkl_random", "src", "_patch.pyx")],
include_dirs=[np.get_include()],
define_macros=defs + [("NDEBUG", None)],
language="c",
)
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _patch extension includes define_macros with PY_ARRAY_UNIQUE_SYMBOL, but _patch.pyx doesn't use the NumPy C API (no cimport numpy or usage of numpy C structures). While this doesn't cause harm, these macros are unnecessary for this extension and could be removed for clarity. The extension only needs include_dirs=[np.get_include()] to compile successfully with Cython's NumPy support.

Copilot uses AI. Check for mistakes.
Comment on lines +185 to +194
_tls = threading_local()


def _is_tls_initialized():
return (getattr(_tls, "initialized", None) is not None) and (_tls.initialized is True)


def _initialize_tls():
_tls.patch = patch()
_tls.initialized = True
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The thread-local storage approach creates a mismatch between thread-local state tracking and global module patching. Since numpy.random is a module-level singleton, patching it in one thread affects all threads, but the is_patched() status is tracked per-thread. This means:

  • Thread A can patch numpy.random and _tls.patch._is_patched becomes True for Thread A
  • Thread B checks is_patched() and gets False (Thread B has its own _tls)
  • But numpy.random is actually patched globally for both threads

Consider using a module-level (not thread-local) patch tracker, with appropriate locking if thread-safety is needed.

Copilot uses AI. Check for mistakes.
Comment on lines +269 to +275
def __enter__(self):
monkey_patch(numpy_module=self._numpy_module, names=self._names, strict=self._strict)
return self

def __exit__(self, *exc):
restore()
return False
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nested context managers will incorrectly restore the patch when the inner context exits. For example:

with mkl_random.mkl_random(np):
    # numpy.random is patched
    with mkl_random.mkl_random(np):
        # still patched
        pass
    # Inner __exit__ calls restore(), unpatching numpy.random
    # But outer context still expects it to be patched

Consider implementing reference counting or a stack-based approach to handle nested contexts correctly.

Copilot uses AI. Check for mistakes.
Comment on lines +125 to +127
# If already patched, only allow idempotent re-entry for the same numpy module.
if self._is_patched:
if self._numpy_module is numpy_module:
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The identity check self._numpy_module is numpy_module may fail in some scenarios where the same module is imported differently (e.g., through different import mechanisms or in certain testing scenarios). Consider using self._numpy_module is not None and self._numpy_module.__name__ == numpy_module.__name__ for a more robust comparison, or document that the exact same module object must be passed.

Suggested change
# If already patched, only allow idempotent re-entry for the same numpy module.
if self._is_patched:
if self._numpy_module is numpy_module:
# If already patched, only allow idempotent re-entry for the same numpy module
# (by module name) to avoid failures when the same module is imported via
# different mechanisms.
if self._is_patched:
same_module = (
self._numpy_module is not None
and numpy_module is not None
and getattr(self._numpy_module, "__name__", None)
== getattr(numpy_module, "__name__", None)
)
if same_module:

Copilot uses AI. Check for mistakes.
assert "normal" in names
assert "RandomState" in names
finally:
mkl_random.restore()
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The use_in_numpy function is exported in init.py and documented as a "backward-compatible alias" but has no test coverage. Consider adding a test to verify that use_in_numpy works identically to monkey_patch, ensuring the backward compatibility is maintained.

Suggested change
mkl_random.restore()
mkl_random.restore()
def test_use_in_numpy_is_alias_for_monkey_patch():
"""
Test that use_in_numpy is a backward-compatible alias for monkey_patch.
"""
assert hasattr(mkl_random, "use_in_numpy")
assert mkl_random.use_in_numpy is mkl_random.monkey_patch

Copilot uses AI. Check for mistakes.


def _is_tls_initialized():
return (getattr(_tls, "initialized", None) is not None) and (_tls.initialized is True)
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _is_tls_initialized() function uses getattr(_tls, "initialized", None) is not None followed by and (_tls.initialized is True). This is redundant - the second check alone (getattr(_tls, "initialized", False)) would be simpler and equivalent. Consider simplifying to: return getattr(_tls, "initialized", False)

Suggested change
return (getattr(_tls, "initialized", None) is not None) and (_tls.initialized is True)
return getattr(_tls, "initialized", False)

Copilot uses AI. Check for mistakes.
Comment on lines +1 to +95
import numpy as np
import mkl_random
import pytest

def test_is_patched():
"""
Test that is_patched() returns correct status.
"""
assert not mkl_random.is_patched()
mkl_random.monkey_patch(np)
assert mkl_random.is_patched()
mkl_random.restore()
assert not mkl_random.is_patched()

def test_monkey_patch_and_restore():
"""
Test that monkey_patch replaces and restore brings back original functions.
"""
# Store original functions
orig_normal = np.random.normal
orig_randint = np.random.randint
orig_RandomState = np.random.RandomState

try:
mkl_random.monkey_patch(np)

# Check that functions are now different objects
assert np.random.normal is not orig_normal
assert np.random.randint is not orig_randint
assert np.random.RandomState is not orig_RandomState

# Check that they are from mkl_random
assert np.random.normal is mkl_random.mklrand.normal
assert np.random.RandomState is mkl_random.mklrand.RandomState

finally:
mkl_random.restore()

# Check that original functions are restored
assert mkl_random.is_patched() is False
assert np.random.normal is orig_normal
assert np.random.randint is orig_randint
assert np.random.RandomState is orig_RandomState

def test_context_manager():
"""
Test that the context manager patches and automatically restores.
"""
orig_uniform = np.random.uniform
assert not mkl_random.is_patched()

with mkl_random.mkl_random(np):
assert mkl_random.is_patched() is True
assert np.random.uniform is not orig_uniform
# Smoke test inside context
arr = np.random.uniform(size=10)
assert arr.shape == (10,)

assert not mkl_random.is_patched()
assert np.random.uniform is orig_uniform

def test_patched_functions_callable():
"""
Smoke test to ensure some patched functions can be called without error.
"""
mkl_random.monkey_patch(np)
try:
# These calls should now be routed to mkl_random's implementations
x = np.random.standard_normal(size=100)
assert x.shape == (100,)

y = np.random.randint(0, 100, size=50)
assert y.shape == (50,)
assert np.all(y >= 0) and np.all(y < 100)

st = np.random.RandomState(12345)
z = st.rand(10)
assert z.shape == (10,)

finally:
mkl_random.restore()

def test_patched_names():
"""
Test that patched_names() returns a list of patched symbols.
"""
try:
mkl_random.monkey_patch(np)
names = mkl_random.patched_names()
assert isinstance(names, list)
assert len(names) > 0
assert "normal" in names
assert "RandomState" in names
finally:
mkl_random.restore()
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no test coverage for edge cases such as:

  • Calling restore() when not patched (should be safe/no-op based on line 163-164 of _patch.pyx)
  • Calling monkey_patch() multiple times with the same numpy module (idempotent behavior from lines 126-128)
  • Calling monkey_patch() with different numpy modules sequentially
  • Error handling when numpy_module doesn't have a 'random' attribute

Consider adding tests for these edge cases to ensure the implementation handles them correctly.

Copilot uses AI. Check for mistakes.
--------
>>> import numpy as np
>>> import mkl_random
>>> with mkl_random.mkl_random():
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The example in the docstring shows with mkl_random.mkl_random(): without passing numpy module, but the test at line 52 shows with mkl_random.mkl_random(np):. The example should clarify whether passing np is optional (defaulting to the imported numpy) or if it's recommended to be explicit. Consider updating the example to match the test pattern for consistency.

Suggested change
>>> with mkl_random.mkl_random():
>>> with mkl_random.mkl_random(np):

Copilot uses AI. Check for mistakes.
Comment on lines +167 to +174
for n, v in self._originals.items():
setattr(np_random, n, v)

self._numpy_module = None
self._originals = {}
self._patched = []
self._is_patched = False

Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The do_unpatch method doesn't handle potential exceptions when restoring attributes. If setattr() fails for any name (e.g., if the module structure changed or an attribute became read-only), the unpatch operation will be incomplete, leaving _is_patched as True and some attributes not restored. Consider wrapping the restore loop in error handling or using a try-finally to ensure state is cleaned up even if restoration fails.

Suggested change
for n, v in self._originals.items():
setattr(np_random, n, v)
self._numpy_module = None
self._originals = {}
self._patched = []
self._is_patched = False
try:
for n, v in self._originals.items():
setattr(np_random, n, v)
finally:
self._numpy_module = None
self._originals = {}
self._patched = []
self._is_patched = False

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant