diff --git a/docs/help.md b/docs/help.md index 509bcda..670e0fb 100644 --- a/docs/help.md +++ b/docs/help.md @@ -187,6 +187,7 @@ Options: -o, --output DIRECTORY Write to this directory instead of one matching the bucket name -p, --pattern TEXT Glob patterns for files to download, e.g. '*/*.js' + -s, --silent Don't show progress bar --access-key TEXT AWS access key ID --secret-key TEXT AWS secret access key --session-token TEXT AWS session token diff --git a/docs/other-commands.md b/docs/other-commands.md index ada25a9..81eb632 100644 --- a/docs/other-commands.md +++ b/docs/other-commands.md @@ -393,6 +393,8 @@ You can pass one or more `--pattern` or `-p` options to download files matching Here the `*` wildcard will match any sequence of characters, including `/`. `?` will match a single character. +A progress bar will be shown by default. Use `-s` or `--silent` to hide it. + ## set-cors-policy and get-cors-policy You can set the [CORS policy](https://docs.aws.amazon.com/AmazonS3/latest/userguide/cors.html) for a bucket using the `set-cors-policy` command. S3 CORS policies are set at the bucket level - they cannot be set for individual items. diff --git a/s3_credentials/cli.py b/s3_credentials/cli.py index 1b3819a..2c92430 100644 --- a/s3_credentials/cli.py +++ b/s3_credentials/cli.py @@ -1052,8 +1052,9 @@ def get_object(bucket, key, output, **boto_options): multiple=True, help="Glob patterns for files to download, e.g. '*/*.js'", ) +@click.option("silent", "-s", "--silent", is_flag=True, help="Don't show progress bar") @common_boto3_options -def get_objects(bucket, keys, output, patterns, **boto_options): +def get_objects(bucket, keys, output, patterns, silent, **boto_options): """ Download multiple objects from an S3 bucket @@ -1076,35 +1077,67 @@ def get_objects(bucket, keys, output, patterns, **boto_options): # If user specified keys and no patterns, use the keys they specified keys_to_download = list(keys) + key_sizes = {} + + if keys and not silent: + # Get sizes of those keys for progress bar + for key in keys: + try: + key_sizes[key] = s3.head_object(Bucket=bucket, Key=key)["ContentLength"] + except botocore.exceptions.ClientError: + # Ignore errors - they will be reported later + key_sizes[key] = 0 if (not keys) or patterns: # Fetch all keys, then filter them if --pattern - all_keys = [ - obj["Key"] - for obj in paginate(s3, "list_objects_v2", "Contents", Bucket=bucket) - ] + all_key_infos = list(paginate(s3, "list_objects_v2", "Contents", Bucket=bucket)) if patterns: filtered = [] for pattern in patterns: - filtered.extend(fnmatch.filter(all_keys, pattern)) + filtered.extend( + fnmatch.filter((k["Key"] for k in all_key_infos), pattern) + ) keys_to_download.extend(filtered) else: - keys_to_download.extend(all_keys) + keys_to_download.extend(k["Key"] for k in all_key_infos) + if not silent: + key_set = set(keys_to_download) + for key in all_key_infos: + if key["Key"] in key_set: + key_sizes[key["Key"]] = key["Size"] output_dir = pathlib.Path(output or ".") if not output_dir.exists(): output_dir.mkdir(parents=True) errors = [] - for key in keys_to_download: + + def download(key, callback=None): # Ensure directory for key exists key_dir = (output_dir / key).parent if not key_dir.exists(): key_dir.mkdir(parents=True) try: - s3.download_file(bucket, key, str(output_dir / key)) + s3.download_file(bucket, key, str(output_dir / key), Callback=callback) except botocore.exceptions.ClientError as e: errors.append("Not found: {}".format(key)) + + if not silent: + total_size = sum(key_sizes.values()) + with click.progressbar( + length=total_size, + label="Downloading {} ({} file{})".format( + format_bytes(total_size), + len(key_sizes), + "s" if len(key_sizes) != 1 else "", + ), + ) as bar: + for key in keys_to_download: + download(key, bar.update) + else: + for key in keys_to_download: + download(key) + if errors: raise click.ClickException("\n".join(errors)) @@ -1269,3 +1302,12 @@ def fix_json(row): for key, value in row.items() ] ) + + +def format_bytes(size): + for x in ("bytes", "KB", "MB", "GB", "TB"): + if size < 1024: + return "{:3.1f} {}".format(size, x) + size /= 1024 + + return size