diff --git a/prouter/api/job.py b/prouter/api/job.py index 7e234a6..79424fd 100644 --- a/prouter/api/job.py +++ b/prouter/api/job.py @@ -214,7 +214,29 @@ async def download_archive( if destination is None: return arcpath with tarfile.open(arcpath, 'r') as arc: - arc.extractall(destination) + + import os + + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(arc, destination) arcpath.unlink() except Exception: if arcpath.exists(): diff --git a/prouter/client/job_client.py b/prouter/client/job_client.py index 9472635..4ab4ffb 100644 --- a/prouter/client/job_client.py +++ b/prouter/client/job_client.py @@ -281,7 +281,29 @@ async def download_directory(self, source, destination, exclude=None): unpack_path = pathlib.Path(tempdir).joinpath('unpacked') unpack_path.mkdir() # pylint: disable=no-member with tarfile.open(arcpath, 'r') as arc: - arc.extractall(unpack_path) + + import os + + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(arc, unpack_path) unpack_path = unpack_path.joinpath(source) for src in unpack_path.glob('**/*'): # pylint: disable=no-member dst = destination.joinpath(src.relative_to(unpack_path)) diff --git a/prouter/test/test_jobs.py b/prouter/test/test_jobs.py index be31e3e..bde38b6 100644 --- a/prouter/test/test_jobs.py +++ b/prouter/test/test_jobs.py @@ -221,7 +221,26 @@ async def test_upload_download_archive(event_loop, router_process, tmpdir): arc.write(chunk) chunk = await response.content.readany() with tarfile.open(DOWNLOAD_ARC, 'r') as arc: - arc.extractall(DOWNLOAD_PATH) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(arc, DOWNLOAD_PATH) assert len(os.listdir(DOWNLOAD_PATH)) == 1 with open(DATA_PATH.joinpath('something.py')) as uploaded: with open(DOWNLOAD_PATH.joinpath('something.py')) as downloaded: