222222import org .tensorflow .op .core .StageClear ;
223223import org .tensorflow .op .core .StagePeek ;
224224import org .tensorflow .op .core .StageSize ;
225+ import org .tensorflow .op .core .StatefulCase ;
226+ import org .tensorflow .op .core .StatefulIf ;
225227import org .tensorflow .op .core .StatefulPartitionedCall ;
228+ import org .tensorflow .op .core .StatefulWhile ;
226229import org .tensorflow .op .core .StatelessIf ;
230+ import org .tensorflow .op .core .StatelessPartitionedCall ;
227231import org .tensorflow .op .core .StatelessWhile ;
228232import org .tensorflow .op .core .StopGradient ;
229233import org .tensorflow .op .core .StridedSlice ;
@@ -1230,6 +1234,7 @@ public Map<String, Operand<?>> call(ConcreteFunction function,
12301234 * }
12311235 * ```
12321236 * </pre>
1237+ * Selects between {@link StatefulCase} and {@link StatelessCase} based on the statefulness of the function arguments.
12331238 *
12341239 * @param branchIndex The branch selector, an int32 Tensor.
12351240 * @param input A list of input tensors passed to the branch function.
@@ -2946,6 +2951,7 @@ public IdentityN identityN(Iterable<Operand<?>> input) {
29462951
29472952 /**
29482953 * output = cond ? then_branch(input) : else_branch(input)
2954+ * Selects between {@link StatefulIf} and {@link StatelessIf} based on the statefulness of the function arguments.
29492955 *
29502956 * @param cond <pre>
29512957 * A Tensor. If the tensor is a scalar of non-boolean type, the
@@ -4015,14 +4021,16 @@ public <T extends TType> ParallelDynamicStitch<T> parallelDynamicStitch(
40154021
40164022 /**
40174023 * returns {@code f(inputs)}, where {@code f}'s body is placed and partitioned.
4024+ * Selects between {@link StatefulPartitionedCall} and {@link StatelessPartitionedCall} based on the statefulness of the function arguments.
40184025 *
40194026 * @param args A list of input tensors.
40204027 * @param Tout A list of output types.
40214028 * @param f <pre>
40224029 * A function that takes 'args', a list of tensors, and returns 'output',
40234030 * another list of tensors. Input and output types are specified by 'Tin'
40244031 * and 'Tout'. The function body of f will be placed and partitioned across
4025- * devices, setting this op apart from the regular Call op.
4032+ * devices, setting this op apart from the regular Call op. This op is
4033+ * stateful.
40264034 * </pre>
40274035 * @param options carries optional attribute values
40284036 * @return a new instance of PartitionedCall
@@ -6048,6 +6056,72 @@ public StageSize stageSize(List<Class<? extends TType>> dtypes, StageSize.Option
60486056 return StageSize .create (scope , dtypes , options );
60496057 }
60506058
6059+ /**
6060+ * An n-way switch statement which calls a single branch function.
6061+ * <pre>
6062+ * An n-way switch statement, implementing the following:
6063+ * ```
6064+ * switch (branch_index) {
6065+ * case 0:
6066+ * output = branches[0](input);
6067+ * break;
6068+ * case 1:
6069+ * output = branches[1](input);
6070+ * break;
6071+ * ...
6072+ * case [[nbranches-1]]:
6073+ * default:
6074+ * output = branches[nbranches-1](input);
6075+ * break;
6076+ * }
6077+ * ```
6078+ * </pre>
6079+ *
6080+ * @param branchIndex The branch selector, an int32 Tensor.
6081+ * @param input A list of input tensors passed to the branch function.
6082+ * @param Tout A list of output types.
6083+ * @param branches <pre>
6084+ * A list of functions each of which takes 'inputs' and returns a list of
6085+ * tensors, whose types are the same as what every other branch returns.
6086+ * </pre>
6087+ * @param options carries optional attribute values
6088+ * @return a new instance of StatefulCase
6089+ */
6090+ public StatefulCase statefulCase (Operand <TInt32 > branchIndex , Iterable <Operand <?>> input ,
6091+ List <Class <? extends TType >> Tout , List <ConcreteFunction > branches , Case .Options ... options ) {
6092+ return StatefulCase .create (scope , branchIndex , input , Tout , branches , options );
6093+ }
6094+
6095+ /**
6096+ * output = cond ? then_branch(input) : else_branch(input)
6097+ *
6098+ * @param cond <pre>
6099+ * A Tensor. If the tensor is a scalar of non-boolean type, the
6100+ * scalar is converted to a boolean according to the
6101+ * following rule: if the scalar is a numerical value, non-zero means
6102+ * `True` and zero means False; if the scalar is a string, non-empty
6103+ * means `True` and empty means `False`. If the tensor is not a scalar,
6104+ * being empty means False and being non-empty means True.
6105+ * </pre>
6106+ * @param input A list of input tensors.
6107+ * @param Tout A list of output types.
6108+ * @param thenBranch <pre>
6109+ * A function that takes 'inputs' and returns a list of tensors, whose
6110+ * types are the same as what else_branch returns.
6111+ * </pre>
6112+ * @param elseBranch <pre>
6113+ * A function that takes 'inputs' and returns a list of tensors, whose
6114+ * types are the same as what then_branch returns.
6115+ * </pre>
6116+ * @param options carries optional attribute values
6117+ * @return a new instance of StatefulIf
6118+ */
6119+ public StatefulIf statefulIf (Operand <? extends TType > cond , Iterable <Operand <?>> input ,
6120+ List <Class <? extends TType >> Tout , ConcreteFunction thenBranch , ConcreteFunction elseBranch ,
6121+ If .Options ... options ) {
6122+ return StatefulIf .create (scope , cond , input , Tout , thenBranch , elseBranch , options );
6123+ }
6124+
60516125 /**
60526126 * returns {@code f(inputs)}, where {@code f}'s body is placed and partitioned.
60536127 *
@@ -6064,11 +6138,36 @@ public StageSize stageSize(List<Class<? extends TType>> dtypes, StageSize.Option
60646138 * @return a new instance of StatefulPartitionedCall
60656139 */
60666140 public StatefulPartitionedCall statefulPartitionedCall (Iterable <Operand <?>> args ,
6067- List <Class <? extends TType >> Tout , ConcreteFunction f ,
6068- StatefulPartitionedCall .Options ... options ) {
6141+ List <Class <? extends TType >> Tout , ConcreteFunction f , PartitionedCall .Options ... options ) {
60696142 return StatefulPartitionedCall .create (scope , args , Tout , f , options );
60706143 }
60716144
6145+ /**
6146+ * output = input; While (Cond(output)) { output = Body(output) }
6147+ *
6148+ * @param input A list of input tensors whose types are T.
6149+ * @param cond <pre>
6150+ * A function takes 'input' and returns a tensor. If the tensor is
6151+ * a scalar of non-boolean, the scalar is converted to a boolean
6152+ * according to the following rule: if the scalar is a numerical
6153+ * value, non-zero means True and zero means False; if the scalar is
6154+ * a string, non-empty means True and empty means False. If the
6155+ * tensor is not a scalar, non-emptiness means True and False
6156+ * otherwise.
6157+ * </pre>
6158+ * @param body <pre>
6159+ * A function that takes a list of tensors and returns another
6160+ * list of tensors. Both lists have the same types as specified
6161+ * by T.
6162+ * </pre>
6163+ * @param options carries optional attribute values
6164+ * @return a new instance of StatefulWhile
6165+ */
6166+ public StatefulWhile statefulWhile (Iterable <Operand <?>> input , ConcreteFunction cond ,
6167+ ConcreteFunction body , While .Options ... options ) {
6168+ return StatefulWhile .create (scope , input , cond , body , options );
6169+ }
6170+
60726171 /**
60736172 * output = cond ? then_branch(input) : else_branch(input)
60746173 *
@@ -6098,10 +6197,29 @@ public StatefulPartitionedCall statefulPartitionedCall(Iterable<Operand<?>> args
60986197 */
60996198 public StatelessIf statelessIf (Operand <? extends TType > cond , Iterable <Operand <?>> input ,
61006199 List <Class <? extends TType >> Tout , ConcreteFunction thenBranch , ConcreteFunction elseBranch ,
6101- StatelessIf .Options ... options ) {
6200+ If .Options ... options ) {
61026201 return StatelessIf .create (scope , cond , input , Tout , thenBranch , elseBranch , options );
61036202 }
61046203
6204+ /**
6205+ * returns {@code f(inputs)}, where {@code f}'s body is placed and partitioned.
6206+ *
6207+ * @param args A list of input tensors.
6208+ * @param Tout A list of output types.
6209+ * @param f <pre>
6210+ * A function that takes 'args', a list of tensors, and returns 'output',
6211+ * another list of tensors. Input and output types are specified by 'Tin'
6212+ * and 'Tout'. The function body of f will be placed and partitioned across
6213+ * devices, setting this op apart from the regular Call op.
6214+ * </pre>
6215+ * @param options carries optional attribute values
6216+ * @return a new instance of StatelessPartitionedCall
6217+ */
6218+ public StatelessPartitionedCall statelessPartitionedCall (Iterable <Operand <?>> args ,
6219+ List <Class <? extends TType >> Tout , ConcreteFunction f , PartitionedCall .Options ... options ) {
6220+ return StatelessPartitionedCall .create (scope , args , Tout , f , options );
6221+ }
6222+
61056223 /**
61066224 * output = input; While (Cond(output)) { output = Body(output) }
61076225 *
@@ -6127,7 +6245,7 @@ public StatelessIf statelessIf(Operand<? extends TType> cond, Iterable<Operand<?
61276245 * @return a new instance of StatelessWhile
61286246 */
61296247 public StatelessWhile statelessWhile (Iterable <Operand <?>> input , ConcreteFunction cond ,
6130- ConcreteFunction body , StatelessWhile .Options ... options ) {
6248+ ConcreteFunction body , While .Options ... options ) {
61316249 return StatelessWhile .create (scope , input , cond , body , options );
61326250 }
61336251
@@ -7990,6 +8108,7 @@ public Where where(Operand<? extends TType> condition) {
79908108
79918109 /**
79928110 * output = input; While (Cond(output)) { output = Body(output) }
8111+ * Selects between {@link StatefulWhile} and {@link StatelessWhile} based on the statefulness of the function arguments.
79938112 *
79948113 * @param input A list of input tensors whose types are T.
79958114 * @param cond <pre>
0 commit comments