diff --git a/src/FSharp.Control.AsyncSeq/AsyncSeq.fs b/src/FSharp.Control.AsyncSeq/AsyncSeq.fs index 00b7613..da15b10 100644 --- a/src/FSharp.Control.AsyncSeq/AsyncSeq.fs +++ b/src/FSharp.Control.AsyncSeq/AsyncSeq.fs @@ -12,6 +12,7 @@ open System.Threading.Tasks open System.Runtime.ExceptionServices #if !FABLE_COMPILER open System.Linq +open System.Threading.Channels #endif #nowarn "40" "3218" @@ -374,6 +375,17 @@ module AsyncSeq = member x.MoveNext() = async { return None } member x.Dispose() = () } } + let emptyAsync<'T> (action : Async) : AsyncSeq<'T> = + { new IAsyncEnumerable<'T> with + member x.GetEnumerator() = + { new IAsyncEnumerator<'T> with + member x.MoveNext() = + async { + do! action + return None + } + member x.Dispose() = () } } + let singleton (v:'T) : AsyncSeq<'T> = { new IAsyncEnumerable<'T> with member x.GetEnumerator() = @@ -1946,6 +1958,75 @@ module AsyncSeq = #endif + #if !FABLE_COMPILER + open System.Threading.Channels + + let toChannel (writer : ChannelWriter<'a>) (xs : AsyncSeq<'a>) : Async = + async { + try + do! + xs + |> iterAsync + (fun x -> + async { + if not (writer.TryWrite(x)) then + let! ct = Async.CancellationToken + + do! + writer.WriteAsync(x, ct).AsTask() + |> Async.AwaitTask + }) + + writer.Complete() + with exn -> + writer.Complete(error = exn) + } + + let fromChannel (reader : ChannelReader<'a>) : AsyncSeq<'a> = + asyncSeq { + let mutable keepGoing = true + + while keepGoing do + let mutable item = Unchecked.defaultof<'a> + + if reader.TryRead(&item) then + yield item + else + let! ct = Async.CancellationToken + + let! hasMoreData = + reader.WaitToReadAsync(ct).AsTask() + |> Async.AwaitTask + + if not hasMoreData then + keepGoing <- false + } + + let prefetch (numberToPrefetch : int) (xs : AsyncSeq<'a>) : AsyncSeq<'a> = + if numberToPrefetch = 0 then + xs + else + if numberToPrefetch < 1 then + invalidArg (nameof numberToPrefetch) "must be at least zero" + asyncSeq { + let opts = BoundedChannelOptions(numberToPrefetch) + opts.SingleWriter <- true + opts.SingleReader <- true + + let channel = Channel.CreateBounded(opts) + + let! fillChannelTask = + toChannel channel.Writer xs + |> Async.StartChild + + yield! + append + (fromChannel channel.Reader) + (emptyAsync fillChannelTask) + } + + #endif + [] module AsyncSeqExtensions = diff --git a/src/FSharp.Control.AsyncSeq/AsyncSeq.fsi b/src/FSharp.Control.AsyncSeq/AsyncSeq.fsi index 4fdb916..6db8140 100644 --- a/src/FSharp.Control.AsyncSeq/AsyncSeq.fsi +++ b/src/FSharp.Control.AsyncSeq/AsyncSeq.fsi @@ -557,6 +557,25 @@ module AsyncSeq = #endif #endif + #if !FABLE_COMPILER + + open System.Threading.Channels + + /// Fills a channel writer with the values from an async seq. + /// The writer will be closed when the async seq completes or raises an error. + val toChannel<'T> : writer: ChannelWriter<'T> -> source: AsyncSeq<'T> -> Async + + /// Creates an async seq from a channel reader. + /// The async seq will read values from the channel reader until it is closed. + /// If the reader raises an error than the sequence will raise it. + val fromChannel<'T> : reader: ChannelReader<'T> -> AsyncSeq<'T> + + /// Transforms an async seq to a new one that fetches values ahead of time to improve throughput. + val prefetch<'T> : numberToPrefetch: int -> source: AsyncSeq<'T> -> AsyncSeq<'T> + + #endif + + /// An automatically-opened module that contains the `asyncSeq` builder and an extension method [] module AsyncSeqExtensions = diff --git a/src/FSharp.Control.AsyncSeq/FSharp.Control.AsyncSeq.fsproj b/src/FSharp.Control.AsyncSeq/FSharp.Control.AsyncSeq.fsproj index 93718a5..a3059a7 100644 --- a/src/FSharp.Control.AsyncSeq/FSharp.Control.AsyncSeq.fsproj +++ b/src/FSharp.Control.AsyncSeq/FSharp.Control.AsyncSeq.fsproj @@ -23,6 +23,7 @@ + diff --git a/tests/FSharp.Control.AsyncSeq.Tests/AsyncSeqTests.fs b/tests/FSharp.Control.AsyncSeq.Tests/AsyncSeqTests.fs index cdfd05a..67f1a8a 100644 --- a/tests/FSharp.Control.AsyncSeq.Tests/AsyncSeqTests.fs +++ b/tests/FSharp.Control.AsyncSeq.Tests/AsyncSeqTests.fs @@ -10,6 +10,7 @@ open NUnit.Framework open FSharp.Control open System open System.Threading +open System.Threading.Channels type AsyncOps = AsyncOps with static member unit : Async = async { return () } @@ -68,6 +69,14 @@ let runTimeout (timeoutMs:int) (a:Async<'a>) : 'a = let runTest a = runTimeout 1000 a +let rec disaggregate (exn : exn) : exn = + match exn with + | :? AggregateException as agg -> + match Seq.tryExactlyOne agg.InnerExceptions with + | Some inner -> disaggregate inner + | None -> exn + | _ -> exn + type Assert with /// Determines equality of two async sequences by convering them to lists, ignoring side-effects. @@ -1622,25 +1631,25 @@ let ``AsyncSeq.iterAsyncParallelThrottled should throttle`` () = let ``AsyncSeq.mapAsyncUnorderedParallel should produce all results`` () = let input = [1; 2; 3; 4; 5] let expected = [2; 4; 6; 8; 10] |> Set.ofList - - let actual = + + let actual = input |> AsyncSeq.ofSeq |> AsyncSeq.mapAsyncUnorderedParallel (fun x -> async { - do! Async.Sleep(10) + do! Async.Sleep(10) return x * 2 }) |> AsyncSeq.toListAsync |> runTest |> Set.ofList - + Assert.AreEqual(expected, actual) [] let ``AsyncSeq.mapAsyncUnorderedParallel should propagate exceptions`` () = let input = [1; 2; 3; 4; 5] - - let res = + + let res = input |> AsyncSeq.ofSeq |> AsyncSeq.mapAsyncUnorderedParallel (fun x -> async { @@ -1650,7 +1659,7 @@ let ``AsyncSeq.mapAsyncUnorderedParallel should propagate exceptions`` () = |> AsyncSeq.toListAsync |> Async.Catch |> runTest - + match res with | Choice2Of2 _ -> () // Expected exception | Choice1Of2 _ -> Assert.Fail("Expected exception but none was thrown") @@ -1660,7 +1669,7 @@ let ``AsyncSeq.mapAsyncUnorderedParallel should not preserve order`` () = // Test that results can come in different order than input let input = [1; 2; 3; 4; 5] let results = System.Collections.Generic.List() - + input |> AsyncSeq.ofSeq |> AsyncSeq.mapAsyncUnorderedParallel (fun x -> async { @@ -1671,12 +1680,12 @@ let ``AsyncSeq.mapAsyncUnorderedParallel should not preserve order`` () = }) |> AsyncSeq.iter ignore |> runTest - + let resultOrder = results |> List.ofSeq - // With unordered parallel processing and varying delays, + // With unordered parallel processing and varying delays, // we expect some reordering (though not guaranteed in all environments) let isReordered = resultOrder <> [1; 2; 3; 4; 5] - + // This test passes regardless of ordering since reordering depends on timing // The main validation is that all results are present let allPresent = (Set.ofList resultOrder) = (Set.ofList input) @@ -1849,17 +1858,17 @@ let ``AsyncSeq.sortByDescending should work``() = let ``async.For with AsyncSeq should work``() = async { let mutable results = [] - let source = asyncSeq { + let source = asyncSeq { yield 1 yield 2 yield 3 } - + do! async { for item in source do results <- item :: results } - + Assert.AreEqual([3; 2; 1], results) } |> Async.RunSynchronously @@ -1869,12 +1878,12 @@ let ``async.For with empty AsyncSeq should work``() = async { let mutable count = 0 let source = AsyncSeq.empty - + do! async { for item in source do count <- count + 1 } - + Assert.AreEqual(0, count) } |> Async.RunSynchronously @@ -1887,7 +1896,7 @@ let ``async.For with exception in AsyncSeq should propagate``() = failwith "test exception" yield 2 } - + try do! async { for item in source do @@ -1895,9 +1904,9 @@ let ``async.For with exception in AsyncSeq should propagate``() = } Assert.Fail("Expected exception to be thrown") with - | ex when ex.Message = "test exception" -> + | ex when ex.Message = "test exception" -> () // Expected - | ex -> + | ex -> Assert.Fail($"Unexpected exception: {ex.Message}") } |> Async.RunSynchronously @@ -1933,7 +1942,7 @@ let ``AsyncSeqExtensions - async.For with exception in AsyncSeq`` () = for item in asyncSeq { yield 1; failwith "test error"; yield 2 } do () with - | ex when ex.Message = "test error" -> + | ex when ex.Message = "test error" -> exceptionCaught <- true } computation |> Async.RunSynchronously @@ -1961,37 +1970,37 @@ let ``Seq.ofAsyncSeq with exception`` () = failwith "test error" yield 2 } - Assert.Throws(fun () -> + Assert.Throws(fun () -> Seq.ofAsyncSeq asyncSeqWithError |> Seq.toList |> ignore ) |> ignore [] -let ``AsyncSeq.intervalMs should generate sequence with timestamps``() = - let result = +let ``AsyncSeq.intervalMs should generate sequence with timestamps``() = + let result = AsyncSeq.intervalMs 50 |> AsyncSeq.take 3 |> AsyncSeq.toListAsync |> AsyncOps.timeoutMs 1000 |> Async.RunSynchronously - + Assert.AreEqual(3, result.Length) // Verify timestamps are increasing Assert.IsTrue(result.[1] > result.[0]) Assert.IsTrue(result.[2] > result.[1]) [] -let ``AsyncSeq.intervalMs with zero period should work``() = - let result = +let ``AsyncSeq.intervalMs with zero period should work``() = + let result = AsyncSeq.intervalMs 0 |> AsyncSeq.take 2 |> AsyncSeq.toListAsync |> AsyncOps.timeoutMs 500 |> Async.RunSynchronously - + Assert.AreEqual(2, result.Length) [] -let ``AsyncSeq.take with negative count should throw ArgumentException``() = +let ``AsyncSeq.take with negative count should throw ArgumentException``() = Assert.Throws(fun () -> AsyncSeq.ofSeq [1;2;3] |> AsyncSeq.take -1 @@ -2001,7 +2010,7 @@ let ``AsyncSeq.take with negative count should throw ArgumentException``() = ) |> ignore [] -let ``AsyncSeq.skip with negative count should throw ArgumentException``() = +let ``AsyncSeq.skip with negative count should throw ArgumentException``() = Assert.Throws(fun () -> AsyncSeq.ofSeq [1;2;3] |> AsyncSeq.skip -1 @@ -2011,32 +2020,32 @@ let ``AsyncSeq.skip with negative count should throw ArgumentException``() = ) |> ignore [] -let ``AsyncSeq.take zero should return empty sequence``() = +let ``AsyncSeq.take zero should return empty sequence``() = let expected = [] - let actual = + let actual = AsyncSeq.ofSeq [1;2;3] |> AsyncSeq.take 0 |> AsyncSeq.toListAsync |> Async.RunSynchronously - + Assert.AreEqual(expected, actual) -[] -let ``AsyncSeq.skip zero should return original sequence``() = +[] +let ``AsyncSeq.skip zero should return original sequence``() = let expected = [1;2;3] - let actual = + let actual = AsyncSeq.ofSeq [1;2;3] |> AsyncSeq.skip 0 |> AsyncSeq.toListAsync |> Async.RunSynchronously - + Assert.AreEqual(expected, actual) [] let ``AsyncSeq.replicateInfinite with exception should propagate exception``() = let exceptionMsg = "test exception" let expected = System.ArgumentException(exceptionMsg) - + Assert.Throws(fun () -> AsyncSeq.replicateInfinite (raise expected) |> AsyncSeq.take 2 @@ -2167,7 +2176,7 @@ let ``Seq.ofAsyncSeq should work``() = yield 2 yield 3 } - + let result = Seq.ofAsyncSeq source |> Seq.toList Assert.AreEqual([1; 2; 3], result) @@ -2184,22 +2193,22 @@ let ``Seq.ofAsyncSeq with exception should propagate``() = failwith "test exception" yield 2 } - + try let _ = Seq.ofAsyncSeq source |> Seq.toList Assert.Fail("Expected exception to be thrown") with - | ex when ex.Message = "test exception" -> + | ex when ex.Message = "test exception" -> () // Expected - | ex -> + | ex -> Assert.Fail($"Unexpected exception: {ex.Message}") #endif [] let ``AsyncSeq.fold with empty sequence should return seed``() = - let result = AsyncSeq.empty - |> AsyncSeq.fold (+) 10 + let result = AsyncSeq.empty + |> AsyncSeq.fold (+) 10 |> Async.RunSynchronously Assert.AreEqual(10, result) @@ -2219,25 +2228,25 @@ let ``AsyncSeq.mapAsync should preserve order with async transformations``() = do! Async.Sleep(50 - x * 10) // Shorter sleep for larger numbers return x * 2 } - - let result = data + + let result = data |> AsyncSeq.mapAsync asyncTransform - |> AsyncSeq.toListAsync + |> AsyncSeq.toListAsync |> Async.RunSynchronously Assert.AreEqual([2; 4; 6; 8; 10], result) [] let ``AsyncSeq.mapAsync should propagate exceptions``() = - let data = [1; 2; 3] |> AsyncSeq.ofSeq + let data = [1; 2; 3] |> AsyncSeq.ofSeq let asyncTransform x = async { if x = 2 then failwith "test error" return x * 2 } - + try - data + data |> AsyncSeq.mapAsync asyncTransform - |> AsyncSeq.toListAsync + |> AsyncSeq.toListAsync |> Async.RunSynchronously |> ignore Assert.Fail("Expected exception to be thrown") @@ -2252,10 +2261,10 @@ let ``AsyncSeq.chooseAsync should filter and transform``() = if x % 2 = 0 then return Some (x * 10) else return None } - - let result = data + + let result = data |> AsyncSeq.chooseAsync asyncChoose - |> AsyncSeq.toListAsync + |> AsyncSeq.toListAsync |> Async.RunSynchronously Assert.AreEqual([20; 40], result) @@ -2266,19 +2275,19 @@ let ``AsyncSeq.filterAsync should work with async predicates``() = do! Async.Sleep(1) return x % 2 = 1 } - - let result = data + + let result = data |> AsyncSeq.filterAsync asyncPredicate - |> AsyncSeq.toListAsync + |> AsyncSeq.toListAsync |> Async.RunSynchronously Assert.AreEqual([1; 3; 5], result) [] let ``AsyncSeq.scan should work with accumulator``() = let data = [1; 2; 3; 4] |> AsyncSeq.ofSeq - let result = data + let result = data |> AsyncSeq.scan (+) 0 - |> AsyncSeq.toListAsync + |> AsyncSeq.toListAsync |> Async.RunSynchronously Assert.AreEqual([0; 1; 3; 6; 10], result) @@ -2289,9 +2298,9 @@ let ``AsyncSeq.scanAsync should work with async accumulator``() = do! Async.Sleep(1) return acc + x } - let result = data + let result = data |> AsyncSeq.scanAsync asyncFolder 0 - |> AsyncSeq.toListAsync + |> AsyncSeq.toListAsync |> Async.RunSynchronously Assert.AreEqual([0; 1; 3; 6], result) @@ -2303,16 +2312,16 @@ let ``AsyncSeq.threadStateAsync should maintain state correctly``() = let output = x * newState return (output, newState) } - - let result = data + + let result = data |> AsyncSeq.threadStateAsync statefulFolder 0 - |> AsyncSeq.toListAsync + |> AsyncSeq.toListAsync |> Async.RunSynchronously Assert.AreEqual([1; 4; 9; 16], result) [] let ``AsyncSeq.lastOrDefault should return default for empty sequence``() = - let result = AsyncSeq.empty + let result = AsyncSeq.empty |> AsyncSeq.lastOrDefault 999 |> Async.RunSynchronously Assert.AreEqual(999, result) @@ -2320,11 +2329,76 @@ let ``AsyncSeq.lastOrDefault should return default for empty sequence``() = [] let ``AsyncSeq.lastOrDefault should return last element``() = let data = [1; 2; 3; 4; 5] |> AsyncSeq.ofSeq - let result = data + let result = data |> AsyncSeq.lastOrDefault 999 |> Async.RunSynchronously Assert.AreEqual(5, result) +[] +let ``AsyncSeq.toChannel and AsyncSeq.fromChannel should work``() = + let input = [1; 2; 4; 5; 6; 7; 8] + let channel = Channel.CreateBounded(3) + let actual = + async { + let! fillTask = + input + |> AsyncSeq.ofSeq + |> AsyncSeq.toChannel channel.Writer + |> Async.StartChild + + let! toListTask = + AsyncSeq.fromChannel channel.Reader + |> AsyncSeq.toListAsync + |> Async.StartChild + + do! fillTask + + return! toListTask + } + |> Async.RunSynchronously + Assert.AreEqual(input, actual) + +[] +let ``AsyncSeq.prefetch should not alter elements``() = + let input = ["h"; "e"; "l"; "l"; "o"] + let actual = + input + |> AsyncSeq.ofSeq + |> AsyncSeq.prefetch 2 + |> AsyncSeq.toListAsync + |> Async.RunSynchronously + Assert.AreEqual(input, actual) + +[] +let ``AsyncSeq.toChannel and AsyncSeq.fromChannel capture exns``() = + async { + let channel = Channel.CreateBounded(2) + + let! fillTask = + asyncSeq { + "a" + "b" + "c" + failwith "Kaboom" + } + |> AsyncSeq.toChannel channel.Writer + |> Async.StartChild + + let! toListTask = + AsyncSeq.fromChannel channel.Reader + |> AsyncSeq.toListAsync + |> Async.StartChild + + do! fillTask + + try + let! _ = toListTask + Assert.Fail() + with exn -> + Assert.AreEqual((disaggregate exn).Message, "Kaboom") + } + |> Async.RunSynchronously + // ---------------------------------------------------------------------------- // Additional Coverage Tests targeting uncovered edge cases and branches @@ -2381,14 +2455,14 @@ let ``AsyncSeq.append with both sequences having exceptions should propagate fir let seq1 = asyncSeq { yield 1; failwith "error1" } let seq2 = asyncSeq { yield 2; failwith "error2" } let combined = AsyncSeq.append seq1 seq2 - + try let! _ = AsyncSeq.toListAsync combined Assert.Fail("Expected exception to be thrown") with - | ex when ex.Message = "error1" -> + | ex when ex.Message = "error1" -> () // Expected - first sequence's error should be thrown - | ex -> + | ex -> Assert.Fail($"Unexpected exception: {ex.Message}") } |> Async.RunSynchronously @@ -2401,14 +2475,14 @@ let ``AsyncSeq.concat with nested exceptions should propagate properly`` () = yield asyncSeq { yield 3 } } let flattened = AsyncSeq.concat nested - + try let! result = AsyncSeq.toListAsync flattened Assert.Fail("Expected exception to be thrown") with - | ex when ex.Message = "nested error" -> + | ex when ex.Message = "nested error" -> () // Expected - | ex -> + | ex -> Assert.Fail($"Unexpected exception: {ex.Message}") } |> Async.RunSynchronously @@ -2449,11 +2523,11 @@ let ``AsyncSeqOp.FoldAsync with unfoldAsync should work`` () = return None } let source = AsyncSeq.unfoldAsync generator 0 - + // This should hit the uncovered FoldAsync method in UnfoldAsyncEnumerator let folder acc x = async { return acc + x } let! result = AsyncSeq.foldAsync folder 0 source - + // Expected: sum of [0, 2, 4, 6, 8] = 20 Assert.AreEqual(20, result) } |> Async.RunSynchronously @@ -2478,13 +2552,13 @@ let ``AsyncSeqOp.FoldAsync with exception in generator should propagate`` () = return failwith "generator error" } let source = AsyncSeq.unfoldAsync generator 0 - + try let folder acc x = async { return acc + x } let! _ = AsyncSeq.foldAsync folder 0 source Assert.Fail("Expected exception to be thrown") with - | ex when ex.Message = "generator error" -> + | ex when ex.Message = "generator error" -> () // Expected } |> Async.RunSynchronously @@ -2498,7 +2572,7 @@ let ``AsyncSeqOp.FoldAsync with exception in folder should propagate`` () = return None } let source = AsyncSeq.unfoldAsync generator 0 - + try let folder acc x = async { if x = 1 then failwith "folder error" @@ -2507,7 +2581,7 @@ let ``AsyncSeqOp.FoldAsync with exception in folder should propagate`` () = let! _ = AsyncSeq.foldAsync folder 0 source Assert.Fail("Expected exception to be thrown") with - | ex when ex.Message = "folder error" -> + | ex when ex.Message = "folder error" -> () // Expected } |> Async.RunSynchronously diff --git a/tests/FSharp.Control.AsyncSeq.Tests/FSharp.Control.AsyncSeq.Tests.fsproj b/tests/FSharp.Control.AsyncSeq.Tests/FSharp.Control.AsyncSeq.Tests.fsproj index 89d9c90..7199d00 100644 --- a/tests/FSharp.Control.AsyncSeq.Tests/FSharp.Control.AsyncSeq.Tests.fsproj +++ b/tests/FSharp.Control.AsyncSeq.Tests/FSharp.Control.AsyncSeq.Tests.fsproj @@ -15,5 +15,6 @@ + \ No newline at end of file