Skip to content

Commit 2662120

Browse files
committed
RyzenAI execution provider
1 parent 43c6bb7 commit 2662120

File tree

4 files changed

+287
-0
lines changed

4 files changed

+287
-0
lines changed

TensorStack.Common/ModelConfig.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// Copyright (c) TensorStack. All rights reserved.
22
// Licensed under the Apache 2.0 License.
3+
using System.Collections.Generic;
34
using System.Text.Json.Serialization;
45

56
namespace TensorStack.Common
@@ -14,6 +15,9 @@ public record ModelConfig
1415
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
1516
public bool IsOptimizationSupported { get; set; }
1617

18+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
19+
public Dictionary<string, string> SessionOptions { get; set; }
20+
1721
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
1822
public ExecutionProvider ExecutionProvider
1923
{
17.7 MB
Binary file not shown.
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
// Copyright (c) TensorStack. All rights reserved.
// Licensed under the Apache 2.0 License.
using Microsoft.ML.OnnxRuntime;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using TensorStack.Common;

namespace TensorStack.Providers
{
    /// <summary>
    /// Factory for ONNX Runtime execution providers: RyzenAI (NPU), DirectML (GPU) and CPU fallback.
    /// </summary>
    public static class Provider
    {
        private static bool _isInitialized;

        // NOTE(review): provider name is "DMLExecutionProvider" even though this assembly is the
        // RyzenAI backend — confirm DeviceManager expects this key for NPU device discovery.
        private const string _providerName = "DMLExecutionProvider";


        /// <summary>
        /// Initializes the Provider. Idempotent: subsequent calls are no-ops.
        /// </summary>
        public static void Initialize()
        {
            if (_isInitialized)
                return;

            _isInitialized = true;
            DeviceManager.Initialize(_providerName);
        }


        /// <summary>
        /// Initializes the Provider with the specified environment options.
        /// Has no effect if the provider was already initialized.
        /// </summary>
        /// <param name="environmentOptions">The ONNX Runtime environment creation options.</param>
        public static void Initialize(EnvironmentCreationOptions environmentOptions)
        {
            if (_isInitialized)
                return;

            _isInitialized = true;
            DeviceManager.Initialize(environmentOptions, _providerName);
        }


        /// <summary>
        /// Gets the name of the provider.
        /// </summary>
        public static string ProviderName => _providerName;


        /// <summary>
        /// Gets the devices known to the DeviceManager, initializing the provider if required.
        /// </summary>
        public static IReadOnlyList<Device> GetDevices()
        {
            Initialize(); // Ensure Initialized
            return DeviceManager.Devices;
        }


        /// <summary>
        /// Gets the default device (first NPU), or null if no NPU is present.
        /// </summary>
        public static Device GetDevice()
        {
            return GetDevice(DeviceType.NPU);
        }


        /// <summary>
        /// Gets the first device of the specified type, or null if none is present.
        /// </summary>
        /// <param name="deviceType">Type of the device.</param>
        public static Device GetDevice(DeviceType deviceType)
        {
            return GetDevices().FirstOrDefault(x => x.Type == deviceType);
        }


        /// <summary>
        /// Gets the Device by DeviceType and DeviceId, or null if not found.
        /// </summary>
        /// <param name="deviceType">Type of the device.</param>
        /// <param name="deviceId">The device identifier.</param>
        public static Device GetDevice(DeviceType deviceType, int deviceId)
        {
            return GetDevices().FirstOrDefault(x => x.Type == deviceType && x.DeviceId == deviceId);
        }


        /// <summary>
        /// Gets the execution provider for the default (NPU) device.
        /// </summary>
        /// <param name="optimizationLevel">The optimization level.</param>
        public static ExecutionProvider GetProvider(GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
        {
            return GetDevice().GetProvider(optimizationLevel);
        }


        /// <summary>
        /// Gets the execution provider for the first device of this DeviceType.
        /// </summary>
        /// <param name="deviceType">Type of the device.</param>
        /// <param name="optimizationLevel">The optimization level.</param>
        public static ExecutionProvider GetProvider(DeviceType deviceType, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
        {
            return GetDevice(deviceType).GetProvider(optimizationLevel);
        }


        /// <summary>
        /// Gets the execution provider for this DeviceType and DeviceId.
        /// </summary>
        /// <param name="deviceType">Type of the device.</param>
        /// <param name="deviceId">The device identifier.</param>
        /// <param name="optimizationLevel">The optimization level.</param>
        public static ExecutionProvider GetProvider(DeviceType deviceType, int deviceId, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
        {
            return GetDevice(deviceType, deviceId).GetProvider(optimizationLevel);
        }


        /// <summary>
        /// Gets the execution provider for this Device: RyzenAI for NPU, CPU for CPU,
        /// DirectML for anything else. Returns default when the device is null
        /// (so lookups that found no device degrade gracefully).
        /// </summary>
        /// <param name="device">The device.</param>
        /// <param name="optimizationLevel">The optimization level.</param>
        public static ExecutionProvider GetProvider(this Device device, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
        {
            if (device == null)
                return default;
            else if (device.Type == DeviceType.NPU)
                return CreateRyzenProvider(device.DeviceId, optimizationLevel);
            else if (device.Type == DeviceType.CPU)
                return CreateCPUProvider(optimizationLevel);

            return CreateDMLProvider(device.DeviceId, optimizationLevel);
        }


        /// <summary>
        /// Creates the RyzenAI (NPU) provider for this DeviceId.
        /// </summary>
        /// <param name="deviceId">The device identifier.</param>
        /// <param name="optimizationLevel">The optimization level.</param>
        private static ExecutionProvider CreateRyzenProvider(int deviceId, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
        {
            var memoryInfo = new OrtMemoryInfo(OrtMemoryInfo.allocatorCPU, OrtAllocatorType.DeviceAllocator, deviceId, OrtMemType.Default);
            return new ExecutionProvider(_providerName, memoryInfo, configuration =>
            {
                var sessionOptions = new SessionOptions
                {
                    GraphOptimizationLevel = optimizationLevel
                };

                // Point the runtime at a pre-built compilation cache next to the model, if one exists.
                var modelCache = Path.Combine(Path.GetDirectoryName(configuration.Path), ".cache");
                if (Directory.Exists(modelCache))
                    sessionOptions.AddSessionConfigEntry("dd_cache", modelCache);

                // Forward any per-model session options from ModelConfig.SessionOptions.
                if (!configuration.SessionOptions.IsNullOrEmpty())
                {
                    foreach (var sessionOption in configuration.SessionOptions)
                    {
                        sessionOptions.AddSessionConfigEntry(sessionOption.Key, sessionOption.Value);
                    }
                }

                // NOTE(review): only the custom-op library and the CPU EP are registered here —
                // confirm whether a VitisAI/RyzenAI execution provider append is also required.
                sessionOptions.RegisterCustomOpLibrary("onnx_custom_ops.dll");
                // Fix: the original appended the CPU execution provider twice (copy-paste duplicate).
                sessionOptions.AppendExecutionProvider_CPU();
                return sessionOptions;
            });
        }


        /// <summary>
        /// Creates the DirectML provider for this DeviceId, with CPU fallback.
        /// </summary>
        /// <param name="deviceId">The device identifier.</param>
        /// <param name="optimizationLevel">The optimization level.</param>
        private static ExecutionProvider CreateDMLProvider(int deviceId, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
        {
            var memoryInfo = new OrtMemoryInfo(OrtMemoryInfo.allocatorCPU, OrtAllocatorType.DeviceAllocator, deviceId, OrtMemType.Default);
            return new ExecutionProvider(_providerName, memoryInfo, configuration =>
            {
                var sessionOptions = new SessionOptions
                {
                    GraphOptimizationLevel = optimizationLevel
                };

                sessionOptions.AppendExecutionProvider_DML(deviceId);
                sessionOptions.AppendExecutionProvider_CPU();
                return sessionOptions;
            });
        }


        /// <summary>
        /// Creates the CPU provider.
        /// </summary>
        /// <param name="optimizationLevel">The optimization level.</param>
        /// <returns>ExecutionProvider.</returns>
        private static ExecutionProvider CreateCPUProvider(GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
        {
            return new ExecutionProvider(DeviceManager.CPUProviderName, OrtMemoryInfo.DefaultInstance, configuration =>
            {
                var sessionOptions = new SessionOptions
                {
                    EnableCpuMemArena = true,
                    EnableMemoryPattern = true,
                    GraphOptimizationLevel = optimizationLevel
                };
                sessionOptions.AppendExecutionProvider_CPU();
                return sessionOptions;
            });
        }
    }
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
<Project Sdk="Microsoft.NET.Sdk">

  <!-- Windows-only (DirectML/RyzenAI); x64 required by the native runtime binaries. -->
  <PropertyGroup>
    <TargetFrameworks>net9.0-windows10.0.17763.0;net10.0-windows10.0.17763.0</TargetFrameworks>
    <PlatformTarget>x64</PlatformTarget>
  </PropertyGroup>

  <!--Projects: Debug builds reference the local TensorStack.Common project directly-->
  <ItemGroup Condition=" '$(Configuration)' == 'Debug'">
    <ProjectReference Include="..\TensorStack.Common\TensorStack.Common.csproj" />
  </ItemGroup>

  <!--Packages: Release builds consume the published TensorStack.Common package instead-->
  <ItemGroup Condition=" '$(Configuration)' == 'Release'">
    <PackageReference Include="TensorStack.Common" Version="$(Version)" />
  </ItemGroup>

  <!--Common Packages-->
  <ItemGroup>
    <PackageReference Include="Microsoft.ML.OnnxRuntime.DirectML" Version="1.23.0" />
  </ItemGroup>

  <!-- Skip the package-supplied DirectML binaries; the local DirectML.dll below is shipped instead. -->
  <PropertyGroup Label="Globals">
    <Microsoft_AI_DirectML_SkipLibraryCopy>True</Microsoft_AI_DirectML_SkipLibraryCopy>
    <Microsoft_AI_DirectML_SkipDebugLayerCopy>True</Microsoft_AI_DirectML_SkipDebugLayerCopy>
  </PropertyGroup>

  <!--Binaries: pack the local DirectML.dll as content and copy it next to consumer output-->
  <ItemGroup>
    <Content Include="DirectML.dll"
             Pack="true"
             PackagePath="contentFiles\any\any\"
             CopyToOutputDirectory="Always"
             PackageCopyToOutput="true" />
    <None Include="DirectML.dll" Pack="false" />
  </ItemGroup>


  <!--Nuget Settings-->
  <PropertyGroup>
    <Title>$(AssemblyName)</Title>
    <PackageId>$(AssemblyName)</PackageId>
    <Product>$(AssemblyName)</Product>
    <PackageIcon>Icon.png</PackageIcon>
    <PackageReadmeFile>README.md</PackageReadmeFile>
    <Description>RyzenAI NPU backend for ONNX tensor computation.</Description>
  </PropertyGroup>
  <!-- README/Icon only exist for packing; exclude them from Debug builds. -->
  <ItemGroup Condition="'$(Configuration)' == 'Debug'">
    <None Remove="README.md" />
    <None Remove="Icon.png" />
  </ItemGroup>
  <ItemGroup Condition="'$(Configuration)' == 'Release'">
    <None Include="README.md">
      <Pack>True</Pack>
      <PackagePath>\</PackagePath>
    </None>
    <None Include="..\Assets\Icon.png">
      <Pack>True</Pack>
      <PackagePath>\</PackagePath>
    </None>
  </ItemGroup>

</Project>

0 commit comments

Comments
 (0)