Skip to content

Commit 6badd21

Browse files
committed
AudioElement details
1 parent cb836e9 commit 6badd21

File tree

16 files changed

+236
-113
lines changed

16 files changed

+236
-113
lines changed

Examples/TensorStack.Example.TextGeneration/MainWindow.xaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,17 @@
1515

1616
<!--Main Menu-->
1717
<Grid DockPanel.Dock="Top" WindowChrome.IsHitTestVisibleInChrome="True">
18-
<UniformGrid Columns="4" Height="30" Margin="2">
18+
<UniformGrid Columns="5" Height="30" Margin="2">
1919

2020
<!--Logo-->
2121
<Grid IsHitTestVisible="False">
2222
<Image Source="{StaticResource ImageTensorstackText}" Height="32" HorizontalAlignment="Left" Margin="4,2,50,0" />
2323
</Grid>
2424

2525
<!--Views-->
26-
<Button Command="{Binding NavigateCommand}" CommandParameter="{x:Static Views:View.TextSummary}" Content="Text Summary" />
27-
<Button Command="{Binding NavigateCommand}" CommandParameter="{x:Static Views:View.Transcribe}" Content="Transcribe" />
26+
<Button Command="{Binding NavigateCommand}" CommandParameter="{x:Static Views:View.TextToText}" Content="TextToText" />
27+
<Button Command="{Binding NavigateCommand}" CommandParameter="{x:Static Views:View.AudioToText}" Content="AudioToText" />
28+
<Button Command="{Binding NavigateCommand}" CommandParameter="{x:Static Views:View.TextToAudio}" Content="TextToAudio" />
2829

2930
<!--Window Options-->
3031
<StackPanel Orientation="Horizontal" HorizontalAlignment="Right">

Examples/TensorStack.Example.TextGeneration/MainWindow.xaml.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public MainWindow(Settings configuration, NavigationService navigation)
1717
NavigateCommand = new AsyncRelayCommand<View>(NavigateAsync, CanNavigate);
1818
InitializeComponent();
1919

20-
NavigateCommand.Execute(View.Transcribe);
20+
NavigateCommand.Execute(View.AudioToText);
2121
}
2222

2323
public NavigationService Navigation { get; }

Examples/TensorStack.Example.TextGeneration/Services/TextService.cs

Lines changed: 85 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@
88
using TensorStack.TextGeneration.Pipelines.Other;
99
using TensorStack.TextGeneration.Pipelines.Phi;
1010
using TensorStack.Providers;
11+
using TensorStack.TextGeneration.Pipelines.Whisper;
12+
using TensorStack.Common.Tensor;
1113

