diff --git a/AIDevGallery.Utils/GitHubModelFileDetails.cs b/AIDevGallery.Utils/GitHubModelFileDetails.cs
index fca956e8..beaef13f 100644
--- a/AIDevGallery.Utils/GitHubModelFileDetails.cs
+++ b/AIDevGallery.Utils/GitHubModelFileDetails.cs
@@ -45,4 +45,17 @@ public class GitHubModelFileDetails
     /// </summary>
     [JsonPropertyName("type")]
     public string? Type { get; init; }
+
+    /// <summary>
+    /// Gets the encoded content of the file.
+    /// For LFS files, this contains the LFS pointer with SHA256.
+    /// </summary>
+    [JsonPropertyName("content")]
+    public string? Content { get; init; }
+
+    /// <summary>
+    /// Gets the encoding of the content (usually "base64").
+    /// </summary>
+    [JsonPropertyName("encoding")]
+    public string? Encoding { get; init; }
 }
\ No newline at end of file
diff --git a/AIDevGallery.Utils/HuggingFaceModelFileDetails.cs b/AIDevGallery.Utils/HuggingFaceModelFileDetails.cs
index 9a982820..942f767e 100644
--- a/AIDevGallery.Utils/HuggingFaceModelFileDetails.cs
+++ b/AIDevGallery.Utils/HuggingFaceModelFileDetails.cs
@@ -27,4 +27,28 @@ public class HuggingFaceModelFileDetails
     /// </summary>
     [JsonPropertyName("path")]
     public string? Path { get; init; }
+
+    /// <summary>
+    /// Gets the LFS (Large File Storage) information for the file.
+    /// </summary>
+    [JsonPropertyName("lfs")]
+    public HuggingFaceLfsInfo? Lfs { get; init; }
+}
+
+/// <summary>
+/// LFS (Large File Storage) information for a Hugging Face file.
+/// </summary>
+public class HuggingFaceLfsInfo
+{
+    /// <summary>
+    /// Gets the OID (SHA256 hash) of the file. Format: "sha256:abc123..."
+    /// </summary>
+    [JsonPropertyName("oid")]
+    public string? Oid { get; init; }
+
+    /// <summary>
+    /// Gets the size of the file in LFS.
+    /// </summary>
+    [JsonPropertyName("size")]
+    public long Size { get; init; }
 }
\ No newline at end of file
diff --git a/AIDevGallery.Utils/ModelFileDetails.cs b/AIDevGallery.Utils/ModelFileDetails.cs
index f5e717df..1cce49bb 100644
--- a/AIDevGallery.Utils/ModelFileDetails.cs
+++ b/AIDevGallery.Utils/ModelFileDetails.cs
@@ -1,6 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+using System;
+
 namespace AIDevGallery.Utils;
 
 /// <summary>
@@ -27,4 +29,25 @@ public class ModelFileDetails
     /// Gets the relative path to the file
     /// </summary>
     public string? Path { get; init; }
+
+    /// <summary>
+    /// Gets the expected SHA256 hash of the file.
+    /// For Hugging Face: from LFS oid field.
+    /// For GitHub: from LFS pointer file.
+    /// </summary>
+    public string? Sha256 { get; init; }
+
+    /// <summary>
+    /// Gets a value indicating whether this file should be verified for integrity.
+    /// </summary>
+    public bool ShouldVerifyIntegrity => Name != null &&
+        (Name.EndsWith(".onnx", StringComparison.OrdinalIgnoreCase) ||
+        Name.EndsWith(".onnx.data", StringComparison.OrdinalIgnoreCase) ||
+        Name.EndsWith(".gguf", StringComparison.OrdinalIgnoreCase) ||
+        Name.EndsWith(".safetensors", StringComparison.OrdinalIgnoreCase));
+
+    /// <summary>
+    /// Gets a value indicating whether this file has a hash available for verification.
+    /// </summary>
+    public bool HasVerificationHash => !string.IsNullOrEmpty(Sha256);
 }
