Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions text2image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import datetime
from tensorflow import keras
from stable_diffusion_tf.stable_diffusion import Text2Image
import argparse
Expand All @@ -17,7 +18,7 @@
"--output",
type=str,
nargs="?",
default="output",
default=None,
help="where to save the output image",
)

Expand Down Expand Up @@ -68,6 +69,7 @@

args = parser.parse_args()


if args.mp:
print("Using mixed precision.")
keras.mixed_precision.set_global_policy("mixed_float16")
Expand All @@ -82,12 +84,31 @@
seed=args.seed,
)

fname = args.output
if not args.output:
# When not providing an output filename, create something using
# the prompt and a timestamp to prevent overwriting of existing files
# Get a timestamp without microseconds, and replace colons with dots
timestamp = datetime.now().isoformat("T").split(".")[0].replace(":", ".")

# Create a 'slug' with only valid alphanumeric characters and spaces to
# prevent filename issues
slug_prompt = "".join(c for c in args.prompt if (c.isalnum() or c in "_- "))

# Trim the length to 100 characters to prevent issues with maximum pathlength
slug_prompt = slug_prompt[0:100]

# And create the final filename
fname = f"{timestamp} - {slug_prompt}"
else:
# Output filename provided, use that
fname = args.output

if fname.endswith(".png"):
fname = fname[:-4]

if args.batch_size == 1:
Image.fromarray(img[0]).save(args.output + ".png")
print(f"saved at {args.output}.png")
Image.fromarray(img[0]).save(f"{fname}.png")
print(f"saved at {fname}.png")
else:
for i in range(args.batch_size):
fname_i = f"{fname}_{i}.png"
Expand Down