1214
namespace TensorStack.Example.Services
1315
{
1416
public class TextService : ServiceBase, ITextService
1517
{
1618
private readonly Settings _settings;
17-
private IPipeline<GenerateResult, GenerateOptions, GenerateProgress> _greedyPipeline;
18-
private IPipeline<GenerateResult[], SearchOptions, GenerateProgress> _beamSearchPipeline;
19+
private IPipeline _currentPipeline;
1920
private CancellationTokenSource _cancellationTokenSource;
2021
private bool _isLoaded;
2122
private bool _isLoading;
@@ -77,34 +78,37 @@ public async Task LoadAsync(TextModel model, Device device)
7778
using (_cancellationTokenSource = new CancellationTokenSource())
7879
{
7980
var cancellationToken = _cancellationTokenSource.Token;
80-
if (_greedyPipeline != null)
81-
await _greedyPipeline.UnloadAsync(cancellationToken);
81+
if (_currentPipeline != null)
82+
await _currentPipeline.UnloadAsync(cancellationToken);
8283

8384
var provider = device.GetProvider();
8485
var providerCPU = Provider.GetProvider(DeviceType.CPU); // TODO: DirectML not working with decoder
8586
if (model.Type == TextModelType.Phi3)
8687
{
8788
if (!Enum.TryParse<PhiType>(model.Version, true, out var phiType))
88-
throw new ArgumentException("Invalid PhiType Version");
89+
throw new ArgumentException("Invalid Phi Version");
8990

90-
var pipeline = Phi3Pipeline.Create(providerCPU, model.Path, phiType);
91-
_greedyPipeline = pipeline;
92-
_beamSearchPipeline = pipeline;
91+
_currentPipeline = Phi3Pipeline.Create(providerCPU, model.Path, phiType);
9392
}
9493
else if (model.Type == TextModelType.Summary)
9594
{
96-
var pipeline = SummaryPipeline.Create(provider, providerCPU, model.Path);
97-
_greedyPipeline = pipeline;
98-
_beamSearchPipeline = pipeline;
95+
_currentPipeline = SummaryPipeline.Create(provider, providerCPU, model.Path);
9996
}
100-
await Task.Run(() => _greedyPipeline.LoadAsync(cancellationToken), cancellationToken);
97+
else if (model.Type == TextModelType.Whisper)
98+
{
99+
if (!Enum.TryParse<WhisperType>(model.Version, true, out var whisperType))
100+
throw new ArgumentException("Invalid Whisper Version");
101+
102+
_currentPipeline = WhisperPipeline.Create(provider, providerCPU, model.Path, whisperType);
103+
}
104+
await Task.Run(() => _currentPipeline.LoadAsync(cancellationToken), cancellationToken);
101105

102106
}
103107
}
104108
catch (OperationCanceledException)
105109
{
106-
_greedyPipeline?.Dispose();
107-
_greedyPipeline = null;
110+
_currentPipeline?.Dispose();
111+
_currentPipeline = null;
108112
_currentConfig = null;
109113
throw;
110114
}
@@ -148,11 +152,13 @@ public async Task<GenerateResult[]> ExecuteAsync(TextRequest options)
148152
if (options.Beams == 0)
149153
{
150154
// Greedy Search
151-
return [await _greedyPipeline.RunAsync(pipelineOptions, cancellationToken: _cancellationTokenSource.Token)];
155+
var greedyPipeline = _currentPipeline as IPipeline<GenerateResult, GenerateOptions, GenerateProgress>;
156+
return [await greedyPipeline.RunAsync(pipelineOptions, cancellationToken: _cancellationTokenSource.Token)];
152157
}
153158

154159
// Beam Search
155-
return await _beamSearchPipeline.RunAsync(new SearchOptions(pipelineOptions), cancellationToken: _cancellationTokenSource.Token);
160+
var beamSearchPipeline = _currentPipeline as IPipeline<GenerateResult[], SearchOptions, GenerateProgress>;
161+
return await beamSearchPipeline.RunAsync(new SearchOptions(pipelineOptions), cancellationToken: _cancellationTokenSource.Token);
156162
});
157163

158164
return pipelineResult;
@@ -165,6 +171,57 @@ public async Task<GenerateResult[]> ExecuteAsync(TextRequest options)
165171
}
166172

167173

