Skip to content

Commit bfcb3ad

Browse files
authored
Merge pull request #4 from TensorStack-AI/PromptCache
Prompt cache
2 parents 794ad03 + a88f721 commit bfcb3ad

File tree

15 files changed

+200
-36
lines changed

15 files changed

+200
-36
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
using System;
using TensorStack.Common.Tensor;

namespace TensorStack.StableDiffusion.Common
{
    /// <summary>
    /// Caches the AutoEncoder output produced for a specific input image so that
    /// repeated generations with an unchanged image can skip the encode pass.
    /// </summary>
    public record EncoderCache
    {
        /// <summary>The input image the cached result was encoded from.</summary>
        public ImageTensor InputImage { get; init; }

        /// <summary>The cached AutoEncoder latent result.</summary>
        public Tensor<float> CacheResult { get; init; }

        /// <summary>
        /// Determines whether the cached entry is valid for the supplied input image.
        /// </summary>
        /// <param name="input">The candidate input image.</param>
        /// <returns>true when the cached entry was produced from identical image data; otherwise false.</returns>
        /// <remarks>
        /// NOTE(review): the cache key is the image data only — the encode pass also resizes to the
        /// requested Width/Height, so a size change between runs would not invalidate this entry.
        /// TODO: confirm callers invalidate on size change, or track size here as well.
        /// </remarks>
        public bool IsValid(ImageTensor input)
        {
            if (input is null || InputImage is null)
                return false;

            // SequenceEqual also returns false for differing lengths,
            // so differently sized images never produce a false hit.
            return InputImage.Span.SequenceEqual(input.Span);
        }
    }
}

TensorStack.StableDiffusion/Common/GenerateOptions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ public record GenerateOptions : IPipelineOptions, ISchedulerOptions
4545
public bool IsLowMemoryEncoderEnabled { get; set; }
4646
public bool IsLowMemoryDecoderEnabled { get; set; }
4747
public bool IsLowMemoryTextEncoderEnabled { get; set; }
48-
48+
public bool IsPipelineCacheEnabled { get; set; } = true;
4949

5050
public bool HasControlNet => ControlNet is not null;
5151
public bool HasInputImage => InputImage is not null;
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
using System;
using TensorStack.StableDiffusion.Pipelines;

namespace TensorStack.StableDiffusion.Common
{
    /// <summary>
    /// Caches the text-encoder output for a prompt/negative-prompt pair so that
    /// repeated generations with unchanged prompts can skip the text-encode pass.
    /// </summary>
    public record PromptCache
    {
        /// <summary>The conditional (positive) prompt the cached result was encoded from.</summary>
        public string Conditional { get; init; }

        /// <summary>The unconditional (negative) prompt the cached result was encoded from.</summary>
        public string Unconditional { get; init; }

        /// <summary>The cached prompt embedding result.</summary>
        public PromptResult CacheResult { get; init; }

        /// <summary>
        /// Determines whether the cached entry is valid for the supplied pipeline options.
        /// </summary>
        /// <param name="options">The pipeline options carrying the current prompts.</param>
        /// <returns>true when both prompts match the cached entry exactly; otherwise false.</returns>
        public bool IsValid(IPipelineOptions options)
        {
            // Ordinal (case-sensitive) comparison: downstream tokenizers/encoders (e.g. T5 in
            // Flux, LLM encoder in Nitro) are case-sensitive, so prompts differing only in case
            // can yield different embeddings. A false miss merely costs a re-encode; a false
            // hit (as with the previous OrdinalIgnoreCase) would return the wrong embeddings.
            return string.Equals(Conditional, options.Prompt, StringComparison.Ordinal)
                && string.Equals(Unconditional, options.NegativePrompt, StringComparison.Ordinal);
        }
    }
}

TensorStack.StableDiffusion/Common/TextEncoderResult.cs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
// Copyright (c) TensorStack. All rights reserved.
22
// Licensed under the Apache 2.0 License.
3-
using System;
43
using TensorStack.Common.Tensor;
54

65
namespace TensorStack.StableDiffusion.Common
@@ -32,6 +31,4 @@ public Tensor<float> GetHiddenStates(int index)
3231
return _hiddenStates[0];
3332
}
3433
}
35-
36-
public record TextEncoderBatchedResult(Memory<float> PromptEmbeds, Memory<float> PromptPooledEmbeds);
3734
}

