diff --git a/src/tree/kth_smallest_element_in_a_bst.py b/src/tree/kth_smallest_element_in_a_bst.py index 962ac8f..4a063bf 100644 --- a/src/tree/kth_smallest_element_in_a_bst.py +++ b/src/tree/kth_smallest_element_in_a_bst.py @@ -28,7 +28,6 @@ def __init__(self, val=0, left=None, right=None): class Solution: def kthSmallest(self, root, k): - self.k = k """ Find the kth smallest element in a BST. @@ -42,20 +41,23 @@ def kthSmallest(self, root, k): Time Complexity: O(h + k) where h is height Space Complexity: O(h) """ + count = k + def traverse(root): + nonlocal count if root is None: return - left_result = traverse(root.left) - # Check "is not None" so that 0 is accepted as a valid answer + left_result = traverse(root.left) + # Check "is not None" so that 0 is accepted as a valid answer if left_result is not None: return left_result - self.k -= 1 - if self.k == 0: + count -= 1 + if count == 0: return root.val - return traverse(root.right) + return traverse(root.right) return traverse(root) diff --git a/tests/test_kth_smallest_element_in_a_bst.py b/tests/test_kth_smallest_element_in_a_bst.py index 5010e3d..65c8eff 100644 --- a/tests/test_kth_smallest_element_in_a_bst.py +++ b/tests/test_kth_smallest_element_in_a_bst.py @@ -35,3 +35,20 @@ def test_example_2(self): """Test case from example 2""" root = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6)) assert self.solution.kthSmallest(root, 3) == 3 + + def test_solution_reusability(self): + """Test that the same Solution instance can be reused for multiple calls""" + # First call + root1 = TreeNode(3, TreeNode(1, None, TreeNode(2)), TreeNode(4)) + result1 = self.solution.kthSmallest(root1, 1) + assert result1 == 1 + + # Second call on the same instance should work correctly + root2 = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6)) + result2 = self.solution.kthSmallest(root2, 3) + assert result2 == 3 + + # Third call with different k value + root3 = TreeNode(3, TreeNode(1, None, TreeNode(2)), TreeNode(4)) + result3 = self.solution.kthSmallest(root3, 2) + assert result3 == 2