|
42 | 42 | import org.tensorflow.op.Scope; |
43 | 43 | import org.tensorflow.op.core.Placeholder; |
44 | 44 | import org.tensorflow.op.core.PlaceholderWithDefault; |
| 45 | +import org.tensorflow.op.core.StatefulPartitionedCall; |
| 46 | +import org.tensorflow.op.core.StatelessPartitionedCall; |
45 | 47 | import org.tensorflow.proto.framework.AttrValue; |
46 | 48 | import org.tensorflow.proto.framework.DataType; |
47 | 49 | import org.tensorflow.proto.framework.FunctionDef; |
@@ -207,11 +209,6 @@ public String toString() { |
207 | 209 | return signature.toString(); |
208 | 210 | } |
209 | 211 |
|
210 | | - // TODO migrate to the actual ops once they are generated |
211 | | - public static final String CALL_OP = "PartitionedCall"; |
212 | | - // TODO migrate to the actual ops once they are generated |
213 | | - public static final String STATEFUL_CALL_OP = "StatefulPartitionedCall"; |
214 | | - |
215 | 212 | /** |
216 | 213 | * Calls the function in an execution environment, adding its graph as a function if it isn't |
217 | 214 | * already present. The inputs and outputs are keyed by the names set in the {@code Signature}. |
@@ -255,7 +252,9 @@ public Map<String, Operand<?>> call(Scope scope, Map<String, Operand<?>> argumen |
255 | 252 | OperationBuilder opBuilder = |
256 | 253 | scope |
257 | 254 | .env() |
258 | | - .opBuilder(isStateful() ? STATEFUL_CALL_OP : CALL_OP, scope.makeOpName(displayName)); |
| 255 | + .opBuilder( |
| 256 | + isStateful() ? StatefulPartitionedCall.OP_NAME : StatelessPartitionedCall.OP_NAME, |
| 257 | + scope.makeOpName(displayName)); |
259 | 258 |
|
260 | 259 | opBuilder.addInputList(inputs); |
261 | 260 |
|
|
0 commit comments