\ No newline at end of file
diff --git a/AIDevGallery.Utils/ModelInformationHelper.cs b/AIDevGallery.Utils/ModelInformationHelper.cs
index 08293a94..d5e69e59 100644
--- a/AIDevGallery.Utils/ModelInformationHelper.cs
+++ b/AIDevGallery.Utils/ModelInformationHelper.cs
@@ -58,13 +58,55 @@ public static async Task> GetDownloadFilesFromGitHub(GitH
         }
 
         return files.Select(f =>
-            new ModelFileDetails()
+        {
+            string? sha256 = null;
+
+            if (f.Content != null && f.Encoding == "base64")
+            {
+                try
+                {
+                    var decodedContent = System.Text.Encoding.UTF8.GetString(Convert.FromBase64String(f.Content));
+                    sha256 = ParseLfsPointerSha256(decodedContent);
+                }
+                catch (FormatException)
+                {
+                    Debug.WriteLine($"Failed to decode base64 content for {f.Path}");
+                }
+            }
+
+            return new ModelFileDetails()
             {
                 DownloadUrl = f.DownloadUrl,
                 Size = f.Size,
                 Name = (f.Path ?? string.Empty).Split(["/"], StringSplitOptions.RemoveEmptyEntries).LastOrDefault(),
-                Path = f.Path
-            }).ToList();
+                Path = f.Path,
+                Sha256 = sha256
+            };
+        }).ToList();
+    }
+
+    /// <summary>
+    /// Parses a Git LFS pointer file content to extract the SHA256 hash.
+    /// </summary>
+    /// <param name="lfsPointerContent">The content of the LFS pointer file.</param>
+    /// <returns>The SHA256 hash if found, otherwise null.</returns>
+    private static string? ParseLfsPointerSha256(string lfsPointerContent)
+    {
+        if (string.IsNullOrEmpty(lfsPointerContent))
+        {
+            return null;
+        }
+
+        var lines = lfsPointerContent.Split(['\n', '\r'], StringSplitOptions.RemoveEmptyEntries);
+        foreach (var line in lines)
+        {
+            if (line.StartsWith("oid sha256:", StringComparison.OrdinalIgnoreCase))
+            {
+                return line.Substring("oid sha256:".Length).Trim();
+            }
+        }
+
+        return null;
     }
 
     /// <summary>
@@ -172,13 +214,26 @@ public static async Task> GetDownloadFilesFromHuggingFace
         }
 
         return hfFiles.Where(f => f.Type != "directory").Select(f =>
-            new ModelFileDetails()
+        {
+            string? sha256 = null;
+            if (f.Lfs?.Oid != null)
+            {
+                sha256 = f.Lfs.Oid;
+                if (sha256.StartsWith("sha256:", StringComparison.OrdinalIgnoreCase))
+                {
+                    sha256 = sha256.Substring("sha256:".Length);
+                }
+            }
+
+            return new ModelFileDetails()
             {
                 DownloadUrl = $"https://huggingface.co/{hfUrl.Organization}/{hfUrl.Repo}/resolve/{hfUrl.Ref}/{f.Path}",
                 Size = f.Size,
                 Name = (f.Path ?? string.Empty).Split(["/"], StringSplitOptions.RemoveEmptyEntries).LastOrDefault(),
-                Path = f.Path
-            }).ToList();
+                Path = f.Path,
+                Sha256 = sha256
+            };
+        }).ToList();
     }
 
     /// <summary>
diff --git a/AIDevGallery.Utils/SourceGenerationContext.cs b/AIDevGallery.Utils/SourceGenerationContext.cs
index 2fa31fda..6c02f5f5 100644
--- a/AIDevGallery.Utils/SourceGenerationContext.cs
+++ b/AIDevGallery.Utils/SourceGenerationContext.cs
@@ -9,6 +9,7 @@ namespace AIDevGallery.Utils;
 [JsonSourceGenerationOptions(WriteIndented = true, AllowTrailingCommas = true)]
 [JsonSerializable(typeof(List))]
 [JsonSerializable(typeof(List))]
