Skip to content

Commit 794ad03

Browse files
committed
Support model variant loading
1 parent f597342 commit 794ad03

18 files changed

+596
-51
lines changed

TensorStack.StableDiffusion/Pipelines/Flux/FluxConfig.cs

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// Copyright (c) TensorStack. All rights reserved.
22
// Licensed under the Apache 2.0 License.
3+
using System;
34
using System.IO;
5+
using System.Linq;
46
using TensorStack.Common;
57
using TensorStack.StableDiffusion.Config;
68
using TensorStack.StableDiffusion.Enums;
@@ -111,16 +113,59 @@ public static FluxConfig FromFile(string configFile, ExecutionProvider execution
111113
/// <param name="executionProvider">The execution provider.</param>
112114
/// <returns>FluxConfig.</returns>
113115
public static FluxConfig FromFolder(string modelFolder, ModelType modelType, ExecutionProvider executionProvider = default)
116+
{
117+
return CreateFromFolder(modelFolder, default, modelType, executionProvider);
118+
}
119+
120+
121+
/// <summary>
122+
/// Create Flux configuration from folder structure
123+
/// </summary>
124+
/// <param name="modelFolder">The model folder.</param>
125+
/// <param name="variant">The variant.</param>
126+
/// <param name="modelType">Type of the model.</param>
127+
/// <param name="executionProvider">The execution provider.</param>
128+
/// <returns>FluxConfig.</returns>
129+
public static FluxConfig FromFolder(string modelFolder, string variant, ModelType modelType, ExecutionProvider executionProvider = default)
130+
{
131+
return CreateFromFolder(modelFolder, variant, modelType, executionProvider);
132+
}
133+
134+
135+
/// <summary>
136+
/// Create Flux configuration from folder structure
137+
/// </summary>
138+
/// <param name="modelFolder">The model folder.</param>
139+
/// <param name="variant">The variant.</param>
140+
/// <param name="executionProvider">The execution provider.</param>
141+
/// <returns>FluxConfig.</returns>
142+
public static FluxConfig FromFolder(string modelFolder, string variant, ExecutionProvider executionProvider = default)
143+
{
144+
string[] typeOptions = ["Turbo", "Distilled", "Dist", "Schnell"];
145+
var modelType = typeOptions.Any(v => variant.Contains(v, StringComparison.OrdinalIgnoreCase)) ? ModelType.Turbo : ModelType.Base;
146+
return CreateFromFolder(modelFolder, variant, modelType, executionProvider);
147+
}
148+
149+
150+
/// <summary>
151+
/// Create Flux configuration from folder structure
152+
/// </summary>
153+
/// <param name="modelFolder">The model folder.</param>
154+
/// <param name="variant">The variant.</param>
155+
/// <param name="modelType">Type of the model.</param>
156+
/// <param name="executionProvider">The execution provider.</param>
157+
/// <returns>FluxConfig.</returns>
158+
private static FluxConfig CreateFromFolder(string modelFolder, string variant, ModelType modelType, ExecutionProvider executionProvider)
114159
{
115160
var config = FromDefault(Path.GetFileNameWithoutExtension(modelFolder), modelType, executionProvider);
116161
config.Tokenizer.Path = Path.Combine(modelFolder, "tokenizer", "vocab.json");
117162
config.Tokenizer2.Path = Path.Combine(modelFolder, "tokenizer_2", "spiece.model");
118-
config.TextEncoder.Path = Path.Combine(modelFolder, "text_encoder", "model.onnx");
119-
config.TextEncoder2.Path = Path.Combine(modelFolder, "text_encoder_2", "model.onnx");
120-
config.Transformer.Path = Path.Combine(modelFolder, "transformer", "model.onnx");
121-
config.AutoEncoder.DecoderModelPath = Path.Combine(modelFolder, "vae_decoder", "model.onnx");
122-
config.AutoEncoder.EncoderModelPath = Path.Combine(modelFolder, "vae_encoder", "model.onnx");
123-
var controlNetPath = Path.Combine(modelFolder, "transformer", "controlnet.onnx");
163+
config.TextEncoder.Path = GetVariantPath(modelFolder, "text_encoder", "model.onnx", variant);
164+
config.TextEncoder2.Path = GetVariantPath(modelFolder, "text_encoder_2", "model.onnx", variant);
165+
config.Transformer.Path = GetVariantPath(modelFolder, "transformer", "model.onnx", variant);
166+
config.AutoEncoder.DecoderModelPath = GetVariantPath(modelFolder, "vae_decoder", "model.onnx", variant);
167+
config.AutoEncoder.EncoderModelPath = GetVariantPath(modelFolder, "vae_encoder", "model.onnx", variant);
168+
var controlNetPath = GetVariantPath(modelFolder, "transformer", "controlnet.onnx", variant);
124169
if (File.Exists(controlNetPath))
125170
config.Transformer.ControlNetPath = controlNetPath;
126171
return config;

TensorStack.StableDiffusion/Pipelines/Flux/FluxPipeline.cs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,5 +82,34 @@ public static FluxPipeline FromFolder(string modelFolder, ModelType modelType, E
8282
{
8383
return new FluxPipeline(FluxConfig.FromFolder(modelFolder, modelType, executionProvider), logger);
8484
}
85+
86+
87+
/// <summary>
88+
/// Create Flux pipeline from folder structure
89+
/// </summary>
90+
/// <param name="modelFolder">The model folder.</param>
91+
/// <param name="variant">The variant.</param>
92+
/// <param name="modelType">Type of the model.</param>
93+
/// <param name="executionProvider">The execution provider.</param>
94+
/// <param name="logger">The logger.</param>
95+
/// <returns>FluxPipeline.</returns>
96+
public static FluxPipeline FromFolder(string modelFolder, string variant, ModelType modelType, ExecutionProvider executionProvider, ILogger logger = default)
97+
{
98+
return new FluxPipeline(FluxConfig.FromFolder(modelFolder, variant, modelType, executionProvider), logger);
99+
}
100+
101+
102+
/// <summary>
103+
/// Create Flux pipeline from folder structure
104+
/// </summary>
105+
/// <param name="modelFolder">The model folder.</param>
106+
/// <param name="variant">The variant.</param>
107+
/// <param name="executionProvider">The execution provider.</param>
108+
/// <param name="logger">The logger.</param>
109+
/// <returns>FluxPipeline.</returns>
110+
public static FluxPipeline FromFolder(string modelFolder, string variant, ExecutionProvider executionProvider, ILogger logger = default)
111+
{
112+
return new FluxPipeline(FluxConfig.FromFolder(modelFolder, variant, executionProvider), logger);
113+
}
85114
}
86115
}

TensorStack.StableDiffusion/Pipelines/LatentConsistency/LatentConsistencyConfig.cs

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,45 @@ public override void Save(string configFile, bool useRelativePaths = true)
6767
/// <param name="executionProvider">The execution provider.</param>
6868
/// <returns>LatentConsistencyConfig.</returns>
6969
public static new LatentConsistencyConfig FromFolder(string modelFolder, ModelType modelType, ExecutionProvider executionProvider = default)
70+
{
71+
return CreateFromFolder(modelFolder, default, modelType, executionProvider);
72+
}
73+
74+
75+
/// <summary>
76+
/// Create LatentConsistency configuration from folder structure
77+
/// </summary>
78+
/// <param name="modelFolder">The model folder.</param>
79+
/// <param name="variant">The variant.</param>
80+
/// <param name="modelType">Type of the model.</param>
81+
/// <param name="executionProvider">The execution provider.</param>
82+
/// <returns>LatentConsistencyConfig.</returns>
83+
public static new LatentConsistencyConfig FromFolder(string modelFolder, string variant, ModelType modelType, ExecutionProvider executionProvider = default)
84+
{
85+
return CreateFromFolder(modelFolder, variant, modelType, executionProvider);
86+
}
87+
88+
89+
/// <summary>
90+
/// Create LatentConsistency configuration from folder structure
91+
/// </summary>
92+
/// <param name="modelFolder">The model folder.</param>
93+
/// <param name="variant">The variant.</param>
94+
/// <param name="modelType">Type of the model.</param>
95+
/// <param name="executionProvider">The execution provider.</param>
96+
/// <returns>LatentConsistencyConfig.</returns>
97+
private static LatentConsistencyConfig CreateFromFolder(string modelFolder, string variant, ModelType modelType, ExecutionProvider executionProvider)
7098
{
7199
var config = FromDefault(Path.GetFileNameWithoutExtension(modelFolder), modelType, executionProvider);
72100
config.Tokenizer.Path = Path.Combine(modelFolder, "tokenizer", "vocab.json");
73-
config.TextEncoder.Path = Path.Combine(modelFolder, "text_encoder", "model.onnx");
74-
config.Unet.Path = Path.Combine(modelFolder, "unet", "model.onnx");
75-
config.AutoEncoder.DecoderModelPath = Path.Combine(modelFolder, "vae_decoder", "model.onnx");
76-
config.AutoEncoder.EncoderModelPath = Path.Combine(modelFolder, "vae_encoder", "model.onnx");
77-
var controlNetPath = Path.Combine(modelFolder, "unet", "controlnet.onnx");
101+
config.TextEncoder.Path = GetVariantPath(modelFolder, "text_encoder", "model.onnx", variant);
102+
config.Unet.Path = GetVariantPath(modelFolder, "unet", "model.onnx", variant);
103+
config.AutoEncoder.DecoderModelPath = GetVariantPath(modelFolder, "vae_decoder", "model.onnx", variant);
104+
config.AutoEncoder.EncoderModelPath = GetVariantPath(modelFolder, "vae_encoder", "model.onnx", variant);
105+
var controlNetPath = GetVariantPath(modelFolder, "unet", "controlnet.onnx", variant);
78106
if (File.Exists(controlNetPath))
79107
config.Unet.ControlNetPath = controlNetPath;
80108
return config;
81109
}
82-
83110
}
84111
}

