How to get quantized latents (indices)? #50

@AmitMY

I am trying to quantize an image into a tensor of codebook indices and then decode from those indices, but `encode` gives me float latents instead.

My full code:

from huggingface_hub import hf_hub_download
from diffusers import VQModel
from pathlib import Path
from PIL import Image
import torch
import numpy as np
import matplotlib.pyplot as plt

# Download the necessary files
files = ["vqvae/config.json", "vqvae/diffusion_pytorch_model.fp16.safetensors", "vqvae/diffusion_pytorch_model.bin"]
downloaded_files = [hf_hub_download(repo_id="microsoft/vq-diffusion-ithq", filename=filename) for filename in files]
vqvae_dir = Path(downloaded_files[0]).parent

# Load the VQModel
vqvae = VQModel.from_pretrained(vqvae_dir)

# Load and preprocess the image
image = Image.open("reference.jpg").resize((512, 512)).convert("RGB")
# image = image.resize((256, 256))  # Resize if necessary
image_tensor = torch.tensor(np.array(image)).float() / 255.0
image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0)  # Convert to (1, 3, H, W)

# Encode the image into quantized latents
latents = vqvae.encode(image_tensor)
quantized_latents = latents.latents  # Get quantized latents

# Print the quantized latents
print(latents.latents)
print(latents.latents.shape, latents.latents.dtype)

# Decode the latents back into an image
# The output of `vqvae.decode()` needs proper handling to access the tensor
decoded_output = vqvae.decode(latents.latents)
restored_image_tensor = decoded_output.sample

# Squeeze and permute for correct shape
restored_image = restored_image_tensor.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
restored_image = (restored_image * 255).clip(0, 255).astype("uint8")  # Rescale to 0-255

# Display the original and restored images
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
ax[0].imshow(image)
ax[0].set_title("Original Image")
ax[0].axis("off")

ax[1].imshow(restored_image)
ax[1].set_title("Restored Image")
ax[1].axis("off")

plt.show()

Encoding and decoding both work, but the latents come back as floats rather than integer indices:

tensor([[[[ 0.4225,  0.2981,  0.3149,  ...,  0.2489,  0.2679,  0.4513],
          [ 0.4001,  0.3239,  0.2657,  ...,  0.4148,  0.3454,  0.3503],
          [ 0.4261,  0.3237,  0.3494,  ...,  0.3703,  0.3447,  0.3494],
          ...,
          [ 0.4650,  0.1186,  0.2364,  ...,  0.2306,  0.2927,  0.2936],
          [ 0.5248,  0.3457,  0.2609,  ...,  0.3057,  0.3070,  0.3021],
          [ 0.4576,  0.4258,  0.3637,  ...,  0.3781,  0.3804,  0.3607]],

         [[-1.0222, -0.7395, -0.6708,  ..., -0.5544, -0.5323, -1.0479],
          [-0.7135, -0.4130, -0.3063,  ..., -0.6664, -0.3887, -0.5997],
          [-0.7237, -0.4842, -0.5760,  ..., -0.6525, -0.6987, -0.9177],
          ...,
          [-0.9906, -0.5066, -0.5344,  ..., -0.2479, -0.3053, -0.5853],
          [-0.8421, -0.4733, -0.2319,  ..., -0.6079, -0.5194, -0.5823],
          [-0.9038, -0.8252, -0.6160,  ..., -0.9664, -0.7653, -0.8778]],

         [[-0.3252, -0.1989, -0.0736,  ..., -0.0459, -0.1201, -0.6306],
          [-0.2184, -0.2671, -0.3236,  ..., -0.2283, -0.0942, -0.5065],
          [-0.3181, -0.0854, -0.1833,  ..., -0.3519, -0.1705, -0.2260],
          ...,
          [-0.2587, -0.1918, -0.1453,  ..., -0.0321, -0.1507, -0.2337],
          [-0.5398, -0.0599, -0.3429,  ..., -0.0813, -0.1139, -0.4409],
          [-0.6048, -0.3892, -0.3652,  ..., -0.5147, -0.4351, -0.3231]],

         ...,

         [[-0.3657, -0.2115,  0.0081,  ..., -0.1041, -0.0889, -0.5111],
          [-0.2293, -0.1542, -0.1134,  ..., -0.2491, -0.0108, -0.3427],
          [-0.1626,  0.0304, -0.0673,  ..., -0.2488, -0.3555, -0.3401],
          ...,
          [-0.2499, -0.2363,  0.0428,  ...,  0.1410, -0.0597, -0.1154],
          [-0.3382,  0.0374, -0.1705,  ..., -0.0545, -0.0871, -0.2866],
          [-0.3992, -0.3094, -0.2202,  ..., -0.4959, -0.3723, -0.3031]],

         [[-1.4538, -1.0330, -1.2147,  ..., -0.9883, -1.0343, -1.4522],
          [-1.4043, -1.1521, -0.7814,  ..., -1.3682, -1.3603, -1.1118],
          [-1.4565, -1.2345, -1.5170,  ..., -1.3795, -1.3626, -1.3186],
          ...,
          [-1.7420, -0.5224, -1.0992,  ..., -1.1148, -1.2031, -1.1217],
          [-1.6478, -1.3150, -0.8264,  ..., -1.1239, -1.1825, -0.9351],
          [-1.4976, -1.4068, -1.2721,  ..., -1.1976, -1.2473, -1.3688]],

         [[-0.7946, -0.6282, -0.5310,  ..., -0.7992, -0.6840, -0.4670],
          [-0.9726, -0.9194, -0.5834,  ..., -1.1396, -0.9707, -0.6007],
          [-0.6839, -0.7608, -0.9145,  ..., -0.7800, -1.1498, -0.7977],
          ...,
          [-0.8311, -0.4279, -0.3612,  ..., -0.5752, -0.7894, -0.4850],
          [-0.7154, -0.8026, -0.6878,  ..., -0.6780, -0.6881, -0.4457],
          [-0.4830, -0.7109, -0.6619,  ..., -0.5163, -0.6582, -0.6847]]]],
       grad_fn=<ConvolutionBackward0>)
torch.Size([1, 128, 64, 64]) torch.float32
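
The `grad_fn=<ConvolutionBackward0>` in the output above suggests that `encode` returns the output of a convolution, i.e. continuous latents before any codebook lookup. A minimal sketch of what converting such latents into indices (and back) could look like, using plain PyTorch nearest-neighbour search: the function names `latents_to_indices` and `indices_to_latents` and the random `codebook` are hypothetical stand-ins for illustration. With the real model, the codebook would be the learned embedding table (in diffusers this is probably reachable as `vqvae.quantize.embedding.weight`, but treat that attribute path as an assumption to check against the installed version).

```python
import torch


def latents_to_indices(latents: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """Map float latents (B, C, H, W) to nearest-codebook indices (B, H, W)."""
    b, c, h, w = latents.shape
    # Put channels last and flatten so each row is one C-dimensional vector.
    flat = latents.permute(0, 2, 3, 1).reshape(-1, c)  # (B*H*W, C)
    # Euclidean distance from every latent vector to every codebook entry.
    distances = torch.cdist(flat, codebook)  # (B*H*W, K)
    return distances.argmin(dim=1).view(b, h, w)


def indices_to_latents(indices: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """Look up codebook vectors for indices (B, H, W) and return (B, C, H, W)."""
    return codebook[indices].permute(0, 3, 1, 2).contiguous()


torch.manual_seed(0)
codebook = torch.randn(16, 8)      # hypothetical codebook: 16 entries of dimension 8
latents = torch.randn(1, 8, 4, 4)  # stand-in for the float tensor returned by encode

indices = latents_to_indices(latents, codebook)
quantized = indices_to_latents(indices, codebook)

print(indices.shape, indices.dtype)  # torch.Size([1, 4, 4]) torch.int64
print(quantized.shape)               # torch.Size([1, 8, 4, 4])
```

If the library's `decode` method performs its own quantization step, the looked-up `quantized` tensor would need to be passed in a way that skips that step; whether and how that is possible depends on the installed version.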