+[JsonSerializable(typeof(HuggingFaceLfsInfo))]
 internal partial class SourceGenerationContext : JsonSerializerContext
 {
 }
\ No newline at end of file
diff --git a/AIDevGallery/Controls/DownloadProgressList.xaml b/AIDevGallery/Controls/DownloadProgressList.xaml
index 50bd8c8d..38a238d3 100644
--- a/AIDevGallery/Controls/DownloadProgressList.xaml
+++ b/AIDevGallery/Controls/DownloadProgressList.xaml
@@ -169,6 +169,24 @@
                             Tag="{x:Bind}"
                             ToolTipService.ToolTip="Cancel"
                             Visibility="{x:Bind vm:DownloadableModel.VisibleWhenDownloading(Status), Mode=OneWay}" />
+
diff --git a/AIDevGallery/Controls/DownloadProgressList.xaml.cs b/AIDevGallery/Controls/DownloadProgressList.xaml.cs
index 1cbfe46c..0baf849b 100644
--- a/AIDevGallery/Controls/DownloadProgressList.xaml.cs
+++ b/AIDevGallery/Controls/DownloadProgressList.xaml.cs
@@ -80,10 +80,20 @@ private void ClearHistory_Click(object sender, RoutedEventArgs e)
     {
         foreach (DownloadableModel model in downloadProgresses.ToList())
         {
-            if (model.Status is DownloadStatus.Completed or DownloadStatus.Canceled)
+            if (model.Status is DownloadStatus.Completed or DownloadStatus.Canceled or DownloadStatus.VerificationFailed)
             {
                 downloadProgresses.Remove(model);
             }
         }
     }
+
+    private void VerificationFailedClicked(object sender, RoutedEventArgs e)
+    {
+        // Retry download when verification failed
+        if (sender is Button button && button.Tag is DownloadableModel downloadableModel)
+        {
+            downloadProgresses.Remove(downloadableModel);
+            App.ModelDownloadQueue.AddModel(downloadableModel.ModelDetails);
+        }
+    }
 }
\ No newline at end of file
diff --git a/AIDevGallery/Telemetry/Events/ModelIntegrityVerificationFailedEvent.cs b/AIDevGallery/Telemetry/Events/ModelIntegrityVerificationFailedEvent.cs
new file mode 100644
index 00000000..c9cbfb1a
--- /dev/null
+++ b/AIDevGallery/Telemetry/Events/ModelIntegrityVerificationFailedEvent.cs
@@ -0,0 +1,44 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using Microsoft.Diagnostics.Telemetry;
+using Microsoft.Diagnostics.Telemetry.Internal;
+using System;
+using System.Diagnostics.Tracing;
+
+namespace AIDevGallery.Telemetry.Events;
+
+[EventData]
+internal class ModelIntegrityVerificationFailedEvent : EventBase
+{
+    internal ModelIntegrityVerificationFailedEvent(string modelUrl, string fileName, string verificationType, string expectedValue, string actualValue)
+    {
+        ModelUrl = modelUrl;
+        FileName = fileName;
+        VerificationType = verificationType;
+        ExpectedValue = expectedValue;
+        ActualValue = actualValue;
+        EventTime = DateTime.UtcNow;
+    }
+
+    public string ModelUrl { get; private set; }
+    public string FileName { get; private set; }
+    public string VerificationType { get; private set; }
+    public string ExpectedValue { get; private set; }
+    public string ActualValue { get; private set; }
+    public DateTime EventTime { get; private set; }
+
+    public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage;
+
+    public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings)
+    {
+    }
+
+    public static void Log(string modelUrl, string fileName, string verificationType, string expectedValue, string actualValue)
+    {
+        TelemetryFactory.Get().LogError(
+            "ModelIntegrityVerificationFailed_Event",
+            LogLevel.Info,
+            new ModelIntegrityVerificationFailedEvent(modelUrl, fileName, verificationType, expectedValue, actualValue));
+    }
+}
\ No newline at end of file
diff --git a/AIDevGallery/Utils/ModelDownload.cs b/AIDevGallery/Utils/ModelDownload.cs
index 6dfd7dc7..201fa8c1 100644
--- a/AIDevGallery/Utils/ModelDownload.cs
+++ b/AIDevGallery/Utils/ModelDownload.cs
@@ -6,9 +6,11 @@
 using AIDevGallery.Telemetry.Events;
 using System;
 using System.Collections.Generic;
