# standard imports
import matplotlib.pyplot as plt
import numpy as np
# sklearn data generation
from sklearn.datasets import make_classification
from sklearn.datasets import make_circles
# pytorch imports
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchsummary import summary
CS 307: Week 15
Logistic Regression as a Neural Network: Linear Data
# Generate "linear" data: 2 classes, 2 informative features, no redundant
# features, one cluster per class (intended to be roughly linearly separable).
X, y = make_classification(
    n_classes=2,
    n_samples=100,
    n_features=2,
    n_informative=2,
    n_redundant=0,
    n_clusters_per_class=1,
    random_state=2,
)
# Scatter the raw linear dataset, colored by class label with the Set1 colormap
fig, ax = plt.subplots()
scatter = ax.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Set1)
ax.set(xlabel="Feature 1", ylabel="Feature 2", title="Linear Data")
ax.grid(color="lightgrey", linestyle="--")
plt.show()
# Define the logistic regression model class
class LogisticRegression(nn.Module):
    """Binary logistic regression written as a one-layer network.

    Args:
        input_size: number of feature columns in each input row.

    ``forward(x)`` returns ``sigmoid(x @ W.T + b)`` with a single output
    column. It is trained with ``nn.BCELoss`` and thresholded at 0.5 elsewhere
    in this notebook.
    """

    def __init__(self, input_size):
        super().__init__()
        # Attribute names matter: they appear in print(model) and in the
        # parameter names reported after training (linear.weight, linear.bias).
        self.linear = nn.Linear(input_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self.sigmoid(self.linear(x))
# Training tensors: float32 features and float32 0/1 targets (BCELoss needs float targets)
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32)

# Logistic regression sized to the number of feature columns
input_size = X.shape[1]
model = LogisticRegression(input_size)
print(model)

# Binary cross-entropy on the sigmoid outputs, minimized with plain SGD
loss_fn = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Full-batch training: every epoch is one gradient step on the whole dataset
num_epochs = 1000
targets = y_tensor.view(-1, 1)  # column vector to match the model's single output column
for epoch in range(num_epochs):
    optimizer.zero_grad()  # drop gradients from the previous step
    outputs = model(X_tensor)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()
# Report each trainable tensor under its qualified parameter name
print("Trained model parameters:")
for name, param in filter(lambda item: item[1].requires_grad, model.named_parameters()):
    print(name, param.data)
LogisticRegression(
(linear): Linear(in_features=2, out_features=1, bias=True)
(sigmoid): Sigmoid()
)
Trained model parameters:
linear.weight tensor([[1.8507, 0.4726]])
linear.bias tensor([-0.1388])
# Evaluation mesh covering the data with a margin of 1 on every side
(x_min, y_min) = X[:, :2].min(axis=0) - 1
(x_max, y_max) = X[:, :2].max(axis=0) + 1
xx, yy = np.meshgrid(
    np.arange(x_min, x_max, 0.02),
    np.arange(y_min, y_max, 0.02),
)
# One row per mesh point: (feature 1, feature 2)
grid_points = np.column_stack((xx.ravel(), yy.ravel()))
grid_tensor = torch.tensor(grid_points, dtype=torch.float32)

# Score the mesh without tracking gradients, then threshold the predicted
# probability at 0.5 into a 0/1 label image shaped like the mesh
with torch.no_grad():
    predictions = model(grid_tensor)
labels = predictions.ge(0.5).float().numpy().reshape(xx.shape)
# Create a new figure and an axes
fig, ax = plt.subplots()
# Plot the decision boundary: filled regions colored by predicted class
ax.contourf(xx, yy, labels, alpha=0.5, cmap=plt.cm.Set1)
# Plot the generated data
scatter = ax.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Set1)
# Set labels and title
ax.set_xlabel("Feature 1")
ax.set_ylabel("Feature 2")
ax.set_title("Linear Data with Decision Boundary")
# Add a legend. Marker colors are read from the scatter's own colormap and
# normalization, so each entry matches the color its points are drawn in.
# With Set1 and 0/1 labels, class 0 is drawn red (first color) and class 1
# grey (last color). A hard-coded grey/red pair would label them backwards.
legend_elements = [
    plt.Line2D(
        [0], [0], marker="o", color="w", label=f"Class {cls}",
        markerfacecolor=scatter.cmap(scatter.norm(cls)), markersize=8,
    )
    for cls in (0, 1)
]
ax.legend(handles=legend_elements)
# Add a grid
ax.grid(color="lightgrey", linestyle="--")
# Display the plot
plt.show()
Logistic Regression as a Neural Network: Circle Data
# Two noisy concentric circles (binary labels): not separable by a straight line
X, y = make_circles(random_state=42, noise=0.05, n_samples=100)
# Scatter the circles dataset, colored by class label
fig, ax = plt.subplots()
scatter = ax.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Set1)
ax.set(xlabel="Feature 1", ylabel="Feature 2", title="Circles Data")
ax.grid(color="lightgrey", linestyle="--")
plt.show()
# Training tensors: float32 features and float32 0/1 targets for BCELoss
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32)

# A fresh logistic regression for the circles data
input_size = X.shape[1]
model = LogisticRegression(input_size)
print(model)

# Binary cross-entropy, minimized with plain SGD
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Full-batch training: every epoch is one gradient step on the whole dataset
num_epochs = 1000
targets = y_tensor.view(-1, 1)  # column vector to match the model's single output column
for epoch in range(num_epochs):
    optimizer.zero_grad()  # drop gradients from the previous step
    outputs = model(X_tensor)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
# Report each trainable tensor under its qualified parameter name
print("Trained model parameters:")
for name, param in filter(lambda item: item[1].requires_grad, model.named_parameters()):
    print(name, param.data)
LogisticRegression(
(linear): Linear(in_features=2, out_features=1, bias=True)
(sigmoid): Sigmoid()
)
Trained model parameters:
linear.weight tensor([[-0.0548, -0.2432]])
linear.bias tensor([0.0183])
# Evaluation mesh covering the data with a margin of 1 on every side
(x_min, y_min) = X[:, :2].min(axis=0) - 1
(x_max, y_max) = X[:, :2].max(axis=0) + 1
xx, yy = np.meshgrid(
    np.arange(x_min, x_max, 0.02),
    np.arange(y_min, y_max, 0.02),
)
# One row per mesh point: (feature 1, feature 2)
grid_points = np.column_stack((xx.ravel(), yy.ravel()))
grid_tensor = torch.tensor(grid_points, dtype=torch.float32)

# Score the mesh without tracking gradients, then threshold the predicted
# probability at 0.5 into a 0/1 label image shaped like the mesh
with torch.no_grad():
    predictions = model(grid_tensor)
labels = predictions.ge(0.5).float().numpy().reshape(xx.shape)
# Create a new figure and an axes
fig, ax = plt.subplots()
# Plot the decision boundary: filled regions colored by predicted class
ax.contourf(xx, yy, labels, alpha=0.5, cmap=plt.cm.Set1)
# Plot the generated data
scatter = ax.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Set1)
# Set labels and title (this is the circles dataset, which is NOT linearly
# separable; the boundary shown is the logistic regression's straight line)
ax.set_xlabel("Feature 1")
ax.set_ylabel("Feature 2")
ax.set_title("Circles Data with Decision Boundary")
# Add a legend. Marker colors are read from the scatter's own colormap and
# normalization, so each entry matches the color its points are drawn in.
# With Set1 and 0/1 labels, class 0 is drawn red (first color) and class 1
# grey (last color). A hard-coded grey/red pair would label them backwards.
legend_elements = [
    plt.Line2D(
        [0], [0], marker="o", color="w", label=f"Class {cls}",
        markerfacecolor=scatter.cmap(scatter.norm(cls)), markersize=8,
    )
    for cls in (0, 1)
]
ax.legend(handles=legend_elements)
# Add a grid (same style as the other plots in this notebook)
ax.grid(color="lightgrey", linestyle="--")
# Display the plot
plt.show()
Multi-Layer Neural Network: Circle Data
# Define multi-layer neural network class
class MLP(nn.Module):
    """Fully connected binary classifier.

    Layout: Linear(input_size, 100) -> ReLU -> Linear(100, 10) -> ReLU
    -> Linear(10, 1) -> Sigmoid.

    Args:
        input_size: number of feature columns in each input row.

    ``forward(x)`` returns one sigmoid-squashed output column per row.
    """

    def __init__(self, input_size):
        super().__init__()
        widths = [input_size, 100, 10, 1]
        layers = []
        # Layers are created in the same order as an explicit listing, so the
        # random initialization under a given seed is unchanged.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.pop()  # the output layer feeds the sigmoid, not a ReLU
        # "linear" is the prefix of every parameter name (linear.0.weight, ...)
        self.linear = nn.Sequential(*layers)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self.sigmoid(self.linear(x))
# Training tensors: float32 features and float32 0/1 targets for BCELoss
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32)

# Multi-layer network sized to the number of feature columns
input_size = X.shape[1]
model = MLP(input_size)
print(model)

# Binary cross-entropy, minimized with SGD at a larger step size than the
# logistic regression used
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Full-batch training: every epoch is one gradient step on the whole dataset
num_epochs = 1000
targets = y_tensor.view(-1, 1)  # column vector to match the model's single output column
for epoch in range(num_epochs):
    optimizer.zero_grad()  # drop gradients from the previous step
    outputs = model(X_tensor)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
# Report each trainable tensor under its qualified parameter name
print("Trained model parameters:")
for name, param in filter(lambda item: item[1].requires_grad, model.named_parameters()):
    print(name, param.data)
