Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions pytests/test_basicoperators.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,10 +592,8 @@ def test_Sum2D_forceflat(par):
assert y.shape == (par["ny"],)
assert xadj.shape == (par["ny"], par["nx"])

with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="Operators have conflicting forceflat"):
Sop_True * Sop_False.H
error_message = str(exception_info.value)
assert "Operators have conflicting forceflat" in error_message

Sop = Sop_True * Sop_None.H
assert Sop.forceflat is True
Expand Down
8 changes: 2 additions & 6 deletions pytests/test_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,11 @@ def test_VStack_incosistent_columns(par):
"""
G1 = np.random.normal(0, 10, (par["ny"], par["nx"])).astype(par["dtype"])
G2 = np.random.normal(0, 10, (par["ny"], par["nx"] + 1)).astype(par["dtype"])
with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="different number of columns"):
VStack(
[MatrixMult(G1, dtype=par["dtype"]), MatrixMult(G2, dtype=par["dtype"])],
dtype=par["dtype"],
)
error_message = str(exception_info.value)
assert "different number of columns" in error_message


@pytest.mark.parametrize("par", [(par1)])
Expand All @@ -55,13 +53,11 @@ def test_HStack_incosistent_rows(par):
"""
G1 = np.random.normal(0, 10, (par["ny"], par["nx"])).astype(par["dtype"])
G2 = np.random.normal(0, 10, (par["ny"] + 1, par["nx"])).astype(par["dtype"])
with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="different number of rows"):
HStack(
[MatrixMult(G1, dtype=par["dtype"]), MatrixMult(G2, dtype=par["dtype"])],
dtype=par["dtype"],
)
error_message = str(exception_info.value)
assert "different number of rows" in error_message


@pytest.mark.parametrize("par", [(par1), (par2), (par1j), (par2j)])
Expand Down
4 changes: 1 addition & 3 deletions pytests/test_dwts.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,8 @@
@pytest.mark.parametrize("par", [(par1)])
def test_unknown_wavelet(par):
"""Check error is raised if unknown wavelet is chosen is passed"""
with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="not in family set"):
_ = DWT(dims=par["nt"], wavelet="foo")
error_message = str(exception_info.value)
assert "not in family set" in error_message


@pytest.mark.skipif(
Expand Down
12 changes: 3 additions & 9 deletions pytests/test_ffts.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,38 +289,32 @@ def _choose_random_axes(ndim, n_choices=2):
@pytest.mark.parametrize("par", [par1])
def test_unknown_engine(par):
"""Check error is raised if unknown engine is passed"""
with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="engine must be"):
_ = FFT(
dims=(par["nt"],),
nfft=par["nfft"],
sampling=0.005,
real=par["real"],
engine="foo",
)
error_message = str(exception_info.value)
assert "engine must be" in error_message

with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="engine must be"):
_ = FFT2D(
dims=(par["nx"], par["nt"]),
nfft=(par["nfft"], par["nfft"]),
sampling=0.005,
real=par["real"],
engine="foo",
)
error_message = str(exception_info.value)
assert "engine must be" in error_message

with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="engine must be"):
_ = FFTND(
dims=(par["ny"], par["nx"], par["nt"]),
nfft=(par["nfft"], par["nfft"], par["nfft"]),
sampling=0.005,
real=par["real"],
engine="foo",
)
error_message = str(exception_info.value)
assert "engine must be" in error_message


dtype_precision = [
Expand Down
8 changes: 2 additions & 6 deletions pytests/test_fourierradon.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,14 @@

def test_unknown_engine2D():
"""Check error is raised if unknown engine is passed"""
with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="engine must be"):
_ = FourierRadon2D(None, None, None, None, engine="foo")
error_message = str(exception_info.value)
assert "engine must be" in error_message


def test_unknown_engine3D():
"""Check error is raised if unknown engine is passed"""
with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="engine must be"):
_ = FourierRadon3D(None, None, None, None, None, None, engine="foo")
error_message = str(exception_info.value)
assert "engine must be" in error_message


@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4)])
Expand Down
4 changes: 1 addition & 3 deletions pytests/test_functionoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,5 @@ def forward_f(x):
assert_array_equal(F_x, G_x)

# check error is raised when applying the adjoint
with pytest.raises(NotImplementedError) as exception_info:
with pytest.raises(NotImplementedError, match="Adjoint not implemented"):
_ = Fop.H @ y
error_message = str(exception_info.value)
assert "Adjoint not implemented" in error_message
8 changes: 2 additions & 6 deletions pytests/test_linearoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,14 +354,10 @@ def test_non_flattened_arrays(par):
assert_array_equal(Y_S, (S @ D @ X_1d).reshape((*S.dimsd, -1)))