TensorStack.StableDiffusion/Pipelines/LatentConsistency/LatentConsistencyPipeline.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,5 +94,20 @@ protected override GenerateOptions ConfigureDefaultOptions()
9494
{
9595
return new LatentConsistencyPipeline(LatentConsistencyConfig.FromFolder(modelFolder, modelType, executionProvider), logger);
9696
}
97+
98+
99+
/// <summary>
100+
/// Create LatentConsistency pipeline from folder structure
101+
/// </summary>
102+
/// <param name="modelFolder">The model folder.</param>
103+
/// <param name="variant">The variant.</param>
104+
/// <param name="modelType">Type of the model.</param>
105+
/// <param name="executionProvider">The execution provider.</param>
106+
/// <param name="logger">The logger.</param>
107+
/// <returns>LatentConsistencyPipeline.</returns>
108+
public static new LatentConsistencyPipeline FromFolder(string modelFolder, string variant, ModelType modelType, ExecutionProvider executionProvider, ILogger logger = default)
109+
{
110+
return new LatentConsistencyPipeline(LatentConsistencyConfig.FromFolder(modelFolder, variant, modelType, executionProvider), logger);
111+
}
97112
}
98113
}

TensorStack.StableDiffusion/Pipelines/LatentConsistency/LatentConsistencyVideoPipeline.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,5 +94,20 @@ protected override GenerateOptions ConfigureDefaultOptions()
9494
{
9595
return new LatentConsistencyVideoPipeline(LatentConsistencyConfig.FromFolder(modelFolder, modelType, executionProvider), logger);
9696
}
97+
98+
99+
/// <summary>
100+
/// Create LatentConsistencyVideoPipeline pipeline from folder structure
101+
/// </summary>
102+
/// <param name="modelFolder">The model folder.</param>
103+
/// <param name="modelType">Type of the model.</param>
104+
/// <param name="executionProvider">The execution provider.</param>
105+
/// <param name="logger">The logger.</param>
106+
/// <returns>LatentConsistencyVideoPipeline.</returns>
107+
public static new LatentConsistencyVideoPipeline FromFolder(string modelFolder, string variant, ModelType modelType, ExecutionProvider executionProvider, ILogger logger = default)
108+
{
109+
return new LatentConsistencyVideoPipeline(LatentConsistencyConfig.FromFolder(modelFolder, variant, modelType, executionProvider), logger);
110+
}
111+
97112
}
98113
}

