Skip to content

Commit 3bee40b

Browse files
committed
Stateful/stateless processing, selector op wrappers
Signed-off-by: Ryan Nett <JNett96@gmail.com>
1 parent bdde828 commit 3bee40b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+1032
-645
lines changed

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Lines changed: 124 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,12 @@
222222
import org.tensorflow.op.core.StageClear;
223223
import org.tensorflow.op.core.StagePeek;
224224
import org.tensorflow.op.core.StageSize;
225+
import org.tensorflow.op.core.StatefulCase;
226+
import org.tensorflow.op.core.StatefulIf;
225227
import org.tensorflow.op.core.StatefulPartitionedCall;
228+
import org.tensorflow.op.core.StatefulWhile;
226229
import org.tensorflow.op.core.StatelessIf;
230+
import org.tensorflow.op.core.StatelessPartitionedCall;
227231
import org.tensorflow.op.core.StatelessWhile;
228232
import org.tensorflow.op.core.StopGradient;
229233
import org.tensorflow.op.core.StridedSlice;
@@ -1230,6 +1234,7 @@ public Map<String, Operand<?>> call(ConcreteFunction function,
12301234
* }
12311235
* ```
12321236
* </pre>
1237+
* Selects between {@link StatefulCase} and {@link StatelessCase} based on the statefulness of the function arguments.
12331238
*
12341239
* @param branchIndex The branch selector, an int32 Tensor.
12351240
* @param input A list of input tensors passed to the branch function.
@@ -2946,6 +2951,7 @@ public IdentityN identityN(Iterable<Operand<?>> input) {
29462951

29472952
/**
29482953
* output = cond ? then_branch(input) : else_branch(input)
2954+
* Selects between {@link StatefulIf} and {@link StatelessIf} based on the statefulness of the function arguments.
29492955
*
29502956
* @param cond <pre>
29512957
* A Tensor. If the tensor is a scalar of non-boolean type, the
@@ -4015,14 +4021,16 @@ public <T extends TType> ParallelDynamicStitch<T> parallelDynamicStitch(
40154021

40164022
/**
40174023
* returns {@code f(inputs)}, where {@code f}'s body is placed and partitioned.
4024+
* Selects between {@link StatefulPartitionedCall} and {@link StatelessPartitionedCall} based on the statefulness of the function arguments.
40184025
*
40194026
* @param args A list of input tensors.
40204027
* @param Tout A list of output types.
40214028
* @param f <pre>
40224029
* A function that takes 'args', a list of tensors, and returns 'output',
40234030
* another list of tensors. Input and output types are specified by 'Tin'
40244031
* and 'Tout'. The function body of f will be placed and partitioned across
4025-
* devices, setting this op apart from the regular Call op.
4032+
* devices, setting this op apart from the regular Call op. This op is
4033+
* stateful.
40264034
* </pre>
40274035
* @param options carries optional attribute values
40284036
* @return a new instance of PartitionedCall
@@ -6048,6 +6056,72 @@ public StageSize stageSize(List<Class<? extends TType>> dtypes, StageSize.Option
60486056
return StageSize.create(scope, dtypes, options);
60496057
}
60506058

6059+
/**
6060+
* An n-way switch statement which calls a single branch function.
6061+
* <pre>
6062+
* An n-way switch statement, implementing the following:
6063+
* ```
6064+
* switch (branch_index) {
6065+
* case 0:
6066+
* output = branches[0](input);
6067+
* break;
6068+
* case 1:
6069+
* output = branches[1](input);
6070+
* break;
6071+
* ...
6072+
* case [[nbranches-1]]:
6073+
* default:
6074+
* output = branches[nbranches-1](input);
6075+
* break;
6076+
* }
6077+
* ```
6078+
* </pre>
6079+
*
6080+
* @param branchIndex The branch selector, an int32 Tensor.
6081+
* @param input A list of input tensors passed to the branch function.
6082+
* @param Tout A list of output types.
6083+
* @param branches <pre>
6084+
* A list of functions each of which takes 'inputs' and returns a list of
6085+
* tensors, whose types are the same as what every other branch returns.
6086+
* </pre>
6087+
* @param options carries optional attribute values
6088+
* @return a new instance of StatefulCase
6089+
*/
6090+
public StatefulCase statefulCase(Operand<TInt32> branchIndex, Iterable<Operand<?>> input,
6091+
List<Class<? extends TType>> Tout, List<ConcreteFunction> branches, Case.Options... options) {
6092+
return StatefulCase.create(scope, branchIndex, input, Tout, branches, options);
6093+
}
6094+
6095+
/**
6096+
* output = cond ? then_branch(input) : else_branch(input)
6097+
*
6098+
* @param cond <pre>
6099+
* A Tensor. If the tensor is a scalar of non-boolean type, the
6100+
* scalar is converted to a boolean according to the
6101+
* following rule: if the scalar is a numerical value, non-zero means
6102+
* `True` and zero means False; if the scalar is a string, non-empty
6103+
* means `True` and empty means `False`. If the tensor is not a scalar,
6104+
* being empty means False and being non-empty means True.
6105+
* </pre>
6106+
* @param input A list of input tensors.
6107+
* @param Tout A list of output types.
6108+
* @param thenBranch <pre>
6109+
* A function that takes 'inputs' and returns a list of tensors, whose
6110+
* types are the same as what else_branch returns.
6111+
* </pre>
6112+
* @param elseBranch <pre>
6113+
* A function that takes 'inputs' and returns a list of tensors, whose
6114+
* types are the same as what then_branch returns.
6115+
* </pre>
6116+
* @param options carries optional attribute values
6117+
* @return a new instance of StatefulIf
6118+
*/
6119+
public StatefulIf statefulIf(Operand<? extends TType> cond, Iterable<Operand<?>> input,
6120+
List<Class<? extends TType>> Tout, ConcreteFunction thenBranch, ConcreteFunction elseBranch,
6121+
If.Options... options) {
6122+
return StatefulIf.create(scope, cond, input, Tout, thenBranch, elseBranch, options);
6123+
}
6124+
60516125
/**
60526126
* returns {@code f(inputs)}, where {@code f}'s body is placed and partitioned.
60536127
*
@@ -6064,11 +6138,36 @@ public StageSize stageSize(List<Class<? extends TType>> dtypes, StageSize.Option
60646138
* @return a new instance of StatefulPartitionedCall
60656139
*/
60666140
public StatefulPartitionedCall statefulPartitionedCall(Iterable<Operand<?>> args,
6067-
List<Class<? extends TType>> Tout, ConcreteFunction f,
6068-
StatefulPartitionedCall.Options... options) {
6141+
List<Class<? extends TType>> Tout, ConcreteFunction f, PartitionedCall.Options... options) {
60696142
return StatefulPartitionedCall.create(scope, args, Tout, f, options);
60706143
}
60716144

6145+
/**
6146+
* output = input; While (Cond(output)) { output = Body(output) }
6147+
*
6148+
* @param input A list of input tensors whose types are T.
6149+
* @param cond <pre>
6150+
* A function takes 'input' and returns a tensor. If the tensor is
6151+
* a scalar of non-boolean, the scalar is converted to a boolean
6152+
* according to the following rule: if the scalar is a numerical
6153+
* value, non-zero means True and zero means False; if the scalar is
6154+
* a string, non-empty means True and empty means False. If the
6155+
* tensor is not a scalar, non-emptiness means True and False
6156+
* otherwise.
6157+
* </pre>
6158+
* @param body <pre>
6159+
* A function that takes a list of tensors and returns another
6160+
* list of tensors. Both lists have the same types as specified
6161+
* by T.
6162+
* </pre>
6163+
* @param options carries optional attribute values
6164+
* @return a new instance of StatefulWhile
6165+
*/
6166+
public StatefulWhile statefulWhile(Iterable<Operand<?>> input, ConcreteFunction cond,
6167+
ConcreteFunction body, While.Options... options) {
6168+
return StatefulWhile.create(scope, input, cond, body, options);
6169+
}
6170+
60726171
/**
60736172
* output = cond ? then_branch(input) : else_branch(input)
60746173
*
@@ -6098,10 +6197,29 @@ public StatefulPartitionedCall statefulPartitionedCall(Iterable<Operand<?>> args
60986197
*/
60996198
public StatelessIf statelessIf(Operand<? extends TType> cond, Iterable<Operand<?>> input,
61006199
List<Class<? extends TType>> Tout, ConcreteFunction thenBranch, ConcreteFunction elseBranch,
6101-
StatelessIf.Options... options) {
6200+
If.Options... options) {
61026201
return StatelessIf.create(scope, cond, input, Tout, thenBranch, elseBranch, options);
61036202
}
61046203

6204+
/**
6205+
* returns {@code f(inputs)}, where {@code f}'s body is placed and partitioned.
6206+
*
6207+
* @param args A list of input tensors.
6208+
* @param Tout A list of output types.
6209+
* @param f <pre>
6210+
* A function that takes 'args', a list of tensors, and returns 'output',
6211+
* another list of tensors. Input and output types are specified by 'Tin'
6212+
* and 'Tout'. The function body of f will be placed and partitioned across
6213+
* devices, setting this op apart from the regular Call op.
6214+
* </pre>
6215+
* @param options carries optional attribute values
6216+
* @return a new instance of StatelessPartitionedCall
6217+
*/
6218+
public StatelessPartitionedCall statelessPartitionedCall(Iterable<Operand<?>> args,
6219+
List<Class<? extends TType>> Tout, ConcreteFunction f, PartitionedCall.Options... options) {
6220+
return StatelessPartitionedCall.create(scope, args, Tout, f, options);
6221+
}
6222+
61056223
/**
61066224
* output = input; While (Cond(output)) { output = Body(output) }
61076225
*
@@ -6127,7 +6245,7 @@ public StatelessIf statelessIf(Operand<? extends TType> cond, Iterable<Operand<?
61276245
* @return a new instance of StatelessWhile
61286246
*/
61296247
public StatelessWhile statelessWhile(Iterable<Operand<?>> input, ConcreteFunction cond,
6130-
ConcreteFunction body, StatelessWhile.Options... options) {
6248+
ConcreteFunction body, While.Options... options) {
61316249
return StatelessWhile.create(scope, input, cond, body, options);
61326250
}
61336251

@@ -7990,6 +8108,7 @@ public Where where(Operand<? extends TType> condition) {
79908108

79918109
/**
79928110
* output = input; While (Cond(output)) { output = Body(output) }
8111+
* Selects between {@link StatefulWhile} and {@link StatelessWhile} based on the statefulness of the function arguments.
79938112
*
79948113
* @param input A list of input tensors whose types are T.
79958114
* @param cond <pre>

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BatchFunction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ public static BatchFunction create(Scope scope, Iterable<Operand<?>> inTensors,
109109
Iterable<Operand<?>> capturedTensors, ConcreteFunction f, Long numBatchThreads,
110110
Long maxBatchSize, Long batchTimeoutMicros, List<Class<? extends TType>> Tout,
111111
Options... options) {
112-
OperationBuilder opBuilder = scope.env().opBuilder("BatchFunction", scope.makeOpName("BatchFunction"));
112+
OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("BatchFunction"));
113113
opBuilder.addInputList(Operands.asOutputs(inTensors));
114114
opBuilder.addInputList(Operands.asOutputs(capturedTensors));
115115
opBuilder = scope.apply(opBuilder);
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================*/
15+
16+
// This class has been generated, DO NOT EDIT!
17+
18+
package org.tensorflow.op.core;
19+
20+
import java.util.Arrays;
21+
import java.util.Iterator;
22+
import java.util.List;
23+
import org.tensorflow.ConcreteFunction;
24+
import org.tensorflow.Operand;
25+
import org.tensorflow.Output;
26+
import org.tensorflow.ndarray.Shape;
27+
import org.tensorflow.op.Scope;
28+
import org.tensorflow.op.annotation.Endpoint;
29+
import org.tensorflow.op.annotation.Operator;
30+
import org.tensorflow.types.TInt32;
31+
import org.tensorflow.types.family.TType;
32+
33+
/**
34+
* An n-way switch statement which calls a single branch function.
35+
* <pre>
36+
* An n-way switch statement, implementing the following:
37+
* ```
38+
* switch (branch_index) {
39+
* case 0:
40+
* output = branches[0](input);
41+
* break;
42+
* case 1:
43+
* output = branches[1](input);
44+
* break;
45+
* ...
46+
* case [[nbranches-1]]:
47+
* default:
48+
* output = branches[nbranches-1](input);
49+
* break;
50+
* }
51+
* ```
52+
* </pre>
53+
* Selects between {@link StatefulCase} and {@link StatelessCase} based on the statefulness of the function arguments.
54+
*/
55+
@Operator
56+
public interface Case extends Iterable<Operand<TType>> {
57+
/**
58+
* Factory method to create a class wrapping a new Case operation.
59+
*
60+
* @param scope current scope
61+
* @param branchIndex The branch selector, an int32 Tensor.
62+
* @param input A list of input tensors passed to the branch function.
63+
* @param Tout A list of output types.
64+
* @param branches <pre>
65+
* A list of functions each of which takes 'inputs' and returns a list of
66+
* tensors, whose types are the same as what every other branch returns.
67+
* </pre>
68+
* @param options carries optional attribute values
69+
* @return a new instance of Case
70+
*/
71+
@Endpoint(
72+
describeByClass = true,
73+
name = "caseOp"
74+
)
75+
static Case create(Scope scope, Operand<TInt32> branchIndex, Iterable<Operand<?>> input,
76+
List<Class<? extends TType>> Tout, List<ConcreteFunction> branches, Options... options) {
77+
boolean isStateful = false;
78+
if (branches.stream().anyMatch(x -> x.isStateful())) {
79+
isStateful = true;
80+
}
81+
if (isStateful) {
82+
return StatefulCase.create(scope, branchIndex, input, Tout, branches, options);
83+
} else {
84+
return StatelessCase.create(scope, branchIndex, input, Tout, branches, options);
85+
}
86+
}
87+
88+
/**
89+
* Sets the outputShapes option.
90+
*
91+
* @param outputShapes the outputShapes option
92+
* @return this Options instance.
93+
*/
94+
static Options outputShapes(List<Shape> outputShapes) {
95+
return new Options().outputShapes(outputShapes);
96+
}
97+
98+
/**
99+
* Sets the outputShapes option.
100+
*
101+
* @param outputShapes the outputShapes option
102+
* @return this Options instance.
103+
*/
104+
static Options outputShapes(Shape[] outputShapes) {
105+
return new Options().outputShapes(outputShapes);
106+
}
107+
108+
/**
109+
* Gets output.
110+
* A list of return values.
111+
* @return output.
112+
*/
113+
List<Output<?>> output();
114+
115+
@Override
116+
@SuppressWarnings({"rawtypes", "unchecked"})
117+
Iterator<Operand<TType>> iterator();
118+
119+
/**
120+
* Optional attributes for {@link org.tensorflow.op.core.Case}
121+
*/
122+
class Options {
123+
List<Shape> outputShapes;
124+
125+
private Options() {
126+
}
127+
128+
/**
129+
* Sets the outputShapes option.
130+
*
131+
* @param outputShapes the outputShapes option
132+
* @return this Options instance.
133+
*/
134+
public Options outputShapes(List<Shape> outputShapes) {
135+
this.outputShapes = outputShapes;
136+
return this;
137+
}
138+
139+
/**
140+
* Sets the outputShapes option.
141+
*
142+
* @param outputShapes the outputShapes option
143+
* @return this Options instance.
144+
*/
145+
public Options outputShapes(Shape... outputShapes) {
146+
this.outputShapes = Arrays.asList(outputShapes);
147+
return this;
148+
}
149+
}
150+
}

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/For.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ private For(Operation operation) {
7878
)
7979
public static For create(Scope scope, Operand<TInt32> start, Operand<TInt32> limit,
8080
Operand<TInt32> delta, Iterable<Operand<?>> input, ConcreteFunction body) {
81-
OperationBuilder opBuilder = scope.env().opBuilder("For", scope.makeOpName("For"));
81+
OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("For"));
8282
opBuilder.addInput(start.asOutput());
8383
opBuilder.addInput(limit.asOutput());
8484
opBuilder.addInput(delta.asOutput());

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GroupByReducerDataset.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ public static GroupByReducerDataset create(Scope scope, Operand<? extends TType>
8282
Iterable<Operand<?>> finalizeFuncOtherArguments, ConcreteFunction keyFunc,
8383
ConcreteFunction initFunc, ConcreteFunction reduceFunc, ConcreteFunction finalizeFunc,
8484
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
85-
OperationBuilder opBuilder = scope.env().opBuilder("GroupByReducerDataset", scope.makeOpName("GroupByReducerDataset"));
85+
OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("GroupByReducerDataset"));
8686
opBuilder.addInput(inputDataset.asOutput());
8787
opBuilder.addInputList(Operands.asOutputs(keyFuncOtherArguments));
8888
opBuilder.addInputList(Operands.asOutputs(initFuncOtherArguments));

0 commit comments

Comments
 (0)