Files
2026-01-19 22:27:20 -05:00

137 lines
5.1 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# train_seq.py
import os, json, argparse
import numpy as np
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
def get_device():
    """Pick the torch device used for training.

    Returns:
        ``torch.device("mps")`` when the MPS backend reports itself as
        available, otherwise ``torch.device("cpu")``.
    """
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
class SeqDataset(Dataset):
    """Dataset of fixed-length landmark sequences with optional augmentation.

    Each item is a ``(sequence, label)`` pair: a float32 tensor of shape
    ``(T, F)`` and a plain Python ``int``. The feature axis is treated as a
    flat list of 3-D points, i.e. ``F = 3 * n_points`` (63 for 21 points).

    Args:
        X: Array-like of shape ``(N, T, F)``; stored as a float32 copy.
        y: Array-like of ``N`` integer class ids; stored as int64.
        augment: When true, every ``__getitem__`` call returns a freshly
            jittered copy of the stored sequence (see ``_augment``).

    Raises:
        ValueError: If ``X`` and ``y`` do not contain the same number of clips.
    """

    def __init__(self, X, y, augment=False):
        X = np.asarray(X)
        y = np.asarray(y)
        # Fail early: a mismatch would otherwise surface as an IndexError
        # mid-epoch, or silently drop the trailing clips.
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same length, got {len(X)} and {len(y)}"
            )
        self.X = X.astype(np.float32)  # (Nclip, T, F)
        self.y = y.astype(np.int64)
        self.augment = augment

    def __len__(self):
        return len(self.y)

    def _augment(self, seq):  # seq: (T, F) with F = 3 * n_points
        """Return a randomly jittered copy of ``seq``; the input is not modified.

        Applies a small 2D rotation (±7°) and scale (±10%) to the first two
        coordinates of every point, then adds Gaussian noise (σ=0.01) to all
        three coordinates. Random numbers come from the global ``np.random``
        state.

        Raises:
            ValueError: (from NumPy) if the feature width is not a multiple of 3.
        """
        n_frames = seq.shape[0]
        # Infer the point count from the feature width instead of assuming 21.
        # The copy keeps the in-place edits below away from self.X.
        pts = seq.reshape(n_frames, -1, 3).copy()
        ang = np.deg2rad(np.random.uniform(-7, 7))
        c, s = np.cos(ang), np.sin(ang)
        R = np.array([[c, -s], [s, c]], np.float32)
        scale = np.random.uniform(0.9, 1.1)
        pts[:, :, :2] = (pts[:, :, :2] @ R.T) * scale
        pts += np.random.normal(0, 0.01, size=pts.shape).astype(np.float32)
        return pts.reshape(seq.shape)

    def __getitem__(self, i):
        xi = self.X[i]
        if self.augment:
            xi = self._augment(xi)
        return torch.from_numpy(xi).float(), int(self.y[i])
class SeqGRU(nn.Module):
    """Single-layer bidirectional GRU followed by a small MLP classifier.

    Args:
        input_dim: Number of features per time step.
        hidden: GRU hidden size per direction.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim=63, hidden=128, num_classes=26):
        super().__init__()
        # Attribute names and layer order determine the state_dict keys,
        # so they are kept stable for existing checkpoints.
        self.gru = nn.GRU(
            input_dim,
            hidden,
            batch_first=True,
            bidirectional=True,
        )
        # Forward and backward outputs are concatenated: 2 * hidden features.
        gru_out_dim = hidden * 2
        self.head = nn.Sequential(
            nn.Linear(gru_out_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):  # x: (B, T, input_dim)
        """Return logits of shape ``(B, num_classes)`` from the last time step."""
        outputs, _ = self.gru(x)  # (B, T, 2 * hidden)
        # NOTE(review): per the nn.GRU docs, the reverse half of the last
        # output step has only seen the final frame. Mean pooling
        # (outputs.mean(1)) or h_n are alternatives, but switching would
        # change the trained model, so the original pooling is kept.
        last_step = outputs[:, -1, :]
        return self.head(last_step)
def _read_json(path):
    """Load a JSON document from ``path``, closing the file handle afterwards."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


