Skip to content

Commit ec92d15

Browse files
committed
Test for wrappers, using If
Signed-off-by: Ryan Nett <JNett96@gmail.com>
1 parent 3bee40b commit ec92d15

File tree

7 files changed

+190
-8
lines changed

7 files changed

+190
-8
lines changed

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,7 +1234,8 @@ public Map<String, Operand<?>> call(ConcreteFunction function,
12341234
* }
12351235
* ```
12361236
* </pre>
1237-
* Selects between {@link StatefulCase} and {@link StatelessCase} based on the statefulness of the function arguments.
1237+
*
1238+
* <p>Selects between {@link StatefulCase} and {@link StatelessCase} based on the statefulness of the function arguments.
12381239
*
12391240
* @param branchIndex The branch selector, an int32 Tensor.
12401241
* @param input A list of input tensors passed to the branch function.
@@ -2951,7 +2952,8 @@ public IdentityN identityN(Iterable<Operand<?>> input) {
29512952

29522953
/**
29532954
* output = cond ? then_branch(input) : else_branch(input)
2954-
* Selects between {@link StatefulIf} and {@link StatelessIf} based on the statefulness of the function arguments.
2955+
*
2956+
* <p>Selects between {@link StatefulIf} and {@link StatelessIf} based on the statefulness of the function arguments.
29552957
*
29562958
* @param cond <pre>
29572959
* A Tensor. If the tensor is a scalar of non-boolean type, the
@@ -4021,7 +4023,8 @@ public <T extends TType> ParallelDynamicStitch<T> parallelDynamicStitch(
40214023

40224024
/**
40234025
* returns {@code f(inputs)}, where {@code f}'s body is placed and partitioned.
4024-
* Selects between {@link StatefulPartitionedCall} and {@link StatelessPartitionedCall} based on the statefulness of the function arguments.
4026+
*
4027+
* <p>Selects between {@link StatefulPartitionedCall} and {@link StatelessPartitionedCall} based on the statefulness of the function arguments.
40254028
*
40264029
* @param args A list of input tensors.
40274030
* @param Tout A list of output types.
@@ -8108,7 +8111,8 @@ public Where where(Operand<? extends TType> condition) {
81088111

81098112
/**
81108113
* output = input; While (Cond(output)) { output = Body(output) }
8111-
* Selects between {@link StatefulWhile} and {@link StatelessWhile} based on the statefulness of the function arguments.
8114+
*
8115+
* <p>Selects between {@link StatefulWhile} and {@link StatelessWhile} based on the statefulness of the function arguments.
81128116
*
81138117
* @param input A list of input tensors whose types are T.
81148118
* @param cond <pre>

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Case.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@
5050
* }
5151
* ```
5252
* </pre>
53-
* Selects between {@link StatefulCase} and {@link StatelessCase} based on the statefulness of the function arguments.
53+
*
54+
* <p>Selects between {@link StatefulCase} and {@link StatelessCase} based on the statefulness of the function arguments.
5455
*/
5556
@Operator
5657
public interface Case extends Iterable<Operand<TType>> {

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/If.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131

3232
/**
3333
* output = cond ? then_branch(input) : else_branch(input)
34-
* Selects between {@link StatefulIf} and {@link StatelessIf} based on the statefulness of the function arguments.
34+
*
35+
* <p>Selects between {@link StatefulIf} and {@link StatelessIf} based on the statefulness of the function arguments.
3536
*/
3637
@Operator
3738
public interface If extends Iterable<Operand<TType>> {

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/PartitionedCall.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929

3030
/**
3131
* returns {@code f(inputs)}, where {@code f}'s body is placed and partitioned.
32-
* Selects between {@link StatefulPartitionedCall} and {@link StatelessPartitionedCall} based on the statefulness of the function arguments.
32+
*
33+
* <p>Selects between {@link StatefulPartitionedCall} and {@link StatelessPartitionedCall} based on the statefulness of the function arguments.
3334
*/
3435
@Operator
3536
public interface PartitionedCall extends Iterable<Operand<TType>> {

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/While.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131

3232
/**
3333
* output = input; While (Cond(output)) { output = Body(output) }
34-
* Selects between {@link StatefulWhile} and {@link StatelessWhile} based on the statefulness of the function arguments.
34+
*
35+
* <p>Selects between {@link StatefulWhile} and {@link StatelessWhile} based on the statefulness of the function arguments.
3536
*/
3637
@Operator
3738
public interface While extends Iterable<Operand<TType>> {
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================
15+
*/
16+
package org.tensorflow.op.core;
17+
18+
import static org.junit.jupiter.api.Assertions.assertEquals;
19+
20+
import java.util.Arrays;
21+
import java.util.Collections;
22+
import org.junit.jupiter.api.Test;
23+
import org.tensorflow.ConcreteFunction;
24+
import org.tensorflow.EagerSession;
25+
import org.tensorflow.Graph;
26+
import org.tensorflow.Operand;
27+
import org.tensorflow.Session;
28+
import org.tensorflow.Signature;
29+
import org.tensorflow.op.Ops;
30+
import org.tensorflow.types.TInt32;
31+
32+
public class IfTest {
33+
34+
private static Operand<TInt32> basicIf(Ops tf, Operand<TInt32> a, Operand<TInt32> b) {
35+
ConcreteFunction thenBranch =
36+
ConcreteFunction.create(
37+
(ops) -> {
38+
Operand<TInt32> a1 = ops.placeholder(TInt32.class);
39+
Operand<TInt32> b1 = ops.placeholder(TInt32.class);
40+
return Signature.builder().input("a", a1).input("b", b1).output("y", a1).build();
41+
});
42+
43+
ConcreteFunction elseBranch =
44+
ConcreteFunction.create(
45+
(ops) -> {
46+
Operand<TInt32> a1 = ops.placeholder(TInt32.class);
47+
Operand<TInt32> b1 = ops.placeholder(TInt32.class);
48+
Operand<TInt32> y = ops.math.neg(b1);
49+
return Signature.builder().input("a", a1).input("b", b1).output("y", y).build();
50+
});
51+
52+
return (Operand<TInt32>)
53+
tf.ifOp(
54+
tf.math.greater(a, b),
55+
Arrays.asList(a, b),
56+
Arrays.asList(TInt32.class),
57+
thenBranch,
58+
elseBranch)
59+
.output()
60+
.get(0);
61+
}
62+
63+
@Test
64+
public void testGraph() {
65+
try (Graph g = new Graph();
66+
Session s = new Session(g)) {
67+
Ops tf = Ops.create(g);
68+
Operand<TInt32> a = tf.placeholder(TInt32.class);
69+
Operand<TInt32> b = tf.placeholder(TInt32.class);
70+
Operand<TInt32> c = basicIf(tf, a, b);
71+
72+
assertEquals(StatelessIf.OP_NAME, c.op().type());
73+
74+
try (TInt32 out =
75+
(TInt32)
76+
s.runner()
77+
.feed(a, TInt32.scalarOf(2))
78+
.feed(b, TInt32.scalarOf(1))
79+
.fetch(c)
80+
.run()
81+
.get(0)) {
82+
assertEquals(2, out.getInt());
83+
}
84+
85+
try (TInt32 out =
86+
(TInt32)
87+
s.runner()
88+
.feed(a, TInt32.scalarOf(2))
89+
.feed(b, TInt32.scalarOf(3))
90+
.fetch(c)
91+
.run()
92+
.get(0)) {
93+
assertEquals(-3, out.getInt());
94+
}
95+
}
96+
}
97+
98+
@Test
99+
public void testStatefullness() {
100+
try (Graph g = new Graph()) {
101+
Ops tf = Ops.create(g);
102+
Operand<TInt32> a = tf.placeholder(TInt32.class);
103+
Operand<TInt32> b = tf.placeholder(TInt32.class);
104+
105+
ConcreteFunction thenBranch =
106+
ConcreteFunction.create(
107+
(ops) -> {
108+
Operand<TInt32> a1 = ops.placeholder(TInt32.class);
109+
Operand<TInt32> b1 = ops.placeholder(TInt32.class);
110+
Operand<TInt32> result =
111+
(Operand<TInt32>)
112+
ops.statefulIf(
113+
ops.constant(false),
114+
Collections.emptyList(),
115+
Arrays.asList(TInt32.class),
116+
ConcreteFunction.create(
117+
(ops1) ->
118+
Signature.builder().output("y", ops1.constant(1)).build()),
119+
ConcreteFunction.create(
120+
(ops1) ->
121+
Signature.builder().output("y", ops1.constant(1)).build()))
122+
.output()
123+
.get(0);
124+
return Signature.builder()
125+
.input("a", a1)
126+
.input("b", b1)
127+
.output("y", result)
128+
.build();
129+
});
130+
131+
ConcreteFunction elseBranch =
132+
ConcreteFunction.create(
133+
(ops) -> {
134+
Operand<TInt32> a1 = ops.placeholder(TInt32.class);
135+
Operand<TInt32> b1 = ops.placeholder(TInt32.class);
136+
Operand<TInt32> y = ops.math.neg(b1);
137+
return Signature.builder().input("a", a1).input("b", b1).output("y", y).build();
138+
});
139+
140+
Operand<TInt32> output =
141+
(Operand<TInt32>)
142+
tf.ifOp(
143+
tf.math.greater(a, b),
144+
Arrays.asList(a, b),
145+
Arrays.asList(TInt32.class),
146+
thenBranch,
147+
elseBranch)
148+
.output()
149+
.get(0);
150+
151+
assertEquals(StatefulIf.OP_NAME, output.op().type());
152+
}
153+
}
154+
155+
@Test
156+
public void testEager() {
157+
try (EagerSession e = EagerSession.create()) {
158+
Ops tf = Ops.create(e);
159+
160+
Operand<TInt32> out1 = basicIf(tf, tf.constant(2), tf.constant(1));
161+
162+
assertEquals(StatelessIf.OP_NAME, out1.op().type());
163+
164+
try (TInt32 out = out1.asTensor()) {
165+
assertEquals(2, out.getInt());
166+
}
167+
168+
try (TInt32 out = basicIf(tf, tf.constant(2), tf.constant(3)).asTensor()) {
169+
assertEquals(-3, out.getInt());
170+
}
171+
}
172+
}
173+
}

tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ void buildClass() {
240240
}
241241

242242
if (isStateSelector) {
243+
builder.addJavadoc("\n<p>");
243244
builder.addJavadoc(
244245
"Selects between {@link "
245246
+ statefulPair.statefulClassName

0 commit comments

Comments
 (0)