diff --git a/AIDevGallery.Utils/GitHubModelFileDetails.cs b/AIDevGallery.Utils/GitHubModelFileDetails.cs
index fca956e8..beaef13f 100644
--- a/AIDevGallery.Utils/GitHubModelFileDetails.cs
+++ b/AIDevGallery.Utils/GitHubModelFileDetails.cs
@@ -45,4 +45,17 @@ public class GitHubModelFileDetails
///
[JsonPropertyName("type")]
public string? Type { get; init; }
+
+ ///
+ /// Gets the encoded content of the file.
+ /// For LFS files, this contains the LFS pointer with SHA256.
+ ///
+ [JsonPropertyName("content")]
+ public string? Content { get; init; }
+
+ ///
+ /// Gets the encoding of the content (usually "base64").
+ ///
+ [JsonPropertyName("encoding")]
+ public string? Encoding { get; init; }
}
\ No newline at end of file
diff --git a/AIDevGallery.Utils/HuggingFaceModelFileDetails.cs b/AIDevGallery.Utils/HuggingFaceModelFileDetails.cs
index 9a982820..942f767e 100644
--- a/AIDevGallery.Utils/HuggingFaceModelFileDetails.cs
+++ b/AIDevGallery.Utils/HuggingFaceModelFileDetails.cs
@@ -27,4 +27,28 @@ public class HuggingFaceModelFileDetails
///
[JsonPropertyName("path")]
public string? Path { get; init; }
+
+ ///
+ /// Gets the LFS (Large File Storage) information for the file.
+ ///
+ [JsonPropertyName("lfs")]
+ public HuggingFaceLfsInfo? Lfs { get; init; }
+}
+
+///
+/// LFS (Large File Storage) information for a Hugging Face file.
+///
+public class HuggingFaceLfsInfo
+{
+ ///
+ /// Gets the OID (SHA256 hash) of the file. Format: "sha256:abc123..."
+ ///
+ [JsonPropertyName("oid")]
+ public string? Oid { get; init; }
+
+ ///
+ /// Gets the size of the file in LFS.
+ ///
+ [JsonPropertyName("size")]
+ public long Size { get; init; }
}
\ No newline at end of file
diff --git a/AIDevGallery.Utils/ModelFileDetails.cs b/AIDevGallery.Utils/ModelFileDetails.cs
index f5e717df..1cce49bb 100644
--- a/AIDevGallery.Utils/ModelFileDetails.cs
+++ b/AIDevGallery.Utils/ModelFileDetails.cs
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+using System;
+
namespace AIDevGallery.Utils;
///
@@ -27,4 +29,25 @@ public class ModelFileDetails
/// Gets the relative path to the file
///
public string? Path { get; init; }
+
+ ///
+ /// Gets the expected SHA256 hash of the file.
+ /// For Hugging Face: from LFS oid field.
+ /// For GitHub: from LFS pointer file.
+ ///
+ public string? Sha256 { get; init; }
+
+ ///
+ /// Gets a value indicating whether this file should be verified for integrity.
+ ///
+ public bool ShouldVerifyIntegrity => Name != null &&
+ (Name.EndsWith(".onnx", StringComparison.OrdinalIgnoreCase) ||
+ Name.EndsWith(".onnx.data", StringComparison.OrdinalIgnoreCase) ||
+ Name.EndsWith(".gguf", StringComparison.OrdinalIgnoreCase) ||
+ Name.EndsWith(".safetensors", StringComparison.OrdinalIgnoreCase));
+
+ ///
+ /// Gets a value indicating whether this file has a hash available for verification.
+ ///
+ public bool HasVerificationHash => !string.IsNullOrEmpty(Sha256);
}
\ No newline at end of file
diff --git a/AIDevGallery.Utils/ModelInformationHelper.cs b/AIDevGallery.Utils/ModelInformationHelper.cs
index 08293a94..d5e69e59 100644
--- a/AIDevGallery.Utils/ModelInformationHelper.cs
+++ b/AIDevGallery.Utils/ModelInformationHelper.cs
@@ -58,13 +58,55 @@ public static async Task> GetDownloadFilesFromGitHub(GitH
}
return files.Select(f =>
- new ModelFileDetails()
+ {
+ string? sha256 = null;
+
+ if (f.Content != null && f.Encoding == "base64")
+ {
+ try
+ {
+ var decodedContent = System.Text.Encoding.UTF8.GetString(Convert.FromBase64String(f.Content));
+ sha256 = ParseLfsPointerSha256(decodedContent);
+ }
+ catch (FormatException)
+ {
+ Debug.WriteLine($"Failed to decode base64 content for {f.Path}");
+ }
+ }
+
+ return new ModelFileDetails()
{
DownloadUrl = f.DownloadUrl,
Size = f.Size,
Name = (f.Path ?? string.Empty).Split(["/"], StringSplitOptions.RemoveEmptyEntries).LastOrDefault(),
- Path = f.Path
- }).ToList();
+ Path = f.Path,
+ Sha256 = sha256
+ };
+ }).ToList();
+ }
+
+ ///
+ /// Parses a Git LFS pointer file content to extract the SHA256 hash.
+ ///
+ /// The content of the LFS pointer file.
+ /// The SHA256 hash if found, otherwise null.
+ private static string? ParseLfsPointerSha256(string lfsPointerContent)
+ {
+ if (string.IsNullOrEmpty(lfsPointerContent))
+ {
+ return null;
+ }
+
+ var lines = lfsPointerContent.Split(['\n', '\r'], StringSplitOptions.RemoveEmptyEntries);
+ foreach (var line in lines)
+ {
+ if (line.StartsWith("oid sha256:", StringComparison.OrdinalIgnoreCase))
+ {
+ return line.Substring("oid sha256:".Length).Trim();
+ }
+ }
+
+ return null;
}
///
@@ -172,13 +214,26 @@ public static async Task> GetDownloadFilesFromHuggingFace
}
return hfFiles.Where(f => f.Type != "directory").Select(f =>
- new ModelFileDetails()
+ {
+ string? sha256 = null;
+ if (f.Lfs?.Oid != null)
+ {
+ sha256 = f.Lfs.Oid;
+ if (sha256.StartsWith("sha256:", StringComparison.OrdinalIgnoreCase))
+ {
+ sha256 = sha256.Substring(7);
+ }
+ }
+
+ return new ModelFileDetails()
{
DownloadUrl = $"https://huggingface.co/{hfUrl.Organization}/{hfUrl.Repo}/resolve/{hfUrl.Ref}/{f.Path}",
Size = f.Size,
Name = (f.Path ?? string.Empty).Split(["/"], StringSplitOptions.RemoveEmptyEntries).LastOrDefault(),
- Path = f.Path
- }).ToList();
+ Path = f.Path,
+ Sha256 = sha256
+ };
+ }).ToList();
}
///
diff --git a/AIDevGallery.Utils/SourceGenerationContext.cs b/AIDevGallery.Utils/SourceGenerationContext.cs
index 2fa31fda..6c02f5f5 100644
--- a/AIDevGallery.Utils/SourceGenerationContext.cs
+++ b/AIDevGallery.Utils/SourceGenerationContext.cs
@@ -9,6 +9,7 @@ namespace AIDevGallery.Utils;
[JsonSourceGenerationOptions(WriteIndented = true, AllowTrailingCommas = true)]
[JsonSerializable(typeof(List))]
[JsonSerializable(typeof(List))]
+[JsonSerializable(typeof(HuggingFaceLfsInfo))]
internal partial class SourceGenerationContext : JsonSerializerContext
{
}
\ No newline at end of file
diff --git a/AIDevGallery/Controls/DownloadProgressList.xaml b/AIDevGallery/Controls/DownloadProgressList.xaml
index 50bd8c8d..38a238d3 100644
--- a/AIDevGallery/Controls/DownloadProgressList.xaml
+++ b/AIDevGallery/Controls/DownloadProgressList.xaml
@@ -169,6 +169,24 @@
Tag="{x:Bind}"
ToolTipService.ToolTip="Cancel"
Visibility="{x:Bind vm:DownloadableModel.VisibleWhenDownloading(Status), Mode=OneWay}" />
+
diff --git a/AIDevGallery/Controls/DownloadProgressList.xaml.cs b/AIDevGallery/Controls/DownloadProgressList.xaml.cs
index 1cbfe46c..0baf849b 100644
--- a/AIDevGallery/Controls/DownloadProgressList.xaml.cs
+++ b/AIDevGallery/Controls/DownloadProgressList.xaml.cs
@@ -80,10 +80,20 @@ private void ClearHistory_Click(object sender, RoutedEventArgs e)
{
foreach (DownloadableModel model in downloadProgresses.ToList())
{
- if (model.Status is DownloadStatus.Completed or DownloadStatus.Canceled)
+ if (model.Status is DownloadStatus.Completed or DownloadStatus.Canceled or DownloadStatus.VerificationFailed)
{
downloadProgresses.Remove(model);
}
}
}
+
+ private void VerificationFailedClicked(object sender, RoutedEventArgs e)
+ {
+ // Retry download when verification failed
+ if (sender is Button button && button.Tag is DownloadableModel downloadableModel)
+ {
+ downloadProgresses.Remove(downloadableModel);
+ App.ModelDownloadQueue.AddModel(downloadableModel.ModelDetails);
+ }
+ }
}
\ No newline at end of file
diff --git a/AIDevGallery/Telemetry/Events/ModelIntegrityVerificationFailedEvent.cs b/AIDevGallery/Telemetry/Events/ModelIntegrityVerificationFailedEvent.cs
new file mode 100644
index 00000000..c9cbfb1a
--- /dev/null
+++ b/AIDevGallery/Telemetry/Events/ModelIntegrityVerificationFailedEvent.cs
@@ -0,0 +1,44 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using Microsoft.Diagnostics.Telemetry;
+using Microsoft.Diagnostics.Telemetry.Internal;
+using System;
+using System.Diagnostics.Tracing;
+
+namespace AIDevGallery.Telemetry.Events;
+
+[EventData]
+internal class ModelIntegrityVerificationFailedEvent : EventBase
+{
+ internal ModelIntegrityVerificationFailedEvent(string modelUrl, string fileName, string verificationType, string expectedValue, string actualValue)
+ {
+ ModelUrl = modelUrl;
+ FileName = fileName;
+ VerificationType = verificationType;
+ ExpectedValue = expectedValue;
+ ActualValue = actualValue;
+ EventTime = DateTime.UtcNow;
+ }
+
+ public string ModelUrl { get; private set; }
+ public string FileName { get; private set; }
+ public string VerificationType { get; private set; }
+ public string ExpectedValue { get; private set; }
+ public string ActualValue { get; private set; }
+ public DateTime EventTime { get; private set; }
+
+ public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage;
+
+ public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings)
+ {
+ }
+
+ public static void Log(string modelUrl, string fileName, string verificationType, string expectedValue, string actualValue)
+ {
+ TelemetryFactory.Get().LogError(
+ "ModelIntegrityVerificationFailed_Event",
+ LogLevel.Info,
+ new ModelIntegrityVerificationFailedEvent(modelUrl, fileName, verificationType, expectedValue, actualValue));
+ }
+}
\ No newline at end of file
diff --git a/AIDevGallery/Utils/ModelDownload.cs b/AIDevGallery/Utils/ModelDownload.cs
index 6dfd7dc7..201fa8c1 100644
--- a/AIDevGallery/Utils/ModelDownload.cs
+++ b/AIDevGallery/Utils/ModelDownload.cs
@@ -6,9 +6,11 @@
using AIDevGallery.Telemetry.Events;
using System;
using System.Collections.Generic;
+using System.Globalization;
using System.IO;
using System.Linq;
using System.Net.Http;
+using System.Security.Cryptography;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
@@ -50,6 +52,22 @@ protected set
}
}
+ private string? _verificationFailureMessage;
+ public string? VerificationFailureMessage
+ {
+ get => _verificationFailureMessage;
+ protected set
+ {
+ _verificationFailureMessage = value;
+ StateChanged?.Invoke(this, new ModelDownloadEventArgs
+ {
+ Progress = DownloadProgress,
+ Status = DownloadStatus,
+ VerificationFailureMessage = _verificationFailureMessage
+ });
+ }
+ }
+
protected CancellationTokenSource CancellationTokenSource { get; }
public void Dispose()
@@ -108,12 +126,15 @@ public override async Task StartDownload()
if (cachedModel == null)
{
- DownloadStatus = DownloadStatus.Canceled;
-
- var localPath = ModelUrl.GetLocalPath(App.AppData.ModelCachePath);
- if (Directory.Exists(localPath))
+ if (DownloadStatus != DownloadStatus.VerificationFailed)
{
- Directory.Delete(localPath, true);
+ DownloadStatus = DownloadStatus.Canceled;
+
+ var localPath = ModelUrl.GetLocalPath(App.AppData.ModelCachePath);
+ if (Directory.Exists(localPath))
+ {
+ Directory.Delete(localPath, true);
+ }
}
return false;
@@ -130,7 +151,7 @@ public override void CancelDownload()
DownloadStatus = DownloadStatus.Canceled;
}
- private async Task DownloadModel(string cacheDir, IProgress? progress = null)
+ private async Task DownloadModel(string cacheDir, IProgress? progress = null)
{
ModelUrl url;
List filesToDownload;
@@ -171,6 +192,9 @@ private async Task DownloadModel(string cacheDir, IProgress?
using var client = new HttpClient();
+ // Track files that need verification
+ List<(string FilePath, ModelFileDetails FileDetails)> filesToVerify = [];
+
foreach (var downloadableFile in filesToDownload)
{
if (downloadableFile.DownloadUrl == null)
@@ -187,6 +211,12 @@ private async Task DownloadModel(string cacheDir, IProgress?
var existingFileInfo = new FileInfo(existingFile);
if (existingFileInfo.Length == downloadableFile.Size)
{
+ // Still need to verify existing files if they have a hash
+ if (downloadableFile.ShouldVerifyIntegrity && downloadableFile.HasVerificationHash)
+ {
+ filesToVerify.Add((filePath, downloadableFile));
+ }
+
continue;
}
}
@@ -201,16 +231,85 @@ private async Task DownloadModel(string cacheDir, IProgress?
var fileInfo = new FileInfo(filePath);
if (fileInfo.Length != downloadableFile.Size)
{
- // file did not download properly, should retry
+ // Size mismatch - log telemetry
+ ModelIntegrityVerificationFailedEvent.Log(
+ Details.Url,
+ downloadableFile.Name ?? filePath,
+ verificationType: "Size",
+ expectedValue: downloadableFile.Size.ToString(CultureInfo.InvariantCulture),
+ actualValue: fileInfo.Length.ToString(CultureInfo.InvariantCulture));
+ VerificationFailureMessage = $"Size verification failed for: {downloadableFile.Name}";
+ DownloadStatus = DownloadStatus.VerificationFailed;
+
+ var localPath = url.GetLocalPath(cacheDir);
+ if (Directory.Exists(localPath))
+ {
+ Directory.Delete(localPath, true);
+ }
+
+ return null;
+ }
+
+ // Add to verification list if it's a main model file with hash
+ if (downloadableFile.ShouldVerifyIntegrity && downloadableFile.HasVerificationHash)
+ {
+ filesToVerify.Add((filePath, downloadableFile));
}
bytesDownloaded += downloadableFile.Size;
}
+ // Verify integrity of main model files
+ if (filesToVerify.Count > 0)
+ {
+ DownloadStatus = DownloadStatus.Verifying;
+
+ foreach (var (filePath, fileDetails) in filesToVerify)
+ {
+ if (string.IsNullOrEmpty(fileDetails.Sha256))
+ {
+ continue;
+ }
+
+ var expectedHash = fileDetails.Sha256;
+ var actualHash = await ComputeSha256Async(filePath, cancellationToken);
+ var verified = string.Equals(actualHash, expectedHash, StringComparison.OrdinalIgnoreCase);
+
+ if (!verified)
+ {
+ ModelIntegrityVerificationFailedEvent.Log(
+ Details.Url,
+ fileDetails.Name ?? filePath,
+ verificationType: "SHA256",
+ expectedValue: expectedHash,
+ actualValue: actualHash);
+ VerificationFailureMessage = $"Integrity verification failed for: {fileDetails.Name ?? filePath}";
+ DownloadStatus = DownloadStatus.VerificationFailed;
+
+ // Delete the downloaded files
+ var localPath = url.GetLocalPath(cacheDir);
+ if (Directory.Exists(localPath))
+ {
+ Directory.Delete(localPath, true);
+ }
+
+ return null;
+ }
+ }
+ }
+
var modelDirectory = url.GetLocalPath(cacheDir);
return new CachedModel(Details, url.IsFile ? $"{modelDirectory}\\{filesToDownload.First().Name}" : modelDirectory, url.IsFile, modelSize);
}
+
+ private static async Task ComputeSha256Async(string filePath, CancellationToken cancellationToken)
+ {
+ using var sha256 = SHA256.Create();
+ using var stream = new FileStream(filePath, FileMode.Open, FileAccess.Read, FileShare.Read, 81920, useAsync: true);
+ var hashBytes = await sha256.ComputeHashAsync(stream, cancellationToken);
+ return Convert.ToHexString(hashBytes).ToLowerInvariant();
+ }
}
internal class FoundryLocalModelDownload : ModelDownload
@@ -263,12 +362,15 @@ internal enum DownloadStatus
{
Waiting,
InProgress,
+ Verifying,
Completed,
- Canceled
+ Canceled,
+ VerificationFailed
}
internal class ModelDownloadEventArgs
{
public required float Progress { get; init; }
public required DownloadStatus Status { get; init; }
+ public string? VerificationFailureMessage { get; init; }
}
\ No newline at end of file
diff --git a/AIDevGallery/ViewModels/DownloadableModel.cs b/AIDevGallery/ViewModels/DownloadableModel.cs
index 5304d984..aa946bed 100644
--- a/AIDevGallery/ViewModels/DownloadableModel.cs
+++ b/AIDevGallery/ViewModels/DownloadableModel.cs
@@ -22,6 +22,9 @@ internal partial class DownloadableModel : BaseModel
[ObservableProperty]
public partial DownloadStatus Status { get; set; } = DownloadStatus.Waiting;
+ [ObservableProperty]
+ public partial string? VerificationFailureMessage { get; set; }
+
public bool IsDownloadEnabled => Compatibility.CompatibilityState != ModelCompatibilityState.NotCompatible;
private ModelDownload? _modelDownload;
@@ -52,6 +55,7 @@ public ModelDownload? ModelDownload
_modelDownload.StateChanged += ModelDownload_StateChanged;
Status = _modelDownload.DownloadStatus;
Progress = _modelDownload.DownloadProgress;
+ VerificationFailureMessage = _modelDownload.VerificationFailureMessage;
CanDownload = false;
}
}
@@ -94,7 +98,13 @@ private void ModelDownload_StateChanged(object? sender, ModelDownloadEventArgs e
_progressTimer.Start();
}
- if (e.Progress == 1)
+ if (e.Progress == 1 && e.Status == DownloadStatus.InProgress)
+ {
+ // Download complete, but may still need verification
+ return;
+ }
+
+ if (e.Status == DownloadStatus.Completed)
{
Status = DownloadStatus.Completed;
ModelDownload = null;
@@ -106,6 +116,19 @@ private void ModelDownload_StateChanged(object? sender, ModelDownloadEventArgs e
ModelDownload = null;
Progress = 0;
}
+
+ if (e.Status == DownloadStatus.VerificationFailed)
+ {
+ Status = DownloadStatus.VerificationFailed;
+ VerificationFailureMessage = e.VerificationFailureMessage;
+ ModelDownload = null;
+ Progress = 0;
+ }
+
+ if (e.Status == DownloadStatus.Verifying)
+ {
+ Status = DownloadStatus.Verifying;
+ }
}
private void ProgressTimer_Tick(object? sender, object e)
@@ -115,6 +138,7 @@ private void ProgressTimer_Tick(object? sender, object e)
{
Progress = ModelDownload.DownloadProgress * 100;
Status = ModelDownload.DownloadStatus;
+ VerificationFailureMessage = ModelDownload.VerificationFailureMessage;
}
}
@@ -144,7 +168,7 @@ public static Visibility DownloadStatusButtonVisibility(ModelDownload download)
public static Visibility VisibleWhenDownloading(DownloadStatus status)
{
- return status is DownloadStatus.InProgress or DownloadStatus.Waiting ? Visibility.Visible : Visibility.Collapsed;
+ return status is DownloadStatus.InProgress or DownloadStatus.Waiting or DownloadStatus.Verifying ? Visibility.Visible : Visibility.Collapsed;
}
public static Visibility VisibleWhenCanceled(DownloadStatus status)
@@ -167,6 +191,11 @@ public static Visibility VisibleWhenDownloaded(DownloadStatus status, ModelDetai
return Visibility.Visible;
}
+ public static Visibility VisibleWhenVerificationFailed(DownloadStatus status)
+ {
+ return status == DownloadStatus.VerificationFailed ? Visibility.Visible : Visibility.Collapsed;
+ }
+
public static Visibility BoolToVisibilityInverse(bool value)
{
return value ? Visibility.Collapsed : Visibility.Visible;
@@ -190,10 +219,14 @@ public static string StatusToText(DownloadStatus status)
return "Waiting..";
case DownloadStatus.InProgress:
return "Downloading..";
+ case DownloadStatus.Verifying:
+ return "Verifying integrity..";
case DownloadStatus.Completed:
return "Downloaded";
case DownloadStatus.Canceled:
return "Canceled";
+ case DownloadStatus.VerificationFailed:
+ return "Verification failed";
default:
return string.Empty;
}