174+
public async Task<GenerateResult[]> ExecuteAsync(WhisperRequest options)
175+
{
176+
try
177+
{
178+
IsExecuting = true;
179+
using (_cancellationTokenSource = new CancellationTokenSource())
180+
{
181+
var pipelineOptions = new WhisperOptions
182+
{
183+
Prompt = options.Prompt,
184+
Seed = options.Seed,
185+
Beams = options.Beams,
186+
TopK = options.TopK,
187+
TopP = options.TopP,
188+
Temperature = options.Temperature,
189+
MaxLength = options.MaxLength,
190+
MinLength = options.MinLength,
191+
NoRepeatNgramSize = options.NoRepeatNgramSize,
192+
LengthPenalty = options.LengthPenalty,
193+
DiversityLength = options.DiversityLength,
194+
EarlyStopping = options.EarlyStopping,
195+
AudioInput = options.AudioInput,
196+
Language = options.Language,
197+
Task = options.Task
198+
};
199+
200+
var pipelineResult = await Task.Run(async () =>
201+
{
202+
if (options.Beams == 0)
203+
{
204+
// Greedy Search
205+
var greedyPipeline = _currentPipeline as IPipeline<GenerateResult, WhisperOptions, GenerateProgress>;
206+
return [await greedyPipeline.RunAsync(pipelineOptions, cancellationToken: _cancellationTokenSource.Token)];
207+
}
208+
209+
// Beam Search
210+
var beamSearchPipeline = _currentPipeline as IPipeline<GenerateResult[], WhisperSearchOptions, GenerateProgress>;
211+
return await beamSearchPipeline.RunAsync(new WhisperSearchOptions(pipelineOptions), cancellationToken: _cancellationTokenSource.Token);
212+
});
213+
214+
return pipelineResult;
215+
}
216+
}
217+
finally
218+
{
219+
IsExecuting = false;
220+
}
221+
}
222+
223+
224+
168225
/// <summary>
169226
/// Cancel the running task (Load or Execute)
170227
/// </summary>
@@ -179,12 +236,12 @@ public async Task CancelAsync()
179236
/// </summary>
180237
public async Task UnloadAsync()
181238
{
182-
if (_greedyPipeline != null)
239+
if (_currentPipeline != null)
183240
{
184241
await _cancellationTokenSource.SafeCancelAsync();
185-
await _greedyPipeline.UnloadAsync();
186-
_greedyPipeline.Dispose();
187-
_greedyPipeline = null;
242+
await _currentPipeline.UnloadAsync();
243+
_currentPipeline.Dispose();
244+
_currentPipeline = null;
188245
_currentConfig = null;
189246
}
190247

@@ -205,6 +262,7 @@ public interface ITextService
205262
Task UnloadAsync();
206263
Task CancelAsync();
207264
Task<GenerateResult[]> ExecuteAsync(TextRequest options);
265+
Task<GenerateResult[]> ExecuteAsync(WhisperRequest options);
208266
}
209267

210268

@@ -224,4 +282,11 @@ public record TextRequest : ITransformerRequest
224282
public int DiversityLength { get; set; } = 5;
225283
}
226284

285+
public record WhisperRequest : TextRequest
286+
{
287+
public AudioTensor AudioInput { get; set; }
288+
public LanguageType Language { get; set; } = LanguageType.EN;
289+
public TaskType Task { get; set; } = TaskType.Transcribe;
290+
}
291+
227292
}

Examples/TensorStack.Example.TextGeneration/Settings.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ public class Settings : IUIConfiguration
1919
public string VideoCodec { get; set; } = "mp4v";
2020
public string DirectoryTemp { get; set; }
2121
public IReadOnlyList<Device> Devices { get; set; }
22-
public ObservableCollection<TextModel> TextModels { get; set; }
22+
public ObservableCollection<TextModel> TextToTextModels { get; set; }
23+
public ObservableCollection<TextModel> TextToAudioModels { get; set; }
24+
public ObservableCollection<TextModel> AudioToTextModels { get; set; }
2325

2426

2527
public void Initialize()