MLP(
(linear): Sequential(
(0): Linear(in_features=2, out_features=100, bias=True)
(1): ReLU()
(2): Linear(in_features=100, out_features=10, bias=True)
(3): ReLU()
(4): Linear(in_features=10, out_features=1, bias=True)
)
(sigmoid): Sigmoid()
)
Trained model parameters:
linear.0.weight tensor([[-0.1552, 0.1662],
[ 0.3908, 0.0612],
[ 0.9017, -0.6374],
[ 0.6277, -0.2109],
[-1.1124, 0.1918],
[-0.3708, 0.5855],
[-0.6589, 0.4729],
[-0.0327, -0.0749],
[ 0.8348, 0.6613],
[ 0.0705, -0.0430],
[ 0.7425, 0.5211],
[ 0.8188, 0.7887],
[ 0.1538, 0.8428],
[-0.2951, 0.0111],
[-0.1261, -0.6081],
[-0.7813, -0.7862],
[ 0.4460, -0.1123],
[ 0.5477, 0.3837],
[-0.5079, -0.7850],
[-0.0103, 0.6520],
[-0.0152, -0.2751],
[-0.1480, -0.0415],
[-0.0583, 0.5241],
[ 0.1338, -0.4054],
[ 0.3726, -0.0295],
[-0.1696, 0.3993],
[-0.8201, -0.0072],
[-0.2884, 0.7008],
[ 0.1657, 0.5699],
[ 0.3853, 0.3299],
[ 0.7182, -0.3519],
[ 0.7159, 0.8815],
[ 0.5753, 0.1802],
[-0.3963, -0.6490],
[ 0.4729, 0.4773],
[-0.2325, -0.2971],
[-0.5070, -0.3100],
[-0.7994, 0.0789],
[-0.2821, -0.3900],
[ 0.0781, -0.2086],
[-0.4211, 0.5708],
[-0.0714, 0.0250],
[ 0.7506, 0.9815],
[-0.0383, -0.8640],
[ 0.2069, 0.9009],
[-0.4418, 0.1689],
[-0.8756, -0.7831],
[ 0.2057, -0.8478],
[ 0.2507, -0.1250],
[-0.2398, 0.7155],
[ 0.3409, 0.2362],
[-0.6451, 0.3934],
[-0.4057, -0.7846],
[ 0.8075, 0.6639],
[-0.3638, 1.0874],
[-0.2788, 0.0188],
[-0.5361, 0.3410],
[-1.0274, 0.1794],
[ 0.1164, 0.3535],
[-0.2864, -0.4005],
[-0.5056, 0.4572],
[-0.2261, -0.6815],
[ 0.8208, -0.4890],
[ 0.0553, -1.0744],
[ 0.6270, -0.0787],
[ 0.0265, -0.7450],
[ 0.2776, 0.5807],
[-0.2465, 0.6260],
[ 0.5899, -0.3780],
[ 0.7019, 0.2656],
[ 0.7896, -0.7558],
[ 0.5840, -0.0084],
[-0.5880, 0.2824],
[ 0.6537, -0.5193],
[ 0.4866, -0.1538],
[ 0.7754, 0.9852],
[-0.1837, 0.2515],
[-0.0591, -0.6305],
[-0.3105, 0.9375],
[ 0.8369, 0.8242],
[ 0.3355, -0.0236],
[-0.2867, -0.1375],
[ 0.4715, 0.5563],
[-0.1985, 0.7202],
[ 0.2024, -0.0913],
[-0.1607, 0.3007],
[ 0.3729, -0.1655],
[-1.0366, -0.0524],
[-0.6882, 0.0674],
[-0.7025, 0.0544],
[ 0.5286, 0.0897],
[-0.4045, 0.0972],
[ 0.2119, -0.4636],
[ 1.0665, 0.3329],
[-0.0363, 0.5967],
[-0.2567, 0.7672],
[ 0.7972, -1.0749],
[ 0.4944, 0.8297],
[ 0.3284, -0.1175],
[ 0.1765, -0.2422]])
linear.0.bias tensor([ 0.2559, -0.4661, -0.3638, -0.1724, -0.0470, -0.0437, -0.2974, -0.4033,
0.0828, 0.7428, -0.0816, -0.4901, 0.1522, 0.5311, 0.6859, -0.4152,
0.2144, -0.0608, 0.2163, 0.6613, 0.2834, -0.5163, 0.5305, 0.4867,
0.4400, 0.0984, 0.0130, 0.7802, 0.6411, 0.5216, -0.3364, -0.5505,
-0.2215, -0.1535, -0.2762, -0.4448, 0.6125, -0.0136, -0.5109, 0.3542,
0.0012, -0.5364, -0.3735, -0.0551, -0.4924, -0.2116, -0.0246, 0.1173,
0.6741, -0.2093, -0.0397, 0.2076, 0.2921, 0.0633, -0.3179, 0.2936,
0.8245, -0.0416, 0.0018, -0.5485, 0.1155, 0.0684, 0.0998, -0.3673,
-0.0914, -0.2500, -0.4425, -0.7039, 0.7244, 0.8031, -0.2319, 0.4367,
-0.1549, 0.1707, -0.5516, -0.3935, 0.3885, 0.0465, -0.2711, -0.4959,
0.0425, 0.0829, 0.1840, -0.1718, 0.0774, 0.6461, 0.0195, -0.2263,
-0.0039, -0.0451, 0.2014, -0.6608, 0.7401, -0.4128, 0.6558, -0.2245,
-0.2245, 0.2906, -0.4606, 0.3155])
linear.2.weight tensor([[ 1.8622e-02, 6.2621e-02, 3.3473e-01, 1.1944e-01, 3.5441e-02,
1.3992e-01, 4.2941e-02, 7.8328e-02, 2.0340e-01, -6.5843e-02,
2.4190e-01, 2.5889e-01, -4.4221e-03, -1.0388e-01, 1.2024e-01,
1.4759e-01, 2.2038e-02, 4.3134e-02, 1.3794e-01, -2.1506e-01,
-2.9418e-02, -4.5284e-02, -2.0085e-01, 1.9045e-02, 7.3719e-02,
1.6164e-03, 8.5542e-02, -1.5003e-01, -5.4246e-03, -7.6768e-02,
1.0049e-01, 1.7256e-01, 1.9108e-01, 2.2537e-01, 4.8786e-02,
-1.2172e-02, -4.9769e-02, 1.5382e-01, -6.8928e-03, 9.3331e-03,
1.0387e-01, 1.4768e-02, 1.9063e-01, 2.4183e-01, 1.7646e-01,
-7.5387e-03, 2.4984e-01, 1.4568e-01, 7.5192e-02, 1.3471e-02,
6.4658e-02, -2.9605e-02, 1.1065e-01, 1.8541e-01, 1.0226e-01,
-3.2445e-02, -2.4388e-01, 5.1804e-02, 4.7509e-02, 7.8389e-02,
7.8369e-04, 1.6394e-01, 2.1678e-01, 2.0788e-01, 1.6703e-01,
1.6338e-01, 4.4945e-02, -3.7065e-02, 4.7272e-02, -5.9655e-02,
1.6999e-01, 3.1912e-02, 5.6825e-02, 1.9961e-01, -1.1651e-02,
1.3927e-01, -4.1680e-03, 6.9804e-02, 5.9821e-04, 2.0834e-01,
5.2322e-02, 7.3099e-03, 1.9956e-01, -5.8514e-03, -4.1797e-03,
-1.5351e-01, -8.4918e-03, 1.6469e-01, 1.4604e-01, 9.3315e-02,
6.2983e-02, 8.3957e-02, 1.2398e-01, 2.2664e-01, -1.5328e-01,
-1.6967e-02, 2.2102e-01, 1.0020e-01, 2.9564e-02, 9.4237e-02],
[ 6.8362e-02, -3.3226e-02, 3.7335e-02, 5.1855e-04, 1.9892e-02,
2.4783e-02, 1.8599e-01, -9.6560e-02, -7.3750e-02, 1.0674e-02,
-6.4436e-02, 2.7056e-02, 2.9267e-02, 3.4604e-02, -7.3872e-02,
3.3928e-02, 1.3361e-02, 1.0624e-02, -3.7467e-02, -1.1267e-01,
-1.1807e-01, 6.4024e-02, 7.7406e-02, -1.5442e-01, -5.7367e-02,
-3.4019e-02, 1.3052e-01, 4.4240e-02, -5.6302e-02, -6.4122e-03,
-5.0595e-02, -2.6326e-02, 2.1997e-02, -7.5018e-03, 8.7540e-02,
-7.5942e-02, 7.0556e-02, 1.8117e-01, -6.5708e-02, -5.2910e-02,
4.0863e-02, -2.3050e-03, 4.8455e-02, -1.7060e-02, 6.5711e-02,
-3.2473e-02, 1.5825e-01, -8.5744e-02, -1.0961e-04, 3.4896e-02,
-4.2165e-02, 7.3071e-02, -4.0172e-02, -8.4240e-02, 8.1444e-02,
1.0928e-01, -6.0524e-02, 1.5955e-01, 7.6167e-02, -8.3315e-02,
9.3807e-02, 5.3447e-02, -8.5636e-02, 1.3680e-03, -4.5060e-03,
1.0577e-02, 4.2857e-02, -8.7732e-02, -2.8685e-02, -1.0850e-01,
9.1827e-02, -9.0820e-02, 4.1101e-02, 3.4341e-02, -4.1979e-02,
-4.2798e-02, 3.7781e-02, 3.3529e-02, -2.5193e-02, 5.6737e-02,
5.5406e-02, 1.0230e-01, -9.4538e-02, 2.7106e-02, 8.0571e-02,
-3.8201e-02, 6.9393e-02, 2.1378e-01, -2.1380e-02, 7.2287e-02,
4.0608e-02, 2.0958e-02, -5.3582e-02, -3.2791e-02, 5.3421e-02,
7.0369e-02, -9.7057e-02, -7.0962e-02, -1.3053e-02, -1.3154e-01],
[ 2.5573e-02, -8.8045e-02, 1.1591e-01, 5.0938e-02, 9.9390e-02,
6.2487e-02, 7.5293e-02, 5.3585e-02, 1.9810e-01, -6.3387e-02,
4.4192e-02, 3.6844e-02, -7.9424e-02, -3.7022e-02, 1.2105e-01,
2.3678e-01, 1.1041e-01, 1.0673e-01, 1.6140e-01, -2.2515e-01,
6.5968e-02, -5.3194e-02, -7.2484e-02, 1.0130e-01, -1.1547e-01,
1.8224e-02, 1.3817e-01, -1.5857e-01, -1.0646e-01, -1.5017e-01,
8.6377e-02, 1.0853e-01, 5.5264e-02, 4.1839e-02, 8.9809e-02,
1.3154e-02, 9.3088e-02, 7.2051e-02, -7.4552e-02, -3.5353e-02,
7.3739e-02, 6.8989e-02, 9.3409e-02, 1.9123e-01, -6.8728e-02,
3.6068e-02, 1.0257e-01, 1.3651e-01, -7.0038e-02, -7.1325e-02,
-1.8393e-02, 2.8170e-02, 1.5445e-01, 6.0257e-02, 1.1762e-01,
-1.2453e-01, -2.4311e-01, 1.9340e-01, 2.9600e-02, 2.5503e-03,
-7.5670e-02, 1.3132e-01, 2.0947e-01, 7.9620e-02, 1.9708e-01,
1.9160e-01, 2.7549e-02, -9.8971e-02, 1.0308e-01, 5.9113e-03,
2.3905e-01, 7.2379e-02, -6.1117e-03, 2.2074e-02, 1.5481e-02,
8.7933e-02, -9.4487e-02, -3.8530e-03, -8.8162e-02, 7.8810e-02,
8.9444e-03, -3.8364e-02, 8.5541e-02, 7.2419e-02, 5.4599e-02,
-1.7598e-01, 1.5230e-01, 1.9488e-01, 2.2441e-02, 3.7222e-02,
9.8094e-02, -9.7661e-02, -1.3322e-02, 2.3324e-01, -2.1238e-01,
9.7348e-03, 1.9131e-01, -3.0601e-02, -6.7119e-02, -1.0651e-02],
[ 1.2001e-01, 6.6184e-02, -3.8852e-01, -1.6756e-01, -4.1074e-01,
-2.7874e-01, -3.2311e-01, 4.6253e-02, -2.7442e-01, 2.8411e-01,
-2.5296e-01, -2.0116e-01, -9.7619e-02, 1.9062e-01, 3.3622e-01,
-2.2674e-01, -4.8533e-02, -1.0872e-01, -2.1704e-01, 2.3131e-01,
1.4035e-01, -8.9710e-02, 1.2239e-01, 2.1842e-01, 1.5730e-01,
-8.5866e-02, -2.3182e-01, 2.1580e-01, 1.9729e-01, 1.7873e-01,
-2.9578e-01, -2.0459e-01, -1.7349e-01, -1.6372e-01, -1.2350e-01,
6.6234e-04, 2.4676e-01, -2.7329e-01, 1.2991e-02, 1.7668e-01,
-1.2833e-01, -1.5400e-02, -4.0336e-01, -8.7595e-02, -2.4353e-01,
-1.2881e-01, -3.9048e-01, -2.2145e-01, 2.3017e-01, -3.0747e-01,
-1.2543e-01, -1.7666e-01, -2.3409e-02, -1.8915e-01, -3.4945e-01,
1.2548e-01, 2.9361e-01, -2.9707e-01, -1.6898e-01, 7.6719e-02,
-2.0233e-01, -5.7121e-02, -1.8372e-01, -2.6115e-01, -1.1382e-01,
-1.3189e-01, -8.4151e-02, -1.0080e-02, 1.5255e-01, 3.1058e-01,
-3.8296e-01, 1.0671e-01, -2.1891e-01, -7.6357e-02, 2.1142e-02,
-2.9106e-01, 1.4400e-01, -1.4506e-01, -2.8305e-01, -2.3570e-01,
-8.7220e-02, -7.0244e-03, -7.5768e-02, -3.4835e-01, 4.1225e-02,
2.2703e-01, -1.5725e-01, -2.9117e-01, -1.2888e-01, -1.7092e-01,
3.4581e-03, 5.3998e-02, 3.3479e-01, -2.9778e-01, 2.3645e-01,
-3.4729e-01, -2.7921e-01, -1.5069e-01, 1.9752e-02, 5.8065e-02],
[-5.4553e-02, -4.5602e-02, 3.2947e-01, 1.0174e-01, -5.2202e-04,
8.2068e-02, -3.8506e-02, -5.3125e-02, 8.6090e-02, 6.1649e-02,
3.6994e-03, 9.2804e-02, 4.3781e-02, -1.2755e-01, -4.7284e-03,
1.6225e-01, -2.9808e-02, -3.9025e-02, 1.3004e-01, -2.2772e-01,
8.8639e-02, 6.9559e-02, -2.0367e-01, 1.3371e-01, 6.5520e-02,
-1.0507e-01, 1.2314e-01, -2.4592e-01, -1.0513e-01, -4.0831e-03,
1.7138e-01, -2.4589e-02, 4.0559e-02, 1.6504e-01, 5.6009e-02,
-6.1749e-02, -9.3438e-02, 7.2871e-02, -4.9889e-02, 1.1650e-01,
2.2124e-02, -8.0194e-02, 1.3347e-01, 1.6645e-01, 9.0684e-02,
-5.9240e-02, 1.7189e-01, 2.3240e-01, 4.5507e-02, -6.2648e-02,
4.0546e-02, -4.1254e-02, 1.7739e-01, 3.0887e-02, -3.5224e-02,
-3.7210e-02, -1.5947e-01, -3.1872e-02, -3.3351e-02, -3.6145e-02,
7.8965e-02, 1.1181e-01, 1.6171e-01, 1.5147e-01, 2.1435e-01,
4.9757e-02, -9.4517e-02, -5.3667e-02, 1.2132e-01, -7.5534e-02,
2.8248e-01, 7.5536e-02, -6.7861e-02, 2.3059e-01, -7.1580e-02,
6.7014e-02, -8.1256e-02, 1.1424e-01, 6.3468e-02, 6.2035e-02,
1.4526e-01, 8.1979e-02, -1.5853e-03, 4.3363e-02, 7.5537e-02,
-2.1663e-01, 1.8044e-02, -1.5894e-02, 8.0112e-02, 9.1889e-02,
1.0201e-01, -5.0367e-02, 3.2039e-02, 2.0247e-01, -2.5252e-01,
8.4274e-02, 3.6241e-01, 8.3770e-02, 9.2811e-02, 2.2759e-02],
[-9.9068e-02, 8.3931e-02, 1.4499e-01, 2.7427e-02, 1.9210e-02,
9.4238e-02, 1.0492e-01, -7.1372e-02, 6.0575e-02, 1.4431e-02,
-2.9898e-03, 1.3384e-01, -9.9928e-02, 2.3328e-02, 1.0005e-01,
9.5734e-02, 8.8587e-02, -8.5461e-03, 1.7961e-01, -1.3813e-01,
4.2987e-02, 1.7874e-02, -1.9960e-02, 5.1177e-03, 8.2496e-02,
1.6523e-02, 3.0284e-03, -2.0240e-01, -1.0978e-01, -1.0157e-01,
7.4652e-02, 5.1767e-02, -1.8972e-02, 1.0541e-01, 4.0136e-02,
7.2443e-02, 3.8280e-02, 9.0670e-02, 8.1112e-02, 5.0933e-02,
6.3011e-02, 8.6634e-03, 6.6254e-02, 1.7582e-01, -5.9472e-02,
-2.0289e-02, 7.3015e-02, 1.4283e-02, -3.8810e-02, 9.7630e-02,
1.1447e-01, 1.1003e-01, 1.0590e-01, 7.7853e-02, 8.9942e-02,
-8.6978e-02, -1.5699e-01, 1.6661e-01, -3.5301e-02, -9.3873e-02,
1.8533e-02, 8.6197e-02, 1.1259e-01, 1.4526e-01, 1.1808e-01,
6.1545e-02, -4.5351e-02, -6.2482e-02, 1.2297e-01, -1.3761e-02,
5.4073e-02, -2.2039e-02, -1.8448e-02, 1.0806e-01, 7.6928e-02,
9.6910e-02, -1.1135e-01, 1.5403e-01, 9.4210e-02, 8.7735e-02,
2.4225e-02, -4.6059e-02, 1.0244e-01, -5.1533e-02, 7.1413e-03,
-6.4063e-03, 5.0946e-02, 7.4849e-05, -1.1183e-02, 2.5344e-02,
8.3296e-02, 9.3464e-03, 1.0900e-01, 1.9461e-01, -1.9923e-01,
9.8326e-02, 1.4660e-01, -5.7975e-02, 9.2319e-02, 8.1158e-02],
[-6.7167e-02, 7.1184e-04, 1.5791e-01, -9.8793e-03, 2.1015e-01,
7.5992e-03, 2.7879e-02, -7.8128e-02, 6.3670e-03, -7.0086e-02,
-5.9196e-02, -4.4937e-02, 9.6057e-03, -2.0698e-02, 9.7075e-02,
2.6569e-01, -6.2905e-02, 6.1533e-02, 1.7112e-01, -1.3718e-01,
8.8466e-02, -7.8341e-02, -2.3541e-01, -4.4109e-02, -1.6006e-01,
-5.0120e-02, 1.8363e-01, -2.1196e-01, -1.9872e-01, -1.2940e-01,
3.3612e-02, 3.6166e-02, -2.2094e-02, 1.0273e-01, 8.0716e-02,
-6.6213e-02, -1.6117e-03, 2.0900e-01, 7.8651e-03, -1.3583e-03,
3.6920e-02, 9.9968e-02, 8.3010e-02, 2.3016e-01, -8.4851e-02,
8.9986e-02, 1.8001e-01, 1.3624e-01, -6.3242e-02, -2.4393e-02,
7.6437e-02, 4.5540e-03, 1.3074e-01, 3.9968e-02, 2.9840e-02,
7.1045e-03, -8.2129e-02, 2.0141e-01, -7.8104e-02, -8.0765e-02,
-6.6262e-04, 2.0895e-01, 1.3460e-01, 2.2653e-01, -2.8971e-02,
7.9451e-02, -4.3767e-03, 7.9412e-02, -1.6735e-02, -2.4690e-01,
9.0527e-02, -1.2564e-01, 4.2072e-02, -3.8262e-02, -2.7756e-03,
7.3702e-02, -5.3369e-02, 1.3912e-02, -6.7335e-02, -1.0511e-02,
-9.6422e-03, -2.9568e-02, 5.7698e-03, -4.0521e-02, -4.2858e-02,
-9.5626e-02, 5.2717e-02, 1.8911e-01, 1.1738e-01, 2.1925e-01,
-8.4953e-02, 9.4305e-02, -3.2372e-02, 4.5195e-02, -2.7433e-01,
9.5198e-02, 2.1260e-01, -9.3293e-02, 3.7255e-02, -1.6457e-02],
[ 7.9190e-02, -9.0281e-02, -4.8434e-01, -2.2859e-01, -4.7977e-01,
-2.8357e-01, -3.3863e-01, 5.7739e-02, -3.9289e-01, 3.6095e-01,
-2.4943e-01, -3.0130e-01, -1.9908e-01, 2.6051e-01, 3.4980e-01,
-4.7705e-01, 3.9321e-02, -3.3786e-01, -2.0962e-01, 2.4211e-01,
1.1756e-01, -5.5718e-02, 1.5113e-01, 2.4147e-01, 2.5145e-01,
-6.7860e-02, -2.4019e-01, 2.6808e-01, 2.7704e-01, 1.9411e-01,
-3.4520e-01, -3.6380e-01, -2.2810e-01, -1.1759e-01, -1.4810e-01,
-1.7398e-02, 2.7162e-01, -1.9678e-01, 1.8633e-02, 2.2822e-01,
-2.7284e-01, -8.7660e-02, -3.8525e-01, -1.7863e-01, -3.8974e-01,
-1.2402e-01, -3.7496e-01, -3.3932e-01, 3.9204e-01, -2.9171e-01,
-2.0227e-01, -2.2362e-01, 2.8815e-02, -2.2475e-01, -4.1650e-01,
1.5006e-01, 3.5947e-01, -4.0337e-01, -1.4737e-01, -4.8684e-02,
-2.7786e-01, -7.9308e-02, -1.3816e-01, -4.3264e-01, -2.1838e-01,
-2.5365e-01, -1.1853e-01, 6.2728e-02, 3.8133e-01, 3.9223e-01,
-3.5855e-01, 7.4689e-02, -1.7893e-01, -8.7796e-02, -4.0315e-02,
-3.8301e-01, 1.5087e-01, -1.7118e-02, -3.3258e-01, -3.4386e-01,
9.4994e-03, -1.7390e-02, -2.7306e-01, -2.5622e-01, -7.9605e-02,
2.6046e-01, -8.5071e-02, -3.7498e-01, -1.9644e-01, -2.7132e-01,
1.9121e-03, 5.6775e-02, 3.1357e-01, -4.1275e-01, 1.9716e-01,
-2.8575e-01, -4.0113e-01, -3.0445e-01, 1.8189e-02, 2.6216e-01],
[ 7.9037e-02, -7.3755e-02, 7.2562e-02, 7.8806e-02, -5.2712e-02,
5.8386e-02, -1.2423e-03, 7.9217e-02, 3.3525e-02, 5.0399e-02,
8.8164e-02, 2.3555e-02, 1.7022e-02, 3.3145e-02, -6.5795e-02,
-8.3565e-02, -5.4582e-02, -7.1657e-02, 4.7716e-02, 8.8822e-02,
8.3683e-02, -5.5514e-02, 8.6972e-03, 2.2264e-02, -9.7950e-02,
7.5916e-03, -8.2544e-02, -4.7238e-02, 1.4602e-02, 9.3010e-02,
-1.7160e-02, -1.0316e-01, 3.0830e-02, -6.4606e-02, 1.4981e-02,
-1.5266e-02, -3.9313e-02, -7.5917e-02, 1.0241e-02, -8.6801e-02,
-4.0603e-02, 8.2484e-02, -3.7631e-02, 4.1494e-02, -1.2643e-02,
-1.3242e-02, 3.8772e-02, -1.0443e-01, 9.8534e-02, 1.7557e-02,
-4.5179e-02, 9.1086e-03, 3.9724e-02, 3.1044e-02, -1.2642e-02,
3.4352e-02, -2.6861e-02, 4.1675e-02, -6.6145e-02, -4.4983e-02,
-3.3449e-02, 7.8412e-02, 1.1455e-02, -5.8443e-02, 8.1806e-02,
-9.7032e-02, -6.2266e-02, -4.0597e-02, 2.7196e-02, -3.9327e-02,
3.6457e-02, -4.1492e-02, 5.7631e-02, -2.0380e-02, -2.2952e-02,
1.4316e-02, -7.8320e-02, -6.7979e-02, -1.7935e-03, 8.0750e-02,
7.1062e-02, -5.4308e-02, -1.3770e-02, 1.1597e-02, 3.4109e-03,
2.2041e-02, -9.7834e-02, 3.3806e-02, -7.7119e-02, 4.1218e-02,
2.8916e-03, 8.4049e-02, -4.7936e-02, -1.8897e-02, -1.1663e-02,
7.2805e-02, 2.5388e-02, -5.3749e-02, 7.6119e-02, 2.3170e-02],
[-1.2779e-01, -7.2839e-02, 4.3740e-01, 3.3655e-01, 5.9364e-01,
2.8691e-01, 4.3598e-01, -7.2318e-02, 4.0493e-01, -3.1163e-01,
3.4899e-01, 4.0392e-01, 3.3925e-01, -1.4136e-01, -3.0471e-01,
5.2648e-01, 4.4249e-02, 2.3727e-01, 3.5257e-01, -1.5803e-01,
-8.4397e-02, 5.7744e-03, -1.5526e-01, -1.5471e-01, -1.5006e-01,
1.3938e-01, 2.0168e-01, -1.7486e-01, -3.0012e-01, -1.9509e-01,
3.6304e-01, 3.5061e-01, 3.4702e-01, 3.7665e-01, 3.3841e-01,
4.3340e-02, -5.6519e-02, 2.7955e-01, 2.6702e-02, -9.8020e-02,
3.4523e-01, 5.2789e-02, 6.1452e-01, 4.5269e-01, 5.0737e-01,
3.0226e-01, 4.4934e-01, 3.5509e-01, -2.8427e-01, 3.4035e-01,
2.3064e-01, 3.1353e-01, 1.7449e-01, 4.6626e-01, 6.2623e-01,
-4.0061e-02, -3.4904e-01, 4.2496e-01, 1.5327e-01, 7.3077e-02,
2.9751e-01, 1.6017e-01, 3.1879e-01, 4.4732e-01, 3.3274e-01,
3.6114e-01, 4.4661e-02, -5.9551e-02, -2.9137e-01, -2.9942e-01,
5.5477e-01, 3.2878e-02, 3.9211e-01, 2.8149e-01, -9.4424e-02,
6.5450e-01, -2.0422e-01, 2.6395e-01, 5.1152e-01, 5.0055e-01,
5.7453e-02, 4.8850e-02, 1.6315e-01, 3.6555e-01, 1.5478e-02,
-1.7038e-01, 1.7878e-01, 5.3672e-01, 3.0686e-01, 3.4304e-01,
1.1242e-01, -9.6645e-02, -3.6526e-01, 5.6846e-01, -1.7448e-01,
4.4394e-01, 6.9694e-01, 3.3837e-01, 6.0589e-03, -3.6500e-02]])
linear.2.bias tensor([-0.1901, -0.1517, -0.0684, 0.3705, -0.1703, -0.1565, -0.1555, 0.6600,
0.0371, -0.4299])
linear.4.weight tensor([[-1.0731, -0.5081, -0.8607, 1.7612, -1.0147, -0.6238, -0.9651, 2.2342,
-0.0099, -2.6831]])
linear.4.bias tensor([0.5882])
# Evaluation mesh covering the data with a margin of 1 on every side
(x_min, y_min) = X[:, :2].min(axis=0) - 1
(x_max, y_max) = X[:, :2].max(axis=0) + 1
xx, yy = np.meshgrid(
    np.arange(x_min, x_max, 0.02),
    np.arange(y_min, y_max, 0.02),
)
# One row per mesh point: (feature 1, feature 2)
grid_points = np.column_stack((xx.ravel(), yy.ravel()))
grid_tensor = torch.tensor(grid_points, dtype=torch.float32)

