From ccf175fd5bfa1b53c5ffc388abc56a281752562b Mon Sep 17 00:00:00 2001 From: mrava87 Date: Tue, 20 Jan 2026 22:45:01 +0000 Subject: [PATCH] test: change with pytest.raises code pattern --- pytests/test_basicoperators.py | 4 +--- pytests/test_combine.py | 8 ++------ pytests/test_dwts.py | 4 +--- pytests/test_ffts.py | 12 +++--------- pytests/test_fourierradon.py | 8 ++------ pytests/test_functionoperator.py | 4 +--- pytests/test_linearoperator.py | 8 ++------ pytests/test_lsm.py | 4 +--- pytests/test_nonstatconvolve.py | 32 ++++++++------------------------ pytests/test_pad.py | 8 ++------ pytests/test_radon.py | 8 ++------ pytests/test_shift.py | 4 +--- pytests/test_sparsity.py | 24 ++++++------------------ 13 files changed, 32 insertions(+), 96 deletions(-) diff --git a/pytests/test_basicoperators.py b/pytests/test_basicoperators.py index 564e8551..447cfa95 100644 --- a/pytests/test_basicoperators.py +++ b/pytests/test_basicoperators.py @@ -592,10 +592,8 @@ def test_Sum2D_forceflat(par): assert y.shape == (par["ny"],) assert xadj.shape == (par["ny"], par["nx"]) - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="Operators have conflicting forceflat"): Sop_True * Sop_False.H - error_message = str(exception_info.value) - assert "Operators have conflicting forceflat" in error_message Sop = Sop_True * Sop_None.H assert Sop.forceflat is True diff --git a/pytests/test_combine.py b/pytests/test_combine.py index ee0324a5..ea7fdacc 100644 --- a/pytests/test_combine.py +++ b/pytests/test_combine.py @@ -39,13 +39,11 @@ def test_VStack_incosistent_columns(par): """ G1 = np.random.normal(0, 10, (par["ny"], par["nx"])).astype(par["dtype"]) G2 = np.random.normal(0, 10, (par["ny"], par["nx"] + 1)).astype(par["dtype"]) - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="different number of columns"): VStack( [MatrixMult(G1, dtype=par["dtype"]), MatrixMult(G2, dtype=par["dtype"])], dtype=par["dtype"], ) - error_message = str(exception_info.value) - assert "different number of columns" in error_message @pytest.mark.parametrize("par", [(par1)]) @@ -55,13 +53,11 @@ def test_HStack_incosistent_rows(par): """ G1 = np.random.normal(0, 10, (par["ny"], par["nx"])).astype(par["dtype"]) G2 = np.random.normal(0, 10, (par["ny"] + 1, par["nx"])).astype(par["dtype"]) - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="different number of rows"): HStack( [MatrixMult(G1, dtype=par["dtype"]), MatrixMult(G2, dtype=par["dtype"])], dtype=par["dtype"], ) - error_message = str(exception_info.value) - assert "different number of rows" in error_message @pytest.mark.parametrize("par", [(par1), (par2), (par1j), (par2j)]) diff --git a/pytests/test_dwts.py b/pytests/test_dwts.py index 7f0bf6ac..efd071ef 100644 --- a/pytests/test_dwts.py +++ b/pytests/test_dwts.py @@ -29,10 +29,8 @@ @pytest.mark.parametrize("par", [(par1)]) def test_unknown_wavelet(par): """Check error is raised if unknown wavelet is chosen is passed""" - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="not in family set"): _ = DWT(dims=par["nt"], wavelet="foo") - error_message = str(exception_info.value) - assert "not in family set" in error_message @pytest.mark.skipif( diff --git a/pytests/test_ffts.py b/pytests/test_ffts.py index 1472c1f7..df6d6c02 100644 --- a/pytests/test_ffts.py +++ b/pytests/test_ffts.py @@ -289,7 +289,7 @@ def _choose_random_axes(ndim, n_choices=2): @pytest.mark.parametrize("par", [par1]) def test_unknown_engine(par): """Check error is raised if unknown engine is passed""" - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="engine must be"): _ = FFT( dims=(par["nt"],), nfft=par["nfft"], @@ -297,10 +297,8 @@ def test_unknown_engine(par): real=par["real"], engine="foo", ) - error_message = str(exception_info.value) - assert "engine must be" in error_message - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="engine must be"): _ = FFT2D( dims=(par["nx"], par["nt"]), nfft=(par["nfft"], par["nfft"]), @@ -308,10 +306,8 @@ def test_unknown_engine(par): real=par["real"], engine="foo", ) - error_message = str(exception_info.value) - assert "engine must be" in error_message - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="engine must be"): _ = FFTND( dims=(par["ny"], par["nx"], par["nt"]), nfft=(par["nfft"], par["nfft"], par["nfft"]), @@ -319,8 +315,6 @@ def test_unknown_engine(par): real=par["real"], engine="foo", ) - error_message = str(exception_info.value) - assert "engine must be" in error_message dtype_precision = [ diff --git a/pytests/test_fourierradon.py b/pytests/test_fourierradon.py index 0666cc24..3f860ef8 100644 --- a/pytests/test_fourierradon.py +++ b/pytests/test_fourierradon.py @@ -67,18 +67,14 @@ def test_unknown_engine2D(): """Check error is raised if unknown engine is passed""" - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="engine must be"): _ = FourierRadon2D(None, None, None, None, engine="foo") - error_message = str(exception_info.value) - assert "engine must be" in error_message def test_unknown_engine3D(): """Check error is raised if unknown engine is passed""" - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="engine must be"): _ = FourierRadon3D(None, None, None, None, None, None, engine="foo") - error_message = str(exception_info.value) - assert "engine must be" in error_message @pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4)]) diff --git a/pytests/test_functionoperator.py b/pytests/test_functionoperator.py index 7dc1ac42..6d976fc9 100644 --- a/pytests/test_functionoperator.py +++ b/pytests/test_functionoperator.py @@ -129,7 +129,5 @@ def forward_f(x): assert_array_equal(F_x, G_x) # check error is raised when applying the adjoint - with pytest.raises(NotImplementedError) as exception_info: + with pytest.raises(NotImplementedError, match="Adjoint not implemented"): _ = Fop.H @ y - error_message = str(exception_info.value) - assert "Adjoint not implemented" in error_message diff --git a/pytests/test_linearoperator.py b/pytests/test_linearoperator.py index 06e08194..e413ee38 100644 --- a/pytests/test_linearoperator.py +++ b/pytests/test_linearoperator.py @@ -354,14 +354,10 @@ def test_non_flattened_arrays(par): assert_array_equal(Y_S, (S @ D @ X_1d).reshape((*S.dimsd, -1))) with pylops.disabled_ndarray_multiplication(): - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="only be applied to 1D"): D @ x_nd - error_message = str(exception_info.value) - assert "only be applied to 1D" in error_message - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="only be applied to 1D"): D @ X_nd - error_message = str(exception_info.value) - assert "only be applied to 1D" in error_message @pytest.mark.parametrize("par", [(par1), (par2j)]) diff --git a/pytests/test_lsm.py b/pytests/test_lsm.py index 20ed90f7..abb0feaf 100644 --- a/pytests/test_lsm.py +++ b/pytests/test_lsm.py @@ -63,10 +63,8 @@ ) def test_unknown_mode(): """Check error is raised if unknown mode is passed""" - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="method must be analytic,"): _ = LSM(z, x, t, s2d, r2d, 0, np.ones(3), 1, mode="foo") - error_message = str(exception_info.value) - assert "method must be analytic," in error_message @pytest.mark.skipif( diff --git a/pytests/test_nonstatconvolve.py b/pytests/test_nonstatconvolve.py index afdda892..ed87ce45 100644 --- a/pytests/test_nonstatconvolve.py +++ b/pytests/test_nonstatconvolve.py @@ -102,72 +102,60 @@ @pytest.mark.parametrize("par", [(par_2d)]) def test_even_filter(par): """Check error is raised if filter has even size""" - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="filters hs must have odd length"): _ = NonStationaryConvolve1D( dims=par["nx"], hs=h1ns[..., :-1], ih=(int(par["nx"] // 4), int(2 * par["nx"] // 4), int(3 * par["nx"] // 4)), ) - error_message = str(exception_info.value) - assert "filters hs must have odd length" in error_message - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="filters hs must have odd length"): _ = NonStationaryConvolve2D( dims=(par["nx"], par["nz"]), hs=h2ns[..., :-1], ihx=(int(par["nx"] // 4), int(2 * par["nx"] // 4), int(3 * par["nx"] // 4)), ihz=(int(par["nz"] // 4), int(3 * par["nz"] // 4)), ) - error_message = str(exception_info.value) - assert "filters hs must have odd length" in error_message - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="filters hs must have odd length"): _ = NonStationaryFilters1D( inp=np.arange(par["nx"]), hsize=nfilts[0] - 1, ih=(int(par["nx"] // 4), int(2 * par["nx"] // 4), int(3 * par["nx"] // 4)), ) - error_message = str(exception_info.value) - assert "filters hs must have odd length" in error_message - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="filters hs must have odd length"): _ = NonStationaryFilters2D( inp=np.ones((par["nx"], par["nz"])), hshape=(nfilts[0] - 1, nfilts[1] - 1), ihx=(int(par["nx"] // 4), int(2 * par["nx"] // 4), int(3 * par["nx"] // 4)), ihz=(int(par["nz"] // 4), int(3 * par["nz"] // 4)), ) - error_message = str(exception_info.value) - assert "filters hs must have odd length" in error_message @pytest.mark.parametrize("par", [(par_2d)]) def test_ih_irregular(par): """Check error is raised if ih (or ihx/ihz) are irregularly sampled""" - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="must be regularly sampled"): _ = NonStationaryConvolve1D( dims=par["nx"], hs=h1ns, ih=(10, 11, 15), ) - error_message = str(exception_info.value) - assert "must be regularly sampled" in error_message - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="must be regularly sampled"): _ = NonStationaryConvolve2D( dims=(par["nx"], par["nz"]), hs=h2ns, ihx=(10, 11, 15), ihz=(int(par["nz"] // 4), int(3 * par["nz"] // 4)), ) - error_message = str(exception_info.value) - assert "must be regularly sampled" in error_message @pytest.mark.parametrize("par", [(par_2d)]) def test_unknown_engine_2d(par): """Check error is raised if unknown engine is passed""" - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="engine must be numpy"): _ = NonStationaryConvolve2D( dims=(par["nx"], par["nz"]), hs=h2ns, @@ -175,10 +163,8 @@ def test_unknown_engine_2d(par): ihz=(int(par["nz"] // 3), int(2 * par["nz"] // 3)), engine="foo", ) - error_message = str(exception_info.value) - assert "engine must be numpy" in error_message - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="engine must be numpy"): _ = NonStationaryFilters2D( inp=np.ones((par["nx"], par["nz"])), hshape=(nfilts[0] - 1, nfilts[1] - 1), @@ -186,8 +172,6 @@ def test_unknown_engine_2d(par): ihz=(int(par["nz"] // 3), int(2 * par["nz"] // 3)), engine="foo", ) - error_message = str(exception_info.value) - assert "engine must be numpy" in error_message @pytest.mark.parametrize("par", [(par1_1d), (par2_1d)]) diff --git a/pytests/test_pad.py b/pytests/test_pad.py index addde1ef..acd20d49 100644 --- a/pytests/test_pad.py +++ b/pytests/test_pad.py @@ -24,19 +24,15 @@ @pytest.mark.parametrize("par", [(par1)]) def test_Pad_1d_negative(par): """Check error is raised when pad has negative number""" - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="Padding must be positive"): _ = Pad(dims=par["ny"], pad=(-10, 0)) - error_message = str(exception_info.value) - assert "Padding must be positive" in error_message @pytest.mark.parametrize("par", [(par1)]) def test_Pad_2d_negative(par): """Check error is raised when pad has negative number for 2d""" - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="Padding must be positive"): _ = Pad(dims=(par["ny"], par["nx"]), pad=((-10, 0), (3, -5))) - error_message = str(exception_info.value) - assert "Padding must be positive" in error_message @pytest.mark.parametrize("par", [(par1)]) diff --git a/pytests/test_radon.py b/pytests/test_radon.py index e577f5c5..48fdc280 100644 --- a/pytests/test_radon.py +++ b/pytests/test_radon.py @@ -119,15 +119,11 @@ ) def test_unknown_engine(): """Check error is raised if unknown engine is passed""" - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="engine must be numpy"): _ = Radon2D(None, None, None, engine="foo") - error_message = str(exception_info.value) - assert "engine must be numpy" in error_message - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="engine must be numpy"): _ = Radon3D(None, None, None, None, None, engine="foo") - error_message = str(exception_info.value) - assert "engine must be numpy" in error_message @pytest.mark.skipif( diff --git a/pytests/test_shift.py b/pytests/test_shift.py index 74ed3aa6..fd08549b 100644 --- a/pytests/test_shift.py +++ b/pytests/test_shift.py @@ -51,14 +51,12 @@ @pytest.mark.parametrize("par", [(par1)]) def test_unknown_engine(par): """Check error is raised if unknown engine is passed""" - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="engine must be numpy"): _ = Shift( par["nt"], 1.0, engine="foo", ) - error_message = str(exception_info.value) - assert "engine must be numpy" in error_message @pytest.mark.parametrize("par", [(par1), (par1j)]) diff --git a/pytests/test_sparsity.py b/pytests/test_sparsity.py index 7f379cdd..1fd37e41 100644 --- a/pytests/test_sparsity.py +++ b/pytests/test_sparsity.py @@ -95,10 +95,8 @@ def test_IRLS_unknown_kind(): """Check error is raised if unknown kind is passed""" - with pytest.raises(NotImplementedError) as exception_info: + with pytest.raises(NotImplementedError, match="kind must be model"): _ = irls(Identity(5), np.ones(5), 10, kind="foo") - error_message = str(exception_info.value) - assert "kind must be model" in error_message @pytest.mark.parametrize("par", [(par3), (par4), (par3j), (par4j)]) @@ -333,28 +331,20 @@ def test_OMP_stopping(par): def test_ISTA_FISTA_unknown_threshkind(): """Check error is raised if unknown threshkind is passed""" - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="threshkind must be"): _ = ista(Identity(5), np.ones(5), 10, threshkind="foo") - error_message = str(exception_info.value) - assert "threshkind must be" in error_message - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="threshkind must be"): _ = fista(Identity(5), np.ones(5), 10, threshkind="foo") - error_message = str(exception_info.value) - assert "threshkind must be" in error_message def test_ISTA_FISTA_missing_perc(): """Check error is raised if perc=None and threshkind is percentile based""" - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="Provide a percentile"): _ = ista(Identity(5), np.ones(5), 10, perc=None, threshkind="soft-percentile") - error_message = str(exception_info.value) - assert "Provide a percentile" in error_message - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="Provide a percentile"): _ = fista(Identity(5), np.ones(5), 10, perc=None, threshkind="soft-percentile") - error_message = str(exception_info.value) - assert "Provide a percentile" in error_message @pytest.mark.parametrize("par", [(par1), (par3), (par5), (par1j), (par3j), (par5j)]) @@ -373,7 +363,7 @@ def test_ISTA_FISTA_alpha_too_high(par): for solver in [ista, fista]: # check that exception is raised - with pytest.raises(ValueError) as exception_info: + with pytest.raises(ValueError, match="due to residual increasing"): _, _, _ = solver( Aop, y, @@ -383,8 +373,6 @@ def test_ISTA_FISTA_alpha_too_high(par): monitorres=True, tol=0, ) - error_message = str(exception_info.value) - assert "due to residual increasing" in error_message # check that CostNanInfCallback catches cost=np.inf _, _, cost = solver(