Examples/TensorStack.Example.TextGeneration/Settings.json

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"DirectoryTemp": "Temp",
3-
"TextModels": [
3+
"TextToTextModels": [
44
{
55
"Id": 1,
66
"Name": "Text Summary",
@@ -17,19 +17,28 @@
1717
"https://huggingface.co/TensorStack/TensorStack/resolve/main/TextGeneration/TextSummary/spiece.model?download=true",
1818
"https://huggingface.co/TensorStack/TensorStack/resolve/main/TextGeneration/TextSummary/tokenizer.json?download=true"
1919
]
20-
},
21-
{
22-
"Id": 2,
20+
}
21+
],
22+
"TextToAudioModels": [
23+
24+
],
25+
"AudioToTextModels":[ {
26+
"Id": 1,
2327
"Name": "Whisper-Small",
28+
"IsDefault": true,
2429
"Type": "Whisper",
30+
"Version": "Small",
31+
"MinLength": 20,
32+
"MaxLength": 512,
2533
"Prefixes": [ "Transcribe", "Translate" ],
2634
"Path": "Models\\TextGeneration\\Whisper-Small",
2735
"UrlPaths": [
28-
"https://huggingface.co/TensorStack/TensorStack/resolve/main/TextGeneration/Whisper/config.json?download=true",
29-
"https://huggingface.co/TensorStack/TensorStack/resolve/main/TextGeneration/Whisper/decoder_model_merged.onnx?download=true",
30-
"https://huggingface.co/TensorStack/TensorStack/resolve/main/TextGeneration/Whisper/encoder_model.onnx?download=true",
31-
"https://huggingface.co/TensorStack/TensorStack/resolve/main/TextGeneration/Whisper/spiece.model?download=true",
32-
"https://huggingface.co/TensorStack/TensorStack/resolve/main/TextGeneration/Whisper/tokenizer.json?download=true"
36+
"https://huggingface.co/TensorStack/TensorStack/resolve/main/TextGeneration/Whisper-Small/decoder_model_merged.onnx?download=true",
37+
"https://huggingface.co/TensorStack/TensorStack/resolve/main/TextGeneration/Whisper-Small/encoder_model.onnx?download=true",
38+
"https://huggingface.co/TensorStack/TensorStack/resolve/main/TextGeneration/Whisper-Small/merges.txt?download=true",
39+
"https://huggingface.co/TensorStack/TensorStack/resolve/main/TextGeneration/Whisper-Small/tokenizer_config.json?download=true",
40+
"https://huggingface.co/TensorStack/TensorStack/resolve/main/TextGeneration/Whisper-Small/vocab.json?download=true",
41+
"https://huggingface.co/TensorStack/TensorStack/resolve/main/TextGeneration/Whisper-Small/mel_filters.npz?download=true"
3342
]
3443
}
3544
]

Examples/TensorStack.Example.TextGeneration/Views/TranscribeView.xaml renamed to Examples/TensorStack.Example.TextGeneration/Views/AudioToTextView.xaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<Views:ViewBase x:Class="TensorStack.Example.Views.TranscribeView"
1+
<Views:ViewBase x:Class="TensorStack.Example.Views.AudioToTextView"
22
xmlns="http://schemas.microsoft.com/winfx/2006/xaml/presentation"
33
xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"
44
xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006"
@@ -9,7 +9,7 @@
99
mc:Ignorable="d"
1010
d:DesignHeight="450"
1111
d:DesignWidth="800">
12-
<Grid DataContext="{Binding RelativeSource={RelativeSource AncestorType={x:Type local:TranscribeView}}}">
12+
<Grid DataContext="{Binding RelativeSource={RelativeSource AncestorType={x:Type local:AudioToTextView}}}">
1313

1414
<Grid.RowDefinitions>
1515
<RowDefinition Height="1*" />
@@ -44,7 +44,7 @@
4444
</Grid.ColumnDefinitions>
4545
<StackPanel Grid.Column="0" IsEnabled="{Binding TextService.IsLoaded, Converter={StaticResource InverseBoolConverter}}">
4646
<TextBlock Text="Models" Style="{StaticResource FieldTextBlockStyle}" />
47-
<ComboBox SelectedItem="{Binding SelectedModel, Mode=TwoWay}" ItemsSource="{Binding Settings.TextModels}" DisplayMemberPath="Name" />
47+
<ComboBox SelectedItem="{Binding SelectedModel, Mode=TwoWay}" ItemsSource="{Binding Settings.AudioToTextModels}" DisplayMemberPath="Name" />
4848
</StackPanel>
4949
<Grid Grid.Column="1" Height="22" Margin="0,18,0,0" >
5050
<Button x:Name="ButtonLoad" Content="Load" Command="{Binding LoadCommand}" />
@@ -54,15 +54,15 @@
5454

5555
<StackPanel>
5656
<TextBlock Text="Audio Input" Style="{StaticResource FieldTextBlockStyle}" />
57-
<CommonControls:AudioElement Source="{Binding AudioInput, Mode=TwoWay}" Configuration="{Binding Settings}" Progress="{Binding Progress}" Height="100" />
57+
<CommonControls:AudioElement Source="{Binding AudioInput, Mode=TwoWay}" Configuration="{Binding Settings}" />
5858
</StackPanel>
5959