# Score the mesh without tracking gradients, then threshold the predicted
# probability at 0.5 into a 0/1 label image shaped like the mesh
with torch.no_grad():
    predictions = model(grid_tensor)
labels = predictions.ge(0.5).float().numpy().reshape(xx.shape)
# Create a new figure and an axes
fig, ax = plt.subplots()
# Plot the decision boundary: filled regions colored by predicted class
ax.contourf(xx, yy, labels, alpha=0.5, cmap=plt.cm.Set1)
# Plot the generated data
scatter = ax.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Set1)
# Set labels and title
ax.set_xlabel("Feature 1")
ax.set_ylabel("Feature 2")
ax.set_title("Circle Data with Decision Boundary")
# Add a legend. Marker colors are read from the scatter's own colormap and
# normalization, so each entry matches the color its points are drawn in.
# With Set1 and 0/1 labels, class 0 is drawn red (first color) and class 1
# grey (last color). A hard-coded grey/red pair would label them backwards.
legend_elements = [
    plt.Line2D(
        [0], [0], marker="o", color="w", label=f"Class {cls}",
        markerfacecolor=scatter.cmap(scatter.norm(cls)), markersize=8,
    )
    for cls in (0, 1)
]
ax.legend(handles=legend_elements)
# Add a grid (same style as the other plots in this notebook)
ax.grid(color="lightgrey", linestyle="--")
# Display the plot
plt.show()
Multi-Layer Neural Network: Five Class Data
from sklearn.datasets import make_blobs

