@@ -118,14 +118,14 @@ public static ExecutionProvider GetProvider(DeviceType deviceType, GraphOptimiza
118118
119119
120120 /// <summary>
121- /// Gets the RyzenAI provider this DeviceType, DeviceId .
121+ /// Gets the RyzenAI provider for NPU if supported, else DirectML GPU fallback .
122122 /// </summary>
123123 /// <param name="deviceType">Type of the device.</param>
124124 /// <param name="deviceId">The device identifier.</param>
125125 /// <param name="optimizationLevel">The optimization level.</param>
126- public static ExecutionProvider GetProvider ( DeviceType deviceType , int deviceId , GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel . ORT_ENABLE_ALL )
126+ public static ExecutionProvider GetProvider ( DeviceType deviceType , int fallbackDeviceId , GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel . ORT_ENABLE_ALL )
127127 {
128- return GetDevice ( deviceType , deviceId ) . GetProvider ( optimizationLevel ) ;
128+ return GetDevice ( deviceType , fallbackDeviceId ) . GetProvider ( optimizationLevel ) ;
129129 }
130130
131131
@@ -146,13 +146,13 @@ public static ExecutionProvider GetProvider(this Device device, GraphOptimizatio
146146
147147
148148 /// <summary>
149- /// Gets the RyzenAI provider for this DeviceId .
149+ /// Gets the RyzenAI provider for NPU if supported, else DirectML fallback .
150150 /// </summary>
151- /// <param name="deviceId ">The device identifier.</param>
151+ /// <param name="fallbackDeviceId ">The fallback device identifier.</param>
152152 /// <param name="optimizationLevel">The optimization level.</param>
153- private static ExecutionProvider CreateRyzenProvider ( int deviceId , GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel . ORT_ENABLE_ALL )
153+ private static ExecutionProvider CreateRyzenProvider ( int fallbackDeviceId , GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel . ORT_ENABLE_ALL )
154154 {
155- var memoryInfo = new OrtMemoryInfo ( OrtMemoryInfo . allocatorCPU , OrtAllocatorType . DeviceAllocator , deviceId , OrtMemType . Default ) ;
155+ var memoryInfo = new OrtMemoryInfo ( OrtMemoryInfo . allocatorCPU , OrtAllocatorType . DeviceAllocator , fallbackDeviceId , OrtMemType . Default ) ;
156156 return new ExecutionProvider ( _providerName , memoryInfo , configuration =>
157157 {
158158 var sessionOptions = new SessionOptions
@@ -166,6 +166,7 @@ private static ExecutionProvider CreateRyzenProvider(int deviceId, GraphOptimiza
166166
167167 sessionOptions . AddSessionConfigEntries ( configuration . SessionOptions ) ;
168168 sessionOptions . RegisterCustomOpLibrary ( "onnx_custom_ops.dll" ) ;
169+ sessionOptions . AppendExecutionProvider_DML ( ) ;
169170 sessionOptions . AppendExecutionProvider_CPU ( ) ;
170171 return sessionOptions ;
171172 } ) ;
0 commit comments