Skip to content

Commit b10865b

Browse files
committed
Fix provider device selection
1 parent 1ced5e1 commit b10865b

File tree

3 files changed

+22
-38
lines changed

3 files changed

+22
-38
lines changed

Directory.Build.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
<Project>
22
<PropertyGroup>
3-
<Version>0.2.0</Version>
3+
<Version>0.2.1</Version>
44
<Company>TensorStack</Company>
55
<Copyright>TensorStack - 2025</Copyright>
66
<RepositoryUrl>https://github.com/TensorStack-AI/TensorStack</RepositoryUrl>

TensorStack.Providers.RyzenAI/Provider.cs

Lines changed: 19 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) TensorStack. All rights reserved.
1+
// Copyright (c) TensorStack, Advanced Micro Devices. All rights reserved.
22
// Licensed under the Apache 2.0 License.
33
using Microsoft.ML.OnnxRuntime;
44
using System.Collections.Generic;
@@ -8,11 +8,13 @@
88

99
namespace TensorStack.Providers
1010
{
11+
/// <summary>
12+
/// RyzenAI NPU provider with DirectML GPU fallback
13+
/// </summary>
1114
public static class Provider
1215
{
1316
private static bool _isInitialized;
14-
private const string _providerName = "DMLExecutionProvider";
15-
17+
private const string _providerName = "RyzenAIExecutionProvider";
1618

1719
/// <summary>
1820
/// Initializes the Provider
@@ -23,7 +25,7 @@ public static void Initialize()
2325
return;
2426

2527
_isInitialized = true;
26-
DeviceManager.Initialize(_providerName);
28+
DeviceManager.Initialize("DMLExecutionProvider");
2729
}
2830

2931

@@ -37,7 +39,7 @@ public static void Initialize(EnvironmentCreationOptions environmentOptions)
3739
return;
3840

3941
_isInitialized = true;
40-
DeviceManager.Initialize(environmentOptions, _providerName);
42+
DeviceManager.Initialize(environmentOptions, "DMLExecutionProvider");
4143
}
4244

4345

@@ -73,6 +75,9 @@ public static Device GetDevice()
7375
/// <param name="deviceType">Type of the device.</param>
7476
public static Device GetDevice(DeviceType deviceType)
7577
{
78+
if (deviceType == DeviceType.NPU)
79+
return GetDevices().FirstOrDefault(x => x.Type == DeviceType.GPU);
80+
7681
return GetDevices().FirstOrDefault(x => x.Type == deviceType);
7782
}
7883

@@ -84,12 +89,15 @@ public static Device GetDevice(DeviceType deviceType)
8489
/// <param name="deviceId">The device identifier.</param>
8590
public static Device GetDevice(DeviceType deviceType, int deviceId)
8691
{
92+
if (deviceType == DeviceType.NPU)
93+
return GetDevices().FirstOrDefault(x => x.Type == DeviceType.GPU && x.DeviceId == deviceId);
94+
8795
return GetDevices().FirstOrDefault(x => x.Type == deviceType && x.DeviceId == deviceId);
8896
}
8997

9098

9199
/// <summary>
92-
/// Gets the DirectML provider this DeviceType.
100+
/// Gets the RyzenAI provider this DeviceType.
93101
/// </summary>
94102
/// <param name="optimizationLevel">The optimization level.</param>
95103
public static ExecutionProvider GetProvider(GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
@@ -99,7 +107,7 @@ public static ExecutionProvider GetProvider(GraphOptimizationLevel optimizationL
99107

100108

101109
/// <summary>
102-
/// Gets the DirectML provider this DeviceType.
110+
/// Gets the RyzenAI provider this DeviceType.
103111
/// </summary>
104112
/// <param name="deviceType">Type of the device.</param>
105113
/// <param name="optimizationLevel">The optimization level.</param>
@@ -110,7 +118,7 @@ public static ExecutionProvider GetProvider(DeviceType deviceType, GraphOptimiza
110118

111119

112120
/// <summary>
113-
/// Gets the DirectML provider this DeviceType, DeviceId.
121+
/// Gets the RyzenAI provider this DeviceType, DeviceId.
114122
/// </summary>
115123
/// <param name="deviceType">Type of the device.</param>
116124
/// <param name="deviceId">The device identifier.</param>
@@ -122,25 +130,23 @@ public static ExecutionProvider GetProvider(DeviceType deviceType, int deviceId,
122130

123131

124132
/// <summary>
125-
/// Gets the DirectML provider for this Device.
133+
/// Gets the RyzenAI provider for this Device.
126134
/// </summary>
127135
/// <param name="device">The device.</param>
128136
/// <param name="optimizationLevel">The optimization level.</param>
129137
public static ExecutionProvider GetProvider(this Device device, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
130138
{
131139
if (device == null)
132140
return default;
133-
else if (device.Type == DeviceType.NPU)
134-
return CreateRyzenProvider(device.DeviceId, optimizationLevel);
135141
else if (device.Type == DeviceType.CPU)
136142
return CreateCPUProvider(optimizationLevel);
137143

138-
return CreateDMLProvider(device.DeviceId, optimizationLevel);
144+
return CreateRyzenProvider(device.DeviceId, optimizationLevel);
139145
}
140146

141147

142148
/// <summary>
143-
/// Gets the DirectML provider for this DeviceId.
149+
/// Gets the RyzenAI provider for this DeviceId.
144150
/// </summary>
145151
/// <param name="deviceId">The device identifier.</param>
146152
/// <param name="optimizationLevel">The optimization level.</param>
@@ -166,29 +172,6 @@ private static ExecutionProvider CreateRyzenProvider(int deviceId, GraphOptimiza
166172
}
167173

168174

169-
/// <summary>
170-
/// Gets the DirectML provider for this DeviceId.
171-
/// </summary>
172-
/// <param name="deviceId">The device identifier.</param>
173-
/// <param name="optimizationLevel">The optimization level.</param>
174-
private static ExecutionProvider CreateDMLProvider(int deviceId, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL)
175-
{
176-
var memoryInfo = new OrtMemoryInfo(OrtMemoryInfo.allocatorCPU, OrtAllocatorType.DeviceAllocator, deviceId, OrtMemType.Default);
177-
return new ExecutionProvider(_providerName, memoryInfo, configuration =>
178-
{
179-
var sessionOptions = new SessionOptions
180-
{
181-
GraphOptimizationLevel = optimizationLevel
182-
};
183-
184-
sessionOptions.AddSessionConfigEntries(configuration.SessionOptions);
185-
sessionOptions.AppendExecutionProvider_DML(deviceId);
186-
sessionOptions.AppendExecutionProvider_CPU();
187-
return sessionOptions;
188-
});
189-
}
190-
191-
192175
/// <summary>
193176
/// Gets the CPU provider.
194177
/// </summary>

TensorStack.Providers.RyzenAI/TensorStack.Providers.RyzenAI.csproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
<PropertyGroup>
44
<TargetFrameworks>net9.0-windows10.0.17763.0;net10.0-windows10.0.17763.0</TargetFrameworks>
55
<PlatformTarget>x64</PlatformTarget>
6-
<Copyright>Advanced Micro Devices, TensorStack - 2025</Copyright>
6+
<Authors>Advanced Micro Devices, TensorStack</Authors>
7+
<Copyright>Advanced Micro Devices - 2025, TensorStack - 2025</Copyright>
78

89
<!--DirectML.dll content warning-->
910
<NoWarn>$(NoWarn);NU5100</NoWarn>

0 commit comments

Comments
 (0)