5 changes: 3 additions & 2 deletions distributed/client.py
@@ -2048,9 +2048,10 @@ async def _scatter(
                     await asyncio.sleep(0.1)
                 if time() > start + timeout:
                     raise TimeoutError("No valid workers found")
-                nthreads = await self.scheduler.ncores(workers=workers)
+                # Exclude paused and closing_gracefully workers
+                nthreads = await self.scheduler.ncores_running(workers=workers)
             if not nthreads:
-                raise ValueError("No valid workers")
+                raise ValueError("No valid workers found")
 
             _, who_has, nbytes = await scatter_to_workers(
                 nthreads, data2, report=False, rpc=self.rpc
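
The client-side change above relies on a new scheduler handler, ncores_running, registered in distributed/scheduler.py below. A minimal sketch of querying it directly, assuming an async Client connected to a live cluster (the address is hypothetical):

import asyncio

from distributed import Client


async def main():
    # Hypothetical scheduler address; substitute your own cluster.
    async with Client("tcp://127.0.0.1:8786", asynchronous=True) as client:
        # Mapping {worker_address: nthreads}, restricted to workers whose
        # status is "running"; paused and closing_gracefully workers are
        # excluded.
        running = await client.scheduler.ncores_running()
        # The pre-existing handler still reports every worker, for comparison.
        everyone = await client.scheduler.ncores()
        print(running, everyone)


asyncio.run(main())
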
46 changes: 32 additions & 14 deletions distributed/scheduler.py
@@ -3826,6 +3826,7 @@ def __init__(
"broadcast": self.broadcast,
"proxy": self.proxy,
"ncores": self.get_ncores,
"ncores_running": self.get_ncores_running,
"has_what": self.get_has_what,
"who_has": self.get_who_has,
"processing": self.get_processing,
@@ -5710,18 +5711,24 @@ async def scatter(
         Scheduler.broadcast:
         """
         parent: SchedulerState = cast(SchedulerState, self)
+        ws: WorkerState
 
         start = time()
-        while not parent._workers_dv:
-            await asyncio.sleep(0.2)
+        while True:
+            if workers is None:
+                wss = parent._running
+            else:
+                workers = [self.coerce_address(w) for w in workers]
+                wss = {parent._workers_dv[w] for w in workers}
+                wss = {ws for ws in wss if ws._status == Status.running}
+
+            if wss:
+                break
             if time() > start + timeout:
-                raise TimeoutError("No workers found")
+                raise TimeoutError("No valid workers found")
+            await asyncio.sleep(0.1)
 
-        if workers is None:
-            ws: WorkerState
-            nthreads = {w: ws._nthreads for w, ws in parent._workers_dv.items()}
-        else:
-            workers = [self.coerce_address(w) for w in workers]
-            nthreads = {w: parent._workers_dv[w].nthreads for w in workers}
+        nthreads = {ws._address: ws.nthreads for ws in wss}
 
         assert isinstance(data, dict)
 
@@ -5732,10 +5739,7 @@
         self.update_data(who_has=who_has, nbytes=nbytes, client=client)
 
         if broadcast:
-            if broadcast == True:  # noqa: E712
-                n = len(nthreads)
-            else:
-                n = broadcast
+            n = len(nthreads) if broadcast is True else broadcast
             await self.replicate(keys=keys, workers=workers, n=n)
 
         self.log_event(
@@ -6451,7 +6455,12 @@ async def replicate(

         assert branching_factor > 0
         async with self._lock if lock else empty_context:
-            workers = {parent._workers_dv[w] for w in self.workers_list(workers)}
+            if workers is not None:
+                workers = {parent._workers_dv[w] for w in self.workers_list(workers)}
+                workers = {ws for ws in workers if ws._status == Status.running}
+            else:
+                workers = parent._running
+
             if n is None:
                 n = len(workers)
             else:
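
With this change, replicate defaults n to the number of running workers and never targets paused ones. A hedged usage sketch (the LocalCluster setup is illustrative, not part of this PR):

import asyncio

from distributed import Client, LocalCluster


async def main():
    async with LocalCluster(n_workers=4, asynchronous=True) as cluster:
        async with Client(cluster, asynchronous=True) as client:
            [fut] = await client.scatter([42])
            # n=None now means "one copy per running worker"; a paused
            # worker neither counts toward n nor receives a copy.
            await client.replicate([fut])
            await client.replicate([fut], n=2)  # exactly two copies


asyncio.run(main())
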
@@ -6989,6 +6998,15 @@ def get_ncores(self, comm=None, workers=None):
         else:
             return {w: ws._nthreads for w, ws in parent._workers_dv.items()}
 
+    def get_ncores_running(self, comm=None, workers=None):
+        parent: SchedulerState = cast(SchedulerState, self)
+        ncores = self.get_ncores(workers=workers)
+        return {
+            w: n
+            for w, n in ncores.items()
+            if parent._workers_dv[w].status == Status.running
+        }
+
     async def get_call_stack(self, comm=None, keys=None):
         parent: SchedulerState = cast(SchedulerState, self)
         ts: TaskState
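
Taken together, Scheduler.scatter now waits until at least one running worker exists and spreads data only across running workers; broadcast=True replicates to every running worker, while an integer caps the replica count (n = len(nthreads) if broadcast is True else broadcast, above). A sketch under the same illustrative LocalCluster assumption:

import asyncio

from distributed import Client, LocalCluster


async def main():
    async with LocalCluster(n_workers=3, asynchronous=True) as cluster:
        async with Client(cluster, asynchronous=True) as client:
            # One copy of "x" per running worker.
            futs = await client.scatter({"x": 1}, broadcast=True)
            # Or replicate to a fixed number of running workers.
            futs = await client.scatter({"y": 2}, broadcast=2)
            print(futs)


asyncio.run(main())
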
27 changes: 27 additions & 0 deletions distributed/tests/test_client.py
@@ -5734,6 +5734,33 @@ def bad_fn(x):
         assert y.status == "error"  # not cancelled
 
 
+@pytest.mark.parametrize("workers_arg", [False, True])
+@pytest.mark.parametrize("direct", [False, True])
+@pytest.mark.parametrize("broadcast", [False, True, 10])
+@gen_cluster(client=True, nthreads=[("", 1)] * 10)
+async def test_scatter_and_replicate_avoid_paused_workers(
+    c, s, *workers, workers_arg, direct, broadcast
+):
+    paused_workers = [w for i, w in enumerate(workers) if i not in (3, 7)]
+    for w in paused_workers:
+        w.memory_pause_fraction = 1e-15
+    while any(s.workers[w.address].status != Status.paused for w in paused_workers):
+        await asyncio.sleep(0.01)
+
+    f = await c.scatter(
+        {"x": 1},
+        workers=[w.address for w in workers[1:-1]] if workers_arg else None,
+        broadcast=broadcast,
+        direct=direct,
+    )
+    if not broadcast:
+        await c.replicate(f, n=10)
+
+    expect = [i in (3, 7) for i in range(10)]
+    actual = [("x" in w.data) for w in workers]
+    assert actual == expect
+
+
 @pytest.mark.xfail(reason="GH#5409 Dask-Default-Threads are frequently detected")
 def test_no_threads_lingering():
     if threading.active_count() < 40:
21 changes: 14 additions & 7 deletions distributed/tests/test_scheduler.py
@@ -784,20 +784,27 @@ async def test_story(c, s, a, b):
     assert s.story(x.key) == s.story(s.tasks[x.key])
 
 
-@gen_cluster(nthreads=[], client=True)
-async def test_scatter_no_workers(c, s):
+@pytest.mark.parametrize("direct", [False, True])
+@gen_cluster(client=True, nthreads=[])
+async def test_scatter_no_workers(c, s, direct):
     with pytest.raises(TimeoutError):
         await s.scatter(data={"x": 1}, client="alice", timeout=0.1)
 
     start = time()
     with pytest.raises(TimeoutError):
-        await c.scatter(123, timeout=0.1)
+        await c.scatter(123, timeout=0.1, direct=direct)
     assert time() < start + 1.5
 
-    w = Worker(s.address, nthreads=3)
-    await asyncio.gather(c.scatter(data={"y": 2}, timeout=5), w)
-
-    assert w.data["y"] == 2
+    fut = c.scatter({"y": 2}, timeout=5, direct=direct)
+    await asyncio.sleep(0.1)
+    async with Worker(s.address) as w:
+        await fut
+        assert w.data["y"] == 2
+
+    # Test race condition between worker init and scatter
+    w = Worker(s.address)
+    await asyncio.gather(c.scatter({"z": 3}, timeout=5, direct=direct), w)
+    assert w.data["z"] == 3
+    await w.close()
 
 