Skip to content

Commit f597342

Browse files
committed
Support Nitro variants
1 parent 1687a1f commit f597342

File tree

4 files changed

+79
-15
lines changed

4 files changed

+79
-15
lines changed

TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,12 @@ public abstract class NitroBase : PipelineBase
2626
/// <param name="textEncoder">The text encoder.</param>
2727
/// <param name="autoEncoder">The automatic encoder.</param>
2828
/// <param name="logger">The logger.</param>
29-
public NitroBase(TransformerNitroModel transformer, LlamaPipeline textEncoder, AutoEncoderModel autoEncoder, ILogger logger = default) : base(logger)
29+
public NitroBase(TransformerNitroModel transformer, LlamaPipeline textEncoder, AutoEncoderModel autoEncoder, int outputSize, ILogger logger = default) : base(logger)
3030
{
3131
Transformer = transformer;
3232
AutoEncoder = autoEncoder;
3333
TextEncoder = textEncoder;
34+
OutputSize = outputSize;
3435
Initialize();
3536
Logger?.LogInformation("[NitroPipeline] Name: {Name}", Name);
3637
}
@@ -50,6 +51,7 @@ public NitroBase(NitroConfig configuration, ILogger logger = default) : this(
5051
Tokenizer = new BPETokenizer(configuration.Tokenizer),
5152
}),
5253
new AutoEncoderModel(configuration.AutoEncoder),
54+
configuration.OutputSize,
5355
logger)
5456
{
5557
Name = configuration.Name;
@@ -80,6 +82,11 @@ public NitroBase(NitroConfig configuration, ILogger logger = default) : this(
8082
/// </summary>
8183
public AutoEncoderModel AutoEncoder { get; init; }
8284

85+
/// <summary>
86+
/// Gets the size of the image output (512 or 1024).
87+
/// </summary>
88+
public int OutputSize { get; }
89+
8390

8491
/// <summary>
8592
/// Loads the pipeline.
@@ -118,6 +125,8 @@ protected override void ValidateOptions(GenerateOptions options)
118125
base.ValidateOptions(options);
119126
if (!Transformer.HasControlNet && options.HasControlNet)
120127
throw new ArgumentException("Model does not support ControlNet");
128+
if (options.Width != OutputSize || options.Height != OutputSize)
129+
throw new ArgumentException($"Model only supports {OutputSize}x{OutputSize} output size");
121130
}
122131

123132

@@ -193,6 +202,7 @@ private async Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, I
193202

194203
protected async Task<Tensor<float>> RunInferenceAsync(IPipelineOptions options, IScheduler scheduler, PromptResult prompt, IProgress<GenerateProgress> progressCallback = null, CancellationToken cancellationToken = default)
195204
{
205+
System.Console.WriteLine(options);
196206
var timestamp = Logger.LogBegin(LogLevel.Debug, "[RunInferenceAsync] Begin Transformer Inference");
197207

198208
// Prompt
@@ -350,8 +360,8 @@ protected override GenerateOptions ConfigureDefaultOptions()
350360
{
351361
Steps = 20,
352362
Shift = 1f,
353-
Width = 512,
354-
Height = 512,
363+
Width = OutputSize,
364+
Height = OutputSize,
355365
GuidanceScale = 4f,
356366
Scheduler = SchedulerType.FlowMatchEulerDiscrete
357367
};
@@ -363,8 +373,8 @@ protected override GenerateOptions ConfigureDefaultOptions()
363373
{
364374
Steps = 4,
365375
Shift = 1f,
366-
Width = 512,
367-
Height = 512,
376+
Width = OutputSize,
377+
Height = OutputSize,
368378
GuidanceScale = 0,
369379
Scheduler = SchedulerType.FlowMatchEulerDiscrete
370380
};

TensorStack.StableDiffusion/Pipelines/Nitro/NitroConfig.cs

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ public record NitroConfig : PipelineConfig
1616
/// </summary>
1717
public NitroConfig()
1818
{
19+
OutputSize = 512;
1920
Tokenizer = new TokenizerConfig
2021
{
2122
BOS = 128000,
@@ -50,6 +51,7 @@ public NitroConfig()
5051
public DecoderConfig TextEncoder { get; init; }
5152
public TransformerModelConfig Transformer { get; init; }
5253
public AutoEncoderModelConfig AutoEncoder { get; init; }
54+
public int OutputSize { get; init; }
5355

5456

5557
/// <summary>
@@ -82,9 +84,9 @@ public override void Save(string configFile, bool useRelativePaths = true)
8284
/// <param name="modelType">Type of the model.</param>
8385
/// <param name="executionProvider">The execution provider.</param>
8486
/// <returns>NitroConfig.</returns>
85-
public static NitroConfig FromDefault(string name, ModelType modelType, ExecutionProvider executionProvider = default)
87+
public static NitroConfig FromDefault(string name, int outputSize, ModelType modelType, ExecutionProvider executionProvider = default)
8688
{
87-
var config = new NitroConfig { Name = name };
89+
var config = new NitroConfig { Name = name, OutputSize = outputSize };
8890
config.Transformer.ModelType = modelType;
8991
config.SetProvider(executionProvider);
9092
return config;
@@ -109,12 +111,12 @@ public static NitroConfig FromFile(string configFile, ExecutionProvider executio
109111
/// Create Nitro configuration from folder structure
110112
/// </summary>
111113
/// <param name="modelFolder">The model folder.</param>
114+
/// <param name="outputSize">Size of the output.</param>
112115
/// <param name="modelType">Type of the model.</param>
113116
/// <param name="executionProvider">The execution provider.</param>
114-
/// <returns>NitroConfig.</returns>
115-
public static NitroConfig FromFolder(string modelFolder, ModelType modelType, ExecutionProvider executionProvider = default)
117+
public static NitroConfig FromFolder(string modelFolder, int outputSize, ModelType modelType, ExecutionProvider executionProvider = default)
116118
{
117-
var config = FromDefault(Path.GetFileNameWithoutExtension(modelFolder), modelType, executionProvider);
119+
var config = FromDefault(Path.GetFileNameWithoutExtension(modelFolder), outputSize, modelType, executionProvider);
118120
config.Tokenizer.Path = Path.Combine(modelFolder, "tokenizer");
119121
config.TextEncoder.Path = Path.Combine(modelFolder, "text_encoder", "model.onnx");
120122
config.Transformer.Path = Path.Combine(modelFolder, "transformer", "model.onnx");
@@ -123,5 +125,24 @@ public static NitroConfig FromFolder(string modelFolder, ModelType modelType, Ex
123125
return config;
124126
}
125127

128+
129+
/// <summary>
130+
/// Create Nitro configuration from folder structure
131+
/// </summary>
132+
/// <param name="modelFolder">The model folder.</param>
133+
/// <param name="modelType">Type of the model.</param>
134+
/// <param name="executionProvider">The execution provider.</param>
135+
public static NitroConfig FromFolder(string modelFolder, string variant, ExecutionProvider executionProvider = default)
136+
{
137+
var outputSize = variant.Contains("1024") ? 1024 : 512;
138+
var modelType = variant.Contains("Turbo", System.StringComparison.OrdinalIgnoreCase) ? ModelType.Turbo : ModelType.Base;
139+
var config = FromDefault(Path.GetFileNameWithoutExtension(modelFolder), outputSize, modelType, executionProvider);
140+
config.Tokenizer.Path = Path.Combine(modelFolder, "tokenizer");
141+
config.TextEncoder.Path = GetVariantPath(modelFolder, "text_encoder", "model.onnx", variant);
142+
config.Transformer.Path = GetVariantPath(modelFolder, "transformer", "model.onnx", variant);
143+
config.AutoEncoder.DecoderModelPath = GetVariantPath(modelFolder, "vae_decoder", "model.onnx", variant);
144+
config.AutoEncoder.EncoderModelPath = GetVariantPath(modelFolder, "vae_encoder", "model.onnx", variant);
145+
return config;
146+
}
126147
}
127148
}

TensorStack.StableDiffusion/Pipelines/Nitro/NitroPipeline.cs

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ public class NitroPipeline : NitroBase, IPipeline<ImageTensor, GenerateOptions,
2323
/// <param name="textEncoder">The text encoder.</param>
2424
/// <param name="autoEncoder">The automatic encoder.</param>
2525
/// <param name="logger">The logger.</param>
26-
public NitroPipeline(TransformerNitroModel transformer, LlamaPipeline textEncoder, AutoEncoderModel autoEncoder, ILogger logger = null)
27-
: base(transformer, textEncoder, autoEncoder, logger) { }
26+
public NitroPipeline(TransformerNitroModel transformer, LlamaPipeline textEncoder, AutoEncoderModel autoEncoder, int outputSize, ILogger logger = null)
27+
: base(transformer, textEncoder, autoEncoder, outputSize, logger) { }
2828

2929
/// <summary>
3030
/// Initializes a new instance of the <see cref="NitroPipeline"/> class.
@@ -70,13 +70,26 @@ public static NitroPipeline FromConfig(string configFile, ExecutionProvider exec
7070
/// Create Nitro pipeline from folder structure
7171
/// </summary>
7272
/// <param name="modelFolder">The model folder.</param>
73+
/// <param name="outputSize">Size of the output. [512, 1024]</param>
7374
/// <param name="modelType">Type of the model.</param>
7475
/// <param name="executionProvider">The execution provider.</param>
7576
/// <param name="logger">The logger.</param>
76-
/// <returns>NitroPipeline.</returns>
77-
public static NitroPipeline FromFolder(string modelFolder, ModelType modelType, ExecutionProvider executionProvider, ILogger logger = default)
77+
public static NitroPipeline FromFolder(string modelFolder, int outputSize, ModelType modelType, ExecutionProvider executionProvider, ILogger logger = default)
78+
{
79+
return new NitroPipeline(NitroConfig.FromFolder(modelFolder, outputSize, modelType, executionProvider), logger);
80+
}
81+
82+
83+
/// <summary>
84+
/// Create Nitro pipeline from folder structure
85+
/// </summary>
86+
/// <param name="modelFolder">The model folder.</param>
87+
/// <param name="variant">The variant.[512, 512-Turbo, 1024, 1024-Turbo]</param>
88+
/// <param name="executionProvider">The execution provider.</param>
89+
/// <param name="logger">The logger.</param>
90+
public static NitroPipeline FromFolder(string modelFolder, string variant, ExecutionProvider executionProvider, ILogger logger = default)
7891
{
79-
return new NitroPipeline(NitroConfig.FromFolder(modelFolder, modelType, executionProvider), logger);
92+
return new NitroPipeline(NitroConfig.FromFolder(modelFolder, variant, executionProvider), logger);
8093
}
8194
}
8295
}

TensorStack.StableDiffusion/Pipelines/PipelineConfig.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// Copyright (c) TensorStack. All rights reserved.
22
// Licensed under the Apache 2.0 License.
3+
using System.IO;
34
using TensorStack.Common;
45
using TensorStack.StableDiffusion.Enums;
56

@@ -24,5 +25,24 @@ public abstract record PipelineConfig
2425
/// </summary>
2526
/// <param name="executionProvider">The execution provider.</param>
2627
public abstract void SetProvider(ExecutionProvider executionProvider);
28+
29+
/// <summary>
30+
/// Gets the variant path if it exists.
31+
/// </summary>
32+
/// <param name="modelFolder">The model folder.</param>
33+
/// <param name="model">The model.</param>
34+
/// <param name="variant">The variant.</param>
35+
/// <param name="filename">The filename.</param>
36+
protected static string GetVariantPath(string modelFolder, string model, string filename, string variant = default)
37+
{
38+
if (!string.IsNullOrEmpty(variant))
39+
{
40+
var variantPath = Path.Combine(modelFolder, model, variant, filename);
41+
if (File.Exists(variantPath))
42+
return variantPath;
43+
}
44+
45+
return Path.Combine(modelFolder, model, filename);
46+
}
2747
}
2848
}

0 commit comments

Comments
 (0)