# Five Gaussian blobs in 2-D, 1000 points total, labels 0..4
X, y = make_blobs(random_state=42, centers=5, n_samples=1000)

# Report the array shapes
for caption, array in (("Shape of X:", X), ("Shape of y:", y)):
    print(caption, array.shape)
Shape of X: (1000, 2)
Shape of y: (1000,)
# Create a new figure and an axes
fig, ax = plt.subplots()
# Plot the generated data, colored by class label (0..4)
scatter = ax.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Set1)
# Set labels and title (this is the five-class blob dataset, not the linear one)
ax.set_xlabel("Feature 1")
ax.set_ylabel("Feature 2")
ax.set_title("Five Class Data")
# Add a grid
ax.grid(color="lightgrey", linestyle="--")
# Display the plot
plt.show()
# Show the raw label vector (rendered as this cell's output in the notebook)
y
array([1, 1, 2, 1, 4, 2, 3, 4, 1, 4, 4, 2, 3, 2, 3, 2, 0, 3, 0, 4, 1, 0,
3, 2, 0, 1, 0, 3, 4, 3, 1, 3, 4, 3, 0, 2, 4, 4, 4, 0, 4, 3, 3, 2,
3, 2, 1, 1, 0, 2, 1, 3, 2, 0, 0, 0, 0, 2, 2, 0, 1, 3, 2, 4, 2, 1,
3, 3, 1, 1, 3, 2, 1, 2, 3, 0, 3, 1, 1, 3, 0, 1, 2, 2, 3, 1, 0, 4,
0, 3, 1, 3, 1, 4, 4, 2, 0, 0, 2, 3, 2, 4, 1, 2, 0, 1, 0, 2, 4, 0,
0, 1, 2, 2, 3, 0, 3, 2, 0, 4, 2, 0, 2, 2, 2, 2, 3, 2, 0, 1, 1, 3,
1, 4, 2, 4, 1, 4, 0, 3, 2, 0, 1, 0, 1, 1, 2, 1, 1, 4, 3, 2, 1, 1,
0, 1, 1, 0, 3, 1, 3, 1, 3, 0, 3, 1, 4, 3, 0, 1, 2, 0, 0, 2, 4, 2,
3, 2, 1, 4, 1, 0, 1, 3, 4, 4, 1, 2, 3, 0, 3, 0, 1, 4, 1, 0, 4, 2,
0, 1, 2, 3, 2, 2, 3, 4, 4, 0, 4, 1, 0, 3, 3, 4, 1, 2, 2, 2, 3, 4,
2, 2, 0, 2, 0, 4, 2, 4, 1, 1, 0, 1, 0, 1, 1, 3, 0, 0, 3, 1, 4, 3,
2, 4, 2, 2, 2, 2, 3, 4, 2, 2, 4, 1, 4, 0, 3, 1, 1, 0, 1, 4, 1, 3,
1, 3, 3, 3, 2, 2, 1, 1, 3, 4, 0, 4, 2, 0, 3, 4, 4, 2, 1, 4, 2, 3,
1, 1, 1, 0, 4, 0, 3, 4, 1, 3, 4, 4, 0, 4, 3, 3, 0, 4, 1, 3, 3, 0,
2, 3, 3, 2, 4, 4, 4, 3, 2, 2, 1, 2, 4, 0, 2, 4, 3, 2, 4, 2, 3, 0,
0, 3, 0, 0, 1, 4, 3, 3, 3, 0, 3, 1, 4, 3, 4, 4, 2, 2, 0, 1, 2, 3,
0, 0, 0, 0, 0, 1, 0, 1, 4, 1, 1, 4, 1, 2, 2, 1, 3, 2, 2, 3, 2, 2,
2, 1, 1, 2, 0, 1, 0, 4, 1, 0, 3, 3, 4, 2, 2, 4, 0, 4, 3, 3, 2, 0,
3, 0, 4, 2, 4, 1, 1, 1, 0, 3, 0, 2, 1, 2, 2, 0, 0, 3, 2, 0, 4, 2,
4, 3, 3, 0, 4, 1, 0, 1, 0, 1, 0, 4, 3, 3, 4, 3, 0, 0, 3, 2, 0, 2,
3, 3, 2, 3, 0, 2, 2, 4, 2, 0, 1, 2, 1, 4, 4, 4, 2, 4, 1, 1, 4, 0,
3, 3, 1, 0, 1, 1, 0, 1, 2, 2, 2, 4, 3, 0, 4, 1, 1, 4, 4, 3, 2, 1,
2, 0, 3, 2, 4, 4, 1, 0, 1, 3, 1, 1, 0, 3, 0, 3, 0, 4, 2, 4, 4, 3,
3, 1, 4, 0, 1, 3, 1, 1, 3, 0, 3, 2, 4, 0, 0, 0, 0, 1, 1, 1, 2, 1,
2, 2, 2, 0, 0, 1, 1, 3, 1, 0, 0, 2, 2, 3, 3, 4, 3, 0, 4, 1, 0, 1,
3, 3, 4, 0, 0, 0, 1, 1, 2, 4, 0, 0, 1, 0, 0, 0, 3, 4, 1, 2, 1, 0,
2, 2, 3, 4, 1, 4, 4, 4, 1, 0, 3, 4, 4, 0, 4, 1, 3, 2, 2, 4, 3, 4,
3, 0, 3, 4, 0, 2, 0, 2, 2, 1, 2, 0, 3, 4, 4, 1, 4, 1, 1, 4, 0, 3,
4, 4, 4, 1, 0, 4, 3, 1, 0, 1, 4, 3, 2, 4, 2, 1, 2, 3, 4, 0, 1, 3,
0, 3, 4, 2, 4, 2, 0, 2, 4, 1, 1, 2, 4, 4, 2, 2, 4, 4, 2, 3, 1, 0,
2, 0, 2, 4, 1, 4, 1, 2, 0, 1, 3, 0, 2, 2, 4, 1, 3, 4, 3, 2, 3, 0,
1, 3, 2, 1, 1, 1, 2, 1, 0, 0, 4, 1, 0, 2, 0, 2, 2, 4, 1, 4, 0, 4,
2, 3, 3, 4, 1, 3, 3, 2, 4, 4, 0, 0, 4, 4, 1, 3, 4, 2, 0, 3, 3, 4,
1, 0, 4, 2, 4, 1, 1, 4, 2, 1, 3, 1, 3, 4, 1, 2, 3, 4, 4, 3, 2, 0,
0, 3, 2, 3, 1, 0, 4, 4, 3, 3, 4, 3, 3, 3, 4, 4, 3, 0, 2, 3, 0, 0,
1, 4, 1, 3, 2, 0, 2, 4, 0, 3, 0, 3, 4, 1, 3, 0, 1, 3, 0, 3, 4, 2,
2, 4, 0, 4, 3, 0, 0, 0, 4, 2, 4, 4, 1, 4, 1, 0, 2, 2, 4, 0, 1, 0,
0, 1, 4, 1, 4, 3, 0, 3, 0, 3, 4, 4, 2, 0, 1, 3, 2, 4, 0, 1, 3, 1,
4, 2, 1, 1, 3, 3, 2, 1, 1, 2, 2, 1, 2, 4, 4, 1, 3, 3, 4, 4, 3, 3,
3, 2, 4, 3, 2, 2, 0, 4, 1, 3, 4, 0, 4, 3, 4, 1, 4, 3, 2, 2, 3, 2,
3, 0, 0, 0, 0, 4, 4, 3, 2, 2, 3, 0, 3, 3, 2, 3, 3, 3, 0, 4, 4, 1,
1, 4, 4, 0, 1, 1, 1, 4, 0, 3, 0, 4, 1, 2, 2, 2, 3, 4, 0, 3, 1, 3,
4, 4, 3, 0, 2, 3, 2, 0, 1, 1, 3, 0, 0, 4, 2, 4, 4, 2, 0, 0, 4, 0,
3, 2, 2, 2, 2, 0, 0, 2, 2, 0, 0, 1, 0, 3, 1, 1, 3, 0, 2, 2, 3, 2,
3, 1, 4, 1, 4, 2, 1, 3, 3, 0, 2, 4, 3, 4, 1, 4, 2, 1, 1, 2, 4, 0,
0, 0, 2, 2, 1, 0, 1, 3, 1, 3])
# Define multi-layer neural network class
class MLP(nn.Module):
    """Fully connected five-class classifier.

    Layout: Linear(input_size, 100) -> ReLU -> Linear(100, 10) -> ReLU
    -> Linear(10, 10) -> ReLU -> Linear(10, 5).

    Args:
        input_size: number of feature columns in each input row.

    ``forward(x)`` returns 5 raw scores per row (no softmax), as expected by
    ``nn.CrossEntropyLoss``.
    """

    def __init__(self, input_size):
        super().__init__()
        widths = [input_size, 100, 10, 10, 5]
        layers = []
        # Layers are created in the same order as an explicit listing, so the
        # random initialization under a given seed is unchanged.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.pop()  # no activation after the output layer: it emits logits
        # "linear" is the prefix of every parameter name (linear.0.weight, ...)
        self.linear = nn.Sequential(*layers)

    def forward(self, x):
        return self.linear(x)
