-
Notifications
You must be signed in to change notification settings - Fork 71
Open
Description
I am trying to quantize an image into a tensor of codebook indices and then decode the image back from those indices, but the encoder gives me floating-point latents instead.
My full code:
from huggingface_hub import hf_hub_download
from diffusers import VQModel
from pathlib import Path
from PIL import Image
import torch
import numpy as np
import matplotlib.pyplot as plt
# Ensure numpy is imported
import numpy as np
# Fetch the VQ-VAE config and weight files from the Hub, one
# hf_hub_download call per file, in the order listed.
repo_id = "microsoft/vq-diffusion-ithq"
vqvae_filenames = [
    "vqvae/config.json",
    "vqvae/diffusion_pytorch_model.fp16.safetensors",
    "vqvae/diffusion_pytorch_model.bin",
]
local_paths = []
for name in vqvae_filenames:
    local_paths.append(hf_hub_download(repo_id=repo_id, filename=name))

# The model directory is the folder that holds the first downloaded file
# (the config.json).
model_dir = Path(local_paths[0]).parent
vqvae = VQModel.from_pretrained(model_dir)

# Read the reference picture, resize it to 512x512 and force RGB.
source_image = Image.open("reference.jpg")
source_image = source_image.resize((512, 512))
source_image = source_image.convert("RGB")

# Pixels as floats scaled by 1/255, then reordered from (H, W, 3) to
# (3, H, W) with a leading batch axis -> (1, 3, H, W).
pixels = torch.tensor(np.array(source_image)).float() / 255.0
batch = pixels.permute(2, 0, 1)
batch = batch.unsqueeze(0)

# Run the encoder and pull out its `.latents` field.
# NOTE(review): the tensor printed here is the float32 tensor shown in the
# output that follows, not integer codebook indices. Presumably `.latents`
# is the encoder output before vector quantization -- confirm against the
# diffusers VQModel documentation.
encoder_output = vqvae.encode(batch)
quantized_latents = encoder_output.latents
print(quantized_latents)
print(quantized_latents.shape, quantized_latents.dtype)

# Feed the same latents to the decoder; the reconstructed image tensor is
# the `.sample` field of what decode() returns.
restored_image_tensor = vqvae.decode(quantized_latents).sample

# Drop the batch axis, go back to (H, W, C), and leave torch for numpy.
restored_image = restored_image_tensor.squeeze(0)
restored_image = restored_image.permute(1, 2, 0).detach().cpu().numpy()
# Rescale to 0-255 and clamp before casting to 8-bit.
restored_image = (restored_image * 255).clip(0, 255).astype("uint8")

# Side-by-side comparison: original on the left, reconstruction on the right.
figure, axes = plt.subplots(1, 2, figsize=(12, 6))
panels = (
    (source_image, "Original Image"),
    (restored_image, "Restored Image"),
)
for axis, (picture, title) in zip(axes, panels):
    axis.imshow(picture)
    axis.set_title(title)
    axis.axis("off")
plt.show()
The encode/decode round trip works, but the latents I get back are floats:

tensor([[[[ 0.4225, 0.2981, 0.3149, ..., 0.2489, 0.2679, 0.4513],
[ 0.4001, 0.3239, 0.2657, ..., 0.4148, 0.3454, 0.3503],
[ 0.4261, 0.3237, 0.3494, ..., 0.3703, 0.3447, 0.3494],
...,
[ 0.4650, 0.1186, 0.2364, ..., 0.2306, 0.2927, 0.2936],
[ 0.5248, 0.3457, 0.2609, ..., 0.3057, 0.3070, 0.3021],
[ 0.4576, 0.4258, 0.3637, ..., 0.3781, 0.3804, 0.3607]],
[[-1.0222, -0.7395, -0.6708, ..., -0.5544, -0.5323, -1.0479],
[-0.7135, -0.4130, -0.3063, ..., -0.6664, -0.3887, -0.5997],
[-0.7237, -0.4842, -0.5760, ..., -0.6525, -0.6987, -0.9177],
...,
[-0.9906, -0.5066, -0.5344, ..., -0.2479, -0.3053, -0.5853],
[-0.8421, -0.4733, -0.2319, ..., -0.6079, -0.5194, -0.5823],
[-0.9038, -0.8252, -0.6160, ..., -0.9664, -0.7653, -0.8778]],
[[-0.3252, -0.1989, -0.0736, ..., -0.0459, -0.1201, -0.6306],
[-0.2184, -0.2671, -0.3236, ..., -0.2283, -0.0942, -0.5065],
[-0.3181, -0.0854, -0.1833, ..., -0.3519, -0.1705, -0.2260],
...,
[-0.2587, -0.1918, -0.1453, ..., -0.0321, -0.1507, -0.2337],
[-0.5398, -0.0599, -0.3429, ..., -0.0813, -0.1139, -0.4409],
[-0.6048, -0.3892, -0.3652, ..., -0.5147, -0.4351, -0.3231]],
...,
[[-0.3657, -0.2115, 0.0081, ..., -0.1041, -0.0889, -0.5111],
[-0.2293, -0.1542, -0.1134, ..., -0.2491, -0.0108, -0.3427],
[-0.1626, 0.0304, -0.0673, ..., -0.2488, -0.3555, -0.3401],
...,
[-0.2499, -0.2363, 0.0428, ..., 0.1410, -0.0597, -0.1154],
[-0.3382, 0.0374, -0.1705, ..., -0.0545, -0.0871, -0.2866],
[-0.3992, -0.3094, -0.2202, ..., -0.4959, -0.3723, -0.3031]],
[[-1.4538, -1.0330, -1.2147, ..., -0.9883, -1.0343, -1.4522],
[-1.4043, -1.1521, -0.7814, ..., -1.3682, -1.3603, -1.1118],
[-1.4565, -1.2345, -1.5170, ..., -1.3795, -1.3626, -1.3186],
...,
[-1.7420, -0.5224, -1.0992, ..., -1.1148, -1.2031, -1.1217],
[-1.6478, -1.3150, -0.8264, ..., -1.1239, -1.1825, -0.9351],
[-1.4976, -1.4068, -1.2721, ..., -1.1976, -1.2473, -1.3688]],
[[-0.7946, -0.6282, -0.5310, ..., -0.7992, -0.6840, -0.4670],
[-0.9726, -0.9194, -0.5834, ..., -1.1396, -0.9707, -0.6007],
[-0.6839, -0.7608, -0.9145, ..., -0.7800, -1.1498, -0.7977],
...,
[-0.8311, -0.4279, -0.3612, ..., -0.5752, -0.7894, -0.4850],
[-0.7154, -0.8026, -0.6878, ..., -0.6780, -0.6881, -0.4457],
[-0.4830, -0.7109, -0.6619, ..., -0.5163, -0.6582, -0.6847]]]],
grad_fn=<ConvolutionBackward0>)
torch.Size([1, 128, 64, 64]) torch.float32
Metadata
Assignees
Labels
No labels