def _train_one_epoch(model, loader, crit, opt, device):
    """Run one optimisation pass over ``loader``; return ``(mean_loss, accuracy)``."""
    model.train()
    tot, correct, loss_sum = 0, 0, 0.0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        opt.zero_grad(set_to_none=True)
        logits = model(xb)
        loss = crit(logits, yb)
        loss.backward()
        opt.step()
        # Weight by batch size so a short final batch does not skew the mean.
        loss_sum += loss.item() * yb.size(0)
        correct += (logits.argmax(1) == yb).sum().item()
        tot += yb.size(0)
    # max(1, tot) keeps an empty loader from dividing by zero.
    return loss_sum / max(1, tot), correct / max(1, tot)


def _accuracy(model, loader, device):
    """Return the top-1 accuracy of ``model`` over ``loader`` without gradients."""
    model.eval()
    tot, correct = 0, 0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            correct += (logits.argmax(1) == yb).sum().item()
            tot += yb.size(0)
    return correct / max(1, tot)


def main():
    """Train the sequence classifier and save the best checkpoint.

    Reads ``train_X.npy``, ``train_y.npy``, ``val_X.npy``, ``val_y.npy``,
    ``class_names.json`` and ``meta.json`` from the ``--landmarks`` folder,
    standardises features with train-set statistics, trains ``SeqGRU`` for
    ``--epochs`` epochs, and writes a checkpoint to ``--out`` whenever
    validation accuracy improves. The first epoch is always saved.

    The checkpoint is a dict with keys ``model``, ``classes``, ``frames``,
    ``X_mean`` and ``X_std``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--landmarks", default="landmarks_seq32", help="Folder from prep_sequence_resampled.py")
    ap.add_argument("--epochs", type=int, default=40)
    ap.add_argument("--batch", type=int, default=64)
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--out", default="asl_seq32_gru.pt")
    args = ap.parse_args()

    # Load dataset
    trX = np.load(os.path.join(args.landmarks, "train_X.npy"))  # (N, T, F)
    trY = np.load(os.path.join(args.landmarks, "train_y.npy"))
    vaX = np.load(os.path.join(args.landmarks, "val_X.npy"))
    vaY = np.load(os.path.join(args.landmarks, "val_y.npy"))
    classes = _read_json(os.path.join(args.landmarks, "class_names.json"))
    meta = _read_json(os.path.join(args.landmarks, "meta.json"))
    T = int(meta["frames"])
    print(f"Loaded: train {trX.shape} val {vaX.shape} classes={classes}")

    # Global mean/std over train (time+batch); validation reuses the train
    # statistics so no information leaks from the validation split.
    feat_dim = trX.shape[-1]
    flat = trX.reshape(-1, feat_dim)
    X_mean = flat.mean(axis=0, keepdims=True).astype(np.float32)  # (1, F)
    X_std = flat.std(axis=0, keepdims=True).astype(np.float32) + 1e-6  # avoid /0
    trXn = (trX - X_mean) / X_std
    vaXn = (vaX - X_mean) / X_std

    tr_ds = SeqDataset(trXn, trY, augment=True)
    va_ds = SeqDataset(vaXn, vaY, augment=False)
    tr_dl = DataLoader(tr_ds, batch_size=args.batch, shuffle=True)
    va_dl = DataLoader(va_ds, batch_size=args.batch, shuffle=False)

    device = get_device()
    # Size the model from the data rather than assuming 63 features.
    model = SeqGRU(input_dim=feat_dim, hidden=128, num_classes=len(classes)).to(device)
    crit = nn.CrossEntropyLoss()
    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epochs)

    best_acc, best_state = 0.0, None
    for epoch in range(1, args.epochs + 1):
        tr_loss, tr_acc = _train_one_epoch(model, tr_dl, crit, opt, device)
        va_acc = _accuracy(model, va_dl, device)
        sch.step()
        print(f"Epoch {epoch:02d}: train_loss={tr_loss:.4f} train_acc={tr_acc:.3f} val_acc={va_acc:.3f}")

        # Always save the first epoch so a checkpoint exists even if
        # validation accuracy never rises above 0.
        if best_state is None or va_acc > best_acc:
            best_acc = va_acc
            best_state = {
                "model": model.state_dict(),
                "classes": classes,
                "frames": T,
                "X_mean": torch.from_numpy(X_mean),  # tensors → future-proof
                "X_std": torch.from_numpy(X_std),
            }
            torch.save(best_state, args.out)
            print(f" ✅ Saved best → {args.out} (val_acc={best_acc:.3f})")
    print("Done. Best val_acc:", best_acc)
# Start training only when run as a script, so importing this module
# has no side effects.
if __name__ == "__main__":
    main()