@@ -148,6 +148,10 @@ protected override void ValidateOptions(GenerateOptions options)
148148 /// <param name="cancellationToken">The cancellation token.</param>
149149 protected async Task < PromptResult > CreatePromptAsync ( IPipelineOptions options , CancellationToken cancellationToken = default )
150150 {
151+ var cachedPrompt = GetPromptCache ( options ) ;
152+ if ( cachedPrompt is not null )
153+ return cachedPrompt ;
154+
151155 // Tokenize2
152156 var promptTokens = await TokenizePromptAsync ( options . Prompt , cancellationToken ) ;
153157 var negativePromptTokens = await TokenizePromptAsync ( options . NegativePrompt , cancellationToken ) ;
@@ -179,7 +183,7 @@ protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, C
179183 var negativePromptPooledEmbeds = negativePromptEmbeddings . TextEmbeds ;
180184 negativePromptPooledEmbeds = negativePromptPooledEmbeds . Reshape ( [ negativePromptPooledEmbeds . Dimensions [ ^ 2 ] , negativePromptPooledEmbeds . Dimensions [ ^ 1 ] ] ) . FirstBatch ( ) ;
181185
182- return new PromptResult ( promptEmbeds , promptPooledEmbeds , negativePromptEmbeds , negativePromptPooledEmbeds ) ;
186+ return SetPromptCache ( options , new PromptResult ( promptEmbeds , promptPooledEmbeds , negativePromptEmbeds , negativePromptPooledEmbeds ) ) ;
183187 }
184188
185189
@@ -264,16 +268,23 @@ protected async Task<ImageTensor> DecodeLatentsAsync(IPipelineOptions options, T
264268 /// <param name="options">The options.</param>
265269 /// <param name="image">The latents.</param>
266270 /// <param name="cancellationToken">The cancellation token.</param>
267- private async Task < Tensor < float > > EncodeLatentsAsync ( IPipelineOptions options , ImageTensor image , CancellationToken cancellationToken = default )
271+ private async Task < Tensor < float > > EncodeLatentsAsync ( IPipelineOptions options , CancellationToken cancellationToken = default )
268272 {
269273 var timestamp = Logger . LogBegin ( LogLevel . Debug , "[EncodeLatentsAsync] Begin AutoEncoder Encode" ) ;
270- var inputTensor = image . ResizeImage ( options . Width , options . Height ) ;
274+ var cacheResult = GetEncoderCache ( options ) ;
275+ if ( cacheResult is not null )
276+ {
277+ Logger . LogEnd ( LogLevel . Debug , timestamp , "[EncodeLatentsAsync] AutoEncoder Encode Complete, Cached Result." ) ;
278+ return cacheResult ;
279+ }
280+
281+ var inputTensor = options . InputImage . ResizeImage ( options . Width , options . Height ) ;
271282 var encoderResult = await AutoEncoder . EncodeAsync ( inputTensor , cancellationToken : cancellationToken ) ;
272283 if ( options . IsLowMemoryEnabled || options . IsLowMemoryEncoderEnabled )
273284 await AutoEncoder . EncoderUnloadAsync ( ) ;
274285
275286 Logger . LogEnd ( LogLevel . Debug , timestamp , "[EncodeLatentsAsync] AutoEncoder Encode Complete" ) ;
276- return encoderResult ;
287+ return SetEncoderCache ( options , encoderResult ) ;
277288 }
278289
279290
@@ -392,7 +403,7 @@ private async Task<Tensor<float>> CreateLatentInputAsync(IPipelineOptions option
392403 if ( options . HasInputImage )
393404 {
394405 var timestep = scheduler . GetStartTimestep ( ) ;
395- var encoderResult = await EncodeLatentsAsync ( options , options . InputImage , cancellationToken ) ;
406+ var encoderResult = await EncodeLatentsAsync ( options , cancellationToken ) ;
396407 var noiseTensor = scheduler . CreateRandomSample ( encoderResult . Dimensions ) ;
397408 return PackLatents ( scheduler . ScaleNoise ( timestep , encoderResult , noiseTensor ) ) ;
398409 }
@@ -410,8 +421,8 @@ private async Task<Tensor<float>> CreateLatentInputAsync(IPipelineOptions option
410421 /// <returns></returns>
411422 protected Tensor < float > CreateLatentImageIds ( IPipelineOptions options )
412423 {
413- var height = options . Height / AutoEncoder . LatentChannels ;
414- var width = options . Width / AutoEncoder . LatentChannels ;
424+ var height = options . Height / AutoEncoder . LatentChannels ;
425+ var width = options . Width / AutoEncoder . LatentChannels ;
415426 var latentIds = new Tensor < float > ( [ height , width , 3 ] ) ;
416427
417428 for ( int i = 0 ; i < height ; i ++ )
0 commit comments