+using System.Globalization;
 using System.IO;
 using System.Linq;
 using System.Net.Http;
+using System.Security.Cryptography;
 using System.Text.Json.Serialization;
 using System.Threading;
 using System.Threading.Tasks;
@@ -50,6 +52,22 @@ protected set
         }
     }
 
+    private string? _verificationFailureMessage;
+    public string? VerificationFailureMessage
+    {
+        get => _verificationFailureMessage;
+        protected set
+        {
+            _verificationFailureMessage = value;
+            StateChanged?.Invoke(this, new ModelDownloadEventArgs
+            {
+                Progress = DownloadProgress,
+                Status = DownloadStatus,
+                VerificationFailureMessage = _verificationFailureMessage
+            });
+        }
+    }
+
     protected CancellationTokenSource CancellationTokenSource { get; }
 
     public void Dispose()
@@ -108,12 +126,15 @@ public override async Task StartDownload()
 
         if (cachedModel == null)
         {
-            DownloadStatus = DownloadStatus.Canceled;
-
-            var localPath = ModelUrl.GetLocalPath(App.AppData.ModelCachePath);
-            if (Directory.Exists(localPath))
+            if (DownloadStatus != DownloadStatus.VerificationFailed)
             {
-                Directory.Delete(localPath, true);
+                DownloadStatus = DownloadStatus.Canceled;
+
+                var localPath = ModelUrl.GetLocalPath(App.AppData.ModelCachePath);
+                if (Directory.Exists(localPath))
+                {
+                    Directory.Delete(localPath, true);
+                }
             }
 
             return false;
@@ -130,7 +151,7 @@ public override void CancelDownload()
         DownloadStatus = DownloadStatus.Canceled;
     }
 
