Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 63 additions & 31 deletions invokeai/app/invocations/qwen_image_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@

@invocation(
"qwen_image_denoise",
title="Denoise - Qwen Image Edit",
title="Denoise - Qwen Image",
tags=["image", "qwen_image"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class QwenImageDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run the denoising process with a Qwen Image Edit model."""
"""Run the denoising process with a Qwen Image model."""

# If latents is provided, this means we are doing image-to-image.
latents: Optional[LatentsField] = InputField(
Expand Down Expand Up @@ -270,7 +270,7 @@ def _run_diffusion(self, context: InvocationContext):

# Try to load the scheduler config from the model's directory (Diffusers models
# have a scheduler/ subdir). For GGUF models this path doesn't exist, so fall
# back to instantiating the scheduler with the known Qwen Image Edit defaults.
# back to instantiating the scheduler with the known Qwen Image defaults.
model_path = context.models.get_absolute_path(context.models.get_config(self.transformer.transformer))
scheduler_path = model_path / "scheduler"
if scheduler_path.is_dir() and (scheduler_path / "scheduler_config.json").exists():
Expand Down Expand Up @@ -304,8 +304,19 @@ def _run_diffusion(self, context: InvocationContext):
init_sigmas = np.linspace(1.0, 1.0 / self.steps, self.steps).tolist()
scheduler.set_timesteps(sigmas=init_sigmas, mu=mu, device=device)

timesteps_sched = scheduler.timesteps
sigmas_sched = scheduler.sigmas
# Clip the schedule based on denoising_start/denoising_end to support img2img strength.
# The scheduler's sigmas go from high (noisy) to 0 (clean). We clip to the fractional range.
sigmas_sched = scheduler.sigmas # (N+1,) including terminal 0
if self.denoising_start > 0 or self.denoising_end < 1:
total_sigmas = len(sigmas_sched) - 1 # exclude terminal
start_idx = int(round(self.denoising_start * total_sigmas))
end_idx = int(round(self.denoising_end * total_sigmas))
sigmas_sched = sigmas_sched[start_idx : end_idx + 1] # +1 to include the next sigma for dt
# Rebuild timesteps from clipped sigmas (exclude terminal 0)
timesteps_sched = sigmas_sched[:-1] * scheduler.config.num_train_timesteps
else:
timesteps_sched = scheduler.timesteps

total_steps = len(timesteps_sched)

cfg_scale = self._prepare_cfg_scale(total_steps)
Expand Down Expand Up @@ -353,29 +364,44 @@ def _run_diffusion(self, context: InvocationContext):
# Pack latents into 2x2 patches: (B, C, H, W) -> (B, H/2*W/2, C*4)
latents = self._pack_latents(latents, 1, out_channels, latent_height, latent_width)

# Pack reference image latents and concatenate along the sequence dimension.
# The edit transformer always expects [noisy_patches ; ref_patches] in its sequence.
if ref_latents is not None:
_, ref_ch, rh, rw = ref_latents.shape
if rh != latent_height or rw != latent_width:
ref_latents = torch.nn.functional.interpolate(
ref_latents, size=(latent_height, latent_width), mode="bilinear"
# Determine whether the model uses reference latent conditioning (zero_cond_t).
# Edit models (zero_cond_t=True) expect [noisy_patches ; ref_patches] in the sequence.
# Txt2img models (zero_cond_t=False) only take noisy patches.
has_zero_cond_t = getattr(transformer_info.model, "zero_cond_t", False) or getattr(
transformer_info.model.config, "zero_cond_t", False
)
use_ref_latents = has_zero_cond_t

ref_latents_packed = None
if use_ref_latents:
if ref_latents is not None:
_, ref_ch, rh, rw = ref_latents.shape
if rh != latent_height or rw != latent_width:
ref_latents = torch.nn.functional.interpolate(
ref_latents, size=(latent_height, latent_width), mode="bilinear"
)
else:
# No reference image provided — use zeros so the model still gets the
# expected sequence layout.
ref_latents = torch.zeros(
1, out_channels, latent_height, latent_width, device=device, dtype=inference_dtype
)
ref_latents_packed = self._pack_latents(ref_latents, 1, out_channels, latent_height, latent_width)

# img_shapes tells the transformer the spatial layout of patches.
if use_ref_latents:
img_shapes = [
[
(1, latent_height // 2, latent_width // 2),
(1, latent_height // 2, latent_width // 2),
]
]
else:
# No reference image provided — use zeros so the model still gets the
# expected sequence layout.
ref_latents = torch.zeros(
1, out_channels, latent_height, latent_width, device=device, dtype=inference_dtype
)
ref_latents_packed = self._pack_latents(ref_latents, 1, out_channels, latent_height, latent_width)

# img_shapes tells the transformer the spatial layout of noisy and reference patches.
img_shapes = [
[
(1, latent_height // 2, latent_width // 2),
(1, latent_height // 2, latent_width // 2),
img_shapes = [
[
(1, latent_height // 2, latent_width // 2),
]
]
]

# Prepare inpaint extension (operates in 4D space, so unpack/repack around it)
inpaint_mask = self._prep_inpaint_mask(context, noise) # noise has the right 4D shape
Expand Down Expand Up @@ -422,14 +448,16 @@ def _run_diffusion(self, context: InvocationContext):
)
)

scheduler.set_begin_index(0)

for step_idx, t in enumerate(tqdm(timesteps_sched)):
# The pipeline passes timestep / 1000 to the transformer
timestep = t.expand(latents.shape[0]).to(inference_dtype)

# Concatenate noisy and reference patches along the sequence dim
model_input = torch.cat([latents, ref_latents_packed], dim=1)
# For edit models: concatenate noisy and reference patches along the sequence dim
# For txt2img models: just use noisy patches
if ref_latents_packed is not None:
model_input = torch.cat([latents, ref_latents_packed], dim=1)
else:
model_input = latents

noise_pred_cond = transformer(
hidden_states=model_input,
Expand Down Expand Up @@ -457,8 +485,12 @@ def _run_diffusion(self, context: InvocationContext):
else:
noise_pred = noise_pred_cond

# Use the scheduler's step method — exactly matching the pipeline
latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
# Euler step using the (possibly clipped) sigma schedule
sigma_curr = sigmas_sched[step_idx]
sigma_next = sigmas_sched[step_idx + 1]
dt = sigma_next - sigma_curr
latents = latents.to(torch.float32) + dt * noise_pred.to(torch.float32)
latents = latents.to(inference_dtype)

if inpaint_extension is not None:
sigma_next = sigmas_sched[step_idx + 1].item()
Expand Down
6 changes: 3 additions & 3 deletions invokeai/app/invocations/qwen_image_image_to_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@

@invocation(
"qwen_image_i2l",
title="Image to Latents - Qwen Image Edit",
title="Image to Latents - Qwen Image",
tags=["image", "latents", "vae", "i2l", "qwen_image"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class QwenImageImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates latents from an image using the Qwen Image Edit VAE."""
"""Generates latents from an image using the Qwen Image VAE."""

image: ImageField = InputField(description="The image to encode.")
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
Expand All @@ -51,7 +51,7 @@ def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tenso

image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
with torch.inference_mode():
# The Qwen Image Edit VAE expects 5D input: (B, C, num_frames, H, W)
# The Qwen Image VAE expects 5D input: (B, C, num_frames, H, W)
if image_tensor.dim() == 4:
image_tensor = image_tensor.unsqueeze(2)

Expand Down
6 changes: 3 additions & 3 deletions invokeai/app/invocations/qwen_image_latents_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@

@invocation(
"qwen_image_l2i",
title="Latents to Image - Qwen Image Edit",
title="Latents to Image - Qwen Image",
tags=["latents", "image", "vae", "l2i", "qwen_image"],
category="latents",
version="1.0.0",
classification=Classification.Prototype,
)
class QwenImageLatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents using the Qwen Image Edit VAE."""
"""Generates an image from latents using the Qwen Image VAE."""

latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
Expand All @@ -56,7 +56,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
TorchDevice.empty_cache()

with torch.inference_mode(), tiling_context:
# The Qwen Image Edit VAE uses per-channel latents_mean / latents_std
# The Qwen Image VAE uses per-channel latents_mean / latents_std
# instead of a single scaling_factor.
# Latents are 5D: (B, C, num_frames, H, W) — the unpack from the
# denoise step already produces this shape.
Expand Down
10 changes: 5 additions & 5 deletions invokeai/app/invocations/qwen_image_lora_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@invocation_output("qwen_image_lora_loader_output")
class QwenImageLoRALoaderOutput(BaseInvocationOutput):
"""Qwen Image Edit LoRA Loader Output"""
"""Qwen Image LoRA Loader Output"""

transformer: Optional[TransformerField] = OutputField(
default=None, description=FieldDescriptions.transformer, title="Transformer"
Expand All @@ -24,14 +24,14 @@ class QwenImageLoRALoaderOutput(BaseInvocationOutput):

@invocation(
"qwen_image_lora_loader",
title="Apply LoRA - Qwen Image Edit",
title="Apply LoRA - Qwen Image",
tags=["lora", "model", "qwen_image"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class QwenImageLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA model to a Qwen Image Edit transformer."""
"""Apply a LoRA model to a Qwen Image transformer."""

lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model,
Expand Down Expand Up @@ -72,14 +72,14 @@ def invoke(self, context: InvocationContext) -> QwenImageLoRALoaderOutput:

@invocation(
"qwen_image_lora_collection_loader",
title="Apply LoRA Collection - Qwen Image Edit",
title="Apply LoRA Collection - Qwen Image",
tags=["lora", "model", "qwen_image"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class QwenImageLoRACollectionLoader(BaseInvocation):
"""Applies a collection of LoRAs to a Qwen Image Edit transformer."""
"""Applies a collection of LoRAs to a Qwen Image transformer."""

loras: Optional[LoRAField | list[LoRAField]] = InputField(
default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
Expand Down
10 changes: 5 additions & 5 deletions invokeai/app/invocations/qwen_image_model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

@invocation_output("qwen_image_model_loader_output")
class QwenImageModelLoaderOutput(BaseInvocationOutput):
"""Qwen Image Edit base model loader output."""
"""Qwen Image model loader output."""

transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
qwen_vl_encoder: QwenVLEncoderField = OutputField(
Expand All @@ -31,14 +31,14 @@ class QwenImageModelLoaderOutput(BaseInvocationOutput):

@invocation(
"qwen_image_model_loader",
title="Main Model - Qwen Image Edit",
title="Main Model - Qwen Image",
tags=["model", "qwen_image"],
category="model",
version="1.1.0",
classification=Classification.Prototype,
)
class QwenImageModelLoaderInvocation(BaseInvocation):
"""Loads a Qwen Image Edit model, outputting its submodels.
"""Loads a Qwen Image model, outputting its submodels.

The transformer is always loaded from the main model (Diffusers or GGUF).

Expand All @@ -59,7 +59,7 @@ class QwenImageModelLoaderInvocation(BaseInvocation):

component_source: Optional[ModelIdentifierField] = InputField(
default=None,
description="Diffusers Qwen Image Edit model to extract the VAE and Qwen VL encoder from. "
description="Diffusers Qwen Image model to extract the VAE and Qwen VL encoder from. "
"Required when using a GGUF quantized transformer. "
"Ignored when the main model is already in Diffusers format.",
input=Input.Direct,
Expand Down Expand Up @@ -96,7 +96,7 @@ def invoke(self, context: InvocationContext) -> QwenImageModelLoaderOutput:
raise ValueError(
"No source for VAE and Qwen VL encoder. "
"GGUF quantized models only contain the transformer — "
"please set 'Component Source' to a Diffusers Qwen Image Edit model "
"please set 'Component Source' to a Diffusers Qwen Image model "
"to provide the VAE and text encoder."
)

Expand Down
53 changes: 37 additions & 16 deletions invokeai/app/invocations/qwen_image_text_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,39 +20,57 @@
QwenImageConditioningInfo,
)

# The Qwen Image Edit pipeline uses a specific system prompt and drops the first
# N tokens (the system prompt prefix) from the embeddings. These constants are
# taken directly from the diffusers QwenImagePipeline.
_SYSTEM_PROMPT = (
# Prompt templates and drop indices for the two Qwen Image model modes.
# These are taken directly from the diffusers pipelines.

# Image editing mode (QwenImagePipeline)
_EDIT_SYSTEM_PROMPT = (
"Describe the key features of the input image (color, shape, size, texture, objects, background), "
"then explain how the user's text instruction should alter or modify the image. "
"Generate a new image that meets the user's requirements while maintaining consistency "
"with the original input where appropriate."
)
_EDIT_DROP_IDX = 64

# Text-to-image mode (QwenImagePipeline)
_GENERATE_SYSTEM_PROMPT = (
"Describe the image by detailing the color, shape, size, texture, quantity, "
"text, spatial relationships of the objects and background:"
)
_GENERATE_DROP_IDX = 34

_IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>"
_DROP_IDX = 64


def _build_prompt(user_prompt: str, num_images: int) -> str:
"""Build the full prompt with one vision placeholder per reference image."""
image_tokens = _IMAGE_PLACEHOLDER * max(num_images, 1)
return (
f"<|im_start|>system\n{_SYSTEM_PROMPT}<|im_end|>\n"
f"<|im_start|>user\n{image_tokens}{user_prompt}<|im_end|>\n"
"<|im_start|>assistant\n"
)
"""Build the full prompt with the appropriate template based on whether reference images are provided."""
if num_images > 0:
# Edit mode: include vision placeholders for reference images
image_tokens = _IMAGE_PLACEHOLDER * num_images
return (
f"<|im_start|>system\n{_EDIT_SYSTEM_PROMPT}<|im_end|>\n"
f"<|im_start|>user\n{image_tokens}{user_prompt}<|im_end|>\n"
"<|im_start|>assistant\n"
)
else:
# Generate mode: text-only prompt
return (
f"<|im_start|>system\n{_GENERATE_SYSTEM_PROMPT}<|im_end|>\n"
f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
"<|im_start|>assistant\n"
)


@invocation(
"qwen_image_text_encoder",
title="Prompt - Qwen Image Edit",
title="Prompt - Qwen Image",
tags=["prompt", "conditioning", "qwen_image"],
category="conditioning",
version="1.2.0",
classification=Classification.Prototype,
)
class QwenImageTextEncoderInvocation(BaseInvocation):
"""Encodes text and reference images for Qwen Image Edit using Qwen2.5-VL."""
"""Encodes text and reference images for Qwen Image using Qwen2.5-VL."""

prompt: str = InputField(description="Text prompt describing the desired edit.", ui_component=UIComponent.Textarea)
reference_images: list[ImageField] = InputField(
Expand Down Expand Up @@ -188,15 +206,18 @@ def _encode(
hidden_states = outputs.hidden_states[-1]

# Extract valid (non-padding) tokens using the attention mask,
# then drop the first _DROP_IDX tokens (system prompt prefix).
# then drop the system prompt prefix tokens.
# The drop index differs between edit mode (64) and generate mode (34).
drop_idx = _EDIT_DROP_IDX if images else _GENERATE_DROP_IDX

attn_mask = model_inputs.attention_mask
bool_mask = attn_mask.bool()
valid_lengths = bool_mask.sum(dim=1)
selected = hidden_states[bool_mask]
split_hidden = torch.split(selected, valid_lengths.tolist(), dim=0)

# Drop system prefix tokens and build padded output
trimmed = [h[_DROP_IDX:] for h in split_hidden]
trimmed = [h[drop_idx:] for h in split_hidden]
attn_mask_list = [torch.ones(h.size(0), dtype=torch.long, device=device) for h in trimmed]
max_seq_len = max(h.size(0) for h in trimmed)

Expand Down
Loading
Loading