6060
</StackPanel>
6161

6262
<!--Advanced-->
6363
<StackPanel DockPanel.Dock="Top" IsEnabled="{Binding TextService.IsLoaded}">
6464
<StackPanel Margin="0,4,0,0" Visibility="{Binding Prefixes, Converter={StaticResource EmptyToVisibilityConverter}}">
65-
<TextBlock Text="Summary Task" Style="{StaticResource FieldTextBlockStyle}" />
65+
<TextBlock Text="Task" Style="{StaticResource FieldTextBlockStyle}" />
6666
<ComboBox ItemsSource="{Binding Prefixes}" SelectedItem="{Binding SelectedPrefix}" />
6767
</StackPanel>
6868
<UniformGrid Columns="4" Margin="0,4,0,0">
@@ -118,7 +118,7 @@
118118
<!--Execute/Cancel-->
119119
<UniformGrid DockPanel.Dock="Bottom" Columns="2" Height="30">
120120
<Button Content="Cancel" Command="{Binding CancelCommand}" />
121-
<Button Content="Summarize" Command="{Binding ExecuteCommand}" />
121+
<Button Content="Transcribe" Command="{Binding ExecuteCommand}" />
122122
</UniformGrid>
123123

124124
</DockPanel>

Examples/TensorStack.Example.TextGeneration/Views/TranscribeView.xaml.cs renamed to Examples/TensorStack.Example.TextGeneration/Views/AudioToTextView.xaml.cs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
namespace TensorStack.Example.Views
1717
{
1818
/// <summary>
19-
/// Interaction logic for TranscribeView.xaml
19+
/// Interaction logic for AudioToTextView.xaml
2020
/// </summary>
21-
public partial class TranscribeView : ViewBase
21+
public partial class AudioToTextView : ViewBase
2222
{
2323
private Device _selectedDevice;
2424
private TextModel _selectedModel;
@@ -37,7 +37,7 @@ public partial class TranscribeView : ViewBase
3737
private int _selectedBeam;
3838
private AudioInput _audioInput;
3939

40-
public TranscribeView(Settings settings, NavigationService navigationService, ITextService textService)
40+
public AudioToTextView(Settings settings, NavigationService navigationService, ITextService textService)
4141
: base(settings, navigationService)
4242
{
4343
TextService = textService;
@@ -46,14 +46,14 @@ public TranscribeView(Settings settings, NavigationService navigationService, IT
4646
ExecuteCommand = new AsyncRelayCommand(ExecuteAsync, CanExecute);
4747
CancelCommand = new AsyncRelayCommand(CancelAsync, CanCancel);
4848
Progress = new ProgressInfo();
49-
SelectedModel = settings.TextModels.First(x => x.IsDefault);
49+
SelectedModel = settings.AudioToTextModels.First(x => x.IsDefault);
5050
SelectedDevice = settings.DefaultDevice;
5151
Prefixes = new ObservableCollection<string>();
5252
TranscribeResults = new ObservableCollection<TranscribeResult>();
5353
InitializeComponent();
5454
}
5555

56-
public override int Id => (int)View.Transcribe;
56+
public override int Id => (int)View.AudioToText;
5757
public ITextService TextService { get; }
5858
public AsyncRelayCommand LoadCommand { get; }
5959
public AsyncRelayCommand UnloadCommand { get; }
@@ -220,7 +220,7 @@ private async Task ExecuteAsync()
220220
Progress.Indeterminate("Generating Results...");
221221

222222
// Run Transcribe
223-
var transcribeResults = await TextService.ExecuteAsync(new TextRequest
223+
var transcribeResults = await TextService.ExecuteAsync(new WhisperRequest
224224
{
225225
//Prompt = promptText,
226226
Beams = _beams,
@@ -234,6 +234,7 @@ private async Task ExecuteAsync()
234234
NoRepeatNgramSize = 4,
235235
DiversityLength = _diversityLength,
236236
EarlyStopping = _earlyStopping,
237+
AudioInput = _audioInput
237238
});
238239

239240
TranscribeResults.Clear();

0 commit comments

Comments
 (0)