Skip to content

Commit cb836e9

Browse files
committed
Merge branch 'master' into TextDemo
2 parents 734750a + a25abce commit cb836e9

File tree

26 files changed

+1013
-54
lines changed

26 files changed

+1013
-54
lines changed

TensorStack.Audio.Windows/AudioInput.cs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace TensorStack.Audio
1212
/// </summary>
1313
public class AudioInput : AudioInputBase
1414
{
15-
private readonly string _sourceFile;
15+
private string _sourceFile;
1616

1717
/// <summary>
1818
/// Initializes a new instance of the <see cref="AudioInput"/> class.
@@ -21,6 +21,13 @@ public class AudioInput : AudioInputBase
2121
public AudioInput(string filename, string audioCodec = "pcm_s16le", int sampleRate = 16000, int channels = 1)
2222
: this(filename, AudioManager.LoadTensor(filename, audioCodec, sampleRate, channels)) { }
2323

24+
/// <summary>
25+
/// Initializes a new instance of the <see cref="AudioInput"/> class.
26+
/// </summary>
27+
/// <param name="audioTensor">The audio tensor.</param>
28+
public AudioInput(AudioTensor audioTensor)
29+
: base(audioTensor) { }
30+
2431
/// <summary>
2532
/// Initializes a new instance of the <see cref="AudioInput"/> class.
2633
/// </summary>
@@ -44,6 +51,9 @@ protected AudioInput(string filename, AudioTensor audioTensor)
4451
/// <param name="filename">The filename.</param>
4552
public override void Save(string filename)
4653
{
54+
if (string.IsNullOrEmpty(_sourceFile))
55+
_sourceFile = filename;
56+
4757
AudioManager.SaveAudio(filename, this);
4858
}
4959

@@ -55,6 +65,9 @@ public override void Save(string filename)
5565
/// <param name="cancellationToken">The cancellation token.</param>
5666
public override async Task SaveAsync(string filename, CancellationToken cancellationToken = default)
5767
{
68+
if (string.IsNullOrEmpty(_sourceFile))
69+
_sourceFile = filename;
70+
5871
await AudioManager.SaveAudioAync(filename, this, cancellationToken);
5972
}
6073

TensorStack.Providers.CPU/Provider.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ public static Device GetDevice(DeviceType deviceType, int deviceId)
9090
/// </summary>
9191
/// <param name="deviceType">Type of the device.</param>
9292
/// <param name="optimizationLevel">The optimization level.</param>
93-
public static ExecutionProvider GetProvider(GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
93+
public static ExecutionProvider GetProvider(GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
9494
{
9595
return GetDevice().GetProvider(optimizationLevel);
9696
}
@@ -101,7 +101,7 @@ public static ExecutionProvider GetProvider(GraphOptimizationLevel optimizationL
101101
/// </summary>
102102
/// <param name="deviceType">Type of the device.</param>
103103
/// <param name="optimizationLevel">The optimization level.</param>
104-
public static ExecutionProvider GetProvider(DeviceType deviceType, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
104+
public static ExecutionProvider GetProvider(DeviceType deviceType, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
105105
{
106106
return GetDevice(deviceType).GetProvider(optimizationLevel);
107107
}
@@ -113,7 +113,7 @@ public static ExecutionProvider GetProvider(DeviceType deviceType, GraphOptimiza
113113
/// <param name="deviceType">Type of the device.</param>
114114
/// <param name="deviceId">The device identifier.</param>
115115
/// <param name="optimizationLevel">The optimization level.</param>
116-
public static ExecutionProvider GetProvider(DeviceType deviceType, int deviceId, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
116+
public static ExecutionProvider GetProvider(DeviceType deviceType, int deviceId, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
117117
{
118118
return GetDevice(deviceType, deviceId).GetProvider(optimizationLevel);
119119
}
@@ -124,7 +124,7 @@ public static ExecutionProvider GetProvider(DeviceType deviceType, int deviceId,
124124
/// </summary>
125125
/// <param name="device">The device.</param>
126126
/// <param name="optimizationLevel">The optimization level.</param>
127-
public static ExecutionProvider GetProvider(this Device device, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
127+
public static ExecutionProvider GetProvider(this Device device, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
128128
{
129129
if (device == null)
130130
return default;
@@ -142,7 +142,7 @@ public static ExecutionProvider GetProvider(this Device device, GraphOptimizatio
142142
/// </summary>
143143
/// <param name="optimizationLevel">The optimization level.</param>
144144
/// <returns>ExecutionProvider.</returns>
145-
private static ExecutionProvider CreateProvider(GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
145+
private static ExecutionProvider CreateProvider(GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
146146
{
147147
return new ExecutionProvider(ProviderName, OrtMemoryInfo.DefaultInstance, configuration =>
148148
{

TensorStack.Providers.CPU/TensorStack.Providers.CPU.csproj

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
<PackageReference Include="TensorStack.Common" Version="$(Version)" />
1616
</ItemGroup>
1717

18+
<!--Common Packages-->
19+
<ItemGroup>
20+
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="1.23.2" />
21+
</ItemGroup>
22+
1823
<!--Nuget Settings-->
1924
<PropertyGroup>
2025
<Title>$(AssemblyName)</Title>

TensorStack.Providers.CUDA/Provider.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ public static Device GetDevice(DeviceType deviceType, int deviceId)
8989
/// Gets the CUDA provider this DeviceType.
9090
/// </summary>
9191
/// <param name="optimizationLevel">The optimization level.</param>
92-
public static ExecutionProvider GetProvider(GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
92+
public static ExecutionProvider GetProvider(GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
9393
{
9494
return GetDevice().GetProvider(optimizationLevel);
9595
}
@@ -100,7 +100,7 @@ public static ExecutionProvider GetProvider(GraphOptimizationLevel optimizationL
100100
/// </summary>
101101
/// <param name="deviceType">Type of the device.</param>
102102
/// <param name="optimizationLevel">The optimization level.</param>
103-
public static ExecutionProvider GetProvider(DeviceType deviceType, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
103+
public static ExecutionProvider GetProvider(DeviceType deviceType, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
104104
{
105105
return GetDevice(deviceType).GetProvider(optimizationLevel);
106106
}
@@ -112,7 +112,7 @@ public static ExecutionProvider GetProvider(DeviceType deviceType, GraphOptimiza
112112
/// <param name="deviceType">Type of the device.</param>
113113
/// <param name="deviceId">The device identifier.</param>
114114
/// <param name="optimizationLevel">The optimization level.</param>
115-
public static ExecutionProvider GetProvider(DeviceType deviceType, int deviceId, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
115+
public static ExecutionProvider GetProvider(DeviceType deviceType, int deviceId, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
116116
{
117117
return GetDevice(deviceType, deviceId).GetProvider(optimizationLevel);
118118
}
@@ -123,7 +123,7 @@ public static ExecutionProvider GetProvider(DeviceType deviceType, int deviceId,
123123
/// </summary>
124124
/// <param name="device">The device.</param>
125125
/// <param name="optimizationLevel">The optimization level.</param>
126-
public static ExecutionProvider GetProvider(this Device device, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
126+
public static ExecutionProvider GetProvider(this Device device, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
127127
{
128128
if (device == null)
129129
return default;
@@ -141,7 +141,7 @@ public static ExecutionProvider GetProvider(this Device device, GraphOptimizatio
141141
/// </summary>
142142
/// <param name="deviceId">The device identifier.</param>
143143
/// <param name="optimizationLevel">The optimization level.</param>
144-
private static ExecutionProvider CreateProvider(int deviceId, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
144+
private static ExecutionProvider CreateProvider(int deviceId, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
145145
{
146146
var memoryInfo = new OrtMemoryInfo(OrtMemoryInfo.allocatorCUDA_PINNED, OrtAllocatorType.DeviceAllocator, deviceId, OrtMemType.Default);
147147
return new ExecutionProvider(_providerName, memoryInfo, configuration =>
@@ -163,7 +163,7 @@ private static ExecutionProvider CreateProvider(int deviceId, GraphOptimizationL
163163
/// </summary>
164164
/// <param name="optimizationLevel">The optimization level.</param>
165165
/// <returns>ExecutionProvider.</returns>
166-
private static ExecutionProvider CreateProvider(GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
166+
private static ExecutionProvider CreateProvider(GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
167167
{
168168
return new ExecutionProvider(DeviceManager.CPUProviderName, OrtMemoryInfo.DefaultInstance, configuration =>
169169
{

TensorStack.Providers.DML/Provider.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ public static Device GetDevice(DeviceType deviceType, int deviceId)
9191
/// Gets the DirectML provider this DeviceType.
9292
/// </summary>
9393
/// <param name="optimizationLevel">The optimization level.</param>
94-
public static ExecutionProvider GetProvider(GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
94+
public static ExecutionProvider GetProvider(GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
9595
{
9696
return GetDevice().GetProvider(optimizationLevel);
9797
}
@@ -102,7 +102,7 @@ public static ExecutionProvider GetProvider(GraphOptimizationLevel optimizationL
102102
/// </summary>
103103
/// <param name="deviceType">Type of the device.</param>
104104
/// <param name="optimizationLevel">The optimization level.</param>
105-
public static ExecutionProvider GetProvider(DeviceType deviceType, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
105+
public static ExecutionProvider GetProvider(DeviceType deviceType, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
106106
{
107107
return GetDevice(deviceType).GetProvider(optimizationLevel);
108108
}
@@ -114,7 +114,7 @@ public static ExecutionProvider GetProvider(DeviceType deviceType, GraphOptimiza
114114
/// <param name="deviceType">Type of the device.</param>
115115
/// <param name="deviceId">The device identifier.</param>
116116
/// <param name="optimizationLevel">The optimization level.</param>
117-
public static ExecutionProvider GetProvider(DeviceType deviceType, int deviceId, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
117+
public static ExecutionProvider GetProvider(DeviceType deviceType, int deviceId, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
118118
{
119119
return GetDevice(deviceType, deviceId).GetProvider(optimizationLevel);
120120
}
@@ -125,7 +125,7 @@ public static ExecutionProvider GetProvider(DeviceType deviceType, int deviceId,
125125
/// </summary>
126126
/// <param name="device">The device.</param>
127127
/// <param name="optimizationLevel">The optimization level.</param>
128-
public static ExecutionProvider GetProvider(this Device device, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
128+
public static ExecutionProvider GetProvider(this Device device, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
129129
{
130130
if (device == null)
131131
return default;
@@ -143,7 +143,7 @@ public static ExecutionProvider GetProvider(this Device device, GraphOptimizatio
143143
/// </summary>
144144
/// <param name="deviceId">The device identifier.</param>
145145
/// <param name="optimizationLevel">The optimization level.</param>
146-
private static ExecutionProvider CreateProvider(int deviceId, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
146+
private static ExecutionProvider CreateProvider(int deviceId, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
147147
{
148148
var memoryInfo = new OrtMemoryInfo(OrtMemoryInfo.allocatorCPU, OrtAllocatorType.DeviceAllocator, deviceId, OrtMemType.Default);
149149
return new ExecutionProvider(_providerName, memoryInfo, configuration =>
@@ -165,7 +165,7 @@ private static ExecutionProvider CreateProvider(int deviceId, GraphOptimizationL
165165
/// </summary>
166166
/// <param name="optimizationLevel">The optimization level.</param>
167167
/// <returns>ExecutionProvider.</returns>
168-
private static ExecutionProvider CreateProvider(GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
168+
private static ExecutionProvider CreateProvider(GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
169169
{
170170
return new ExecutionProvider(DeviceManager.CPUProviderName, OrtMemoryInfo.DefaultInstance, configuration =>
171171
{
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using System;
2+
using TensorStack.Common.Tensor;
3+
4+
namespace TensorStack.StableDiffusion.Common
5+
{
6+
public record EncoderCache
7+
{
8+
public ImageTensor InputImage { get; init; }
9+
public Tensor<float> CacheResult { get; init; }
10+
11+
public bool IsValid(ImageTensor input)
12+
{
13+
if (input is null || InputImage is null)
14+
return false;
15+
16+
if (!InputImage.Span.SequenceEqual(input.Span))
17+
return false;
18+
19+
return true;
20+
}
21+
}
22+
}

TensorStack.StableDiffusion/Common/GenerateOptions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ public record GenerateOptions : IPipelineOptions, ISchedulerOptions
4545
public bool IsLowMemoryEncoderEnabled { get; set; }
4646
public bool IsLowMemoryDecoderEnabled { get; set; }
4747
public bool IsLowMemoryTextEncoderEnabled { get; set; }
48-
48+
public bool IsPipelineCacheEnabled { get; set; } = true;
4949

5050
public bool HasControlNet => ControlNet is not null;
5151
public bool HasInputImage => InputImage is not null;
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using System;
2+
using TensorStack.StableDiffusion.Pipelines;
3+
4+
namespace TensorStack.StableDiffusion.Common
5+
{
6+
public record PromptCache
7+
{
8+
public string Conditional { get; init; }
9+
public string Unconditional { get; init; }
10+
public PromptResult CacheResult { get; init; }
11+
12+
public bool IsValid(IPipelineOptions options)
13+
{
14+
return string.Equals(Conditional, options.Prompt, StringComparison.OrdinalIgnoreCase)
15+
&& string.Equals(Unconditional, options.NegativePrompt, StringComparison.OrdinalIgnoreCase);
16+
}
17+
}
18+
}

TensorStack.StableDiffusion/Common/TextEncoderResult.cs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
// Copyright (c) TensorStack. All rights reserved.
22
// Licensed under the Apache 2.0 License.
3-
using System;
43
using TensorStack.Common.Tensor;
54

65
namespace TensorStack.StableDiffusion.Common
@@ -32,6 +31,4 @@ public Tensor<float> GetHiddenStates(int index)
3231
return _hiddenStates[0];
3332
}
3433
}
35-
36-
public record TextEncoderBatchedResult(Memory<float> PromptEmbeds, Memory<float> PromptPooledEmbeds);
3734
}

TensorStack.StableDiffusion/Models/CLIPTextModel.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
using System.Threading;
55
using System.Threading.Tasks;
66
using TensorStack.Common;
7-
using TensorStack.Common.Tensor;
87
using TensorStack.StableDiffusion.Common;
98
using TensorStack.StableDiffusion.Config;
109
using TensorStack.TextGeneration.Tokenizers;

0 commit comments

Comments
 (0)