Skip to content

Commit cbd5256

Browse files
committed
RyzenAI fallback to GPU not CPU
1 parent b10865b commit cbd5256

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

TensorStack.Providers.RyzenAI/Provider.cs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -118,14 +118,14 @@ public static ExecutionProvider GetProvider(DeviceType deviceType, GraphOptimiza
118118

119119

120120
/// <summary>
121-
/// Gets the RyzenAI provider this DeviceType, DeviceId.
121+
/// Gets the RyzenAI provider for NPU if supported, else DirectML GPU fallback.
122122
/// </summary>
123123
/// <param name="deviceType">Type of the device.</param>
124124
/// <param name="deviceId">The device identifier.</param>
125125
/// <param name="optimizationLevel">The optimization level.</param>
126-
public static ExecutionProvider GetProvider(DeviceType deviceType, int deviceId, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
126+
public static ExecutionProvider GetProvider(DeviceType deviceType, int fallbackDeviceId, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
127127
{
128-
return GetDevice(deviceType, deviceId).GetProvider(optimizationLevel);
128+
return GetDevice(deviceType, fallbackDeviceId).GetProvider(optimizationLevel);
129129
}
130130

131131

@@ -146,13 +146,13 @@ public static ExecutionProvider GetProvider(this Device device, GraphOptimizatio
146146

147147

148148
/// <summary>
149-
/// Gets the RyzenAI provider for this DeviceId.
149+
/// Gets the RyzenAI provider for NPU if supported, else DirectML fallback.
150150
/// </summary>
151-
/// <param name="deviceId">The device identifier.</param>
151+
/// <param name="fallbackDeviceId">The fallback device identifier.</param>
152152
/// <param name="optimizationLevel">The optimization level.</param>
153-
private static ExecutionProvider CreateRyzenProvider(int deviceId, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
153+
private static ExecutionProvider CreateRyzenProvider(int fallbackDeviceId, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
154154
{
155-
var memoryInfo = new OrtMemoryInfo(OrtMemoryInfo.allocatorCPU, OrtAllocatorType.DeviceAllocator, deviceId, OrtMemType.Default);
155+
var memoryInfo = new OrtMemoryInfo(OrtMemoryInfo.allocatorCPU, OrtAllocatorType.DeviceAllocator, fallbackDeviceId, OrtMemType.Default);
156156
return new ExecutionProvider(_providerName, memoryInfo, configuration =>
157157
{
158158
var sessionOptions = new SessionOptions
@@ -166,6 +166,7 @@ private static ExecutionProvider CreateRyzenProvider(int deviceId, GraphOptimiza
166166

167167
sessionOptions.AddSessionConfigEntries(configuration.SessionOptions);
168168
sessionOptions.RegisterCustomOpLibrary("onnx_custom_ops.dll");
169+
sessionOptions.AppendExecutionProvider_DML();
169170
sessionOptions.AppendExecutionProvider_CPU();
170171
return sessionOptions;
171172
});

TensorStack.sln

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TensorStack.TextGeneration"
2929
EndProject
3030
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{02EA681E-C7D8-13C7-8484-4AC65E1B71E8}"
3131
ProjectSection(SolutionItems) = preProject
32+
BuildDemos.bat = BuildDemos.bat
33+
BuildRelease.bat = BuildRelease.bat
3234
Directory.Build.props = Directory.Build.props
3335
EndProjectSection
3436
EndProject
@@ -50,6 +52,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TensorStack.Example.Extract
5052
EndProject
5153
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TensorStack.Example.TextGeneration", "Examples\TensorStack.Example.TextGeneration\TensorStack.Example.TextGeneration.csproj", "{0B9E97D9-76FA-43BF-8217-7DBE5536EB88}"
5254
EndProject
55+
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TensorStack.Providers.RyzenAI", "TensorStack.Providers.RyzenAI\TensorStack.Providers.RyzenAI.csproj", "{93621B62-037D-44FA-9FA5-7A9A5566F7FE}"
56+
EndProject
5357
Global
5458
GlobalSection(SolutionConfigurationPlatforms) = preSolution
5559
Debug_CUDA|Any CPU = Debug_CUDA|Any CPU
@@ -218,6 +222,14 @@ Global
218222
{0B9E97D9-76FA-43BF-8217-7DBE5536EB88}.Release_CUDA|Any CPU.Build.0 = Release|Any CPU
219223
{0B9E97D9-76FA-43BF-8217-7DBE5536EB88}.Release|Any CPU.ActiveCfg = Release|Any CPU
220224
{0B9E97D9-76FA-43BF-8217-7DBE5536EB88}.Release|Any CPU.Build.0 = Release|Any CPU
225+
{93621B62-037D-44FA-9FA5-7A9A5566F7FE}.Debug_CUDA|Any CPU.ActiveCfg = Debug|Any CPU
226+
{93621B62-037D-44FA-9FA5-7A9A5566F7FE}.Debug_CUDA|Any CPU.Build.0 = Debug|Any CPU
227+
{93621B62-037D-44FA-9FA5-7A9A5566F7FE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
228+
{93621B62-037D-44FA-9FA5-7A9A5566F7FE}.Debug|Any CPU.Build.0 = Debug|Any CPU
229+
{93621B62-037D-44FA-9FA5-7A9A5566F7FE}.Release_CUDA|Any CPU.ActiveCfg = Release|Any CPU
230+
{93621B62-037D-44FA-9FA5-7A9A5566F7FE}.Release_CUDA|Any CPU.Build.0 = Release|Any CPU
231+
{93621B62-037D-44FA-9FA5-7A9A5566F7FE}.Release|Any CPU.ActiveCfg = Release|Any CPU
232+
{93621B62-037D-44FA-9FA5-7A9A5566F7FE}.Release|Any CPU.Build.0 = Release|Any CPU
221233
EndGlobalSection
222234
GlobalSection(SolutionProperties) = preSolution
223235
HideSolutionNode = FALSE

0 commit comments

Comments
 (0)