with pylops.disabled_ndarray_multiplication():
with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="only be applied to 1D"):
D @ x_nd
error_message = str(exception_info.value)
assert "only be applied to 1D" in error_message
with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="only be applied to 1D"):
D @ X_nd
error_message = str(exception_info.value)
assert "only be applied to 1D" in error_message


@pytest.mark.parametrize("par", [(par1), (par2j)])
Expand Down
4 changes: 1 addition & 3 deletions pytests/test_lsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,8 @@
)
def test_unknown_mode():
"""Check error is raised if unknown mode is passed"""
with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="method must be analytic,"):
_ = LSM(z, x, t, s2d, r2d, 0, np.ones(3), 1, mode="foo")
error_message = str(exception_info.value)
assert "method must be analytic," in error_message


@pytest.mark.skipif(
Expand Down
32 changes: 8 additions & 24 deletions pytests/test_nonstatconvolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,92 +102,76 @@
@pytest.mark.parametrize("par", [(par_2d)])
def test_even_filter(par):
"""Check error is raised if filter has even size"""
with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="filters hs must have odd length"):
_ = NonStationaryConvolve1D(
dims=par["nx"],
hs=h1ns[..., :-1],
ih=(int(par["nx"] // 4), int(2 * par["nx"] // 4), int(3 * par["nx"] // 4)),
)
error_message = str(exception_info.value)
assert "filters hs must have odd length" in error_message

with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="filters hs must have odd length"):
_ = NonStationaryConvolve2D(
dims=(par["nx"], par["nz"]),
hs=h2ns[..., :-1],
ihx=(int(par["nx"] // 4), int(2 * par["nx"] // 4), int(3 * par["nx"] // 4)),
ihz=(int(par["nz"] // 4), int(3 * par["nz"] // 4)),
)
error_message = str(exception_info.value)
assert "filters hs must have odd length" in error_message

with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="filters hs must have odd length"):
_ = NonStationaryFilters1D(
inp=np.arange(par["nx"]),
hsize=nfilts[0] - 1,
ih=(int(par["nx"] // 4), int(2 * par["nx"] // 4), int(3 * par["nx"] // 4)),
)
error_message = str(exception_info.value)
assert "filters hs must have odd length" in error_message

with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="filters hs must have odd length"):
_ = NonStationaryFilters2D(
inp=np.ones((par["nx"], par["nz"])),
hshape=(nfilts[0] - 1, nfilts[1] - 1),
ihx=(int(par["nx"] // 4), int(2 * par["nx"] // 4), int(3 * par["nx"] // 4)),
ihz=(int(par["nz"] // 4), int(3 * par["nz"] // 4)),
)
error_message = str(exception_info.value)
assert "filters hs must have odd length" in error_message


@pytest.mark.parametrize("par", [(par_2d)])
def test_ih_irregular(par):
"""Check error is raised if ih (or ihx/ihz) are irregularly sampled"""
with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="must be regularly sampled"):
_ = NonStationaryConvolve1D(
dims=par["nx"],
hs=h1ns,
ih=(10, 11, 15),
)
error_message = str(exception_info.value)
assert "must be regularly sampled" in error_message

with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="must be regularly sampled"):
_ = NonStationaryConvolve2D(
dims=(par["nx"], par["nz"]),
hs=h2ns,
ihx=(10, 11, 15),
ihz=(int(par["nz"] // 4), int(3 * par["nz"] // 4)),
)
error_message = str(exception_info.value)
assert "must be regularly sampled" in error_message


@pytest.mark.parametrize("par", [(par_2d)])
def test_unknown_engine_2d(par):
"""Check error is raised if unknown engine is passed"""
with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="engine must be numpy"):
_ = NonStationaryConvolve2D(
dims=(par["nx"], par["nz"]),
hs=h2ns,
ihx=(int(par["nx"] // 3), int(2 * par["nx"] // 3)),
ihz=(int(par["nz"] // 3), int(2 * par["nz"] // 3)),
engine="foo",
)
error_message = str(exception_info.value)
assert "engine must be numpy" in error_message

with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="engine must be numpy"):
_ = NonStationaryFilters2D(
inp=np.ones((par["nx"], par["nz"])),
hshape=(nfilts[0] - 1, nfilts[1] - 1),
ihx=(int(par["nx"] // 3), int(2 * par["nx"] // 3)),
ihz=(int(par["nz"] // 3), int(2 * par["nz"] // 3)),
engine="foo",
)
error_message = str(exception_info.value)
assert "engine must be numpy" in error_message


@pytest.mark.parametrize("par", [(par1_1d), (par2_1d)])
Expand Down
8 changes: 2 additions & 6 deletions pytests/test_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,15 @@
@pytest.mark.parametrize("par", [(par1)])
def test_Pad_1d_negative(par):
"""Check error is raised when pad has negative number"""
with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="Padding must be positive"):
_ = Pad(dims=par["ny"], pad=(-10, 0))
error_message = str(exception_info.value)
assert "Padding must be positive" in error_message


@pytest.mark.parametrize("par", [(par1)])
def test_Pad_2d_negative(par):
"""Check error is raised when pad has negative number for 2d"""
with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="Padding must be positive"):
_ = Pad(dims=(par["ny"], par["nx"]), pad=((-10, 0), (3, -5)))
error_message = str(exception_info.value)
assert "Padding must be positive" in error_message


@pytest.mark.parametrize("par", [(par1)])
Expand Down
8 changes: 2 additions & 6 deletions pytests/test_radon.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,11 @@
)
def test_unknown_engine():
"""Check error is raised if unknown engine is passed"""
with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="engine must be numpy"):
_ = Radon2D(None, None, None, engine="foo")
error_message = str(exception_info.value)
assert "engine must be numpy" in error_message

with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="engine must be numpy"):
_ = Radon3D(None, None, None, None, None, engine="foo")
error_message = str(exception_info.value)
assert "engine must be numpy" in error_message


@pytest.mark.skipif(
Expand Down
4 changes: 1 addition & 3 deletions pytests/test_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,12 @@
@pytest.mark.parametrize("par", [(par1)])
def test_unknown_engine(par):
"""Check error is raised if unknown engine is passed"""
with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="engine must be numpy"):
_ = Shift(
par["nt"],
1.0,
engine="foo",
)
error_message = str(exception_info.value)
assert "engine must be numpy" in error_message


@pytest.mark.parametrize("par", [(par1), (par1j)])
Expand Down
24 changes: 6 additions & 18 deletions pytests/test_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,8 @@

def test_IRLS_unknown_kind():
"""Check error is raised if unknown kind is passed"""
with pytest.raises(NotImplementedError) as exception_info:
with pytest.raises(NotImplementedError, match="kind must be model"):
_ = irls(Identity(5), np.ones(5), 10, kind="foo")
error_message = str(exception_info.value)
assert "kind must be model" in error_message


@pytest.mark.parametrize("par", [(par3), (par4), (par3j), (par4j)])
Expand Down Expand Up @@ -333,28 +331,20 @@ def test_OMP_stopping(par):

def test_ISTA_FISTA_unknown_threshkind():
"""Check error is raised if unknown threshkind is passed"""
with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="threshkind must be"):
_ = ista(Identity(5), np.ones(5), 10, threshkind="foo")
error_message = str(exception_info.value)
assert "threshkind must be" in error_message

with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="threshkind must be"):
_ = fista(Identity(5), np.ones(5), 10, threshkind="foo")
error_message = str(exception_info.value)
assert "threshkind must be" in error_message


def test_ISTA_FISTA_missing_perc():
"""Check error is raised if perc=None and threshkind is percentile based"""
with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="Provide a percentile"):
_ = ista(Identity(5), np.ones(5), 10, perc=None, threshkind="soft-percentile")
error_message = str(exception_info.value)
assert "Provide a percentile" in error_message

with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="Provide a percentile"):
_ = fista(Identity(5), np.ones(5), 10, perc=None, threshkind="soft-percentile")
error_message = str(exception_info.value)
assert "Provide a percentile" in error_message


@pytest.mark.parametrize("par", [(par1), (par3), (par5), (par1j), (par3j), (par5j)])
Expand All @@ -373,7 +363,7 @@ def test_ISTA_FISTA_alpha_too_high(par):

for solver in [ista, fista]:
# check that exception is raised
with pytest.raises(ValueError) as exception_info:
with pytest.raises(ValueError, match="due to residual increasing"):
_, _, _ = solver(
Aop,
y,
Expand All @@ -383,8 +373,6 @@ def test_ISTA_FISTA_alpha_too_high(par):
monitorres=True,
tol=0,
)
error_message = str(exception_info.value)
assert "due to residual increasing" in error_message

# check that CostNanInfCallback catches cost=np.inf
_, _, cost = solver(
Expand Down