Generative AI models such as Stable Diffusion XL (SDXL) enable the creation of high-quality, realistic content with a wide range of applications. However, harnessing the power of such models presents significant challenges and computational costs. SDXL is a large image generation model whose UNet component is about three times as large as the one in the previous version of the model. Deploying a model like this in production is difficult because of the increased memory requirements and longer inference times. Today, we are excited to announce that Hugging Face Diffusers now supports serving SDXL using JAX on Cloud TPUs, enabling high-performance, cost-effective inference.
Google Cloud TPUs are custom-designed AI accelerators optimized for training and inference of large-scale AI models, including state-of-the-art LLMs and generative models such as SDXL. The new Cloud TPU v5e is purpose-built to bring the cost-efficiency and performance required for large-scale AI training and inference. At less than half the cost of TPU v4, TPU v5e makes it possible for more organizations to train and deploy AI models.
The Diffusers JAX integration offers a convenient way to run SDXL on TPUs via XLA, and we built a demo to showcase it. You can try it in this Space or in the playground embedded below.
Under the hood, this demo runs on several TPU v5e-4 instances (each instance has 4 TPU chips) and takes advantage of parallelization to serve four large 1024×1024 images in about 4 seconds. This time includes format conversions, communication time, and frontend processing; the actual generation time is about 2.3 seconds, as we'll see below!
In this blog post we will:
- Explain why JAX + TPU + Diffusers is a powerful framework to run SDXL
- Show how you can write a simple image generation pipeline with Diffusers and JAX
- Show benchmarks comparing different TPU settings
Why JAX + TPU v5e for SDXL?
Serving SDXL with JAX on Cloud TPU v5e offers both high performance and cost-efficiency. This is possible thanks to the combination of purpose-built TPU hardware and a software stack optimized for performance. Below we highlight two key factors: JAX just-in-time (JIT) compilation and XLA compiler-driven parallelism with JAX pmap.
JIT Compilation
A notable feature of JAX is its just-in-time (JIT) compilation. The JIT compiler traces the code during the first run and generates highly optimized TPU binaries that are reused in subsequent calls. The catch of this process is that it requires all input, intermediate, and output shapes to be static; in other words, they must be known in advance. Every time a shape changes, a new, costly compilation is triggered. JIT compilation is ideal for services that can be designed around static shapes: compilation runs once, and then you take advantage of ultra-fast inference times.
Image generation is well suited for JIT compilation. If we always generate the same number of images and they have the same size, then the output shapes are constant and known in advance. The text inputs are also constant: by design, Stable Diffusion and SDXL use fixed-shape embedding vectors (with padding) to represent the prompts typed by the user. Therefore, we can write JAX code that relies on fixed shapes.
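To make the shape constraint concrete, here is a minimal, self-contained sketch, unrelated to the SDXL pipeline itself (scaled_sum is just a hypothetical toy function): jax.jit compiles on the first call for a given input shape, reuses the compiled binary on later calls with the same shape, and recompiles when the shape changes.

import time
import jax
import jax.numpy as jnp

@jax.jit
def scaled_sum(x):
    # Any pure function of arrays with static shapes can be jit-compiled.
    return (x * 2.0).sum()

x = jnp.ones((4, 1024, 1024, 3))
start = time.time()
scaled_sum(x).block_until_ready()   # first call: traces and compiles
print(f"First call (compile): {time.time() - start:.3f}s")

start = time.time()
scaled_sum(x).block_until_ready()   # same shape: reuses the compiled binary
print(f"Second call (cached): {time.time() - start:.3f}s")

y = jnp.ones((8, 1024, 1024, 3))    # a different shape triggers a new compilation
scaled_sum(y).block_until_ready()

This is the same trade-off we will see later with the SDXL pipeline: a slow first call while compiling, followed by very fast subsequent calls.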
High-performance throughput for high batch sizes
Workloads can be scaled across multiple devices using JAX's pmap, which expresses single-program multiple-data (SPMD) programs. Applying pmap to a function compiles it with XLA and then runs it in parallel on various XLA devices. For text-to-image generation workloads, this means that increasing the number of images rendered simultaneously is straightforward to implement and comes at no performance penalty. For example, running SDXL on a TPU with 8 chips will generate 8 images in the same time it takes 1 chip to create a single image.
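As a minimal illustration of this pattern (denoise_step is a hypothetical toy function, not the actual SDXL denoising loop), jax.pmap maps a function over the leading axis of its inputs and runs one shard per device in parallel:

import jax
import jax.numpy as jnp

def denoise_step(latents):
    # Stand-in for the real per-device computation.
    return latents * 0.5

num_devices = jax.device_count()
# One shard per device, e.g. hypothetical latents of shape (1, 128, 128, 4) each.
latents = jnp.ones((num_devices, 1, 128, 128, 4))

p_denoise = jax.pmap(denoise_step)  # compiled once with XLA, replicated across devices
out = p_denoise(latents)            # each device processes its own shard in parallel
print(out.shape)                    # (num_devices, 1, 128, 128, 4)

In the pipeline code later in this post, this parallel execution is handled for us when we pass jit=True to the pipeline call.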
TPU v5e instances come in multiple shapes, including 1-, 4-, and 8-chip configurations, with ultra-fast ICI links between chips and up to 256 chips in a full TPU v5e pod. This allows you to choose the TPU shape that best suits your use case and easily take advantage of the parallelism that JAX and TPUs provide.
How to write an image generation pipeline in JAX
We'll go step by step through the code you need to write to run inference super fast using JAX! First, let's import the dependencies.
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from diffusers import FlaxStableDiffusionXLPipeline
import time
Now we'll load the base SDXL model and the rest of the components required for inference. The Diffusers pipeline takes care of downloading and caching everything for us. In keeping with JAX's functional approach, the model's parameters are returned separately and have to be passed to the pipeline during inference.
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", split_head_dim=True
)
Model parameters are downloaded in 32-bit precision by default. To save memory and run computation faster, we'll convert them to bfloat16, an efficient 16-bit representation. However, there's a caveat: for best results, we have to keep the scheduler state in float32, otherwise precision errors accumulate and result in low-quality or even black images.
scheduler_state = params.pop("scheduler")
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
params["scheduler"] = scheduler_state
We are now ready to set up our prompt and the rest of the pipeline inputs.
default_prompt = "high-quality photo of baby dolphins playing in a pool and wearing party hats"
default_neg_prompt = "illustration, low quality"
default_seed = 33
default_guidance_scale = 5.0
default_num_steps = 25
The prompts have to be supplied to the pipeline as tensors, and they must always have the same dimensions across calls. This is what allows the inference call to be compiled. The pipeline's prepare_inputs method performs all the necessary steps for us, so we'll create a helper function to prepare both our prompt and negative prompt as tensors. We'll use it later from our generate function.
def tokenize_prompt(prompt, neg_prompt):
    prompt_ids = pipeline.prepare_inputs(prompt)
    neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
    return prompt_ids, neg_prompt_ids
To take advantage of parallelization, we'll replicate the inputs across devices. A Cloud TPU v5e-4 has 4 chips, so by replicating the inputs we get each chip to generate a different image, in parallel. We need to be careful to supply a different random seed to each chip so that the 4 images are different.
num_devices = jax.device_count()

# Model parameters don't change during inference, so we only need to replicate them once.
p_params = replicate(params)

def replicate_all(prompt_ids, neg_prompt_ids, seed):
    p_prompt_ids = replicate(prompt_ids)
    p_neg_prompt_ids = replicate(neg_prompt_ids)
    rng = jax.random.PRNGKey(seed)
    rng = jax.random.split(rng, num_devices)
    return p_prompt_ids, p_neg_prompt_ids, rng
We're now ready to put everything together in a generate function.
def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    images = pipeline(
        prompt_ids,
        p_params,
        rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=neg_prompt_ids,
        guidance_scale=guidance_scale,
        jit=True,
    ).images

    # Flatten the per-device outputs and convert them to PIL images
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    return pipeline.numpy_to_pil(np.array(images))
jit=True indicates that we want the pipeline call to be compiled. This will happen the first time we call generate, and it will be very slow: JAX needs to trace the operations, optimize them, and convert them to low-level primitives. We'll run a first generation to complete this process and warm things up.
start = time.time()
print("Compiling ...")
generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.time() - start}")
The first run took about three minutes. But once the code has been compiled, inference is super fast. Let's try again!
start = time.time()
prompt = "llama in ancient Greece, oil on canvas"
neg_prompt = "cartoon, illustration, animation"
images = generate(prompt, neg_prompt)
print(f"Inference in {time.time() - start}")
It took about 2 seconds to generate the four images!
Benchmark
The following measurements were obtained running SDXL 1.0 base for 20 steps with the default Euler Discrete scheduler. We compare Cloud TPU v5e with TPU v4 for the same batch sizes. Note that, due to parallelism, a TPU v5e-4 like the ones used in our demo generates 4 images when using a batch size of 1 (or 8 images with a batch size of 2). Similarly, a TPU v5e-8 generates 8 images when using a batch size of 1.
The Cloud TPU tests were run using Python 3.10 and JAX version 0.4.16. These are the same specs used in our demo Space.
|                 | Batch Size | Latency | Perf/$ |
|-----------------|------------|---------|--------|
| TPU v5e-4 (JAX) | 4          | 2.33s   | 21.46  |
|                 | 8          | 4.99s   | 20.04  |
| TPU v4-8 (JAX)  | 4          | 2.16s   | 9.05   |
|                 | 8          | 4.17s   | 8.98   |
TPU v5e achieves up to 2.4x greater perf/$ on SDXL compared to TPU v4, demonstrating the cost-efficiency of the latest TPU generation.
To measure inference performance, we use the industry-standard metric of throughput. First, we measure latency per image when the model has been compiled and loaded. Then, we calculate throughput by dividing batch size by latency per chip. As a result, throughput measures how the model performs in production environments regardless of how many chips are used. We then divide throughput by the list price to get performance per dollar.
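For illustration only, the two definitions above can be sketched as simple helpers. The list_price argument is a hypothetical placeholder rather than an actual Cloud TPU price, so this snippet does not reproduce the numbers in the table.

def throughput(batch_size: int, latency_per_chip_s: float) -> float:
    # Throughput: batch size divided by latency per chip (images per second).
    return batch_size / latency_per_chip_s

def perf_per_dollar(batch_size: int, latency_per_chip_s: float, list_price: float) -> float:
    # Performance per dollar: throughput divided by the list price.
    # list_price is a placeholder here, not an actual Cloud TPU list price.
    return throughput(batch_size, latency_per_chip_s) / list_price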
How does the demo work?
The demo shown earlier was created using a script that essentially follows the code posted in this blog post. It runs on a few Cloud TPU v5e devices with 4 chips each, and a simple load-balancing server routes user requests to the backend servers at random. When you enter a prompt in the demo, your request is assigned to one of the backend servers, and you receive the 4 images it generates.
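The demo's serving code isn't reproduced in this post, but the random-routing idea could be sketched as follows; the backend addresses and the /generate endpoint here are hypothetical placeholders, not the demo's real API.

import random
import requests

# Hypothetical backend addresses; the real demo uses its own pre-allocated TPU v5e hosts.
BACKENDS = [
    "http://tpu-backend-0:8000",
    "http://tpu-backend-1:8000",
]

def generate_on_random_backend(prompt: str, neg_prompt: str):
    backend = random.choice(BACKENDS)  # pick one backend at random
    response = requests.post(
        f"{backend}/generate",         # hypothetical endpoint name
        json={"prompt": prompt, "negative_prompt": neg_prompt},
        timeout=60,
    )
    response.raise_for_status()
    return response.json()             # the backend returns its four generated images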
This simple solution is based on several pre-allocated TPU instances. In a future post, we'll show how to use GKE to create a dynamic solution that adapts to load.
All the code for the demo is open source and available in Hugging Face Diffusers today. We are excited to see what you build with Diffusers + JAX + Cloud TPUs!