-    private async Task DownloadModel(string cacheDir, IProgress? progress = null)
+    private async Task DownloadModel(string cacheDir, IProgress? progress = null)
     {
         ModelUrl url;
         List filesToDownload;
@@ -171,6 +192,9 @@ private async Task DownloadModel(string cacheDir, IProgress?
 
         using var client = new HttpClient();
 
+        // Track files that need verification
+        List<(string FilePath, ModelFileDetails FileDetails)> filesToVerify = [];
+
         foreach (var downloadableFile in filesToDownload)
         {
             if (downloadableFile.DownloadUrl == null)
@@ -187,6 +211,12 @@ private async Task DownloadModel(string cacheDir, IProgress?
                 var existingFileInfo = new FileInfo(existingFile);
                 if (existingFileInfo.Length == downloadableFile.Size)
                 {
+                    // Still need to verify existing files if they have a hash
+                    if (downloadableFile.ShouldVerifyIntegrity && downloadableFile.HasVerificationHash)
+                    {
+                        filesToVerify.Add((filePath, downloadableFile));
+                    }
+
                     continue;
                 }
             }
@@ -201,16 +231,91 @@ private async Task DownloadModel(string cacheDir, IProgress?
             var fileInfo = new FileInfo(filePath);
             if (fileInfo.Length != downloadableFile.Size)
             {
-                // file did not download properly, should retry
+                // Size mismatch - log telemetry
+                ModelIntegrityVerificationFailedEvent.Log(
+                    Details.Url,
+                    downloadableFile.Name ?? filePath,
+                    verificationType: "Size",
+                    expectedValue: downloadableFile.Size.ToString(CultureInfo.InvariantCulture),
+                    actualValue: fileInfo.Length.ToString(CultureInfo.InvariantCulture));
+                VerificationFailureMessage = $"Size verification failed for: {downloadableFile.Name ?? filePath}";
+                DownloadStatus = DownloadStatus.VerificationFailed;
+
+                var localPath = url.GetLocalPath(cacheDir);
+                if (Directory.Exists(localPath))
+                {
+                    Directory.Delete(localPath, true);
+                }
+
+                return null;
+            }
+
+            // Add to verification list if it's a main model file with hash
+            if (downloadableFile.ShouldVerifyIntegrity && downloadableFile.HasVerificationHash)
+            {
+                filesToVerify.Add((filePath, downloadableFile));
             }
 
             bytesDownloaded += downloadableFile.Size;
         }
 
+        // Verify integrity of main model files
+        if (filesToVerify.Count > 0)
+        {
+            DownloadStatus = DownloadStatus.Verifying;
+
+            foreach (var (filePath, fileDetails) in filesToVerify)
+            {
+                if (string.IsNullOrEmpty(fileDetails.Sha256))
+                {
+                    continue;
+                }
+
+                var expectedHash = fileDetails.Sha256;
+                var actualHash = await ComputeSha256Async(filePath, cancellationToken);
+                var verified = string.Equals(actualHash, expectedHash, StringComparison.OrdinalIgnoreCase);
+
+                if (!verified)
+                {
+                    ModelIntegrityVerificationFailedEvent.Log(
+                        Details.Url,
+                        fileDetails.Name ?? filePath,
+                        verificationType: "SHA256",
+                        expectedValue: expectedHash,
+                        actualValue: actualHash);
+                    VerificationFailureMessage = $"Integrity verification failed for: {fileDetails.Name ?? filePath}";
+                    DownloadStatus = DownloadStatus.VerificationFailed;
+
+                    // Delete the downloaded files
+                    var localPath = url.GetLocalPath(cacheDir);
+                    if (Directory.Exists(localPath))
+                    {
+                        Directory.Delete(localPath, true);
+                    }
+
+                    return null;
+                }
+            }
+        }
+
         var modelDirectory = url.GetLocalPath(cacheDir);
 
         return new CachedModel(Details, url.IsFile ? $"{modelDirectory}\\{filesToDownload.First().Name}" : modelDirectory, url.IsFile, modelSize);
     }
+
+    /// <summary>
+    /// Computes the SHA256 hash of a file by streaming its contents.
+    /// </summary>
+    /// <param name="filePath">Path of the file to hash.</param>
+    /// <param name="cancellationToken">Token passed to the hash computation.</param>
+    /// <returns>The hash as a lowercase hexadecimal string.</returns>
+    private static async Task ComputeSha256Async(string filePath, CancellationToken cancellationToken)
+    {
+        using var sha256 = SHA256.Create();
+        using var stream = new FileStream(filePath, FileMode.Open, FileAccess.Read, FileShare.Read, 81920, useAsync: true);
+        var hashBytes = await sha256.ComputeHashAsync(stream, cancellationToken);
+        return Convert.ToHexString(hashBytes).ToLowerInvariant();
+    }
 }
 
 internal class FoundryLocalModelDownload : ModelDownload
@@ -263,12 +368,15 @@ internal enum DownloadStatus
 {
     Waiting,
     InProgress,
+    Verifying,
     Completed,
-    Canceled
+    Canceled,
+    VerificationFailed
 }
 
 internal class ModelDownloadEventArgs
 {
     public required float Progress { get; init; }
     public required DownloadStatus Status { get; init; }
+    public string? VerificationFailureMessage { get; init; }
 }
\ No newline at end of file
diff --git a/AIDevGallery/ViewModels/DownloadableModel.cs b/AIDevGallery/ViewModels/DownloadableModel.cs
index 5304d984..aa946bed 100644
--- a/AIDevGallery/ViewModels/DownloadableModel.cs
+++ b/AIDevGallery/ViewModels/DownloadableModel.cs
@@ -22,6 +22,9 @@ internal partial class DownloadableModel : BaseModel
     [ObservableProperty]
     public partial DownloadStatus Status { get; set; } = DownloadStatus.Waiting;
 
+    [ObservableProperty]
+    public partial string? VerificationFailureMessage { get; set; }
+
     public bool IsDownloadEnabled => Compatibility.CompatibilityState != ModelCompatibilityState.NotCompatible;
 
     private ModelDownload? _modelDownload;
@@ -52,6 +55,7 @@ public ModelDownload? ModelDownload
                 _modelDownload.StateChanged += ModelDownload_StateChanged;
                 Status = _modelDownload.DownloadStatus;
                 Progress = _modelDownload.DownloadProgress;
+                VerificationFailureMessage = _modelDownload.VerificationFailureMessage;
                 CanDownload = false;
             }
         }
