63 lines
2.0 KiB
Python
63 lines
2.0 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Evaluate the trained per-letter model on the saved val split.
|
||
Prints confusion matrix and a classification report.
|
||
|
||
Usage:
|
||
python eval_val.py --letter A
|
||
"""
|
||
import argparse, json
|
||
import numpy as np
|
||
from sklearn.metrics import confusion_matrix, classification_report
|
||
import torch
|
||
import torch.nn as nn
|
||
|
||
class MLP(nn.Module):
    """Feed-forward classifier: in_dim -> 128 -> 64 -> num_classes.

    Each hidden layer is followed by ReLU and dropout (p=0.2 after the first,
    p=0.1 after the second). The final layer emits raw logits with no
    activation. All layers live in one ``nn.Sequential`` stored as
    ``self.net``. That attribute name and the layer order define the
    ``state_dict`` keys (``net.0``, ``net.3``, ``net.6``), so changing either
    breaks loading of previously saved weights.

    Args:
        in_dim: Number of input features per sample.
        num_classes: Number of output logits.
    """

    def __init__(self, in_dim, num_classes):
        super().__init__()
        layers = [
            # Hidden block 1
            nn.Linear(in_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            # Hidden block 2
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            # Output projection: unnormalized logits
            nn.Linear(64, num_classes),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for ``x`` (a tensor whose last dim is in_dim)."""
        logits = self.net(x)
        return logits
|
||
|
||
def _as_float32(value):
    """Convert a torch tensor or array-like to a NumPy float32 array.

    Tensors are detached and moved to CPU before conversion, so a tensor
    that still requires grad does not make ``.numpy()`` fail.
    """
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()
    return np.asarray(value, dtype=np.float32)


def main():
    """Evaluate ``asl_<LETTER>_mlp.pt`` on ``landmarks_<LETTER>/val_*.npy``.

    Loads the validation features, labels and class names for the letter
    given by ``--letter``. Normalizes the features with the ``X_mean`` /
    ``X_std`` stored in the checkpoint, then runs the MLP. Prints the class
    names, a confusion matrix (rows = true, cols = predicted) and a
    classification report.

    Exits with an error message if the validation split is empty or if the
    feature and label counts disagree.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--letter", required=True, help="Target letter (A–Z)")
    args = ap.parse_args()
    letter = args.letter.upper()
    data_dir = f"landmarks_{letter}"

    # Load the validation split and class names (file handle closed promptly).
    X = np.load(f"{data_dir}/val_X.npy")
    y = np.load(f"{data_dir}/val_y.npy")
    with open(f"{data_dir}/class_names.json", encoding="utf-8") as fh:
        classes = json.load(fh)

    if len(X) != len(y):
        raise SystemExit(
            f"val_X has {len(X)} rows but val_y has {len(y)} labels"
        )
    if len(y) == 0:
        raise SystemExit("Validation split is empty; nothing to evaluate.")

    # weights_only=False disables torch's restricted unpickler. The stored
    # normalization stats may be ndarrays rather than tensors. NOTE: full
    # unpickling can execute arbitrary code, so only load trusted checkpoints.
    state = torch.load(
        f"asl_{letter}_mlp.pt", map_location="cpu", weights_only=False
    )
    X_mean = _as_float32(state["X_mean"])
    X_std = _as_float32(state["X_std"]) + 1e-6  # guard against zero std

    model = MLP(X.shape[1], len(classes))
    model.load_state_dict(state["model"])
    model.eval()  # disable dropout for inference

    # Normalize and predict.
    Xn = (X - X_mean) / X_std
    with torch.no_grad():
        logits = model(torch.from_numpy(Xn).float())
        probs = torch.softmax(logits, dim=1).numpy()
    pred = probs.argmax(axis=1)

    # Predictions are class indices (argmax), so pin the label set to every
    # index. This keeps the matrix len(classes) x len(classes) and keeps
    # target_names aligned even when a class never occurs in y or pred.
    labels = np.arange(len(classes))

    print("Classes:", classes)  # e.g., ['Not_A','A']
    print("\nConfusion matrix (rows=true, cols=pred):")
    print(confusion_matrix(y, pred, labels=labels))
    print("\nReport:")
    print(
        classification_report(
            y, pred, labels=labels, target_names=classes, digits=3
        )
    )
|
||
|
||
if __name__ == "__main__":
|
||
main()
|