#!/usr/bin/env python3
# Evaluate a trained SeqGRU on the validation set; reads input_dim from meta.json
import os, json, argparse # stdlib
|
|
import numpy as np # arrays
|
|
import torch, torch.nn as nn # model
|
|
from sklearn.metrics import classification_report, confusion_matrix # metrics
|
|
|
|
class SeqGRU(nn.Module):
    """Bidirectional-GRU sequence classifier.

    Architecture: GRU(input_dim -> hidden, bidirectional) followed by a small
    MLP head (Linear -> ReLU -> Dropout -> Linear) that maps the final time
    step's concatenated forward/backward features to class logits.
    """

    def __init__(self, input_dim, hidden=128, num_classes=26):
        super().__init__()
        # Temporal encoder; batch_first so inputs arrive as (B, T, F).
        self.gru = nn.GRU(input_dim, hidden, batch_first=True, bidirectional=True)
        # Classification head over the 2*hidden bidirectional features.
        self.head = nn.Sequential(
            nn.Linear(2 * hidden, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Map a batch of sequences (B, T, input_dim) to logits (B, num_classes)."""
        seq_out, _ = self.gru(x)      # (B, T, 2*hidden)
        last_step = seq_out[:, -1]    # features at the final time step
        return self.head(last_step)
def _read_json(path):
    """Load a JSON file, closing the handle deterministically."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def main():
    """Evaluate a SeqGRU checkpoint on the validation split.

    Loads val_X/val_y plus class names and meta from --landmarks, normalizes
    features with the mean/std stored in the checkpoint, runs one batched
    forward pass, and prints a confusion matrix and classification report.

    Raises:
        ValueError: if meta.json's input_dim disagrees with the feature
            dimension of val_X (a mismatched dataset/checkpoint pairing).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--landmarks", default="landmarks_seq32")  # dataset folder
    ap.add_argument("--model", required=True)                  # .pt checkpoint path
    args = ap.parse_args()

    vaX = np.load(os.path.join(args.landmarks, "val_X.npy"))   # (N, T, F)
    vaY = np.load(os.path.join(args.landmarks, "val_y.npy"))   # (N,)
    classes = _read_json(os.path.join(args.landmarks, "class_names.json"))  # label names
    meta = _read_json(os.path.join(args.landmarks, "meta.json"))            # frames, input_dim

    input_dim = int(meta.get("input_dim", vaX.shape[-1]))      # feature dimension
    if vaX.shape[-1] != input_dim:
        # Fail early with a clear message instead of a cryptic matmul error.
        raise ValueError(
            f"meta.json input_dim={input_dim} does not match "
            f"val_X feature dim {vaX.shape[-1]}"
        )

    # NOTE: weights_only=False is required because the checkpoint stores the
    # numpy normalization stats alongside the state dict — only load files
    # from a trusted source (full pickle deserialization).
    state = torch.load(args.model, map_location="cpu", weights_only=False)
    X_mean, X_std = state["X_mean"], state["X_std"]  # stored normalization stats
    if isinstance(X_mean, torch.Tensor):
        X_mean = X_mean.numpy()
    if isinstance(X_std, torch.Tensor):
        X_std = X_std.numpy()
    X_mean = X_mean.astype(np.float32)
    X_std = X_std.astype(np.float32) + 1e-6  # epsilon guards divide-by-zero

    vaXn = (vaX - X_mean) / X_std  # normalize with the training-time stats

    # Apple-GPU acceleration when available, otherwise CPU.
    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    model = SeqGRU(input_dim=input_dim, hidden=128, num_classes=len(classes))
    model.load_state_dict(state["model"])  # load trained weights
    model.eval().to(device)

    with torch.no_grad():  # inference only
        xb = torch.from_numpy(vaXn).float().to(device)  # whole val set in one batch
        logits = model(xb)
        pred = logits.argmax(1).cpu().numpy()  # top-1 class indices

    cm = confusion_matrix(vaY, pred)
    print("Classes:", classes)
    print("\nConfusion matrix (rows=true, cols=pred):\n", cm)
    print("\nReport:\n", classification_report(vaY, pred, target_names=classes))  # precision/recall/F1
|
if __name__ == "__main__":
    main()  # CLI entry point: evaluate checkpoint given --landmarks/--model