
How to get quantized latents (indices)? #50

@AmitMY


I am trying to quantize an image into a tensor of indices, then decode from it, but I am getting float latents.

My full code:

from huggingface_hub import hf_hub_download
from diffusers import VQModel
from pathlib import Path
from PIL import Image
import torch
import numpy as np
import matplotlib.pyplot as plt

# Download the necessary files
files = ["vqvae/config.json", "vqvae/diffusion_pytorch_model.fp16.safetensors", "vqvae/diffusion_pytorch_model.bin"]
downloaded_files = [hf_hub_download(repo_id="microsoft/vq-diffusion-ithq", filename=filename) for filename in files]
vqvae_dir = Path(downloaded_files[0]).parent

# Load the VQModel
vqvae = VQModel.from_pretrained(vqvae_dir)

# Load and preprocess the image
image = Image.open("reference.jpg").resize((512, 512)).convert("RGB")
# image = image.resize((256, 256))  # Resize if necessary
image_tensor = torch.tensor(np.array(image)).float() / 255.0
image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0)  # Convert to (1, 3, H, W)

# Encode the image into quantized latents
latents = vqvae.encode(image_tensor)
quantized_latents = latents.latents  # Get quantized latents

# Print the quantized latents
print(latents.latents)
print(latents.latents.shape, latents.latents.dtype)

# Decode the latents back into an image
# The output of `vqvae.decode()` needs proper handling to access the tensor
decoded_output = vqvae.decode(latents.latents)
restored_image_tensor = decoded_output.sample

# Squeeze and permute for correct shape
restored_image = restored_image_tensor.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
restored_image = (restored_image * 255).clip(0, 255).astype("uint8")  # Rescale to 0-255

# Display the original and restored images
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
ax[0].imshow(image)
ax[0].set_title("Original Image")
ax[0].axis("off")

ax[1].imshow(restored_image)
ax[1].set_title("Restored Image")
ax[1].axis("off")

plt.show()

Encoding and decoding both work, but the latents come back as floats instead of integer indices:

tensor([[[[ 0.4225,  0.2981,  0.3149,  ...,  0.2489,  0.2679,  0.4513],
          [ 0.4001,  0.3239,  0.2657,  ...,  0.4148,  0.3454,  0.3503],
          [ 0.4261,  0.3237,  0.3494,  ...,  0.3703,  0.3447,  0.3494],
          ...,
          [ 0.4650,  0.1186,  0.2364,  ...,  0.2306,  0.2927,  0.2936],
          [ 0.5248,  0.3457,  0.2609,  ...,  0.3057,  0.3070,  0.3021],
          [ 0.4576,  0.4258,  0.3637,  ...,  0.3781,  0.3804,  0.3607]],

         [[-1.0222, -0.7395, -0.6708,  ..., -0.5544, -0.5323, -1.0479],
          [-0.7135, -0.4130, -0.3063,  ..., -0.6664, -0.3887, -0.5997],
          [-0.7237, -0.4842, -0.5760,  ..., -0.6525, -0.6987, -0.9177],
          ...,
          [-0.9906, -0.5066, -0.5344,  ..., -0.2479, -0.3053, -0.5853],
          [-0.8421, -0.4733, -0.2319,  ..., -0.6079, -0.5194, -0.5823],
          [-0.9038, -0.8252, -0.6160,  ..., -0.9664, -0.7653, -0.8778]],

         [[-0.3252, -0.1989, -0.0736,  ..., -0.0459, -0.1201, -0.6306],
          [-0.2184, -0.2671, -0.3236,  ..., -0.2283, -0.0942, -0.5065],
          [-0.3181, -0.0854, -0.1833,  ..., -0.3519, -0.1705, -0.2260],
          ...,
          [-0.2587, -0.1918, -0.1453,  ..., -0.0321, -0.1507, -0.2337],
          [-0.5398, -0.0599, -0.3429,  ..., -0.0813, -0.1139, -0.4409],
          [-0.6048, -0.3892, -0.3652,  ..., -0.5147, -0.4351, -0.3231]],

         ...,

         [[-0.3657, -0.2115,  0.0081,  ..., -0.1041, -0.0889, -0.5111],
          [-0.2293, -0.1542, -0.1134,  ..., -0.2491, -0.0108, -0.3427],
          [-0.1626,  0.0304, -0.0673,  ..., -0.2488, -0.3555, -0.3401],
          ...,
          [-0.2499, -0.2363,  0.0428,  ...,  0.1410, -0.0597, -0.1154],
          [-0.3382,  0.0374, -0.1705,  ..., -0.0545, -0.0871, -0.2866],
          [-0.3992, -0.3094, -0.2202,  ..., -0.4959, -0.3723, -0.3031]],

         [[-1.4538, -1.0330, -1.2147,  ..., -0.9883, -1.0343, -1.4522],
          [-1.4043, -1.1521, -0.7814,  ..., -1.3682, -1.3603, -1.1118],
          [-1.4565, -1.2345, -1.5170,  ..., -1.3795, -1.3626, -1.3186],
          ...,
          [-1.7420, -0.5224, -1.0992,  ..., -1.1148, -1.2031, -1.1217],
          [-1.6478, -1.3150, -0.8264,  ..., -1.1239, -1.1825, -0.9351],
          [-1.4976, -1.4068, -1.2721,  ..., -1.1976, -1.2473, -1.3688]],

         [[-0.7946, -0.6282, -0.5310,  ..., -0.7992, -0.6840, -0.4670],
          [-0.9726, -0.9194, -0.5834,  ..., -1.1396, -0.9707, -0.6007],
          [-0.6839, -0.7608, -0.9145,  ..., -0.7800, -1.1498, -0.7977],
          ...,
          [-0.8311, -0.4279, -0.3612,  ..., -0.5752, -0.7894, -0.4850],
          [-0.7154, -0.8026, -0.6878,  ..., -0.6780, -0.6881, -0.4457],
          [-0.4830, -0.7109, -0.6619,  ..., -0.5163, -0.6582, -0.6847]]]],
       grad_fn=<ConvolutionBackward0>)
torch.Size([1, 128, 64, 64]) torch.float32
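
For reference, here is a minimal sketch of what turning float latents like these into indices involves, assuming the usual VQ-VAE nearest-neighbour lookup. Each of the 64×64 spatial positions holds a 128-dimensional vector, and its index is the codebook row closest to it. The `codebook` tensor, its size of 16 entries, and both helper functions are hypothetical stand-ins, not the model's weights or the diffusers API.

```python
import torch


def latents_to_indices(latents: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """Map float latents (B, C, H, W) to nearest-codebook indices (B, H, W)."""
    b, c, h, w = latents.shape
    # Move channels last and flatten so every row is one spatial position.
    flat = latents.permute(0, 2, 3, 1).reshape(-1, c)  # (B*H*W, C)
    # Euclidean distance from every position to every codebook entry.
    distances = torch.cdist(flat, codebook)            # (B*H*W, K)
    return distances.argmin(dim=1).view(b, h, w)


def indices_to_latents(indices: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """Look indices (B, H, W) up in the codebook, returning (B, C, H, W)."""
    return codebook[indices].permute(0, 3, 1, 2).contiguous()


torch.manual_seed(0)
codebook = torch.randn(16, 128)        # hypothetical codebook: 16 entries of dim 128
latents = torch.randn(1, 128, 64, 64)  # same shape as the encoder output above

indices = latents_to_indices(latents, codebook)
quantized = indices_to_latents(indices, codebook)
print(indices.shape, indices.dtype)      # torch.Size([1, 64, 64]) torch.int64
print(quantized.shape, quantized.dtype)  # torch.Size([1, 128, 64, 64]) torch.float32
```

In diffusers, `encode` appears to return the latent before this lookup, which would match the `ConvolutionBackward0` grad function in the output above, and the lookup itself seems to happen inside `decode` through the model's `quantize` module. If so, the indices would have to be taken from that module's output rather than from `encode`. This should be checked against the installed diffusers version.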
