CS 307: Week 15

# standard imports
import matplotlib.pyplot as plt
import numpy as np

# sklearn data generation
from sklearn.datasets import make_classification
from sklearn.datasets import make_circles

# pytorch imports
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchsummary import summary
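
The trained parameters printed later in these notes will differ from run to run, because the weights are randomly initialized. For repeatable results we could seed PyTorch's random number generator up front (a minimal sketch; the seed value is arbitrary):

torch.manual_seed(42)  # make weight initialization repeatable across runs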

Logistic Regression as a Neural Network: Linear Data

# Generate "linear" data
X, y = make_classification(
    n_samples=100,
    n_features=2,
    n_informative=2,
    n_redundant=0,
    n_clusters_per_class=1,
    random_state=2,
    n_classes=2,
)
# Create a new figure and an axes
fig, ax = plt.subplots()

# Plot the generated data
scatter = ax.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Set1)

# Set labels and title
ax.set_xlabel("Feature 1")
ax.set_ylabel("Feature 2")
ax.set_title("Linear Data")

# Add a grid
ax.grid(color="lightgrey", linestyle="--")

# Display the plot
plt.show()

# Define the logistic regression model class
class LogisticRegression(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.linear = nn.Linear(input_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.linear(x)
        out = self.sigmoid(out)
        return out
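
An aside: because this model applies nn.Sigmoid explicitly, it pairs with nn.BCELoss below. An equivalent and often more numerically stable formulation returns raw logits and uses nn.BCEWithLogitsLoss instead. A minimal sketch of that variant (the class name LogisticRegressionLogits is ours, not part of this lab):

# Variant that outputs raw logits; pair it with nn.BCEWithLogitsLoss,
# which applies the sigmoid internally in a numerically stable way.
class LogisticRegressionLogits(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.linear = nn.Linear(input_size, 1)

    def forward(self, x):
        return self.linear(x)  # no sigmoid here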


# Create the model instance
input_size = X.shape[1]
model = LogisticRegression(input_size)
print(model)

# Define the loss function
loss_fn = nn.BCELoss()

# Define the optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Convert the data to PyTorch tensors
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32)

# Train the model
num_epochs = 1000
for epoch in range(num_epochs):
    # Forward pass
    outputs = model(X_tensor)
    loss = loss_fn(outputs, y_tensor.view(-1, 1))

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Print the trained model parameters
print("Trained model parameters:")
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data)
LogisticRegression(
  (linear): Linear(in_features=2, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)
Trained model parameters:
linear.weight tensor([[1.8507, 0.4726]])
linear.bias tensor([-0.1388])
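With the parameters learned, a quick sanity check is the accuracy on the training data itself (a sketch; the exact value depends on the run):

# Training accuracy for the logistic regression model
with torch.no_grad():
    preds = (model(X_tensor) >= 0.5).float().view(-1)
    print("Training accuracy:", (preds == y_tensor).float().mean().item())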
# Generate a grid of points
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02), np.arange(y_min, y_max, 0.02))
grid_points = np.c_[xx.ravel(), yy.ravel()]

# Convert the grid points to PyTorch tensor
grid_tensor = torch.tensor(grid_points, dtype=torch.float32)

# Use the trained model to predict the class labels for the grid points
with torch.no_grad():
    predictions = model(grid_tensor)
    labels = (predictions >= 0.5).float().numpy().reshape(xx.shape)

# Create a new figure and an axes
fig, ax = plt.subplots()

# Plot the decision boundary
ax.contourf(xx, yy, labels, alpha=0.5, cmap=plt.cm.Set1)

# Plot the generated data
scatter = ax.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Set1)

# Set labels and title
ax.set_xlabel("Feature 1")
ax.set_ylabel("Feature 2")
ax.set_title("Linear Data with Decision Boundary")

# Add a legend
legend_elements = [
    plt.Line2D(
        [0], [0], marker="o", color="w", label="Class 0", markerfacecolor="grey", markersize=8
    ),
    plt.Line2D([0], [0], marker="o", color="w", label="Class 1", markerfacecolor="r", markersize=8),
]
ax.legend(handles=legend_elements)

# Add a grid
ax.grid(color="lightgrey", linestyle="--")

# Display the plot
plt.show()
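
Because the model is linear in its inputs, the boundary shown above is exactly the line w1*x1 + w2*x2 + b = 0. It could also be drawn directly from the learned parameters (a sketch, assuming w2 != 0):

# Recover the boundary line from the trained weights
w = model.linear.weight.detach().numpy().ravel()  # [w1, w2]
b = model.linear.bias.item()
x1_line = np.linspace(x_min, x_max, 100)
x2_line = -(w[0] * x1_line + b) / w[1]  # solve w1*x1 + w2*x2 + b = 0 for x2
plt.plot(x1_line, x2_line, "k--")
plt.show()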

Logistic Regression as a Neural Network: Circles Data

# Generate circles data
X, y = make_circles(n_samples=100, noise=0.05, random_state=42)

# Plot the generated data
fig, ax = plt.subplots()
scatter = ax.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Set1)
ax.set_xlabel("Feature 1")
ax.set_ylabel("Feature 2")
ax.set_title("Circles Data")
ax.grid(color="lightgrey", linestyle="--")
plt.show()

# Create the model instance
input_size = X.shape[1]
model = LogisticRegression(input_size)
print(model)

# Define the loss function
criterion = nn.BCELoss()

# Define the optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Convert the data to PyTorch tensors
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32)

# Train the model
num_epochs = 1000
for epoch in range(num_epochs):
    # Forward pass
    outputs = model(X_tensor)
    loss = criterion(outputs, y_tensor.view(-1, 1))

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Print the trained model parameters
print("Trained model parameters:")
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data)
LogisticRegression(
  (linear): Linear(in_features=2, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)
Trained model parameters:
linear.weight tensor([[-0.0548, -0.2432]])
linear.bias tensor([0.0183])
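The weights are close to zero, which suggests the linear model found little signal here. A quick accuracy check makes this concrete (a sketch; on concentric circles we would expect a value near chance, about 0.5):

with torch.no_grad():
    preds = (model(X_tensor) >= 0.5).float().view(-1)
    print("Training accuracy:", (preds == y_tensor).float().mean().item())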
# Generate a grid of points
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02), np.arange(y_min, y_max, 0.02))
grid_points = np.c_[xx.ravel(), yy.ravel()]

# Convert the grid points to PyTorch tensor
grid_tensor = torch.tensor(grid_points, dtype=torch.float32)

# Use the trained model to predict the class labels for the grid points
with torch.no_grad():
    predictions = model(grid_tensor)
    labels = (predictions >= 0.5).float().numpy().reshape(xx.shape)

# Create a new figure and an axes
fig, ax = plt.subplots()

# Plot the decision boundary
ax.contourf(xx, yy, labels, alpha=0.5, cmap=plt.cm.Set1)

# Plot the generated data
scatter = ax.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Set1)

# Set labels and title
ax.set_xlabel("Feature 1")
ax.set_ylabel("Feature 2")
ax.set_title("Linearly Separable Data with Decision Boundary")

# Add a legend
legend_elements = [
    plt.Line2D(
        [0], [0], marker="o", color="w", label="Class 0", markerfacecolor="grey", markersize=8
    ),
    plt.Line2D([0], [0], marker="o", color="w", label="Class 1", markerfacecolor="r", markersize=8),
]
ax.legend(handles=legend_elements)

# Add a grid
ax.grid(True, color="lightgrey")

# Display the plot
plt.show()
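
An aside before moving to the network: the circles would become linearly separable if we engineered a radial feature such as x1^2 + x2^2, since one class sits at a larger radius than the other. A minimal sketch of that feature map (not part of this lab's pipeline):

# Append the squared radius as a third feature; ordinary logistic
# regression on [x1, x2, x1^2 + x2^2] can then separate the circles.
X_radial = np.column_stack([X, (X ** 2).sum(axis=1)])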

Multi-Layer Neural Network: Circles Data

# Define multi-layer neural network class
class MLP(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.linear = nn.Sequential(
            nn.Linear(input_size, 100),
            nn.ReLU(),
            nn.Linear(100, 10),
            nn.ReLU(),
            nn.Linear(10, 1),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.linear(x)
        out = self.sigmoid(out)
        return out


# Create the model instance
input_size = X.shape[1]
model = MLP(input_size)
print(model)

# Define the loss function
criterion = nn.BCELoss()

# Define the optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Convert the data to PyTorch tensors
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32)

# Train the model
num_epochs = 1000
for epoch in range(num_epochs):
    # Forward pass
    outputs = model(X_tensor)
    loss = criterion(outputs, y_tensor.view(-1, 1))

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Print the trained model parameters
print("Trained model parameters:")
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data)
MLP(
  (linear): Sequential(
    (0): Linear(in_features=2, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=10, bias=True)
    (3): ReLU()
    (4): Linear(in_features=10, out_features=1, bias=True)
  )
  (sigmoid): Sigmoid()
)
Trained model parameters:
linear.0.weight tensor of shape (100, 2)  (values omitted for brevity)
linear.0.bias tensor of shape (100,)  (values omitted for brevity)
linear.2.weight tensor of shape (10, 100)  (values omitted for brevity)
linear.2.bias tensor([-0.1901, -0.1517, -0.0684,  0.3705, -0.1703, -0.1565, -0.1555,  0.6600,
         0.0371, -0.4299])
linear.4.weight tensor([[-1.0731, -0.5081, -0.8607,  1.7612, -1.0147, -0.6238, -0.9651,  2.2342,
         -0.0099, -2.6831]])
linear.4.bias tensor([0.5882])
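torchsummary was imported at the top but not used so far; for a layer-by-layer overview of this MLP it could be called like this (a sketch, assuming a CPU-only machine):

summary(model, input_size=(2,), device="cpu")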
# Generate a grid of points
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02), np.arange(y_min, y_max, 0.02))
grid_points = np.c_[xx.ravel(), yy.ravel()]

# Convert the grid points to PyTorch tensor
grid_tensor = torch.tensor(grid_points, dtype=torch.float32)

# Use the trained model to predict the class labels for the grid points
with torch.no_grad():
    predictions = model(grid_tensor)
    labels = (predictions >= 0.5).float().numpy().reshape(xx.shape)

# Create a new figure and an axes
fig, ax = plt.subplots()

# Plot the decision boundary
ax.contourf(xx, yy, labels, alpha=0.5, cmap=plt.cm.Set1)

# Plot the generated data
scatter = ax.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Set1)

# Set labels and title
ax.set_xlabel("Feature 1")
ax.set_ylabel("Feature 2")
ax.set_title("Circle Data with Decision Boundary")

# Add a legend
legend_elements = [
    plt.Line2D(
        [0], [0], marker="o", color="w", label="Class 0", markerfacecolor="grey", markersize=8
    ),
    plt.Line2D([0], [0], marker="o", color="w", label="Class 1", markerfacecolor="r", markersize=8),
]
ax.legend(handles=legend_elements)

# Add a grid
ax.grid(color="lightgrey", linestyle="--")

# Display the plot
plt.show()
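
For a sense of model size, we can count the trainable parameters (a sketch):

# (2*100 + 100) + (100*10 + 10) + (10*1 + 1) = 1321 parameters
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Trainable parameters:", n_params)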

Multi-Layer Neural Network: Five Class Data

from sklearn.datasets import make_blobs

# Generate the dataset
X, y = make_blobs(n_samples=1000, centers=5, random_state=42)

# Print the shape of the dataset
print("Shape of X:", X.shape)
print("Shape of y:", y.shape)
Shape of X: (1000, 2)
Shape of y: (1000,)
# Create a new figure and an axes
fig, ax = plt.subplots()

# Plot the generated data
scatter = ax.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Set1)

# Set labels and title
ax.set_xlabel("Feature 1")
ax.set_ylabel("Feature 2")
ax.set_title("Linear Data")

# Add a grid
ax.grid(color="lightgrey", linestyle="--")

# Display the plot
plt.show()

y
array([1, 1, 2, 1, 4, 2, 3, 4, 1, 4, 4, 2, 3, 2, 3, 2, 0, 3, 0, 4, 1, 0,
       ...,
       0, 0, 2, 2, 1, 0, 1, 3, 1, 3])
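make_blobs draws the samples evenly from the five centers by default, which we can confirm with a quick count (a sketch):

print(np.bincount(y))  # expected: [200 200 200 200 200]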
# Define multi-layer neural network class
class MLP(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.linear = nn.Sequential(
            nn.Linear(input_size, 100),
            nn.ReLU(),
            nn.Linear(100, 10),
            nn.ReLU(),
            nn.Linear(10, 10),
            nn.ReLU(),
            nn.Linear(10, 5),
        )

    def forward(self, x):
        out = self.linear(x)
        return out
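
Note that this network ends in raw logits, with no softmax or sigmoid: nn.CrossEntropyLoss, used below, applies log-softmax internally, so adding one in forward would be redundant. If class probabilities are wanted at inference time, softmax can be applied explicitly (a sketch with hypothetical scores):

logits = torch.tensor([[2.0, 0.5, -1.0, 0.0, 1.0]])  # hypothetical raw scores
probs = torch.softmax(logits, dim=1)                 # rows sum to 1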


# Create the model instance
input_size = X.shape[1]
model = MLP(input_size)
print(model)

# Define the loss function
criterion = nn.CrossEntropyLoss()

# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Convert the data to PyTorch tensors
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.long)

# Train the model
num_epochs = 1000
for epoch in range(num_epochs):
    # Forward pass
    outputs = model(X_tensor)
    loss = criterion(outputs, y_tensor)

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Print the trained model parameters
print("Trained model parameters:")
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data)
MLP(
  (linear): Sequential(
    (0): Linear(in_features=2, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=10, bias=True)
    (3): ReLU()
    (4): Linear(in_features=10, out_features=10, bias=True)
    (5): ReLU()
    (6): Linear(in_features=10, out_features=5, bias=True)
  )
)
Trained model parameters:
linear.0.weight tensor of shape (100, 2)  (values omitted for brevity)
linear.0.bias tensor of shape (100,)  (values omitted for brevity)
linear.2.weight tensor of shape (10, 100)  (values omitted for brevity)
linear.2.bias tensor([-3.1027e-01,  4.5636e-02,  2.5239e-02,  2.4211e-02, -3.0502e-04,
         2.1908e-02, -9.0223e-02,  4.4743e-01,  2.0019e-01,  6.7779e-02])
linear.4.weight tensor of shape (10, 10)  (values omitted for brevity)
linear.4.bias tensor([ 0.4209, -0.4941,  0.3248,  0.2670,  0.5732,  0.1356, -0.1319,  0.1554,
         0.3631,  0.1190])
linear.6.weight tensor of shape (5, 10)  (values omitted for brevity)
linear.6.bias tensor([-0.2300, -0.2272,  0.1777,  0.0780,  0.5216])
from sklearn.metrics import accuracy_score

# Convert the model predictions to numpy array
with torch.no_grad():
    predictions = model(X_tensor)
    predicted_labels = torch.argmax(predictions, dim=1).numpy()

# Calculate the accuracy
accuracy = accuracy_score(y, predicted_labels)
print("Accuracy:", accuracy)
Accuracy: 0.983
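Beyond a single accuracy number, a confusion matrix shows which of the five classes get mixed up (a sketch using sklearn):

from sklearn.metrics import confusion_matrix
print(confusion_matrix(y, predicted_labels))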
# Generate a grid of points
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02), np.arange(y_min, y_max, 0.02))
grid_points = np.c_[xx.ravel(), yy.ravel()]

# Convert the grid points to PyTorch tensor
grid_tensor = torch.tensor(grid_points, dtype=torch.float32)

# Use the trained model to predict the class labels for the grid points
with torch.no_grad():
    predictions = model(grid_tensor)
    labels = torch.argmax(predictions, dim=1).numpy().reshape(xx.shape)

# Create a new figure and an axes
fig, ax = plt.subplots()

# Plot the decision boundary
ax.contourf(xx, yy, labels, alpha=0.5, cmap=plt.cm.Set1)

# Plot the generated data
scatter = ax.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Set1)

# Set labels and title
ax.set_xlabel("Feature 1")
ax.set_ylabel("Feature 2")
ax.set_title("Decision Boundary")

# Add a grid
ax.grid(color="lightgrey", linestyle="--")

# Display the plot
plt.show()

Image Classification with KMNIST

See Week 13!