#!/usr/bin/env python3
"""
infer_webcam.py
Live webcam demo: detect a hand with MediaPipe, normalize landmarks,
classify with a trained MLP model.
Examples:
python infer_webcam.py --letter A # loads asl_A_mlp.pt
python infer_webcam.py --letter B # loads asl_B_mlp.pt
python infer_webcam.py --model /path/to/asl_A_mlp.pt
Press 'q' to quit.
"""
import os, math, argparse
import numpy as np
import cv2
import torch
import mediapipe as mp
# ---------- geometry helpers ----------
def _angle(v): return math.atan2(v[1], v[0])
def _rot2d(t):
    c, s = math.cos(t), math.sin(t)
    return np.array([[c, -s], [s, c]], dtype=np.float32)
def normalize_landmarks(pts, handedness_label=None):
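    """Canonicalize 21 MediaPipe landmarks (21×3) into a 63-D feature vector:
    wrist translated to the origin, left hands mirrored to look right-handed,
    the wrist→middle-MCP axis rotated to +Y, and coordinates scaled by the
    maximum pairwise xy distance. Invariant to in-plane translation, rotation,
    uniform scale, and handedness."""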
    pts = pts.astype(np.float32).copy()
    # translate wrist to origin
    pts[:, :2] -= pts[0, :2]
    # mirror left→right
    if handedness_label and handedness_label.lower().startswith("left"):
        pts[:, 0] *= -1.0
    # rotate wrist→middle_mcp to +Y
    v = pts[9, :2]
    R = _rot2d(math.pi / 2 - _angle(v))
    pts[:, :2] = pts[:, :2] @ R.T
    # scale by max pairwise distance
    xy = pts[:, :2]
    d = np.linalg.norm(xy[None, :, :] - xy[:, None, :], axis=-1).max()
    d = 1.0 if d < 1e-6 else float(d)
    pts[:, :2] /= d
    pts[:, 2] /= d
    return pts.reshape(-1)
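# A quick sanity check (hypothetical `pts`: any 21×3 landmark array) — the
# features should be unchanged under uniform scaling and in-plane translation:
#   f = normalize_landmarks(pts)
#   assert np.allclose(f, normalize_landmarks(pts * 2.0), atol=1e-5)
#   assert np.allclose(f, normalize_landmarks(pts + [0.1, -0.2, 0.0]), atol=1e-5)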
# ---------- model ----------
class MLP(torch.nn.Module):
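    """Small 3-layer MLP over the 63-D normalized-landmark feature."""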
    def __init__(self, in_dim, num_classes):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_dim, 128),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(64, num_classes),
        )
    def forward(self, x): return self.net(x)
# ---------- main ----------
def main():
    ap = argparse.ArgumentParser()
    grp = ap.add_mutually_exclusive_group(required=True)
    grp.add_argument("--letter", help="Target letter (A-Z). Loads asl_<LETTER>_mlp.pt")
    grp.add_argument("--model", help="Path to a trained .pt model (use instead of --letter)")
    ap.add_argument("--camera", type=int, default=0, help="OpenCV camera index (default: 0)")
    args = ap.parse_args()
    # Resolve model path
    model_path = args.model
    if model_path is None:
        letter = args.letter.upper()
        model_path = f"asl_{letter}_mlp.pt"
    if not os.path.exists(model_path):
        raise SystemExit(f"❌ Model file not found: {model_path}")
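    # Assumed checkpoint layout (whatever the training script saved): a dict
    # with "model" (state_dict), "classes", and the standardization stats
    # "X_mean" / "X_std".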
    # Load the full checkpoint; it may contain numpy arrays as well as tensors,
    # so weights_only must be False.
    state = torch.load(model_path, map_location="cpu", weights_only=False)
    classes = state["classes"]
    X_mean = state["X_mean"]
    X_std = state["X_std"]
    # Convert X_mean/X_std to numpy no matter how they were saved
    if isinstance(X_mean, torch.Tensor): X_mean = X_mean.cpu().numpy()
    if isinstance(X_std, torch.Tensor): X_std = X_std.cpu().numpy()
    X_mean = np.asarray(X_mean, dtype=np.float32)
    X_std = np.asarray(X_std, dtype=np.float32) + 1e-6  # avoid division by zero
    model = MLP(63, len(classes))
    model.load_state_dict(state["model"])
    model.eval()
    # Prefer the Apple-silicon GPU (MPS) when available, else fall back to CPU.
    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    model.to(device)
    hands = mp.solutions.hands.Hands(
        static_image_mode=False, max_num_hands=1, min_detection_confidence=0.5
    )
    cap = cv2.VideoCapture(args.camera)
    if not cap.isOpened():
        raise SystemExit(f"❌ Could not open camera index {args.camera}")
    print(f"✅ Loaded {model_path} with classes {classes}")
    print("Press 'q' to quit.")
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        res = hands.process(rgb)
        label_text = "No hand"
        if res.multi_hand_landmarks:
            ih = res.multi_hand_landmarks[0]
            handed = None
            if res.multi_handedness:
                handed = res.multi_handedness[0].classification[0].label
            pts = np.array([[lm.x, lm.y, lm.z] for lm in ih.landmark], dtype=np.float32)
            feat = normalize_landmarks(pts, handedness_label=handed)
            # standardize with the training-set statistics
            xn = (feat - X_mean.flatten()) / X_std.flatten()
            xt = torch.from_numpy(xn).float().unsqueeze(0).to(device)
            with torch.no_grad():
                probs = torch.softmax(model(xt), dim=1)[0].cpu().numpy()
            idx = int(probs.argmax())
            label_text = f"{classes[idx]} {probs[idx]*100:.1f}%"
        cv2.putText(frame, label_text, (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.1, (0, 255, 0), 2)
        cv2.imshow("ASL handshape demo", frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    hands.close()
    cap.release()
    cv2.destroyAllWindows()
if __name__ == "__main__":
    main()