# Build the network: its input width is the number of feature columns in X.
input_size = X.shape[1]
model = MLP(input_size)
print(model)

# Multi-class cross-entropy on the model's raw logits.
criterion = nn.CrossEntropyLoss()

# Adam over all model parameters, learning rate 0.01.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

# NumPy -> PyTorch: float32 features for the network input,
# int64 (long) class indices as CrossEntropyLoss targets.
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.long)
# Full-batch training: each epoch feeds the whole dataset through the model
# once and takes a single optimizer step.
num_epochs = 1000
for epoch in range(num_epochs):
    # Clear gradients accumulated by the previous step.
    optimizer.zero_grad()

    # Forward pass: logits for every sample, then the loss against the labels.
    outputs = model(X_tensor)
    loss = criterion(outputs, y_tensor)

    # Backpropagate and apply the parameter update.
    loss.backward()
    optimizer.step()
# Show each trainable tensor under its qualified name (e.g. "linear.0.weight").
print("Trained model parameters:")
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    print(name, param.data)
MLP(
(linear): Sequential(
(0): Linear(in_features=2, out_features=100, bias=True)
(1): ReLU()
(2): Linear(in_features=100, out_features=10, bias=True)
(3): ReLU()
(4): Linear(in_features=10, out_features=10, bias=True)
(5): ReLU()
(6): Linear(in_features=10, out_features=5, bias=True)
)
)
Trained model parameters:
linear.0.weight tensor([[-0.7157, -0.0858],
[ 0.8468, 0.7653],
[-0.6592, -0.3552],
[-0.3824, -0.4735],
[ 0.7051, 0.5343],
[ 0.9784, 0.6403],
[-0.6483, 0.4195],
[-0.8238, -0.3240],
[ 0.3995, 0.3407],
[ 0.1389, 0.1510],
[ 0.5828, 0.1263],
[-0.4348, -0.2220],
[-0.9549, 0.6182],
[ 0.6842, -0.9855],
[-0.1463, -0.4805],
[-0.1099, -0.2883],
[-0.3011, -0.5472],
[-0.3674, -0.4947],
[ 0.2536, -0.5226],
[-0.1205, -0.4218],
[-0.0295, -0.4661],
[-0.3602, 0.7304],
[ 0.5083, 0.6896],
[-0.0423, -0.6907],
[ 0.6464, -0.9132],
[-0.5978, 0.3346],
[ 0.3256, 0.2024],
[-0.7463, -0.3477],
[ 0.6988, 0.1492],
[ 0.6165, 0.4373],
[ 0.6975, 0.1740],
[-0.6336, -0.3413],
[-0.7579, -0.4046],
[-0.6774, -0.3200],
[-1.0890, 1.1611],
[-0.1804, -0.7217],
[-0.9017, -0.5263],
[-0.6620, -0.9339],
[-0.1110, -0.5026],
[-0.3379, 0.1735],
[-0.9509, 0.2527],
[ 0.6989, 0.1140],
[-0.8720, 0.2883],
[-0.3166, 0.7973],
[ 0.8615, 0.5377],
[-0.6605, 0.2671],
[-0.3825, 0.3720],
[-0.1639, -0.5769],
[ 0.0574, 0.6976],
[-0.3036, 0.4223],
[ 0.2659, 0.5329],
[-0.8830, -0.2844],
[ 0.9417, -0.0289],
[ 0.6779, 0.0245],
[-0.1493, 0.5753],
[-0.5855, 0.7524],
[-0.5105, -0.5524],
[-0.3072, -0.1284],
[ 0.2744, -0.3410],
[ 0.3426, 0.5276],
[ 0.8322, 0.7613],
[-0.1371, -0.6629],
[ 1.4103, 0.0273],
[-0.0424, 0.4994],
[-0.2259, -0.4207],
[ 0.2477, 0.5408],
[-0.1335, 0.4674],
[-0.5586, 0.2430],
[-0.5568, 0.1298],
[ 0.4542, 0.2175],
[-0.0707, 0.8915],
[ 0.1364, 0.5181],
[-0.8475, -0.2746],
[ 0.7219, 0.1751],
[-0.6209, 0.2819],
[ 0.7304, -1.0402],
[-1.1036, -0.3350],
[ 0.5568, -0.5678],
[-0.2201, -0.4648],
[-0.7181, 0.7876],
[ 0.5435, 0.6588],
[ 0.5208, -0.7508],
[ 0.6194, -0.0130],
[ 0.8899, -0.1771],
[ 0.9394, -1.3419],
[-0.5504, 0.1257],
[ 0.1923, 0.2786],
[ 0.6418, 0.5386],
[ 0.4070, 0.6006],
[ 0.3295, -0.9100],
[-0.2571, 0.7362],
[ 0.2759, 0.3048],
[ 0.0132, -0.0330],
[-0.6007, -0.1256],
[-0.3616, -0.8582],
[ 0.3765, 0.1562],
[ 0.7790, -1.1187],
[-0.1705, 0.6608],
[-0.2889, -0.4123],
[ 0.8092, 0.0993]])
linear.0.bias tensor([-0.3259, 0.4577, -0.7003, 0.6879, 0.2302, 1.2855, 0.4530, 0.3994,
0.8464, 0.3627, -0.6694, 0.4105, 0.6682, -0.6181, 0.1842, 0.3095,
0.7343, 0.0305, -0.0508, -0.0894, -0.3163, -0.4564, 0.9797, 0.1803,
-0.5751, -0.4383, -0.7591, -0.4591, -0.7893, 0.5091, 1.5919, -0.7568,
-0.4396, -0.2020, -1.4721, 0.2546, 0.2386, 0.7231, 0.4761, -0.9233,
-1.2430, -0.9320, -0.5040, 0.6994, 0.0944, -0.3940, 0.2890, 0.2279,
-0.4219, 0.0791, 0.7432, 0.2517, 0.0751, -0.2889, -1.4752, 0.3308,
-0.4447, -0.5624, 0.2417, 0.4560, 0.3327, -0.5938, 0.3420, -0.8407,
-0.0536, 0.6167, 0.3761, -0.4051, -0.8270, 1.2853, -0.0592, 0.9944,
-0.1577, 0.8056, -1.5105, -0.6463, 1.9510, -0.4385, 0.6694, -0.9947,
0.8537, -0.4753, 1.5830, 0.4305, -0.8299, -0.3378, 0.5756, 0.4544,
0.8127, 1.5117, 0.6624, 0.6745, -0.6311, 0.6612, -0.4096, 0.3036,
-0.6912, -0.7848, 1.2749, -0.3954])
linear.2.weight tensor([[ 1.5666e-01, -2.2399e-01, -3.3300e-03, -1.4687e-01, -2.3979e-01,
-1.1060e-01, 2.2122e-01, 1.7277e-01, -2.1965e-01, -2.1749e-01,
-9.9902e-01, 2.2263e-01, 3.0209e-01, -2.0226e-01, -2.0015e-01,
-1.9789e-01, -1.0800e-01, -2.2008e-01, -6.0523e-02, -6.8356e-02,
-2.3571e-01, 2.5550e-01, -9.2719e-02, -1.4963e-01, -2.3781e-01,
2.8744e-01, 3.8025e-01, 6.7432e-02, -1.2754e+00, -2.1807e-01,
-5.7075e-01, 9.2589e-02, -2.1055e-03, 3.3537e-02, 3.6740e-01,
-2.3746e-01, 3.5173e-02, -1.5688e-01, -1.4359e-01, 2.8132e-01,
2.7926e-01, -1.7235e+00, 2.6881e-01, 3.0697e-01, -7.6701e-02,
1.9857e-01, 2.2061e-01, -1.9372e-01, 1.7258e-01, 2.1779e-01,
1.3777e-02, 2.5309e-01, -2.8005e+00, -9.4578e-03, 4.5968e-01,
2.9711e-01, -1.7701e-01, 1.1434e-01, 1.4672e-02, -1.6135e-02,
-2.0751e-01, -1.0493e-01, -2.9138e+00, 3.2780e-01, -2.0744e-01,
-5.3776e-03, 2.7753e-01, 2.3219e-01, 3.0444e-01, -1.0228e+00,
2.5998e-01, 4.0859e-02, 1.6085e-01, -4.5444e-01, 4.5938e-01,
5.1375e-02, -2.0712e-01, -1.4677e-01, -2.2085e-01, 3.8048e-01,
-1.4932e-01, 9.7451e-02, -9.5515e-01, -4.0730e-02, 3.1863e-02,
3.4480e-01, -3.2218e-02, -1.8000e-01, -3.8381e-02, -1.3159e-01,
2.7847e-01, -9.5910e-02, 4.9402e-03, 1.0162e-01, -2.2593e-01,
-2.4486e-01, -1.1703e-01, 2.7868e-01, -9.5483e-02, -4.9591e-01],
[ 9.7559e-02, 9.9721e-02, 5.2562e-03, -3.3478e-02, -3.9585e-02,
-3.3170e-02, -4.6732e-02, -3.7832e-02, -8.7473e-02, -4.2170e-02,
-1.4273e-02, 7.4447e-02, -3.3671e-03, -1.0403e-02, 3.5913e-02,
-8.3806e-02, -1.1356e-02, 5.2117e-02, -6.0668e-02, 2.8677e-02,
6.5223e-02, -4.0261e-02, 9.8021e-02, -4.6573e-02, -3.5789e-02,
1.7325e-02, 9.3751e-02, -6.9879e-02, -1.8653e-02, -2.2977e-02,
-7.5078e-02, 2.3377e-02, -5.4457e-02, 7.0035e-02, -9.3567e-02,
-6.4344e-02, -8.1320e-02, 3.9031e-02, -2.8538e-02, 9.3248e-03,
7.6998e-02, 6.4107e-02, -4.4981e-02, -8.2442e-02, 7.3602e-02,
8.0456e-02, 3.3351e-02, -8.0334e-02, -5.1095e-02, -1.8214e-03,
-4.5014e-02, -8.9755e-02, 8.5262e-02, -1.8515e-02, 2.2350e-02,
3.9358e-02, -4.0820e-02, -5.2095e-02, 4.6096e-02, 1.8926e-03,
-5.4466e-02, 3.1584e-02, -1.6151e-02, -9.0292e-02, 8.3092e-02,
-9.3269e-02, 1.4231e-02, -8.2068e-02, -3.7458e-02, -3.5819e-02,
5.7825e-02, 7.0045e-02, 7.6307e-02, -9.2453e-02, 1.8708e-02,
6.1323e-02, -4.1380e-02, -3.1388e-02, 4.4657e-02, -2.4942e-02,
1.5333e-03, 4.4527e-02, -1.8141e-03, 6.9774e-03, 6.7264e-02,
-8.7481e-02, -9.8603e-03, -6.1178e-02, -2.7478e-02, -1.6766e-02,
-1.6482e-02, -4.3876e-02, 4.6738e-02, 6.9412e-02, -3.9773e-02,
8.4581e-02, -4.7432e-02, 1.6062e-02, -2.8793e-02, -9.7105e-02],
[ 3.4066e-01, -2.2669e-01, 2.0638e-01, 3.0436e-01, -1.4012e-01,
-2.1195e-01, 2.0294e-01, 3.3177e-01, -2.0788e-01, -1.6617e-01,
1.8914e-01, 3.5530e-01, 2.5570e-01, 1.6404e-01, 2.2100e-01,
1.3264e-01, 1.9906e-01, 2.3174e-01, 1.7208e-01, 2.1081e-01,
1.3878e-01, 1.0688e-01, -1.1241e-01, 2.8781e-01, 1.7327e-01,
1.8737e-01, -1.4073e-01, 2.9072e-01, 3.2045e-02, -2.0583e-01,
1.1135e-03, 2.4074e-01, 3.5334e-01, 3.2847e-01, 8.3936e-02,
2.1050e-01, 3.2807e-01, 3.0017e-01, 1.3213e-01, 1.6703e-01,
2.6032e-01, -1.1563e-01, 2.3754e-01, 6.4110e-03, -8.7744e-02,
2.5765e-01, 1.4098e-01, 2.8086e-01, -3.2349e-02, 4.8971e-02,
-2.6792e-01, 3.1783e-01, -1.2590e-01, -1.5280e-03, -1.8023e-01,
6.4615e-02, 2.4404e-01, 2.1271e-01, 1.5348e-01, -1.0202e-01,
-3.1126e-01, 3.0044e-01, -1.0591e-01, 1.9645e-02, 1.8606e-01,
-1.3696e-01, -5.1265e-03, 1.3365e-01, 3.2478e-01, -7.2510e-01,
-4.9263e-02, -1.3800e-01, 2.6668e-01, -4.0788e-02, 1.7252e-01,
-1.2523e-01, 2.2963e-01, 2.2844e-01, 2.7504e-01, 2.0454e-01,
-2.1326e-01, -3.8035e-02, -3.3890e-02, -1.1165e-01, 6.5699e-02,
3.0182e-01, -1.4481e-01, -1.6535e-01, -1.5763e-01, 2.1423e-01,
-1.6284e-02, -2.3269e-01, 1.1910e-02, 2.7782e-01, 3.1665e-01,
-5.9026e-02, 2.0754e-01, 4.0115e-02, 2.6167e-01, -5.7767e-02],
[-1.1849e-01, -7.7969e-03, -1.5661e-01, -6.1125e-02, 2.2625e-02,
-9.1237e-02, 4.0104e-02, -6.6335e-02, 2.6674e-02, -1.2272e-01,
-2.2597e-02, 4.3267e-02, -8.9480e-02, -6.1147e-02, -1.0059e-02,
-1.3936e-01, -8.6508e-02, -2.2515e-02, -9.3664e-02, -6.6825e-03,
-1.0201e-01, -1.0720e-01, -5.0068e-02, -1.1679e-01, 3.0857e-02,
-1.1464e-01, -5.6722e-02, -1.3030e-01, -1.5132e-01, 4.6598e-02,
-4.3135e-02, 2.5682e-03, 4.3127e-03, -8.0890e-02, -6.5783e-02,
-7.9201e-02, -9.9236e-02, -1.4267e-02, -9.8268e-04, -4.1820e-03,
-1.1445e-01, -1.5028e-01, -2.5004e-02, -2.6387e-03, -1.2209e-01,
-1.5330e-02, -1.5169e-01, -1.5652e-01, -1.0176e-01, -1.7024e-02,
3.9792e-02, -1.5171e-01, 3.4623e-02, -2.0488e-02, -4.3970e-03,
-1.4869e-01, -4.2210e-02, 3.5327e-02, 2.2972e-02, -7.8951e-02,
2.4533e-02, -9.5509e-02, -5.1651e-02, -8.3961e-02, -1.8472e-02,
4.2319e-02, -2.7645e-02, 9.0295e-03, 2.2153e-02, -6.5848e-02,
1.6244e-02, 9.4096e-03, -1.3045e-02, -3.8044e-02, -2.4185e-02,
-1.3695e-01, 1.5228e-02, -1.0525e-02, 4.6720e-03, -2.6101e-02,
-4.3436e-03, -4.6502e-02, -9.1774e-02, -1.1180e-01, 1.8824e-03,
6.0288e-04, -1.1517e-01, 1.9949e-02, -1.0356e-01, -1.3704e-01,
-8.3517e-02, -6.9738e-02, 5.1833e-02, -1.4469e-01, 3.5419e-02,
-1.2094e-01, 1.7723e-02, 6.9318e-03, -1.4521e-01, -1.0900e-01],
[ 3.3103e-02, -2.9652e-02, 1.9098e-02, -1.1797e-01, -8.8985e-03,
-9.5452e-02, 1.2919e-02, -1.7926e-02, -4.0095e-02, -2.6798e-02,
-8.7017e-02, -1.0745e-01, -3.8287e-02, -9.4597e-02, -2.5650e-02,
2.9995e-02, -7.4850e-02, -5.7272e-02, 2.8965e-02, -1.4841e-01,
-1.3993e-01, -8.7357e-02, -9.6490e-02, 3.6011e-02, -9.7039e-02,
3.8971e-02, 1.4832e-03, -1.3230e-01, -6.2370e-02, -4.9680e-02,
5.6635e-02, -6.6010e-02, -2.8427e-02, 1.0241e-02, -1.0028e-01,
-1.5885e-01, -1.2632e-01, -8.8555e-02, -1.2777e-01, -2.6624e-02,
-2.8859e-02, 8.7423e-02, 2.8555e-02, -1.5382e-01, -3.9508e-02,
-1.2824e-01, -1.0591e-01, 1.5083e-02, 4.9543e-02, 2.0838e-02,
9.7050e-02, -1.2707e-01, -1.0845e-01, -7.4149e-02, -9.0393e-02,
-1.5021e-01, -6.9915e-02, -7.3932e-02, -1.5561e-01, -2.6250e-02,
-4.8112e-02, 3.1877e-02, -8.8039e-02, 6.8075e-02, 7.8870e-03,
6.6852e-02, 2.2080e-02, -2.7804e-02, -1.2938e-01, -3.1030e-02,
-6.7066e-02, -1.7240e-03, -3.2691e-03, -4.2015e-02, -3.4847e-02,
-3.3286e-02, -4.4260e-02, -3.7307e-02, -1.4968e-01, -1.4191e-01,
7.5149e-02, -9.8460e-02, -4.1743e-02, -1.4487e-01, -4.9243e-02,
-1.5390e-01, 2.2282e-03, -5.1129e-02, -5.0759e-02, -1.5388e-01,
-4.8148e-02, 2.9664e-02, -3.4954e-02, -3.9795e-03, -2.4351e-02,
4.6323e-02, -6.1543e-02, 3.4615e-02, 3.9667e-02, 5.8227e-03],
[-4.6510e-02, -4.9032e-02, -7.1243e-02, -1.5057e-02, -3.9990e-02,
-8.6335e-02, -5.9020e-02, 3.9320e-02, -8.6666e-02, 7.7752e-02,
-3.2762e-02, -6.9965e-02, 4.2544e-02, 1.4781e-03, 2.6269e-02,
-7.7447e-02, 6.3324e-02, 3.9288e-02, -9.1361e-03, -5.1340e-02,
7.9445e-03, 2.2402e-02, -5.4517e-02, 4.5531e-02, -4.2306e-02,
-8.5663e-02, -4.0078e-02, -4.6542e-02, 2.7651e-02, 4.3961e-02,
9.8856e-03, -1.1676e-02, -2.2226e-04, -7.7728e-03, -1.2191e-02,
-9.8001e-02, -6.7645e-02, 5.0356e-02, 6.0369e-02, -8.4953e-02,
-1.0358e-01, 3.1401e-02, -8.3145e-02, -3.7875e-02, -7.1154e-02,
-1.2621e-01, -7.8417e-02, -6.1136e-02, 2.8972e-02, -4.8982e-02,
-9.5655e-02, -1.0762e-01, 1.7080e-02, 2.9777e-03, 9.8989e-04,
-6.3690e-02, -3.9715e-02, -2.5035e-02, 8.5934e-02, -2.7804e-02,
-6.8254e-02, -8.5664e-03, -2.6031e-02, -2.1849e-02, 6.0124e-02,
-6.5162e-03, 5.4263e-03, 2.6216e-02, 5.9158e-02, -9.5998e-02,
2.9186e-02, -3.9914e-02, -2.2287e-02, -8.0053e-02, -6.2409e-03,
-3.5394e-02, 6.0710e-02, -4.0315e-02, 2.9293e-02, -6.5881e-02,
-8.5151e-02, -8.0181e-02, 5.1541e-03, -7.7113e-04, 8.5335e-02,
-6.1182e-02, 1.4217e-04, -5.8660e-02, 4.9669e-02, 1.9340e-03,
-6.0617e-02, 7.2222e-02, 9.6410e-02, -6.6794e-02, 8.3251e-02,
-3.4567e-02, -5.1212e-02, 5.6342e-03, -7.7391e-02, -8.5255e-02],
[ 7.3992e-03, -9.2201e-02, -7.6840e-02, -5.7200e-04, -3.3702e-02,
1.6946e-02, -6.7936e-02, -1.5845e-01, -1.4161e-01, -9.3280e-02,
-9.1885e-02, 3.6732e-03, -6.5342e-03, 7.9814e-03, -9.3123e-02,
-1.4473e-01, -1.2320e-01, -6.3563e-02, -1.5855e-01, 2.1975e-03,
-1.3163e-01, -3.9574e-03, -1.5698e-01, 7.1899e-03, -1.4187e-01,
3.4146e-03, 3.8125e-02, 2.5843e-03, -1.3079e-02, -1.3204e-01,
8.3786e-02, -2.2120e-02, 3.9271e-02, -1.2596e-01, 1.2143e-02,
-8.5971e-02, -1.2114e-01, -1.4848e-01, -1.2931e-01, 4.6192e-03,
9.3149e-03, 4.7038e-02, -1.3509e-01, -1.3677e-01, 5.4563e-02,
-4.3882e-02, -9.9015e-02, -1.4939e-01, 2.6642e-02, -1.7332e-02,
4.0257e-03, 1.8777e-02, 9.0937e-04, -1.8833e-02, -9.0741e-02,
-1.0528e-01, -1.5416e-02, -6.9905e-03, -8.7717e-02, -4.0205e-02,
1.7305e-02, -9.7544e-02, 4.3502e-02, -1.0028e-01, -8.7938e-02,
-4.0312e-02, 1.5484e-02, 2.4350e-02, -6.2460e-02, -9.3478e-03,
-9.1472e-02, -1.2269e-01, -5.5012e-02, 8.3711e-02, -9.9640e-02,
2.4393e-03, -7.6099e-02, -4.0428e-02, -1.6589e-02, -1.3674e-01,
-7.4197e-02, 6.2661e-02, 3.1131e-02, -8.8770e-02, 8.4476e-02,
-1.4011e-01, -7.4726e-02, -6.8375e-03, -1.0614e-01, 8.3319e-02,
-1.1718e-01, 4.0877e-02, -3.3579e-02, 1.2562e-02, 2.6064e-02,
-2.0993e-02, -6.6200e-02, -4.2079e-02, -3.9781e-02, 3.0279e-02],
[-9.4463e-03, 2.5273e-01, -7.4421e-01, 2.3654e-01, 2.7598e-01,
3.1193e-01, 7.2331e-01, -2.5938e-01, 1.8057e-01, 2.2860e-02,
-4.2454e-02, 4.6235e-02, 5.6263e-01, -1.7237e+00, 7.5630e-02,
1.7184e-01, 1.2490e+00, 2.7224e-01, -1.0309e-01, -4.6537e-02,
1.3509e-01, 1.6600e-01, 2.8930e-01, 6.5499e-02, -1.6997e+00,
1.7876e-01, -1.2066e-01, -5.9579e-01, -4.4608e-02, 1.5029e-01,
2.5303e-01, -3.9610e-01, -3.0922e+00, -5.0539e-01, 4.2130e-01,
-5.0266e-02, -4.2651e-01, 2.8539e-01, -1.4207e-01, 1.1180e-01,
-1.1013e-03, 3.2779e-03, 7.6952e-02, 1.9291e-01, 1.6175e-01,
1.9929e-01, 1.4937e-01, 1.3756e-01, 1.7965e-01, 7.7044e-02,
1.9019e-01, -1.5527e-01, 2.6779e-01, 9.8191e-02, 1.2652e+00,
1.6396e-01, 1.3644e-01, -3.7035e-01, -5.6507e-02, 1.2423e-01,
1.2563e-01, -1.2620e-01, 2.7061e-01, 1.7234e-01, -5.9522e-03,
9.2409e-02, 1.3593e-01, 1.0498e-01, 3.6392e-02, 3.4221e-01,
1.9488e-01, 2.6107e-01, -3.8584e-01, 1.2479e-01, -4.2127e-03,
-3.8013e-01, -4.2373e-01, 2.4846e-01, 2.1683e-01, 6.1581e-01,
1.6604e-01, -1.7717e-01, 2.8340e-01, 2.7094e-01, -3.8997e-01,
5.1320e-02, 2.4170e-01, 1.7787e-01, 2.5139e-01, -4.2393e-01,
2.0944e-01, 2.7613e-01, -7.9983e-02, -2.9780e-01, -1.7064e-01,
7.5258e-02, -4.3994e+00, 1.9829e-01, -1.4221e+00, 5.4893e-03],
[-1.4490e-01, 3.0999e-01, -1.4755e-01, -2.7093e-02, 2.6288e-01,
2.5813e-01, -5.3249e-01, -6.2880e-03, 2.9064e-01, 1.6039e-01,
3.5534e-01, -1.2712e-02, -3.4056e-01, 1.4612e+00, -1.5646e-01,
-1.1143e-01, -4.6872e-03, -3.3256e-02, 2.0862e-01, -5.4929e-02,
-1.0332e-01, -1.9081e-01, 2.1010e-01, 4.3397e-02, 9.1897e-01,
-1.8204e-01, -2.8244e-01, -9.0974e-03, 2.0856e-01, 1.2192e-01,
2.2300e-01, -9.8558e-02, 3.4543e-03, -1.3415e-01, -4.9014e-01,
-1.3772e-01, -8.4227e-02, -8.6236e-02, -1.8728e-02, -3.6871e-01,
-5.5334e-01, 1.8979e-01, -2.3790e-01, 9.2382e-02, 2.7006e-01,
-2.4869e-01, 1.0566e-01, -1.6615e-01, -5.9913e-03, -1.0295e-01,
1.4341e-01, -2.9302e-02, 2.7770e-01, 1.1276e-01, -1.5408e+00,
5.1996e-02, -1.5886e-01, -1.3873e-01, 1.8660e-01, 2.2475e-01,
2.3754e-01, 1.5193e-03, 1.9058e-01, -1.3767e-01, -5.6253e-02,
1.5812e-01, -3.7213e-02, -1.1783e-01, -3.6599e-01, 2.8579e-01,
2.5834e-02, 1.7404e-01, -5.1877e-02, 3.0329e-01, -5.2061e-01,
3.4646e-01, 2.9267e+00, 2.3889e-01, -4.5155e-02, -5.4413e-01,
2.7813e-01, 2.6674e-01, 1.9599e-01, 2.0122e-01, 4.6119e-01,
-1.9064e-01, 2.5398e-01, 1.7401e-01, 1.9930e-01, 4.4993e-01,
9.3427e-02, 1.5471e-01, 4.2120e-02, 2.0102e+00, -1.2324e-01,
2.9004e-01, 2.0032e+00, -2.1398e-02, 7.0167e-01, 2.7802e-01],
[-9.8567e-02, -1.8608e-02, 9.8736e-03, -5.7179e-02, -8.6996e-02,
-8.7135e-02, 8.6413e-02, 2.4692e-02, 2.5092e-02, 3.0335e-02,
-4.2105e-02, -2.8809e-02, -4.4483e-02, 8.1848e-02, 2.9793e-02,
-8.0890e-02, 2.2661e-02, 9.8140e-02, 1.3173e-02, -1.3363e-02,
4.7422e-02, -5.3408e-02, -5.8562e-02, -4.1046e-02, -6.2692e-04,
1.1508e-02, -6.1444e-02, 2.8698e-02, 6.9115e-02, -2.2743e-02,
-3.5880e-02, -3.3627e-02, 2.8998e-02, -9.5922e-02, -9.2932e-02,
-9.6507e-02, -3.5611e-02, 5.4674e-02, -4.7794e-02, -8.3346e-02,
-6.8014e-02, 1.2078e-02, 1.6166e-02, -3.2733e-02, -9.1847e-02,
8.4442e-02, -9.3683e-02, 9.7822e-02, -5.6965e-02, -2.0141e-02,
-8.2946e-02, -1.8858e-02, 3.7601e-02, -6.8778e-02, 4.4359e-02,
5.0159e-02, -6.2003e-02, 7.2751e-02, -5.8908e-02, -5.0894e-02,
2.9734e-02, -2.5002e-02, -2.3154e-02, -3.2335e-02, -4.4888e-02,
2.6441e-02, -9.5169e-03, 8.6024e-03, -3.9696e-02, 6.8351e-02,
-6.1218e-03, 4.0225e-02, -4.4624e-02, 5.2940e-02, 4.8544e-02,
-6.7251e-02, -8.8407e-02, -1.8300e-03, 8.6462e-03, 1.4661e-02,
-9.3749e-02, -2.8587e-02, -9.6928e-02, 8.7356e-02, -9.0838e-02,
-7.9566e-02, -9.5027e-02, 1.8029e-02, 1.9667e-02, 9.0467e-02,
-8.4525e-03, -7.6852e-02, 7.9725e-03, 3.3333e-02, 8.2368e-02,
8.3399e-03, -3.9683e-03, -6.3617e-02, 2.6470e-02, -6.8539e-02]])
linear.2.bias tensor([-3.1027e-01, 4.5636e-02, 2.5239e-02, 2.4211e-02, -3.0502e-04,
2.1908e-02, -9.0223e-02, 4.4743e-01, 2.0019e-01, 6.7779e-02])
linear.4.weight tensor([[-0.3507, 0.2981, -0.1999, -0.1106, -0.1271, 0.1342, -0.2358, 0.1883,
0.5624, 0.1241],
[ 0.5603, 0.2650, 0.0084, -0.2541, 0.2511, 0.1093, 0.0400, 0.0160,
-0.3641, -0.0306],
[-0.0265, -0.2103, -0.2343, -0.2293, 0.0729, -0.0196, -0.0141, -0.0164,
0.3967, 0.0959],
[-0.2783, 0.2216, 0.2876, -0.0598, 0.3290, -0.2400, 0.3703, -0.3434,
0.2738, -0.2622],
[ 0.2651, 0.1938, -0.0362, -0.0228, -0.2329, -0.1324, 0.2049, 0.3382,
-0.1831, 0.1810],
[-0.0159, 0.3087, -0.7259, -0.1561, -0.0213, 0.2309, 0.2505, 0.5566,
-0.2723, -0.0016],
[-1.0350, 0.0426, -0.1839, 0.1457, 0.2306, 0.2729, 0.1060, -0.0189,
0.5539, 0.1271],
[-0.2176, -0.0341, -0.1587, 0.0922, 0.2077, 0.2225, 0.1486, -0.3146,
-0.0643, -0.0218],
[-0.2164, 0.2170, -0.4762, -0.0297, 0.0013, -0.1703, 0.1134, 0.5581,
-0.5296, -0.1345],
[ 0.0289, -0.2450, 0.4034, 0.2436, -0.2010, -0.2607, -0.0513, -0.1674,
0.1470, -0.0469]])
linear.4.bias tensor([ 0.4209, -0.4941, 0.3248, 0.2670, 0.5732, 0.1356, -0.1319, 0.1554,
0.3631, 0.1190])
linear.6.weight tensor([[-0.7229, 0.1407, -0.5544, -0.0383, 0.2408, 0.4584, -4.2137, 0.1853,
0.5194, -0.3519],
[ 0.4638, -0.1626, 0.2902, -0.3500, -0.1614, -0.5834, 0.3103, -0.1113,
-0.3584, -0.0282],
[-0.3584, 0.1217, -0.3695, 0.5293, -0.4095, 0.0567, -0.0190, -0.1553,
-0.3932, 0.4102],
[-0.4371, 0.4393, -0.1164, -0.0548, 0.2365, -0.5351, -0.0148, 0.2300,
-0.4858, 0.0137],
[ 0.4004, -0.6929, -0.0416, -0.6043, 0.3351, 0.1766, -0.0683, 0.0609,
0.3810, -0.0724]])
linear.6.bias tensor([-0.2300, -0.2272, 0.1777, 0.0780, 0.5216])
from sklearn.metrics import accuracy_score

