88using TensorStack . TextGeneration . Pipelines . Other ;
99using TensorStack . TextGeneration . Pipelines . Phi ;
1010using TensorStack . Providers ;
11+ using TensorStack . TextGeneration . Pipelines . Whisper ;
12+ using TensorStack . Common . Tensor ;
1113
1214namespace TensorStack . Example . Services
1315{
1416 public class TextService : ServiceBase , ITextService
1517 {
1618 private readonly Settings _settings ;
17- private IPipeline < GenerateResult , GenerateOptions , GenerateProgress > _greedyPipeline ;
18- private IPipeline < GenerateResult [ ] , SearchOptions , GenerateProgress > _beamSearchPipeline ;
19+ private IPipeline _currentPipeline ;
1920 private CancellationTokenSource _cancellationTokenSource ;
2021 private bool _isLoaded ;
2122 private bool _isLoading ;
@@ -77,34 +78,37 @@ public async Task LoadAsync(TextModel model, Device device)
7778 using ( _cancellationTokenSource = new CancellationTokenSource ( ) )
7879 {
7980 var cancellationToken = _cancellationTokenSource . Token ;
80- if ( _greedyPipeline != null )
81- await _greedyPipeline . UnloadAsync ( cancellationToken ) ;
81+ if ( _currentPipeline != null )
82+ await _currentPipeline . UnloadAsync ( cancellationToken ) ;
8283
8384 var provider = device . GetProvider ( ) ;
8485 var providerCPU = Provider . GetProvider ( DeviceType . CPU ) ; // TODO: DirectML not working with decoder
8586 if ( model . Type == TextModelType . Phi3 )
8687 {
8788 if ( ! Enum . TryParse < PhiType > ( model . Version , true , out var phiType ) )
88- throw new ArgumentException ( "Invalid PhiType Version" ) ;
89+ throw new ArgumentException ( "Invalid Phi Version" ) ;
8990
90- var pipeline = Phi3Pipeline . Create ( providerCPU , model . Path , phiType ) ;
91- _greedyPipeline = pipeline ;
92- _beamSearchPipeline = pipeline ;
91+ _currentPipeline = Phi3Pipeline . Create ( providerCPU , model . Path , phiType ) ;
9392 }
9493 else if ( model . Type == TextModelType . Summary )
9594 {
96- var pipeline = SummaryPipeline . Create ( provider , providerCPU , model . Path ) ;
97- _greedyPipeline = pipeline ;
98- _beamSearchPipeline = pipeline ;
95+ _currentPipeline = SummaryPipeline . Create ( provider , providerCPU , model . Path ) ;
9996 }
100- await Task . Run ( ( ) => _greedyPipeline . LoadAsync ( cancellationToken ) , cancellationToken ) ;
97+ else if ( model . Type == TextModelType . Whisper )
98+ {
99+ if ( ! Enum . TryParse < WhisperType > ( model . Version , true , out var whisperType ) )
100+ throw new ArgumentException ( "Invalid Whisper Version" ) ;
101+
102+ _currentPipeline = WhisperPipeline . Create ( provider , providerCPU , model . Path , whisperType ) ;
103+ }
104+ await Task . Run ( ( ) => _currentPipeline . LoadAsync ( cancellationToken ) , cancellationToken ) ;
101105
102106 }
103107 }
104108 catch ( OperationCanceledException )
105109 {
106- _greedyPipeline ? . Dispose ( ) ;
107- _greedyPipeline = null ;
110+ _currentPipeline ? . Dispose ( ) ;
111+ _currentPipeline = null ;
108112 _currentConfig = null ;
109113 throw ;
110114 }
@@ -148,11 +152,13 @@ public async Task<GenerateResult[]> ExecuteAsync(TextRequest options)
148152 if ( options . Beams == 0 )
149153 {
150154 // Greedy Search
151- return [ await _greedyPipeline . RunAsync ( pipelineOptions , cancellationToken : _cancellationTokenSource . Token ) ] ;
155+ var greedyPipeline = _currentPipeline as IPipeline < GenerateResult , GenerateOptions , GenerateProgress > ;
156+ return [ await greedyPipeline . RunAsync ( pipelineOptions , cancellationToken : _cancellationTokenSource . Token ) ] ;
152157 }
153158
154159 // Beam Search
155- return await _beamSearchPipeline . RunAsync ( new SearchOptions ( pipelineOptions ) , cancellationToken : _cancellationTokenSource . Token ) ;
160+ var beamSearchPipeline = _currentPipeline as IPipeline < GenerateResult [ ] , SearchOptions , GenerateProgress > ;
161+ return await beamSearchPipeline . RunAsync ( new SearchOptions ( pipelineOptions ) , cancellationToken : _cancellationTokenSource . Token ) ;
156162 } ) ;
157163
158164 return pipelineResult ;
@@ -165,6 +171,57 @@ public async Task<GenerateResult[]> ExecuteAsync(TextRequest options)
165171 }
166172
167173
174+ public async Task < GenerateResult [ ] > ExecuteAsync ( WhisperRequest options )
175+ {
176+ try
177+ {
178+ IsExecuting = true ;
179+ using ( _cancellationTokenSource = new CancellationTokenSource ( ) )
180+ {
181+ var pipelineOptions = new WhisperOptions
182+ {
183+ Prompt = options . Prompt ,
184+ Seed = options . Seed ,
185+ Beams = options . Beams ,
186+ TopK = options . TopK ,
187+ TopP = options . TopP ,
188+ Temperature = options . Temperature ,
189+ MaxLength = options . MaxLength ,
190+ MinLength = options . MinLength ,
191+ NoRepeatNgramSize = options . NoRepeatNgramSize ,
192+ LengthPenalty = options . LengthPenalty ,
193+ DiversityLength = options . DiversityLength ,
194+ EarlyStopping = options . EarlyStopping ,
195+ AudioInput = options . AudioInput ,
196+ Language = options . Language ,
197+ Task = options . Task
198+ } ;
199+
200+ var pipelineResult = await Task . Run ( async ( ) =>
201+ {
202+ if ( options . Beams == 0 )
203+ {
204+ // Greedy Search
205+ var greedyPipeline = _currentPipeline as IPipeline < GenerateResult , WhisperOptions , GenerateProgress > ;
206+ return [ await greedyPipeline . RunAsync ( pipelineOptions , cancellationToken : _cancellationTokenSource . Token ) ] ;
207+ }
208+
209+ // Beam Search
210+ var beamSearchPipeline = _currentPipeline as IPipeline < GenerateResult [ ] , WhisperSearchOptions , GenerateProgress > ;
211+ return await beamSearchPipeline . RunAsync ( new WhisperSearchOptions ( pipelineOptions ) , cancellationToken : _cancellationTokenSource . Token ) ;
212+ } ) ;
213+
214+ return pipelineResult ;
215+ }
216+ }
217+ finally
218+ {
219+ IsExecuting = false ;
220+ }
221+ }
222+
223+
224+
168225 /// <summary>
169226 /// Cancel the running task (Load or Execute)
170227 /// </summary>
@@ -179,12 +236,12 @@ public async Task CancelAsync()
179236 /// </summary>
180237 public async Task UnloadAsync ( )
181238 {
182- if ( _greedyPipeline != null )
239+ if ( _currentPipeline != null )
183240 {
184241 await _cancellationTokenSource . SafeCancelAsync ( ) ;
185- await _greedyPipeline . UnloadAsync ( ) ;
186- _greedyPipeline . Dispose ( ) ;
187- _greedyPipeline = null ;
242+ await _currentPipeline . UnloadAsync ( ) ;
243+ _currentPipeline . Dispose ( ) ;
244+ _currentPipeline = null ;
188245 _currentConfig = null ;
189246 }
190247
@@ -205,6 +262,7 @@ public interface ITextService
205262 Task UnloadAsync ( ) ;
206263 Task CancelAsync ( ) ;
207264 Task < GenerateResult [ ] > ExecuteAsync ( TextRequest options ) ;
265+ Task < GenerateResult [ ] > ExecuteAsync ( WhisperRequest options ) ;
208266 }
209267
210268
@@ -224,4 +282,11 @@ public record TextRequest : ITransformerRequest
224282 public int DiversityLength { get ; set ; } = 5 ;
225283 }
226284
    /// <summary>
    /// Request parameters for a Whisper audio pipeline run, extending the
    /// common text-generation settings of <see cref="TextRequest"/>.
    /// </summary>
    public record WhisperRequest : TextRequest
    {
        /// <summary>The audio to process.</summary>
        public AudioTensor AudioInput { get; set; }

        /// <summary>The language of the audio input (default: <see cref="LanguageType.EN"/>).</summary>
        public LanguageType Language { get; set; } = LanguageType.EN;

        /// <summary>The task to perform (default: <see cref="TaskType.Transcribe"/>).</summary>
        public TaskType Task { get; set; } = TaskType.Transcribe;
    }
291+
227292}
0 commit comments