diff --git a/s3transfer/__init__.py b/s3transfer/__init__.py index e8ff66f0..6fe61a1e 100644 --- a/s3transfer/__init__.py +++ b/s3transfer/__init__.py @@ -533,6 +533,11 @@ def __init__( def download_file( self, bucket, key, filename, object_size, extra_args, callback=None + ): + self.download_versioned_file(bucket, key, filename, object_size, None, extra_args, callback=callback) + + def download_versioned_file( + self, bucket, key, filename, object_size, object_version, extra_args, callback=None ): with self._executor_cls(max_workers=2) as controller: # 1 thread for the future that manages the uploading of files @@ -543,6 +548,7 @@ def download_file( key, filename, object_size, + object_version, callback, ) parts_future = controller.submit(download_parts_handler) @@ -563,7 +569,7 @@ def _process_future_results(self, futures): future.result() def _download_file_as_future( - self, bucket, key, filename, object_size, callback + self, bucket, key, filename, object_size, object_version, callback ): part_size = self._config.multipart_chunksize num_parts = int(math.ceil(object_size / float(part_size))) @@ -573,6 +579,7 @@ def _download_file_as_future( bucket, key, filename, + object_version, part_size, num_parts, callback, @@ -593,7 +600,7 @@ def _calculate_range_param(self, part_size, part_index, num_parts): return range_param def _download_range( - self, bucket, key, filename, part_size, num_parts, callback, part_index + self, bucket, key, filename, object_version, part_size, num_parts, callback, part_index ): try: range_param = self._calculate_range_param( @@ -605,9 +612,10 @@ def _download_range( for i in range(max_attempts): try: logger.debug("Making get_object call.") - response = self._client.get_object( - Bucket=bucket, Key=key, Range=range_param - ) + kwargs = {'Bucket': bucket, 'Key': key, 'Range': range_param} + if object_version is not None: + kwargs['VersionId'] = object_version + response = self._client.get_object(**kwargs) streaming_body = StreamReaderProgress( response['Body'], callback ) @@ -781,11 +789,23 @@ def download_file( if extra_args is None: extra_args = {} self._validate_all_known_args(extra_args, self.ALLOWED_DOWNLOAD_ARGS) - object_size = self._object_size(bucket, key, extra_args) + object_meta = self._object_meta(bucket, key, extra_args) + object_size = object_meta['ContentLength'] + + # If the latest version of the file changes during a multipart download and we make + # multiple concurrent ranged downloads, then each download may see a different version. + # To avoid this, we specify a common version for all. If the bucket does not have + # version, then there is nothing that can be done, and we specify no version. + object_version = object_meta.get('VersionId') + if object_version is not None: + logger.debug("Using version ID %s for %s/%s", object_version, bucket, key) + else: + logger.debug("Not using version ID for %s/%s", bucket, key) + temp_filename = filename + os.extsep + random_file_extension() try: self._download_file( - bucket, key, temp_filename, object_size, extra_args, callback + bucket, key, temp_filename, object_size, object_version, extra_args, callback ) except Exception: logger.debug( @@ -800,11 +820,11 @@ def download_file( self._osutil.rename_file(temp_filename, filename) def _download_file( - self, bucket, key, filename, object_size, extra_args, callback + self, bucket, key, filename, object_size, object_version, extra_args, callback ): if object_size >= self._config.multipart_threshold: self._ranged_download( - bucket, key, filename, object_size, extra_args, callback + bucket, key, filename, object_size, object_version, extra_args, callback ) else: self._get_object(bucket, key, filename, extra_args, callback) @@ -818,13 +838,13 @@ def _validate_all_known_args(self, actual, allowed): ) def _ranged_download( - self, bucket, key, filename, object_size, extra_args, callback + self, bucket, key, filename, object_size, object_version, extra_args, callback ): downloader = MultipartDownloader( self._client, self._config, self._osutil ) - downloader.download_file( - bucket, key, filename, object_size, extra_args, callback + downloader.download_versioned_file( + bucket, key, filename, object_size, object_version, extra_args, callback ) def _get_object(self, bucket, key, filename, extra_args, callback): @@ -865,10 +885,8 @@ def _do_get_object(self, bucket, key, filename, extra_args, callback): for chunk in iter(lambda: streaming_body.read(8192), b''): f.write(chunk) - def _object_size(self, bucket, key, extra_args): - return self._client.head_object(Bucket=bucket, Key=key, **extra_args)[ - 'ContentLength' - ] + def _object_meta(self, bucket, key, extra_args): + return self._client.head_object(Bucket=bucket, Key=key, **extra_args) def _multipart_upload(self, filename, bucket, key, callback, extra_args): uploader = MultipartUploader(self._client, self._config, self._osutil) diff --git a/tests/unit/test_s3transfer.py b/tests/unit/test_s3transfer.py index a2f46a13..010c541f 100644 --- a/tests/unit/test_s3transfer.py +++ b/tests/unit/test_s3transfer.py @@ -382,6 +382,7 @@ class TestMultipartDownloader(unittest.TestCase): def test_multipart_download_uses_correct_client_calls(self): client = mock.Mock() + version_id = '123' response_body = b'foobarbaz' client.get_object.return_value = {'Body': BytesIO(response_body)} @@ -396,6 +397,14 @@ def test_multipart_download_uses_correct_client_calls(self): Range='bytes=0-', Bucket='bucket', Key='key' ) + downloader.download_versioned_file( + 'bucket', 'key', 'filename', len(response_body), version_id, {} + ) + + client.get_object.assert_called_with( + Range='bytes=0-', Bucket='bucket', Key='key', VersionId=version_id + ) + def test_multipart_download_with_multiple_parts(self): client = mock.Mock() response_body = b'foobarbaz' @@ -608,12 +617,39 @@ def test_uses_multipart_download_when_over_threshold(self): 'bucket', 'key', 'filename', callback=callback ) - downloader.return_value.download_file.assert_called_with( + downloader.return_value.download_versioned_file.assert_called_with( + # Note how we're downloading to a temporary random file. + 'bucket', + 'key', + 'filename.RANDOM', + over_multipart_threshold, + None, + {}, + callback, + ) + + def test_multipart_download_uses_version(self): + with mock.patch('s3transfer.MultipartDownloader') as downloader: + osutil = InMemoryOSLayer({}) + over_multipart_threshold = 100 * 1024 * 1024 + version_id = 'version-id' + transfer = S3Transfer(self.client, osutil=osutil) + callback = mock.sentinel.CALLBACK + self.client.head_object.return_value = { + 'ContentLength': over_multipart_threshold, + 'VersionId': version_id, + } + transfer.download_file( + 'bucket', 'key', 'filename', callback=callback + ) + + downloader.return_value.download_versioned_file.assert_called_with( # Note how we're downloading to a temporary random file. 'bucket', 'key', 'filename.RANDOM', over_multipart_threshold, + version_id, {}, callback, )