# Predicted class = index of the largest logit. No gradients are needed.
# NOTE: X_tensor is the data the model was trained on, so this is
# training-set accuracy, not an estimate of generalisation.
with torch.no_grad():
    predictions = model(X_tensor)
predicted_labels = predictions.argmax(dim=1).numpy()

# Fraction of samples whose predicted label equals the true label.
accuracy = accuracy_score(y, predicted_labels)
print("Accuracy:", accuracy)
Accuracy: 0.983
# Generate a grid of points covering the data with a 1-unit margin
# (step 0.02 along each axis).
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02), np.arange(y_min, y_max, 0.02))
grid_points = np.c_[xx.ravel(), yy.ravel()]

# Convert the grid points to a PyTorch tensor
grid_tensor = torch.tensor(grid_points, dtype=torch.float32)

# Use the trained model to predict a class label for every grid point
with torch.no_grad():
    predictions = model(grid_tensor)
labels = torch.argmax(predictions, dim=1).numpy().reshape(xx.shape)

# Give each class the same Set1 colour in the background and in the points.
# The contour levels sit half-way between consecutive integer labels, so each
# filled band corresponds to exactly one class. Both plots also share
# vmin/vmax. With contourf's automatic levels, its colour scaling comes from
# the level midpoints, while scatter's comes from min/max of y, so the
# colours of the two generally disagree.
num_plot_classes = int(max(y.max(), labels.max())) + 1
class_levels = np.arange(num_plot_classes + 1) - 0.5
color_vmin, color_vmax = 0, num_plot_classes - 1

# Create a new figure and an axes
fig, ax = plt.subplots()

# Plot the decision regions
ax.contourf(
    xx,
    yy,
    labels,
    levels=class_levels,
    alpha=0.5,
    cmap=plt.cm.Set1,
    vmin=color_vmin,
    vmax=color_vmax,
)

# Plot the data points with the same colour mapping
scatter = ax.scatter(
    X[:, 0],
    X[:, 1],
    c=y,
    cmap=plt.cm.Set1,
    vmin=color_vmin,
    vmax=color_vmax,
)

# Set labels and title
ax.set_xlabel("Feature 1")
ax.set_ylabel("Feature 2")
ax.set_title("Decision Boundary")

# Add a grid
ax.grid(True, color="lightgrey")

# Display the plot
plt.show()
Image Classification with KMNIST
See Week 13!