TensorStack.StableDiffusion/Pipelines/Nitro/NitroConfig.cs

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// Copyright (c) TensorStack. All rights reserved.
22
// Licensed under the Apache 2.0 License.
3+
using System;
34
using System.IO;
5+
using System.Linq;
46
using TensorStack.Common;
57
using TensorStack.StableDiffusion.Config;
68
using TensorStack.StableDiffusion.Enums;
@@ -116,32 +118,62 @@ public static NitroConfig FromFile(string configFile, ExecutionProvider executio
116118
/// <param name="executionProvider">The execution provider.</param>
117119
public static NitroConfig FromFolder(string modelFolder, int outputSize, ModelType modelType, ExecutionProvider executionProvider = default)
118120
{
119-
var config = FromDefault(Path.GetFileNameWithoutExtension(modelFolder), outputSize, modelType, executionProvider);
120-
config.Tokenizer.Path = Path.Combine(modelFolder, "tokenizer");
121-
config.TextEncoder.Path = Path.Combine(modelFolder, "text_encoder", "model.onnx");
122-
config.Transformer.Path = Path.Combine(modelFolder, "transformer", "model.onnx");
123-
config.AutoEncoder.DecoderModelPath = Path.Combine(modelFolder, "vae_decoder", "model.onnx");
124-
config.AutoEncoder.EncoderModelPath = Path.Combine(modelFolder, "vae_encoder", "model.onnx");
125-
return config;
121+
return CreateFromFolder(modelFolder, default, outputSize, modelType, executionProvider);
126122
}
127123

128124

129125
/// <summary>
130126
/// Create Nitro configuration from folder structure
131127
/// </summary>
132128
/// <param name="modelFolder">The model folder.</param>
129+
/// <param name="variant">The variant.</param>
130+
/// <param name="outputSize">Size of the output.</param>
133131
/// <param name="modelType">Type of the model.</param>
134132
/// <param name="executionProvider">The execution provider.</param>
133+
/// <returns>NitroConfig.</returns>
134+
public static NitroConfig FromFolder(string modelFolder, string variant, int outputSize, ModelType modelType, ExecutionProvider executionProvider = default)
135+
{
136+
return CreateFromFolder(modelFolder, variant, outputSize, modelType, executionProvider);
137+
}
138+
139+
140+
/// <summary>
141+
/// Create Nitro configuration from folder structure
142+
/// </summary>
143+
/// <param name="modelFolder">The model folder.</param>
144+
/// <param name="variant">The variant.</param>
145+
/// <param name="executionProvider">The execution provider.</param>
146+
/// <returns>NitroConfig.</returns>
135147
public static NitroConfig FromFolder(string modelFolder, string variant, ExecutionProvider executionProvider = default)
136148
{
137-
var outputSize = variant.Contains("1024") ? 1024 : 512;
138-
var modelType = variant.Contains("Turbo", System.StringComparison.OrdinalIgnoreCase) ? ModelType.Turbo : ModelType.Base;
149+
string[] sizeOptions = ["XL", "Large", "1024"];
150+
string[] typeOptions = ["Turbo", "Distilled", "Dist"];
151+
var outputSize = sizeOptions.Any(v => variant.Contains(v, StringComparison.OrdinalIgnoreCase)) ? 1024 : 512;
152+
var modelType = typeOptions.Any(v => variant.Contains(v, StringComparison.OrdinalIgnoreCase)) ? ModelType.Turbo : ModelType.Base;
153+
return CreateFromFolder(modelFolder, variant, outputSize, modelType, executionProvider);
154+
}
155+
156+
157+
/// <summary>
158+
/// Create Nitro configuration from folder structure
159+
/// </summary>
160+
/// <param name="modelFolder">The model folder.</param>
161+
/// <param name="variant">The variant.</param>
162+
/// <param name="outputSize">Size of the output.</param>
163+
/// <param name="modelType">Type of the model.</param>
164+
/// <param name="executionProvider">The execution provider.</param>
165+
/// <returns>NitroConfig.</returns>
166+
private static NitroConfig CreateFromFolder(string modelFolder, string variant, int outputSize, ModelType modelType, ExecutionProvider executionProvider)
167+
{
139168
var config = FromDefault(Path.GetFileNameWithoutExtension(modelFolder), outputSize, modelType, executionProvider);
140169
config.Tokenizer.Path = Path.Combine(modelFolder, "tokenizer");
141170
config.TextEncoder.Path = GetVariantPath(modelFolder, "text_encoder", "model.onnx", variant);
142171
config.Transformer.Path = GetVariantPath(modelFolder, "transformer", "model.onnx", variant);
143172
config.AutoEncoder.DecoderModelPath = GetVariantPath(modelFolder, "vae_decoder", "model.onnx", variant);
144173
config.AutoEncoder.EncoderModelPath = GetVariantPath(modelFolder, "vae_encoder", "model.onnx", variant);
174+
var controlNetPath = GetVariantPath(modelFolder, "transformer", "controlnet.onnx", variant);
175+
if (File.Exists(controlNetPath))
176+
config.Transformer.ControlNetPath = controlNetPath;
145177
return config;
146178
}
147179
}

TensorStack.StableDiffusion/Pipelines/Nitro/NitroPipeline.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,22 @@ public static NitroPipeline FromFolder(string modelFolder, int outputSize, Model
8080
}
8181

8282

83+
/// <summary>
84+
/// Create Nitro pipeline from folder structure
85+
/// </summary>
86+
/// <param name="modelFolder">The model folder.</param>
87+
/// <param name="variant">The variant.</param>
88+
/// <param name="outputSize">Size of the output.</param>
89+
/// <param name="modelType">Type of the model.</param>
90+
/// <param name="executionProvider">The execution provider.</param>
91+
/// <param name="logger">The logger.</param>
92+
/// <returns>NitroPipeline.</returns>
93+
public static NitroPipeline FromFolder(string modelFolder, string variant, int outputSize, ModelType modelType, ExecutionProvider executionProvider, ILogger logger = default)
94+
{
95+
return new NitroPipeline(NitroConfig.FromFolder(modelFolder, variant, outputSize, modelType, executionProvider), logger);
96+
}
97+
98+
8399
/// <summary>
84100
/// Create Nitro pipeline from folder structure
85101
/// </summary>

0 commit comments

Comments
 (0)