TensorStack.StableDiffusion/Models/CLIPTextModel.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
using System.Threading;
55
using System.Threading.Tasks;
66
using TensorStack.Common;
7-
using TensorStack.Common.Tensor;
87
using TensorStack.StableDiffusion.Common;
98
using TensorStack.StableDiffusion.Config;
109
using TensorStack.TextGeneration.Tokenizers;

TensorStack.StableDiffusion/Pipelines/Flux/FluxBase.cs

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ protected override void ValidateOptions(GenerateOptions options)
148148
/// <param name="cancellationToken">The cancellation token.</param>
149149
protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
150150
{
151+
var cachedPrompt = GetPromptCache(options);
152+
if (cachedPrompt is not null)
153+
return cachedPrompt;
154+
151155
// Tokenize2
152156
var promptTokens = await TokenizePromptAsync(options.Prompt, cancellationToken);
153157
var negativePromptTokens = await TokenizePromptAsync(options.NegativePrompt, cancellationToken);
@@ -179,7 +183,7 @@ protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, C
179183
var negativePromptPooledEmbeds = negativePromptEmbeddings.TextEmbeds;
180184
negativePromptPooledEmbeds = negativePromptPooledEmbeds.Reshape([negativePromptPooledEmbeds.Dimensions[^2], negativePromptPooledEmbeds.Dimensions[^1]]).FirstBatch();
181185

182-
return new PromptResult(promptEmbeds, promptPooledEmbeds, negativePromptEmbeds, negativePromptPooledEmbeds);
186+
return SetPromptCache(options, new PromptResult(promptEmbeds, promptPooledEmbeds, negativePromptEmbeds, negativePromptPooledEmbeds));
183187
}
184188

185189

