Skip to content

Commit e6512b4

Browse files
committed
Add gramian algorithm
1 parent 8106aea commit e6512b4

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed

computer_vision/gramian.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""
2+
Image style reconstruction with Gram matrices.
3+
4+
https://arxiv.org/pdf/1603.08155#page=7&zoom=auto,-294,3
5+
"""
6+
7+
import numpy as np
8+
9+
10+
def gram_matrix(mat: np.ndarray) -> np.ndarray:
11+
"""
12+
Returns the Gram (Gramian) matrix of an image.
13+
14+
:param mat: matrix of shape (C, H, W); C = color channels, H = height, W = width.
15+
:type mat: np.ndarray
16+
:return: matrix of shape (C, C).
17+
:rtype: np.ndarray
18+
19+
Examples
20+
--------
21+
>>> gram_matrix(np.ones((2,5,5)))
22+
array([[0.5, 0.5],
23+
[0.5, 0.5]])
24+
>>> gram_matrix(np.ones((3,5,5)))
25+
array([[0.33333333, 0.33333333, 0.33333333],
26+
[0.33333333, 0.33333333, 0.33333333],
27+
[0.33333333, 0.33333333, 0.33333333]])
28+
>>> gram_matrix(np.ones((3,5,5))).shape
29+
(3, 3)
30+
"""
31+
color, height, width = mat.shape
32+
vec = mat.reshape(color, height * width)
33+
gram = vec @ vec.T
34+
return gram / (color * height * width)
35+
36+
37+
def gram_loss(input_features: np.ndarray, reference_features: np.ndarray) -> np.float64:
38+
"""
39+
Calculates the squared Frobenius norm of the difference between
40+
the Gram matrices of the input and reference image.
41+
42+
:param input_features: Feature map of shape (C, H, W)
43+
:type input_features: np.ndarray
44+
:param reference_features: Feature map of shape (C, H, W)
45+
:type reference_features: np.ndarray
46+
:return: Gram loss between the two feature maps.
47+
:rtype: float64
48+
49+
Examples
50+
--------
51+
>>> a = np.random.randn(3,5,5)
52+
>>> gram_loss(a, a)
53+
np.float64(0.0)
54+
>>> a = np.zeros((3,5,5))
55+
>>> b = np.ones((3,5,5))
56+
>>> gram_loss(a, b)
57+
np.float64(1.0)
58+
"""
59+
input_gram = gram_matrix(input_features)
60+
reference_gram = gram_matrix(reference_features)
61+
return np.sum(np.square(input_gram - reference_gram)).astype(np.float64)
62+
63+
64+
if __name__ == "__main__":
65+
import doctest
66+
67+
doctest.testmod()

0 commit comments

Comments
 (0)