Files
slr_handshapes/train_mlp.py
2026-01-19 22:19:15 -05:00

128 lines
4.6 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
train_mlp.py
Train a small MLP on landmarks for a single letter (binary: Letter vs Not_Letter).
Expected workflow:
    python prep_landmarks_binary.py --letter A  # saves landmarks_A/
    python train_mlp.py --letter A --epochs 40 --batch 64
    python infer_webcam.py --letter A
"""
import os, json, argparse
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
def get_device():
    """Pick the best available torch device for training.

    Preference order is CUDA, then Apple Metal (MPS), then CPU.

    Returns:
        torch.device: the selected device.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Some PyTorch builds do not expose torch.backends.mps at all; look it up
    # defensively so those builds fall through to CPU instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")
class MLP(nn.Module):
    """Small fully connected classifier over flat feature vectors.

    Every hidden stage is ``Linear -> ReLU -> Dropout``; a final ``Linear``
    emits one logit per class. With the default arguments the layout is
    ``in_dim -> 128 -> 64 -> num_classes`` with dropout 0.2 and 0.1, which
    yields the ``state_dict`` keys ``net.0.*``, ``net.3.*`` and ``net.6.*``.
    """

    def __init__(self, in_dim, num_classes, hidden_dims=(128, 64), dropouts=(0.2, 0.1)):
        """Build the layer stack.

        Args:
            in_dim: Number of input features per sample.
            num_classes: Number of output logits.
            hidden_dims: Width of each hidden layer, in order.
            dropouts: Dropout probability applied after each hidden layer;
                must have the same length as ``hidden_dims``.

        Raises:
            ValueError: If ``hidden_dims`` and ``dropouts`` differ in length.
        """
        super().__init__()
        if len(hidden_dims) != len(dropouts):
            raise ValueError("hidden_dims and dropouts must have the same length")

        layers = []
        prev_dim = in_dim
        for width, drop_p in zip(hidden_dims, dropouts):
            layers.extend([nn.Linear(prev_dim, width), nn.ReLU(), nn.Dropout(drop_p)])
            prev_dim = width
        layers.append(nn.Linear(prev_dim, num_classes))

        # Keep the attribute name "net" and the layer order: both determine the
        # parameter names stored in saved checkpoints.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw (unnormalized) class logits for the batch ``x``."""
        return self.net(x)
def main():
    """Train the single-letter MLP and save the best checkpoint.

    Command-line arguments:
        --letter     Target letter; upper-cased and used to build default paths.
        --epochs     Number of training epochs (default 40).
        --batch      Mini-batch size for the train and val loaders (default 64).
        --lr         AdamW learning rate (default 1e-3).
        --landmarks  Folder holding train_X.npy, train_y.npy, val_X.npy,
                     val_y.npy and class_names.json
                     (default: landmarks_<LETTER>).
        --out        Checkpoint path (default: asl_<LETTER>_mlp.pt).

    Features are standardized with the training-set mean/std. After each epoch
    the model is evaluated on the validation split, and whenever validation
    accuracy improves (or no checkpoint exists yet) a dict with the keys
    "model", "classes", "X_mean" and "X_std" is written to the output path.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--letter", required=True, help="Target letter (A-Z)")
    ap.add_argument("--epochs", type=int, default=40)
    ap.add_argument("--batch", type=int, default=64)
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--landmarks", default=None,
                    help="Landmarks folder (default: landmarks_<LETTER>)")
    ap.add_argument("--out", default=None,
                    help="Output filename (default: asl_<LETTER>_mlp.pt)")
    args = ap.parse_args()

    letter = args.letter.upper()
    landmarks_dir = args.landmarks or f"landmarks_{letter}"
    out_file = args.out or f"asl_{letter}_mlp.pt"

    # Load data
    trX = np.load(os.path.join(landmarks_dir, "train_X.npy"))
    trY = np.load(os.path.join(landmarks_dir, "train_y.npy"))
    vaX = np.load(os.path.join(landmarks_dir, "val_X.npy"))
    vaY = np.load(os.path.join(landmarks_dir, "val_y.npy"))
    with open(os.path.join(landmarks_dir, "class_names.json"), encoding="utf-8") as f:
        classes = json.load(f)
    print(f"Letter: {letter}")
    print(f"Loaded: train {trX.shape} val {vaX.shape} classes={classes}")

    # Standardize using train mean/std; the epsilon keeps constant features
    # from dividing by zero.
    X_mean_np = trX.mean(axis=0, keepdims=True).astype(np.float32)
    X_std_np = (trX.std(axis=0, keepdims=True) + 1e-6).astype(np.float32)
    trXn = (trX - X_mean_np) / X_std_np
    vaXn = (vaX - X_mean_np) / X_std_np

    # Torch datasets
    tr_ds = TensorDataset(torch.from_numpy(trXn).float(), torch.from_numpy(trY).long())
    va_ds = TensorDataset(torch.from_numpy(vaXn).float(), torch.from_numpy(vaY).long())
    tr_dl = DataLoader(tr_ds, batch_size=args.batch, shuffle=True)
    va_dl = DataLoader(va_ds, batch_size=args.batch, shuffle=False)

    device = get_device()
    model = MLP(in_dim=trX.shape[1], num_classes=len(classes)).to(device)
    criterion = nn.CrossEntropyLoss()
    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epochs)

    best_acc, best_state = 0.0, None
    for epoch in range(1, args.epochs + 1):
        # Train
        model.train()
        tot, correct, loss_sum = 0, 0, 0.0
        for xb, yb in tr_dl:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = criterion(logits, yb)
            loss.backward()
            opt.step()
            # Weight the batch-mean loss by batch size so the epoch average is
            # exact even when the last batch is smaller.
            loss_sum += loss.item() * yb.size(0)
            correct += (logits.argmax(1) == yb).sum().item()
            tot += yb.size(0)
        tr_loss = loss_sum / max(1, tot)
        tr_acc = correct / max(1, tot)

        # Validate
        model.eval()
        vtot, vcorrect = 0, 0
        with torch.no_grad():
            for xb, yb in va_dl:
                xb, yb = xb.to(device), yb.to(device)
                logits = model(xb)
                vcorrect += (logits.argmax(1) == yb).sum().item()
                vtot += yb.size(0)
        va_acc = vcorrect / max(1, vtot)

        sched.step()
        print(f"Epoch {epoch:02d}: train_loss={tr_loss:.4f} train_acc={tr_acc:.3f} val_acc={va_acc:.3f}")

        # Always write a checkpoint on the first epoch: with a strict ">" test
        # against 0.0 alone, a run whose val accuracy stays at 0.0 would finish
        # without producing any output file.
        if best_state is None or va_acc > best_acc:
            best_acc = va_acc
            # Move weights to CPU before saving so the checkpoint can be loaded
            # on machines without the training accelerator. The dict returned
            # by state_dict() is a fresh container, so rebinding its values
            # does not touch the live model.
            weights = model.state_dict()
            for name, tensor in weights.items():
                weights[name] = tensor.detach().cpu()
            # Save stats as **tensors** (future-proof for torch.load safety)
            best_state = {
                "model": weights,
                "classes": classes,
                "X_mean": torch.from_numpy(X_mean_np),  # tensor
                "X_std": torch.from_numpy(X_std_np),  # tensor
            }
            torch.save(best_state, out_file)
            print(f" ✅ Saved best → {out_file} (val_acc={best_acc:.3f})")

    print("Done. Best val_acc:", best_acc)
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()