@@ -264,16 +268,23 @@ protected async Task<ImageTensor> DecodeLatentsAsync(IPipelineOptions options, T
264268
/// <summary>
/// Encodes the options' input image to latents with the AutoEncoder, returning a
/// cached result when the same image was already encoded.
/// </summary>
/// <param name="options">The options (supplies InputImage, Width, Height and the low-memory flags).</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>The encoded latent tensor.</returns>
private async Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
{
    var timestamp = Logger.LogBegin(LogLevel.Debug, "[EncodeLatentsAsync] Begin AutoEncoder Encode");
    var cacheResult = GetEncoderCache(options);
    if (cacheResult is not null)
    {
        Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete, Cached Result.");
        return cacheResult;
    }

    // NOTE(review): the cache key is the input image only; a Width/Height change between runs
    // would reuse latents encoded at the old size — confirm upstream invalidation handles this.
    var inputTensor = options.InputImage.ResizeImage(options.Width, options.Height);
    var encoderResult = await AutoEncoder.EncodeAsync(inputTensor, cancellationToken: cancellationToken);
    if (options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled)
        await AutoEncoder.EncoderUnloadAsync();

    Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete");
    return SetEncoderCache(options, encoderResult);
}
278289

279290

@@ -392,7 +403,7 @@ private async Task<Tensor<float>> CreateLatentInputAsync(IPipelineOptions option
392403
if (options.HasInputImage)
393404
{
394405
var timestep = scheduler.GetStartTimestep();
395-
var encoderResult = await EncodeLatentsAsync(options, options.InputImage, cancellationToken);
406+
var encoderResult = await EncodeLatentsAsync(options, cancellationToken);
396407
var noiseTensor = scheduler.CreateRandomSample(encoderResult.Dimensions);
397408
return PackLatents(scheduler.ScaleNoise(timestep, encoderResult, noiseTensor));
398409
}
@@ -410,8 +421,8 @@ private async Task<Tensor<float>> CreateLatentInputAsync(IPipelineOptions option
410421
/// <returns></returns>
411422
protected Tensor<float> CreateLatentImageIds(IPipelineOptions options)
412423
{
413-
var height = options.Height / AutoEncoder.LatentChannels;
414-
var width = options.Width / AutoEncoder.LatentChannels;
424+
var height = options.Height / AutoEncoder.LatentChannels;
425+
var width = options.Width / AutoEncoder.LatentChannels;
415426
var latentIds = new Tensor<float>([height, width, 3]);
416427

417428
for (int i = 0; i < height; i++)

TensorStack.StableDiffusion/Pipelines/Flux/FluxConfig.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public record FluxConfig : PipelineConfig
1818
public FluxConfig()
1919
{
2020
Tokenizer = new TokenizerConfig();
21-
Tokenizer2 = new TokenizerConfig{MaxLength = 512 };
21+
Tokenizer2 = new TokenizerConfig { MaxLength = 512 };
2222
TextEncoder = new CLIPModelConfig();
2323
TextEncoder2 = new CLIPModelConfig
2424
{

TensorStack.StableDiffusion/Pipelines/IPipelineOptions.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public interface IPipelineOptions : IRunOptions
2727
float ControlNetStrength { get; set; }
2828
ImageTensor InputControlImage { get; set; }
2929

30-
int ClipSkip{ get; set; }
30+
int ClipSkip { get; set; }
3131
float AestheticScore { get; set; }
3232
float AestheticNegativeScore { get; set; }
3333

@@ -36,7 +36,7 @@ public interface IPipelineOptions : IRunOptions
3636
bool IsLowMemoryEncoderEnabled { get; set; }
3737
bool IsLowMemoryDecoderEnabled { get; set; }
3838
bool IsLowMemoryTextEncoderEnabled { get; set; }
39-
39+
bool IsPipelineCacheEnabled { get; set; }
4040

4141
bool HasControlNet => ControlNet is not null;
4242
bool HasInputImage => InputImage is not null;

TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,10 @@ protected override void ValidateOptions(GenerateOptions options)
137137
/// <param name="cancellationToken">The cancellation token.</param>
138138
protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
139139
{
140+
var cachedPrompt = GetPromptCache(options);
141+
if (cachedPrompt is not null)
142+
return cachedPrompt;
143+
140144
// Conditional Prompt
141145
var promptEmbeds = await TextEncoder.GetLastHiddenState(new TextGeneration.Common.GenerateOptions
142146
{
@@ -159,7 +163,7 @@ protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, C
159163
}, cancellationToken);
160164
}
161165

162-
return new PromptResult(promptEmbeds, default, negativePromptEmbeds, default);
166+
return SetPromptCache(options, new PromptResult(promptEmbeds, default, negativePromptEmbeds, default));
163167
}
164168

165169

@@ -187,16 +191,23 @@ protected async Task<ImageTensor> DecodeLatentsAsync(IPipelineOptions options, T
187191
/// <param name="options">The options.</param>
188192
/// <param name="image">The latents.</param>
189193
/// <param name="cancellationToken">The cancellation token.</param>
190-
private async Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, ImageTensor image, CancellationToken cancellationToken = default)
194+
private async Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
191195
{
192196
var timestamp = Logger.LogBegin(LogLevel.Debug, "[EncodeLatentsAsync] Begin AutoEncoder Encode");
193-
var inputTensor = image.ResizeImage(options.Width, options.Height);
197+
var cacheResult = GetEncoderCache(options);
198+
if (cacheResult is not null)
199+
{
200+
Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete, Cached Result.");
201+
return cacheResult;
202+
}
203+
204+
var inputTensor = options.InputImage.ResizeImage(options.Width, options.Height);
194205
var encoderResult = await AutoEncoder.EncodeAsync(inputTensor, cancellationToken: cancellationToken);
195206
if (options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled)
196207
await AutoEncoder.EncoderUnloadAsync();
197208

198209
Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete");
199-
return encoderResult;
210+
return SetEncoderCache(options, encoderResult);
200211
}
201212

202213

@@ -270,7 +281,7 @@ private async Task<Tensor<float>> CreateLatentInputAsync(IPipelineOptions option
270281
if (options.HasInputImage)
271282
{
272283
var timestep = scheduler.GetStartTimestep();
273-
var encoderResult = await EncodeLatentsAsync(options, options.InputImage, cancellationToken);
284+
var encoderResult = await EncodeLatentsAsync(options, cancellationToken);
274285
return scheduler.ScaleNoise(timestep, encoderResult, noiseTensor);
275286
}
276287
return noiseTensor;

TensorStack.StableDiffusion/Pipelines/PipelineBase.cs

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ namespace TensorStack.StableDiffusion.Pipelines
1515
{
1616
public abstract class PipelineBase : IDisposable
1717
{
18+
private PromptCache _promptCache;
19+
private EncoderCache _encoderCache;
1820
private GenerateOptions _defaultOptions;
1921
private IReadOnlyList<SchedulerType> _schedulers;
2022

@@ -149,11 +151,78 @@ protected Tensor<float> ApplyGuidance(Tensor<float> conditional, Tensor<float> u
149151
}
150152

151153

154+
/// <summary>
/// Gets the cached prompt result valid for the current options, if any.
/// </summary>
/// <param name="options">The options supplying the prompts and the cache-enabled flag.</param>
/// <returns>The cached <see cref="PromptResult"/>, or default when caching is disabled or no valid entry exists.</returns>
protected PromptResult GetPromptCache(IPipelineOptions options)
{
    if (!options.IsPipelineCacheEnabled)
        return default;

    // Snapshot the field so the check and the read see the same entry.
    var cache = _promptCache;
    if (cache is not null && cache.IsValid(options))
        return cache.CacheResult;

    return default;
}
168+
169+
170+
/// <summary>
/// Stores the supplied prompt result in the cache, keyed by the current prompts.
/// </summary>
/// <param name="options">The options supplying Prompt and NegativePrompt.</param>
/// <param name="promptResult">The prompt result to cache.</param>
/// <returns>The <paramref name="promptResult"/> unchanged, for fluent use at return sites.</returns>
protected PromptResult SetPromptCache(IPipelineOptions options, PromptResult promptResult)
{
    // Only retain the entry when caching is enabled; otherwise large embedding
    // tensors would be kept alive for no benefit (GetPromptCache never reads them).
    if (options.IsPipelineCacheEnabled)
    {
        _promptCache = new PromptCache
        {
            CacheResult = promptResult,
            Conditional = options.Prompt,
            Unconditional = options.NegativePrompt,
        };
    }
    return promptResult;
}
185+
186+
187+
/// <summary>
/// Gets the cached AutoEncoder result valid for the current input image, if any.
/// </summary>
/// <param name="options">The options supplying InputImage and the cache-enabled flag.</param>
/// <returns>The cached latent tensor, or default when caching is disabled or no valid entry exists.</returns>
/// <remarks>
/// NOTE(review): validity is judged on the image alone; Width/Height are not part of the
/// key although the encode pass resizes to them — confirm this cannot serve stale sizes.
/// </remarks>
protected Tensor<float> GetEncoderCache(IPipelineOptions options)
{
    if (!options.IsPipelineCacheEnabled)
        return default;

    // Snapshot the field so the check and the read see the same entry.
    var cache = _encoderCache;
    if (cache is not null && cache.IsValid(options.InputImage))
        return cache.CacheResult;

    return default;
}
201+
202+
203+
/// <summary>
/// Stores the supplied AutoEncoder result in the cache, keyed by the input image.
/// </summary>
/// <param name="options">The options supplying the InputImage used as the cache key.</param>
/// <param name="encoded">The encoded latent tensor to cache.</param>
/// <returns>The <paramref name="encoded"/> tensor unchanged, for fluent use at return sites.</returns>
protected Tensor<float> SetEncoderCache(IPipelineOptions options, Tensor<float> encoded)
{
    // Only retain the entry when caching is enabled; otherwise the input image and
    // latents would be kept alive for no benefit (GetEncoderCache never reads them).
    if (options.IsPipelineCacheEnabled)
    {
        _encoderCache = new EncoderCache
        {
            InputImage = options.InputImage,
            CacheResult = encoded
        };
    }
    return encoded;
}
217+
218+
152219
/// <summary>
/// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
/// </summary>
public void Dispose()
{
    // Release cached prompt/encoder results so their tensors can be collected promptly.
    (_promptCache, _encoderCache) = (null, null);
    Dispose(disposing: true);
    GC.SuppressFinalize(this);
}

0 commit comments

Comments
 (0)