@@ -94,7 +98,13 @@ private void ModelDownload_StateChanged(object? sender, ModelDownloadEventArgs e
             _progressTimer.Start();
         }
 
-        if (e.Progress == 1)
+        if (e.Progress == 1 && e.Status == DownloadStatus.InProgress)
+        {
+            // Download complete, but may still need verification
+            return;
+        }
+
+        if (e.Status == DownloadStatus.Completed)
         {
             Status = DownloadStatus.Completed;
             ModelDownload = null;
@@ -106,6 +116,21 @@ private void ModelDownload_StateChanged(object? sender, ModelDownloadEventArgs e
             ModelDownload = null;
             Progress = 0;
         }
+
+        if (e.Status == DownloadStatus.VerificationFailed)
+        {
+            Status = DownloadStatus.VerificationFailed;
+            // The message is optional on the event args, so a status-only notification may
+            // arrive without it; fall back to the value held by the download that raised it.
+            VerificationFailureMessage = e.VerificationFailureMessage ?? (sender as ModelDownload)?.VerificationFailureMessage;
+            ModelDownload = null;
+            Progress = 0;
+        }
+
+        if (e.Status == DownloadStatus.Verifying)
+        {
+            Status = DownloadStatus.Verifying;
+        }
     }
 
     private void ProgressTimer_Tick(object? sender, object e)
@@ -115,6 +140,7 @@ private void ProgressTimer_Tick(object? sender, object e)
         {
             Progress = ModelDownload.DownloadProgress * 100;
             Status = ModelDownload.DownloadStatus;
+            VerificationFailureMessage = ModelDownload.VerificationFailureMessage;
         }
     }
 
@@ -144,7 +170,7 @@ public static Visibility DownloadStatusButtonVisibility(ModelDownload download)
 
     public static Visibility VisibleWhenDownloading(DownloadStatus status)
     {
-        return status is DownloadStatus.InProgress or DownloadStatus.Waiting ? Visibility.Visible : Visibility.Collapsed;
+        return status is DownloadStatus.InProgress or DownloadStatus.Waiting or DownloadStatus.Verifying ? Visibility.Visible : Visibility.Collapsed;
     }
 
     public static Visibility VisibleWhenCanceled(DownloadStatus status)
@@ -167,6 +193,11 @@ public static Visibility VisibleWhenDownloaded(DownloadStatus status, ModelDetai
         return Visibility.Visible;
     }
 
+    public static Visibility VisibleWhenVerificationFailed(DownloadStatus status)
+    {
+        return status == DownloadStatus.VerificationFailed ? Visibility.Visible : Visibility.Collapsed;
+    }
+
     public static Visibility BoolToVisibilityInverse(bool value)
     {
         return value ? Visibility.Collapsed : Visibility.Visible;
@@ -190,10 +221,14 @@ public static string StatusToText(DownloadStatus status)
                 return "Waiting..";
             case DownloadStatus.InProgress:
                 return "Downloading..";
+            case DownloadStatus.Verifying:
+                return "Verifying integrity..";
             case DownloadStatus.Completed:
                 return "Downloaded";
             case DownloadStatus.Canceled:
                 return "Canceled";
+            case DownloadStatus.VerificationFailed:
+                return "Verification failed";
             default:
                 return string.Empty;
         }