Merged
104 changes: 36 additions & 68 deletions distributed/tests/test_client.py
@@ -2929,95 +2929,63 @@ def __reduce__(self):
await x


@gen_cluster(
client=True,
Worker=Nanny,
worker_kwargs={"memory_limit": "1 GiB"},
config={"distributed.worker.memory.rebalance.sender-min": 0.3},
)
async def test_rebalance(c, s, *_):
# Set rebalance() to work predictably on small amounts of managed memory. By default, it
# uses optimistic memory, which would only be possible to test by allocating very large
# amounts of managed memory, so that they would hide variations in unmanaged memory.
REBALANCE_MANAGED_CONFIG = {
"distributed.worker.memory.rebalance.measure": "managed",
"distributed.worker.memory.rebalance.sender-min": 0,
"distributed.worker.memory.rebalance.sender-recipient-gap": 0,
}


@gen_cluster(client=True, config=REBALANCE_MANAGED_CONFIG)
async def test_rebalance(c, s, a, b):
"""Test Client.rebalance(). These are just to test the Client wrapper around
Scheduler.rebalance(); for more thorough tests on the latter see test_scheduler.py.
"""
# We used nannies to have separate processes for each worker
a, b = s.workers

# Generate 10 buffers worth 512 MiB total on worker a. This sends its memory
# utilisation slightly above 50% (after counting unmanaged) which is above the
# distributed.worker.memory.rebalance.sender-min threshold.
futures = c.map(lambda _: "x" * (2 ** 29 // 10), range(10), workers=[a])
await wait(futures)
# Wait for heartbeats
while s.memory.process < 2 ** 29:
await asyncio.sleep(0.1)

assert await c.run(lambda dask_worker: len(dask_worker.data)) == {a: 10, b: 0}

futures = await c.scatter(range(100), workers=[a.address])
assert len(a.data) == 100
assert len(b.data) == 0
await c.rebalance()

ndata = await c.run(lambda dask_worker: len(dask_worker.data))
# Allow for some uncertainty as the unmanaged memory is not stable
assert sum(ndata.values()) == 10
assert 3 <= ndata[a] <= 7
assert 3 <= ndata[b] <= 7
assert len(a.data) == 50
assert len(b.data) == 50


@gen_cluster(
nthreads=[("127.0.0.1", 1)] * 3,
client=True,
Worker=Nanny,
worker_kwargs={"memory_limit": "1 GiB"},
)
async def test_rebalance_workers_and_keys(client, s, *_):
@gen_cluster(nthreads=[("", 1)] * 3, client=True, config=REBALANCE_MANAGED_CONFIG)
async def test_rebalance_workers_and_keys(client, s, a, b, c):
"""Test Client.rebalance(). These are just to test the Client wrapper around
Scheduler.rebalance(); for more thorough tests on the latter see test_scheduler.py.
"""
a, b, c = s.workers
futures = client.map(lambda _: "x" * (2 ** 29 // 10), range(10), workers=[a])
await wait(futures)
# Wait for heartbeats
while s.memory.process < 2 ** 29:
await asyncio.sleep(0.1)
futures = await client.scatter(range(100), workers=[a.address])
assert (len(a.data), len(b.data), len(c.data)) == (100, 0, 0)

# Passing empty iterables is not the same as omitting the arguments
await client.rebalance([])
await client.rebalance(workers=[])
assert await client.run(lambda dask_worker: len(dask_worker.data)) == {
a: 10,
b: 0,
c: 0,
}
assert (len(a.data), len(b.data), len(c.data)) == (100, 0, 0)

# Limit rebalancing to two arbitrary keys and two arbitrary workers.
await client.rebalance([futures[3], futures[7]], [a, b])
assert await client.run(lambda dask_worker: len(dask_worker.data)) == {
a: 8,
b: 2,
c: 0,
}
await client.rebalance([futures[3], futures[7]], [a.address, b.address])
assert (len(a.data), len(b.data), len(c.data)) == (98, 2, 0)

with pytest.raises(KeyError):
await client.rebalance(workers=["notexist"])


def test_rebalance_sync():
# can't use the 'c' fixture because we need workers to run in a separate process
with Client(n_workers=2, memory_limit="1 GiB", dashboard_address=":0") as c:
s = c.cluster.scheduler
a, b = (ws.address for ws in s.workers.values())
futures = c.map(lambda _: "x" * (2 ** 29 // 10), range(10), workers=[a])
wait(futures)
# Wait for heartbeat
while s.memory.process < 2 ** 29:
sleep(0.1)

assert c.run(lambda dask_worker: len(dask_worker.data)) == {a: 10, b: 0}
c.rebalance()
ndata = c.run(lambda dask_worker: len(dask_worker.data))
# Allow for some uncertainty as the unmanaged memory is not stable
assert sum(ndata.values()) == 10
assert 3 <= ndata[a] <= 7
assert 3 <= ndata[b] <= 7
with dask.config.set(REBALANCE_MANAGED_CONFIG):
with Client(n_workers=2, processes=False, dashboard_address=":0") as c:
s = c.cluster.scheduler
a = c.cluster.workers[0]
b = c.cluster.workers[1]
futures = c.scatter(range(100), workers=[a.address])

assert len(a.data) == 100
assert len(b.data) == 0
c.rebalance()
assert len(a.data) == 50
assert len(b.data) == 50


@gen_cluster(client=True)
148 changes: 64 additions & 84 deletions distributed/tests/test_scheduler.py
@@ -2675,52 +2675,66 @@ async def assert_ndata(client, by_addr, total=None):
worker_kwargs={"memory_limit": "1 GiB"},
config={"distributed.worker.memory.rebalance.sender-min": 0.3},
)
async def test_rebalance(c, s, *_):
async def test_rebalance(c, s, a, b):
# We used nannies to have separate processes for each worker
a, b = s.workers

# Generate 10 buffers worth 512 MiB total on worker a. This sends its memory
# Generate 500 buffers worth 512 MiB total on worker a. This sends its memory
# utilisation slightly above 50% (after counting unmanaged) which is above the
# distributed.worker.memory.rebalance.sender-min threshold.
futures = c.map(lambda _: "x" * (2 ** 29 // 10), range(10), workers=[a])
futures = c.map(
lambda _: "x" * (2 ** 29 // 500), range(500), workers=[a.worker_address]
)
await wait(futures)
# Wait for heartbeats
await assert_memory(s, "process", 512, 1024)
await assert_ndata(c, {a: 10, b: 0})
await assert_ndata(c, {a.worker_address: 500, b.worker_address: 0})
await s.rebalance()
# Allow for some uncertainty as the unmanaged memory is not stable
await assert_ndata(c, {a: (3, 7), b: (3, 7)}, total=10)
await assert_ndata(
c, {a.worker_address: (50, 450), b.worker_address: (50, 450)}, total=500
)

# rebalance() when there is nothing to do
await s.rebalance()
await assert_ndata(c, {a: (3, 7), b: (3, 7)}, total=10)
await assert_ndata(
c, {a.worker_address: (50, 450), b.worker_address: (50, 450)}, total=500
)


@gen_cluster(
nthreads=[("127.0.0.1", 1)] * 3,
client=True,
Worker=Nanny,
worker_kwargs={"memory_limit": "1 GiB"},
)
async def test_rebalance_workers_and_keys(client, s, *_):
a, b, c = s.workers
futures = client.map(lambda _: "x" * (2 ** 29 // 10), range(10), workers=[a])
await wait(futures)
# Wait for heartbeats
await assert_memory(s, "process", 512, 1024)
# Set rebalance() to work predictably on small amounts of managed memory. By default, it
# uses optimistic memory, which would only be possible to test by allocating very large
# amounts of managed memory, so that they would hide variations in unmanaged memory.
REBALANCE_MANAGED_CONFIG = {
"distributed.worker.memory.rebalance.measure": "managed",
"distributed.worker.memory.rebalance.sender-min": 0,
"distributed.worker.memory.rebalance.sender-recipient-gap": 0,
}


@gen_cluster(client=True, config=REBALANCE_MANAGED_CONFIG)
async def test_rebalance_managed_memory(c, s, a, b):
Collaborator:
Isn't this a duplicate of test_client.py::test_rebalance (besides calling rebalance on the client vs the scheduler)? I understand wanting to unit-test the actual rebalance logic on the scheduler, and test that the client is invoking it correctly, but testing so much of the rebalance logic in the client tests too feels a little redundant? (Same goes for other tests here.)

Collaborator (Author):
All tests on test_client specifically test the Client API. It's very easy to accidentally forget a parameter or to convert a None to an empty list (which have different meanings).

Additionally, some behaviour is slightly different on Client.rebalance vs Scheduler.rebalance (not my design):

  • test_rebalance_unprepared: Client waits for unfinished tasks. Scheduler expects all tasks to be already finished.
  • test_rebalance_raises_on_explicit_missing_data: exception handling is specifically implemented client side
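
A minimal sketch of the None-vs-empty-list distinction described above. It is illustrative only, not code from this PR, and the helper name select_for_rebalance is made up:

# Omitting the argument (None) means "consider everything"; passing an empty
# iterable means "restrict to nothing", i.e. the call becomes a no-op.
from typing import Collection, Optional

def select_for_rebalance(
    all_items: Collection[str], restrict_to: Optional[Collection[str]] = None
) -> set:
    if restrict_to is None:
        return set(all_items)                      # no restriction: everything is eligible
    return set(all_items) & set(restrict_to)       # empty restriction selects nothing

assert select_for_rebalance({"x", "y"}) == {"x", "y"}             # argument omitted
assert select_for_rebalance({"x", "y"}, restrict_to=[]) == set()  # explicit empty list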

futures = await c.scatter(range(100), workers=[a.address])
assert len(a.data) == 100
assert len(b.data) == 0
await s.rebalance()
assert len(a.data) == 50
assert len(b.data) == 50


@gen_cluster(nthreads=[("", 1)] * 3, client=True, config=REBALANCE_MANAGED_CONFIG)
async def test_rebalance_workers_and_keys(client, s, a, b, c):
futures = await client.scatter(range(100), workers=[a.address])
assert (len(a.data), len(b.data), len(c.data)) == (100, 0, 0)

# Passing empty iterables is not the same as omitting the arguments
await s.rebalance(keys=[])
await assert_ndata(client, {a: 10, b: 0, c: 0})
await s.rebalance(workers=[])
await assert_ndata(client, {a: 10, b: 0, c: 0})
# Limit operation to workers that have nothing to do
await s.rebalance(workers=[b, c])
await assert_ndata(client, {a: 10, b: 0, c: 0})
assert (len(a.data), len(b.data), len(c.data)) == (100, 0, 0)

# Limit rebalancing to two arbitrary keys and two arbitrary workers
await s.rebalance(keys=[futures[3].key, futures[7].key], workers=[a, b])
await assert_ndata(client, {a: 8, b: 2, c: 0}, total=10)
# Limit rebalancing to two arbitrary keys and two arbitrary workers.
await s.rebalance(
keys=[futures[3].key, futures[7].key], workers=[a.address, b.address]
)
assert (len(a.data), len(b.data), len(c.data)) == (98, 2, 0)

with pytest.raises(KeyError):
await s.rebalance(workers=["notexist"])
@@ -2746,24 +2760,20 @@ async def test_rebalance_missing_data2(c, s, a, b):


@pytest.mark.parametrize("explicit", [False, True])
@gen_cluster(client=True, Worker=Nanny, worker_kwargs={"memory_limit": "1 GiB"})
async def test_rebalance_raises_missing_data3(c, s, *_, explicit):
@gen_cluster(client=True, config=REBALANCE_MANAGED_CONFIG)
async def test_rebalance_raises_missing_data3(c, s, a, b, explicit):
"""keys exist when the sync part of rebalance runs, but are gone by the time the
actual data movement runs.
There is an error message only if the keys are explicitly listed in the API call.
"""
a, _ = s.workers
futures = c.map(lambda _: "x" * (2 ** 29 // 10), range(10), workers=[a])
await wait(futures)
# Wait for heartbeats
await assert_memory(s, "process", 512, 1024)
futures = await c.scatter(range(100), workers=[a.address])

if explicit:
keys = [f.key for f in futures]
del futures
out = await s.rebalance(keys=keys)
assert out["status"] == "partial-fail"
assert 1 <= len(out["keys"]) <= 10
assert 1 <= len(out["keys"]) <= 100
else:
del futures
out = await s.rebalance()
@@ -2775,48 +2785,19 @@ async def test_rebalance_no_workers(s):
await s.rebalance()


@gen_cluster(
client=True,
Worker=Nanny,
worker_kwargs={"memory_limit": "1000 MiB"},
config={
"distributed.worker.memory.rebalance.measure": "managed",
"distributed.worker.memory.rebalance.sender-min": 0.3,
},
)
async def test_rebalance_managed_memory(c, s, *_):
a, b = s.workers
# Generate 100 buffers worth 400 MiB total on worker a. This sends its memory
# utilisation to exactly 40%, ignoring unmanaged, which is above the
# distributed.worker.memory.rebalance.sender-min threshold.
futures = c.map(lambda _: "x" * (2 ** 22), range(100), workers=[a])
await wait(futures)
# Even if we're just using managed memory, which is instantaneously accounted for as
# soon as the tasks finish, MemoryState.managed is still capped by the process
# memory, so we need to wait for the heartbeat.
await assert_memory(s, "managed", 400, 401)
await assert_ndata(c, {a: 100, b: 0})
await s.rebalance()
# We can expect an exact, stable result because we are completely bypassing the
# unpredictability of unmanaged memory.
await assert_ndata(c, {a: 62, b: 38})


@gen_cluster(
client=True,
worker_kwargs={"memory_limit": 0},
config={"distributed.worker.memory.rebalance.measure": "managed"},
)
async def test_rebalance_no_limit(c, s, a, b):
# See notes in test_rebalance_managed_memory
futures = c.map(lambda _: "x", range(100), workers=[a.address])
await wait(futures)
# No reason to wait for memory here as we're allocating hundreds of bytes, so
# there's plenty of unmanaged process memory to pad it out
await assert_ndata(c, {a.address: 100, b.address: 0})
futures = await c.scatter(range(100), workers=[a.address])
assert len(a.data) == 100
assert len(b.data) == 0
await s.rebalance()
# Disabling memory_limit made us ignore all % thresholds set in the config
await assert_ndata(c, {a.address: 50, b.address: 50})
assert len(a.data) == 50
assert len(b.data) == 50


@gen_cluster(
@@ -2829,33 +2810,32 @@ async def test_rebalance_no_limit(c, s, a, b):
"distributed.worker.memory.rebalance.recipient-max": 0.1,
},
)
async def test_rebalance_no_recipients(c, s, *_):
async def test_rebalance_no_recipients(c, s, a, b):
"""There are sender workers, but no recipient workers"""
a, b = s.workers
fut_a = c.map(lambda _: "x" * (2 ** 20), range(250), workers=[a]) # 25%
fut_b = c.map(lambda _: "x" * (2 ** 20), range(100), workers=[b]) # 10%
# Fill 25% of the memory of a and 10% of the memory of b
fut_a = c.map(lambda _: "x" * (2 ** 20), range(250), workers=[a.worker_address])
fut_b = c.map(lambda _: "x" * (2 ** 20), range(100), workers=[b.worker_address])
await wait(fut_a + fut_b)
await assert_memory(s, "managed", 350, 351)
await assert_ndata(c, {a: 250, b: 100})
await assert_ndata(c, {a.worker_address: 250, b.worker_address: 100})
await s.rebalance()
await assert_ndata(c, {a: 250, b: 100})
await assert_ndata(c, {a.worker_address: 250, b.worker_address: 100})


@gen_cluster(
nthreads=[("127.0.0.1", 1)] * 3,
nthreads=[("", 1)] * 3,
client=True,
worker_kwargs={"memory_limit": 0},
config={"distributed.worker.memory.rebalance.measure": "managed"},
)
async def test_rebalance_skip_recipient(client, s, a, b, c):
"""A recipient is skipped because it already holds a copy of the key to be sent"""
futures = client.map(lambda _: "x", range(10), workers=[a.address])
await wait(futures)
futures = await client.scatter(range(10), workers=[a.address])
await client.replicate(futures[0:2], workers=[a.address, b.address])
await client.replicate(futures[2:4], workers=[a.address, c.address])
await assert_ndata(client, {a.address: 10, b.address: 2, c.address: 2})
assert (len(a.data), len(b.data), len(c.data)) == (10, 2, 2)
await client.rebalance(futures[:2])
await assert_ndata(client, {a.address: 8, b.address: 2, c.address: 4})
assert (len(a.data), len(b.data), len(c.data)) == (8, 2, 4)


@gen_cluster(
@@ -2865,12 +2845,12 @@ async def test_rebalance_skip_recipient(client, s, a, b, c):
)
async def test_rebalance_skip_all_recipients(c, s, a, b):
"""All recipients are skipped because they already hold copies"""
futures = c.map(lambda _: "x", range(10), workers=[a.address])
futures = await c.scatter(range(10), workers=[a.address])
await wait(futures)
await c.replicate([futures[0]])
await assert_ndata(c, {a.address: 10, b.address: 1})
assert (len(a.data), len(b.data)) == (10, 1)
await c.rebalance(futures[:2])
await assert_ndata(c, {a.address: 9, b.address: 2})
assert (len(a.data), len(b.data)) == (9, 2)


@gen_cluster(