Skip to content

Commit bdde828

Browse files
committed
Start of stateful/stateless processing
Signed-off-by: Ryan Nett <JNett96@gmail.com>
1 parent db51573 commit bdde828

File tree

8 files changed

+196
-68
lines changed

8 files changed

+196
-68
lines changed

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@
252252
<executions>
253253
<execution>
254254
<!-- Runs in initialize phase to fail fast in case of formatting issues (should be before codegen).-->
255-
<id>spotless-check</id>
255+
<id>spotless-apply</id>
256256
<phase>initialize</phase>
257257
<goals>
258258
<goal>check</goal>

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Case.java renamed to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulCase.java

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
* </pre>
5757
*/
5858
@Operator
59-
public final class Case extends RawOp implements Iterable<Operand<TType>> {
59+
public final class StatefulCase extends RawOp implements Iterable<Operand<TType>> {
6060
/**
6161
* The name of this op, as known by TensorFlow core engine
6262
*/
@@ -65,7 +65,7 @@ public final class Case extends RawOp implements Iterable<Operand<TType>> {
6565
private List<Output<?>> output;
6666

6767
@SuppressWarnings("unchecked")
68-
private Case(Operation operation) {
68+
private StatefulCase(Operation operation) {
6969
super(operation);
7070
int outputIdx = 0;
7171
int outputLength = operation.outputListLength("output");
@@ -85,15 +85,15 @@ private Case(Operation operation) {
8585
* tensors, whose types are the same as what every other branch returns.
8686
* </pre>
8787
* @param options carries optional attribute values
88-
* @return a new instance of Case
88+
* @return a new instance of StatefulCase
8989
*/
9090
@Endpoint(
91-
describeByClass = true,
92-
name = "caseOp"
91+
describeByClass = true
9392
)
94-
public static Case create(Scope scope, Operand<TInt32> branchIndex, Iterable<Operand<?>> input,
95-
List<Class<? extends TType>> Tout, List<ConcreteFunction> branches, Options... options) {
96-
OperationBuilder opBuilder = scope.env().opBuilder("Case", scope.makeOpName("Case"));
93+
public static StatefulCase create(Scope scope, Operand<TInt32> branchIndex,
94+
Iterable<Operand<?>> input, List<Class<? extends TType>> Tout,
95+
List<ConcreteFunction> branches, Options... options) {
96+
OperationBuilder opBuilder = scope.env().opBuilder("Case", scope.makeOpName("StatefulCase"));
9797
opBuilder.addInput(branchIndex.asOutput());
9898
opBuilder.addInputList(Operands.asOutputs(input));
9999
opBuilder = scope.apply(opBuilder);
@@ -114,7 +114,7 @@ public static Case create(Scope scope, Operand<TInt32> branchIndex, Iterable<Ope
114114
}
115115
}
116116
}
117-
return new Case(opBuilder.build());
117+
return new StatefulCase(opBuilder.build());
118118
}
119119

120120
/**
@@ -153,7 +153,7 @@ public Iterator<Operand<TType>> iterator() {
153153
}
154154

155155
/**
156-
* Optional attributes for {@link org.tensorflow.op.core.Case}
156+
* Optional attributes for {@link org.tensorflow.op.core.StatefulCase}
157157
*/
158158
public static class Options {
159159
private List<Shape> outputShapes;

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/If.java renamed to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulIf.java

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
* output = cond ? then_branch(input) : else_branch(input)
3838
*/
3939
@Operator
40-
public final class If extends RawOp implements Iterable<Operand<TType>> {
40+
public final class StatefulIf extends RawOp implements Iterable<Operand<TType>> {
4141
/**
4242
* The name of this op, as known by TensorFlow core engine
4343
*/
@@ -46,7 +46,7 @@ public final class If extends RawOp implements Iterable<Operand<TType>> {
4646
private List<Output<?>> output;
4747

4848
@SuppressWarnings("unchecked")
49-
private If(Operation operation) {
49+
private StatefulIf(Operation operation) {
5050
super(operation);
5151
int outputIdx = 0;
5252
int outputLength = operation.outputListLength("output");
@@ -77,16 +77,15 @@ private If(Operation operation) {
7777
* types are the same as what then_branch returns.
7878
* </pre>
7979
* @param options carries optional attribute values
80-
* @return a new instance of If
80+
* @return a new instance of StatefulIf
8181
*/
8282
@Endpoint(
83-
describeByClass = true,
84-
name = "ifOp"
83+
describeByClass = true
8584
)
86-
public static If create(Scope scope, Operand<? extends TType> cond, Iterable<Operand<?>> input,
87-
List<Class<? extends TType>> Tout, ConcreteFunction thenBranch, ConcreteFunction elseBranch,
88-
Options... options) {
89-
OperationBuilder opBuilder = scope.env().opBuilder("If", scope.makeOpName("If"));
85+
public static StatefulIf create(Scope scope, Operand<? extends TType> cond,
86+
Iterable<Operand<?>> input, List<Class<? extends TType>> Tout, ConcreteFunction thenBranch,
87+
ConcreteFunction elseBranch, Options... options) {
88+
OperationBuilder opBuilder = scope.env().opBuilder("If", scope.makeOpName("StatefulIf"));
9089
opBuilder.addInput(cond.asOutput());
9190
opBuilder.addInputList(Operands.asOutputs(input));
9291
opBuilder = scope.apply(opBuilder);
@@ -104,7 +103,7 @@ public static If create(Scope scope, Operand<? extends TType> cond, Iterable<Ope
104103
}
105104
}
106105
}
107-
return new If(opBuilder.build());
106+
return new StatefulIf(opBuilder.build());
108107
}
109108

110109
/**
@@ -143,7 +142,7 @@ public Iterator<Operand<TType>> iterator() {
143142
}
144143

145144
/**
146-
* Optional attributes for {@link org.tensorflow.op.core.If}
145+
* Optional attributes for {@link org.tensorflow.op.core.StatefulIf}
147146
*/
148147
public static class Options {
149148
private List<Shape> outputShapes;

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/While.java renamed to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulWhile.java

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
* output = input; While (Cond(output)) { output = Body(output) }
3838
*/
3939
@Operator
40-
public final class While extends RawOp implements Iterable<Operand<TType>> {
40+
public final class StatefulWhile extends RawOp implements Iterable<Operand<TType>> {
4141
/**
4242
* The name of this op, as known by TensorFlow core engine
4343
*/
@@ -46,7 +46,7 @@ public final class While extends RawOp implements Iterable<Operand<TType>> {
4646
private List<Output<?>> output;
4747

4848
@SuppressWarnings("unchecked")
49-
private While(Operation operation) {
49+
private StatefulWhile(Operation operation) {
5050
super(operation);
5151
int outputIdx = 0;
5252
int outputLength = operation.outputListLength("output");
@@ -74,15 +74,14 @@ private While(Operation operation) {
7474
* by T.
7575
* </pre>
7676
* @param options carries optional attribute values
77-
* @return a new instance of While
77+
* @return a new instance of StatefulWhile
7878
*/
7979
@Endpoint(
80-
describeByClass = true,
81-
name = "whileOp"
80+
describeByClass = true
8281
)
83-
public static While create(Scope scope, Iterable<Operand<?>> input, ConcreteFunction cond,
82+
public static StatefulWhile create(Scope scope, Iterable<Operand<?>> input, ConcreteFunction cond,
8483
ConcreteFunction body, Options... options) {
85-
OperationBuilder opBuilder = scope.env().opBuilder("While", scope.makeOpName("While"));
84+
OperationBuilder opBuilder = scope.env().opBuilder("While", scope.makeOpName("StatefulWhile"));
8685
opBuilder.addInputList(Operands.asOutputs(input));
8786
opBuilder = scope.apply(opBuilder);
8887
opBuilder.setAttr("cond", cond);
@@ -101,7 +100,7 @@ public static While create(Scope scope, Iterable<Operand<?>> input, ConcreteFunc
101100
}
102101
}
103102
}
104-
return new While(opBuilder.build());
103+
return new StatefulWhile(opBuilder.build());
105104
}
106105

107106
/**
@@ -150,7 +149,7 @@ public Iterator<Operand<TType>> iterator() {
150149
}
151150

152151
/**
153-
* Optional attributes for {@link org.tensorflow.op.core.While}
152+
* Optional attributes for {@link org.tensorflow.op.core.StatefulWhile}
154153
*/
155154
public static class Options {
156155
private List<Shape> outputShapes;

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/PartitionedCall.java renamed to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessPartitionedCall.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
* returns {@code f(inputs)}, where {@code f}'s body is placed and partitioned.
3737
*/
3838
@Operator
39-
public final class PartitionedCall extends RawOp implements Iterable<Operand<TType>> {
39+
public final class StatelessPartitionedCall extends RawOp implements Iterable<Operand<TType>> {
4040
/**
4141
* The name of this op, as known by TensorFlow core engine
4242
*/
@@ -45,7 +45,7 @@ public final class PartitionedCall extends RawOp implements Iterable<Operand<TTy
4545
private List<Output<?>> output;
4646

4747
@SuppressWarnings("unchecked")
48-
private PartitionedCall(Operation operation) {
48+
private StatelessPartitionedCall(Operation operation) {
4949
super(operation);
5050
int outputIdx = 0;
5151
int outputLength = operation.outputListLength("output");
@@ -66,14 +66,14 @@ private PartitionedCall(Operation operation) {
6666
* devices, setting this op apart from the regular Call op.
6767
* </pre>
6868
* @param options carries optional attribute values
69-
* @return a new instance of PartitionedCall
69+
* @return a new instance of StatelessPartitionedCall
7070
*/
7171
@Endpoint(
7272
describeByClass = true
7373
)
74-
public static PartitionedCall create(Scope scope, Iterable<Operand<?>> args,
74+
public static StatelessPartitionedCall create(Scope scope, Iterable<Operand<?>> args,
7575
List<Class<? extends TType>> Tout, ConcreteFunction f, Options... options) {
76-
OperationBuilder opBuilder = scope.env().opBuilder("PartitionedCall", scope.makeOpName("PartitionedCall"));
76+
OperationBuilder opBuilder = scope.env().opBuilder("PartitionedCall", scope.makeOpName("StatelessPartitionedCall"));
7777
opBuilder.addInputList(Operands.asOutputs(args));
7878
opBuilder = scope.apply(opBuilder);
7979
opBuilder.setAttr("Tout", Operands.toDataTypes(Tout));
@@ -91,7 +91,7 @@ public static PartitionedCall create(Scope scope, Iterable<Operand<?>> args,
9191
}
9292
}
9393
}
94-
return new PartitionedCall(opBuilder.build());
94+
return new StatelessPartitionedCall(opBuilder.build());
9595
}
9696

9797
/**
@@ -140,7 +140,7 @@ public Iterator<Operand<TType>> iterator() {
140140
}
141141

142142
/**
143-
* Optional attributes for {@link org.tensorflow.op.core.PartitionedCall}
143+
* Optional attributes for {@link org.tensorflow.op.core.StatelessPartitionedCall}
144144
*/
145145
public static class Options {
146146
private String config;

tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/FullOpDef.java

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package org.tensorflow.op.generator;
1717

1818
import com.squareup.javapoet.TypeSpec;
19+
import java.util.StringJoiner;
1920
import org.tensorflow.proto.framework.ApiDef;
2021
import org.tensorflow.proto.framework.ApiDef.Endpoint;
2122
import org.tensorflow.proto.framework.OpDef;
@@ -56,13 +57,21 @@ public boolean isStateful() {
5657
return opDef.getIsStateful();
5758
}
5859

59-
public boolean equalOtherThanState(FullOpDef other) {
60+
public boolean isStateVariant(FullOpDef other) {
61+
if (this.equals(other)) return false;
62+
63+
if (this.isStateful() == other.isStateful()) return false;
64+
6065
OpDef copy =
6166
opDef.toBuilder().setName(other.opDef.getName()).setIsStateful(other.isStateful()).build();
62-
return copy.equals(other.opDef);
67+
return copy.equals(other.opDef) && packageName.equals(other.packageName);
6368
}
6469

6570
public TypeSpec buildOpClass() {
71+
return buildOpClass(className);
72+
}
73+
74+
public TypeSpec buildOpClass(String className) {
6675
TypeSpec.Builder cls = TypeSpec.classBuilder(className);
6776
try {
6877
new ClassGenerator(cls, opDef, apiDef, basePackage, packageName, group, className, endpoint)
@@ -116,4 +125,17 @@ public int hashCode() {
116125
result = 31 * result + endpoint.hashCode();
117126
return result;
118127
}
128+
129+
@Override
130+
public String toString() {
131+
return new StringJoiner(", ", FullOpDef.class.getSimpleName() + "(", ")")
132+
.add("opDef=" + opDef)
133+
.add("apiDef=" + apiDef)
134+
.add("basePackage='" + basePackage + "'")
135+
.add("packageName='" + packageName + "'")
136+
.add("group='" + group + "'")
137+
.add("className='" + className + "'")
138+
.add("endpoint=" + endpoint)
139+
.toString();
140+
}
119141
}

tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,30 @@ private static void generate(File outputDir, String packageName, File opDefs) {
186186
generate(outputDir, packageName, defs);
187187
}
188188

189+
private static void writeToFile(TypeSpec spec, File outputDir, String packageName) {
190+
JavaFile file =
191+
JavaFile.builder(packageName, spec).indent(" ").skipJavaLangImports(true).build();
192+
193+
File outputFile =
194+
new File(outputDir, packageName.replace('.', '/') + '/' + spec.name + ".java");
195+
outputFile.getParentFile().mkdirs();
196+
try {
197+
StringBuilder builder = new StringBuilder();
198+
builder.append(LICENSE + '\n');
199+
builder.append("// This class has been generated, DO NOT EDIT!\n\n");
200+
file.writeTo(builder);
201+
202+
Files.write(
203+
outputFile.toPath(),
204+
builder.toString().getBytes(StandardCharsets.UTF_8),
205+
StandardOpenOption.WRITE,
206+
StandardOpenOption.CREATE,
207+
StandardOpenOption.TRUNCATE_EXISTING);
208+
} catch (IOException ioException) {
209+
throw new IllegalStateException("Failed to write file " + outputFile, ioException);
210+
}
211+
}
212+
189213
/** Generate all the ops that pass {@link ClassGenerator#canGenerateOp(OpDef, ApiDef)}. */
190214
private static void generate(File outputDir, String basePackage, Map<OpDef, ApiDef> ops) {
191215
List<FullOpDef> fullOps =
@@ -219,34 +243,19 @@ private static void generate(File outputDir, String basePackage, Map<OpDef, ApiD
219243
}))
220244
.collect(Collectors.toList());
221245

246+
List<StatefulPair> statefulPairs = StatefulPair.extractStatefulPairs(fullOps);
247+
222248
fullOps.forEach(
223249
(def) -> {
224250
TypeSpec spec = def.buildOpClass();
225251

226-
JavaFile file =
227-
JavaFile.builder(def.packageName, spec)
228-
.indent(" ")
229-
.skipJavaLangImports(true)
230-
.build();
231-
232-
File outputFile =
233-
new File(outputDir, def.packageName.replace('.', '/') + '/' + spec.name + ".java");
234-
outputFile.getParentFile().mkdirs();
235-
try {
236-
StringBuilder builder = new StringBuilder();
237-
builder.append(LICENSE + '\n');
238-
builder.append("// This class has been generated, DO NOT EDIT!\n\n");
239-
file.writeTo(builder);
240-
241-
Files.write(
242-
outputFile.toPath(),
243-
builder.toString().getBytes(StandardCharsets.UTF_8),
244-
StandardOpenOption.WRITE,
245-
StandardOpenOption.CREATE,
246-
StandardOpenOption.TRUNCATE_EXISTING);
247-
} catch (IOException ioException) {
248-
throw new IllegalStateException("Failed to write file " + outputFile, ioException);
249-
}
252+
writeToFile(spec, outputDir, def.packageName);
253+
});
254+
255+
statefulPairs.forEach(
256+
(pair) -> {
257+
pair.buildOpClasses()
258+
.forEach((spec) -> writeToFile(spec, outputDir, pair.getPackageName()));
250259
});
251260
}
252261
}

0 commit comments

Comments
 (0)