77import os
88import pathlib
99from collections import defaultdict
10- from contextlib import redirect_stderr , redirect_stdout
10+ from contextlib import contextmanager , redirect_stderr , redirect_stdout
1111from queue import Queue
1212from threading import Thread
1313from typing import Dict , List , Optional
@@ -248,6 +248,17 @@ def run_cell(self) -> List[IPytestResult]:
248248
249249 return test_results
250250
251+ @contextmanager
252+ def traceback_handling (self , debug : bool ):
253+ """Context manager to temporarily modify traceback behavior"""
254+ original_traceback = self .shell ._showtraceback
255+ try :
256+ if not debug :
257+ self .shell ._showtraceback = lambda * args , ** kwargs : None
258+ yield
259+ finally :
260+ self .shell ._showtraceback = original_traceback
261+
251262 @cell_magic
252263 def ipytest (self , line : str , cell : str ):
253264 """The `%%ipytest` cell magic"""
@@ -270,56 +281,53 @@ def ipytest(self, line: str, cell: str):
270281 self .threaded = True
271282 self .test_queue = Queue ()
272283
273- # If debug is in the line, then we want to show the traceback
274- if self .debug :
275- self .shell ._showtraceback = self ._orig_traceback
276- else :
277- self .shell ._showtraceback = lambda * args , ** kwargs : None
278-
279- # Get the module containing the test(s)
280- if (
281- module_name := get_module_name (
282- " " .join (line_contents ), self .shell .user_global_ns
283- )
284- ) is None :
285- raise TestModuleNotFoundError
284+ with self .traceback_handling (self .debug ):
285+ # Get the module containing the test(s)
286+ if (
287+ module_name := get_module_name (
288+ " " .join (line_contents ), self .shell .user_global_ns
289+ )
290+ ) is None :
291+ raise TestModuleNotFoundError
286292
287- self .module_name = module_name
293+ self .module_name = module_name
288294
289- # Check that the test module file exists
290- if not (
291- module_file := pathlib .Path (f"tutorial/tests/test_{ self .module_name } .py" )
292- ).exists ():
293- raise FileNotFoundError (module_file )
295+ # Check that the test module file exists
296+ if not (
297+ module_file := pathlib .Path (
298+ f"tutorial/tests/test_{ self .module_name } .py"
299+ )
300+ ).exists ():
301+ raise FileNotFoundError (module_file )
294302
295- self .module_file = module_file
303+ self .module_file = module_file
296304
297- # Run the cell
298- results = self .run_cell ()
305+ # Run the cell
306+ results = self .run_cell ()
299307
300- # If in debug mode, display debug information first
301- if self .debug :
302- debug_output = DebugOutput (
303- module_name = self .module_name ,
304- module_file = self .module_file ,
305- results = results ,
306- )
307- display (HTML (debug_output .to_html ()))
308-
309- # Parse the AST of the test module to retrieve the solution code
310- ast_parser = AstParser (self .module_file )
311- # Display the test results and the solution code
312- for result in results :
313- solution = (
314- ast_parser .get_solution_code (result .function .name )
315- if result .function and result .function .name
316- else None
317- )
318- TestResultOutput (
319- result ,
320- solution ,
321- self .shell .openai_client , # type: ignore
322- ).display_results ()
308+ # If in debug mode, display debug information first
309+ if self .debug :
310+ debug_output = DebugOutput (
311+ module_name = self .module_name ,
312+ module_file = self .module_file ,
313+ results = results ,
314+ )
315+ display (HTML (debug_output .to_html ()))
316+
317+ # Parse the AST of the test module to retrieve the solution code
318+ ast_parser = AstParser (self .module_file )
319+ # Display the test results and the solution code
320+ for result in results :
321+ solution = (
322+ ast_parser .get_solution_code (result .function .name )
323+ if result .function and result .function .name
324+ else None
325+ )
326+ TestResultOutput (
327+ result ,
328+ solution ,
329+ self .shell .openai_client , # type: ignore
330+ ).display_results ()
323331
324332
325333def load_ipython_extension (ipython ):
0 commit comments