Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
using System;
using System.Diagnostics.CodeAnalysis;
using System.Threading.Tasks;
using Azure.Identity;
using Azure.Core;
using Microsoft.Azure.Cosmos;

namespace Microsoft.Agents.AI;
Expand Down Expand Up @@ -47,23 +47,30 @@ public static ChatClientAgentOptions WithCosmosDBMessageStore(
/// <param name="accountEndpoint">The Cosmos DB account endpoint URI.</param>
/// <param name="databaseId">The identifier of the Cosmos DB database.</param>
/// <param name="containerId">The identifier of the Cosmos DB container.</param>
/// <param name="tokenCredential">The TokenCredential to use for authentication (e.g., DefaultAzureCredential, ManagedIdentityCredential).</param>
/// <returns>The configured <see cref="ChatClientAgentOptions"/>.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="options"/> is null.</exception>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="options"/> or <paramref name="tokenCredential"/> is null.</exception>
/// <exception cref="ArgumentException">Thrown when any string parameter is null or whitespace.</exception>
[RequiresUnreferencedCode("The CosmosChatMessageStore uses JSON serialization which is incompatible with trimming.")]
[RequiresDynamicCode("The CosmosChatMessageStore uses JSON serialization which is incompatible with NativeAOT.")]
public static ChatClientAgentOptions WithCosmosDBMessageStoreUsingManagedIdentity(
this ChatClientAgentOptions options,
string accountEndpoint,
string databaseId,
string containerId)
string containerId,
TokenCredential tokenCredential)
{
if (options is null)
{
throw new ArgumentNullException(nameof(options));
}

options.ChatMessageStoreFactory = (context, ct) => new ValueTask<ChatMessageStore>(new CosmosChatMessageStore(accountEndpoint, new DefaultAzureCredential(), databaseId, containerId));
if (tokenCredential is null)
{
throw new ArgumentNullException(nameof(tokenCredential));
}

options.ChatMessageStoreFactory = (context, ct) => new ValueTask<ChatMessageStore>(new CosmosChatMessageStore(accountEndpoint, tokenCredential, databaseId, containerId));
return options;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

using System;
using System.Diagnostics.CodeAnalysis;
using Azure.Identity;
using Azure.Core;
using Microsoft.Agents.AI.Workflows.Checkpointing;
using Microsoft.Azure.Cosmos;

Expand Down Expand Up @@ -52,14 +52,17 @@ public static CosmosCheckpointStore CreateCheckpointStore(
/// <param name="accountEndpoint">The Cosmos DB account endpoint URI.</param>
/// <param name="databaseId">The identifier of the Cosmos DB database.</param>
/// <param name="containerId">The identifier of the Cosmos DB container.</param>
/// <param name="tokenCredential">The TokenCredential to use for authentication (e.g., DefaultAzureCredential, ManagedIdentityCredential).</param>
/// <returns>A new instance of <see cref="CosmosCheckpointStore"/>.</returns>
/// <exception cref="ArgumentException">Thrown when any string parameter is null or whitespace.</exception>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="tokenCredential"/> is null.</exception>
[RequiresUnreferencedCode("The CosmosCheckpointStore uses JSON serialization which is incompatible with trimming.")]
[RequiresDynamicCode("The CosmosCheckpointStore uses JSON serialization which is incompatible with NativeAOT.")]
public static CosmosCheckpointStore CreateCheckpointStoreUsingManagedIdentity(
string accountEndpoint,
string databaseId,
string containerId)
string containerId,
TokenCredential tokenCredential)
{
if (string.IsNullOrWhiteSpace(accountEndpoint))
{
Expand All @@ -76,7 +79,12 @@ public static CosmosCheckpointStore CreateCheckpointStoreUsingManagedIdentity(
throw new ArgumentException("Cannot be null or whitespace", nameof(containerId));
}

return new CosmosCheckpointStore(accountEndpoint, new DefaultAzureCredential(), databaseId, containerId);
if (tokenCredential is null)
{
throw new ArgumentNullException(nameof(tokenCredential));
}

return new CosmosCheckpointStore(accountEndpoint, tokenCredential, databaseId, containerId);
}

/// <summary>
Expand Down Expand Up @@ -154,14 +162,17 @@ public static CosmosCheckpointStore<T> CreateCheckpointStore<T>(
/// <param name="accountEndpoint">The Cosmos DB account endpoint URI.</param>
/// <param name="databaseId">The identifier of the Cosmos DB database.</param>
/// <param name="containerId">The identifier of the Cosmos DB container.</param>
/// <param name="tokenCredential">The TokenCredential to use for authentication (e.g., DefaultAzureCredential, ManagedIdentityCredential).</param>
/// <returns>A new instance of <see cref="CosmosCheckpointStore{T}"/>.</returns>
/// <exception cref="ArgumentException">Thrown when any string parameter is null or whitespace.</exception>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="tokenCredential"/> is null.</exception>
[RequiresUnreferencedCode("The CosmosCheckpointStore uses JSON serialization which is incompatible with trimming.")]
[RequiresDynamicCode("The CosmosCheckpointStore uses JSON serialization which is incompatible with NativeAOT.")]
public static CosmosCheckpointStore<T> CreateCheckpointStoreUsingManagedIdentity<T>(
string accountEndpoint,
string databaseId,
string containerId)
string containerId,
TokenCredential tokenCredential)
{
if (string.IsNullOrWhiteSpace(accountEndpoint))
{
Expand All @@ -178,7 +189,12 @@ public static CosmosCheckpointStore<T> CreateCheckpointStoreUsingManagedIdentity
throw new ArgumentException("Cannot be null or whitespace", nameof(containerId));
}

return new CosmosCheckpointStore<T>(accountEndpoint, new DefaultAzureCredential(), databaseId, containerId);
if (tokenCredential is null)
{
throw new ArgumentNullException(nameof(tokenCredential));
}

return new CosmosCheckpointStore<T>(accountEndpoint, tokenCredential, databaseId, containerId);
}

/// <summary>
Expand Down
Loading