In [ ]:
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax.nn as jnn
import jax.random as jnr
import jax
import optax

plt.rcParams.update({'font.size': 20})

# Starting logits: a fixed PRNG key makes every Restart & Run All begin from
# the same point; the small scale keeps the initial softmax close to uniform.
SEED = 4
NUM_ELEMENTS = 10
INIT_SCALE = 0.1

init_key = jnr.PRNGKey(SEED)
values = INIT_SCALE * jnr.normal(init_key, (NUM_ELEMENTS,))
values
Out[ ]:
Array([ 0.11777242,  0.07384811, -0.10801565,  0.03344669,  0.00033997,
        0.02359026, -0.06941637,  0.12355181,  0.11009053, -0.02784047],      dtype=float32)
In [ ]:
# Bar chart of the raw starting logits, one bar per element.
fig, ax = plt.subplots()
ax.bar(range(len(values)), values)
ax.set_title("Initial values")
plt.show()
No description has been provided for this image
In [ ]:
# Bar chart of the starting probability distribution (softmax of the logits).
initial_probs = jnn.softmax(values)
fig, ax = plt.subplots()
ax.bar(range(len(values)), initial_probs)
ax.set_title("Initial softmax values")
plt.show()
No description has been provided for this image
In [ ]:
# Adam optimizer over the logits; its state is shaped after `values`.
LEARNING_RATE = 3e-1
optimizer = optax.adam(LEARNING_RATE)
opt_state = optimizer.init(values)

# The log function could be replaced with *any* concave function of the
# probabilities: minimization pushes the softmax towards one-hot, maximization
# spreads it out towards uniform (entropy-like behaviour).
# An alternative that also works (sum of squares is convex, so negate it):
#     return -jnp.sum(jnn.softmax(values) ** 2.0)
def loss_fn(values):
    """Sum of the log-softmax probabilities of ``values``.

    Args:
        values: array of logits; the softmax is taken over the last axis
            (the ``jnn.log_softmax`` default).

    Returns:
        Scalar ``sum_i log softmax(values)_i``. For a 1-D input of length n,
        its gradient w.r.t. logit j is ``1 - n * softmax(values)_j``, so
        gradient descent grows logits whose probability exceeds 1/n and
        shrinks the rest.

    ``jnn.log_softmax`` is used instead of ``jnp.log(jnn.softmax(values))``.
    The two are mathematically equal. However, the latter underflows to
    ``log(0) = -inf`` (with a NaN gradient) once the logits spread by roughly
    100 or more in float32, which longer or faster optimization runs reach.
    """
    return jnp.sum(jnn.log_softmax(values))
In [ ]:
# Optimize loop.
# Training runs on a local copy (`params`) with a freshly initialised optimizer
# state, so re-executing this cell reproduces the same trajectory instead of
# continuing from wherever the previous run stopped. The global `values` stays
# at its initial setting, so the "Initial values" plots remain correct on re-run.
grad_fn = jax.jit(jax.grad(loss_fn))  # build (and compile) the gradient once, not every epoch
params = values
opt_state = optimizer.init(params)
results = []  # softmax of `params` after each epoch
epochs = 100
for _ in range(epochs):
    grads = grad_fn(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    results.append(jnn.softmax(params))
In [ ]:
# Plotting: one dot per (epoch, element) pair, coloured by epoch.
res_arr = jnp.array(results)  # shape: (epochs, num_classes)
epochs, num_classes = res_arr.shape
class_indices = np.arange(num_classes)

# res_arr flattens row-major (epoch by epoch), so element indices cycle within
# each epoch while the epoch id repeats once per element.
element_idx = np.tile(class_indices, epochs)
probs_flat = res_arr.reshape(-1)
epoch_ids = np.repeat(np.arange(epochs), num_classes)

fig, ax = plt.subplots(figsize=(10, 5))
points = ax.scatter(element_idx, probs_flat, c=epoch_ids, cmap='viridis', s=30, alpha=0.8)
fig.colorbar(points, ax=ax, label="Epoch")
ax.set_xlabel("Element index")
ax.set_ylabel("Softmax")
ax.set_title("Softmax over epochs (Color = Epoch)", fontsize=15)
ax.set_xticks(class_indices)
ax.grid(True, linestyle='--', alpha=0.4)
fig.tight_layout()
plt.show()
No description has been provided for this image
In [ ]:
# Minimal line chart: one trajectory per element, ticks and frame hidden.
fig, ax = plt.subplots()
ax.plot(res_arr, linewidth=1)
ax.set_xlabel("Epochs")
ax.set_ylabel("Probability")
ax.set_xticks([])
ax.set_yticks([])
ax.set_frame_on(False)
fig.tight_layout()
plt.show()
No description has been provided for this image
In [ ]:
# Modern art version for printing: same trajectories with thicker lines,
# no axes at all, on a large (19 x 10 inch) canvas.
fig, ax = plt.subplots(figsize=(19, 10))
ax.plot(res_arr, linewidth=2)
ax.axis('off')
plt.show()
No description has been provided for this image