1- // Copyright (c) TensorStack. All rights reserved.
1+ // Copyright (c) TensorStack, Advanced Micro Devices . All rights reserved.
22// Licensed under the Apache 2.0 License.
33using Microsoft . ML . OnnxRuntime ;
44using System . Collections . Generic ;
88
99namespace TensorStack . Providers
1010{
11+ /// <summary>
12+ /// RyzenAI NPU provider with DirectML GPU fallback
13+ /// </summary>
1114 public static class Provider
1215 {
1316 private static bool _isInitialized ;
14- private const string _providerName = "DMLExecutionProvider" ;
15-
17+ private const string _providerName = "RyzenAIExecutionProvider" ;
1618
1719 /// <summary>
1820 /// Initializes the Provider
@@ -23,7 +25,7 @@ public static void Initialize()
2325 return ;
2426
2527 _isInitialized = true ;
26- DeviceManager . Initialize ( _providerName ) ;
28+ DeviceManager . Initialize ( "DMLExecutionProvider" ) ;
2729 }
2830
2931
@@ -37,7 +39,7 @@ public static void Initialize(EnvironmentCreationOptions environmentOptions)
3739 return ;
3840
3941 _isInitialized = true ;
40- DeviceManager . Initialize ( environmentOptions , _providerName ) ;
42+ DeviceManager . Initialize ( environmentOptions , "DMLExecutionProvider" ) ;
4143 }
4244
4345
@@ -73,6 +75,9 @@ public static Device GetDevice()
7375 /// <param name="deviceType">Type of the device.</param>
7476 public static Device GetDevice ( DeviceType deviceType )
7577 {
78+ if ( deviceType == DeviceType . NPU )
79+ return GetDevices ( ) . FirstOrDefault ( x => x . Type == DeviceType . GPU ) ;
80+
7681 return GetDevices ( ) . FirstOrDefault ( x => x . Type == deviceType ) ;
7782 }
7883
@@ -84,12 +89,15 @@ public static Device GetDevice(DeviceType deviceType)
8489 /// <param name="deviceId">The device identifier.</param>
8590 public static Device GetDevice ( DeviceType deviceType , int deviceId )
8691 {
92+ if ( deviceType == DeviceType . NPU )
93+ return GetDevices ( ) . FirstOrDefault ( x => x . Type == DeviceType . GPU && x . DeviceId == deviceId ) ;
94+
8795 return GetDevices ( ) . FirstOrDefault ( x => x . Type == deviceType && x . DeviceId == deviceId ) ;
8896 }
8997
9098
9199 /// <summary>
92- /// Gets the DirectML provider this DeviceType.
100+ /// Gets the RyzenAI provider this DeviceType.
93101 /// </summary>
94102 /// <param name="optimizationLevel">The optimization level.</param>
95103 public static ExecutionProvider GetProvider ( GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel . ORT_ENABLE_ALL )
@@ -99,7 +107,7 @@ public static ExecutionProvider GetProvider(GraphOptimizationLevel optimizationL
99107
100108
101109 /// <summary>
102- /// Gets the DirectML provider this DeviceType.
110+ /// Gets the RyzenAI provider this DeviceType.
103111 /// </summary>
104112 /// <param name="deviceType">Type of the device.</param>
105113 /// <param name="optimizationLevel">The optimization level.</param>
@@ -110,7 +118,7 @@ public static ExecutionProvider GetProvider(DeviceType deviceType, GraphOptimiza
110118
111119
112120 /// <summary>
113- /// Gets the DirectML provider this DeviceType, DeviceId.
121+ /// Gets the RyzenAI provider this DeviceType, DeviceId.
114122 /// </summary>
115123 /// <param name="deviceType">Type of the device.</param>
116124 /// <param name="deviceId">The device identifier.</param>
@@ -122,25 +130,23 @@ public static ExecutionProvider GetProvider(DeviceType deviceType, int deviceId,
122130
123131
124132 /// <summary>
125- /// Gets the DirectML provider for this Device.
133+ /// Gets the RyzenAI provider for this Device.
126134 /// </summary>
127135 /// <param name="device">The device.</param>
128136 /// <param name="optimizationLevel">The optimization level.</param>
129137 public static ExecutionProvider GetProvider ( this Device device , GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel . ORT_ENABLE_ALL )
130138 {
131139 if ( device == null )
132140 return default ;
133- else if ( device . Type == DeviceType . NPU )
134- return CreateRyzenProvider ( device . DeviceId , optimizationLevel ) ;
135141 else if ( device . Type == DeviceType . CPU )
136142 return CreateCPUProvider ( optimizationLevel ) ;
137143
138- return CreateDMLProvider ( device . DeviceId , optimizationLevel ) ;
144+ return CreateRyzenProvider ( device . DeviceId , optimizationLevel ) ;
139145 }
140146
141147
142148 /// <summary>
143- /// Gets the DirectML provider for this DeviceId.
149+ /// Gets the RyzenAI provider for this DeviceId.
144150 /// </summary>
145151 /// <param name="deviceId">The device identifier.</param>
146152 /// <param name="optimizationLevel">The optimization level.</param>
@@ -166,29 +172,6 @@ private static ExecutionProvider CreateRyzenProvider(int deviceId, GraphOptimiza
166172 }
167173
168174
169- /// <summary>
170- /// Gets the DirectML provider for this DeviceId.
171- /// </summary>
172- /// <param name="deviceId">The device identifier.</param>
173- /// <param name="optimizationLevel">The optimization level.</param>
174- private static ExecutionProvider CreateDMLProvider ( int deviceId , GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel . ORT_ENABLE_ALL )
175- {
176- var memoryInfo = new OrtMemoryInfo ( OrtMemoryInfo . allocatorCPU , OrtAllocatorType . DeviceAllocator , deviceId , OrtMemType . Default ) ;
177- return new ExecutionProvider ( _providerName , memoryInfo , configuration =>
178- {
179- var sessionOptions = new SessionOptions
180- {
181- GraphOptimizationLevel = optimizationLevel
182- } ;
183-
184- sessionOptions . AddSessionConfigEntries ( configuration . SessionOptions ) ;
185- sessionOptions . AppendExecutionProvider_DML ( deviceId ) ;
186- sessionOptions . AppendExecutionProvider_CPU ( ) ;
187- return sessionOptions ;
188- } ) ;
189- }
190-
191-
192175 /// <summary>
193176 /// Gets the CPU provider.
194177 /// </summary>
0 commit comments