#!/usr/bin/env python3 """ infer_seq_webcam.py Live webcam demo: detect a hand with MediaPipe, normalize landmarks, classify with a trained sequence GRU model (multiclass). Examples: python infer_seq_webcam.py --model asl_seq32_gru_ABJZ.pt --threshold 0.8 --smooth 0.7 python infer_seq_webcam.py --model asl_seq32_gru_ABJZ.pt --threshold 0.85 --smooth 1.0 --url https://www.google.com """ import os, math, argparse, time, webbrowser import numpy as np import cv2 import torch import mediapipe as mp # --- Quiet logs --- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" os.environ["GLOG_minloglevel"] = "2" import absl.logging absl.logging.set_verbosity(absl.logging.ERROR) cv2.setLogLevel(0) # ---------- geometry helpers ---------- def _angle(v): return math.atan2(v[1], v[0]) def _rot2d(t): c, s = math.cos(t), math.sin(t) return np.array([[c, -s], [s, c]], dtype=np.float32) def normalize_landmarks(pts, handedness_label=None): """ pts: (21,3) MediaPipe normalized coords in [0..1] Steps: translate wrist->origin, mirror left to right, rotate to +Y, scale by max pairwise distance. Returns: (63,) float32 """ pts = pts.astype(np.float32).copy() pts[:, :2] -= pts[0, :2] if handedness_label and handedness_label.lower().startswith("left"): pts[:, 0] *= -1.0 v = pts[9, :2] # middle MCP R = _rot2d(math.pi/2 - _angle(v)) pts[:, :2] = pts[:, :2] @ R.T xy = pts[:, :2] d = np.linalg.norm(xy[None,:,:] - xy[:,None,:], axis=-1).max() d = 1.0 if d < 1e-6 else float(d) pts[:, :2] /= d; pts[:, 2] /= d return pts.reshape(-1) # ---------- sequence model ---------- class SeqGRU(torch.nn.Module): def __init__(self, input_dim=63, hidden=128, num_classes=26): super().__init__() self.gru = torch.nn.GRU(input_dim, hidden, batch_first=True, bidirectional=True) self.head = torch.nn.Sequential( torch.nn.Linear(hidden*2, 128), torch.nn.ReLU(), torch.nn.Dropout(0.2), torch.nn.Linear(128, num_classes), ) def forward(self, x): h, _ = self.gru(x) # (B,T,2H) h_last = h[:, -1, :] # or h.mean(1) return self.head(h_last) # ---------- main ---------- def main(): ap = argparse.ArgumentParser() ap.add_argument("--model", required=True, help="Path to trained .pt model") ap.add_argument("--camera", type=int, default=0) ap.add_argument("--threshold", type=float, default=0.8) ap.add_argument("--smooth", type=float, default=0.7, help="EMA smoothing window in seconds (0 disables smoothing)") ap.add_argument("--width", type=int, default=640) ap.add_argument("--height", type=int, default=480) ap.add_argument("--url", type=str, default="https://www.google.com", help="URL to open when the sequence W→E→B is detected") args = ap.parse_args() if not os.path.exists(args.model): raise SystemExit(f"❌ Model file not found: {args.model}") # Load checkpoint (support numpy or tensor stats; support 'frames' if present) state = torch.load(args.model, map_location="cpu", weights_only=False) classes = state["classes"] T = int(state.get("frames", 32)) X_mean, X_std = state["X_mean"], state["X_std"] if isinstance(X_mean, torch.Tensor): X_mean = X_mean.cpu().numpy() if isinstance(X_std, torch.Tensor): X_std = X_std.cpu().numpy() X_mean = X_mean.astype(np.float32) X_std = (X_std.astype(np.float32) + 1e-6) device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu") model = SeqGRU(63, 128, num_classes=len(classes)).to(device) model.load_state_dict(state["model"]) model.eval() hands = mp.solutions.hands.Hands( static_image_mode=False, max_num_hands=1, min_detection_confidence=0.5 ) cap = cv2.VideoCapture(args.camera) if not cap.isOpened(): raise SystemExit(f"❌ Could not open camera index {args.camera}") cap.set(cv2.CAP_PROP_FRAME_WIDTH, args.width) cap.set(cv2.CAP_PROP_FRAME_HEIGHT, args.height) print(f"✅ Loaded {args.model} frames={T} classes={classes}") print("Press 'q' to quit.") seq_buffer, ema_probs = [], None last_ts = time.time() last_emitted_letter = None # Rolling history of emitted letters to detect the sequence "WEB" detected_history = [] # only stores emitted letters (deduped by change) while True: ok, frame = cap.read() if not ok: break now = time.time() dt = max(1e-6, now - last_ts) last_ts = now rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) res = hands.process(rgb) overlay_text = "No hand" current_letter = None if res.multi_hand_landmarks: ih = res.multi_hand_landmarks[0] handed = None if res.multi_handedness: handed = res.multi_handedness[0].classification[0].label pts = np.array([[lm.x, lm.y, lm.z] for lm in ih.landmark], dtype=np.float32) feat = normalize_landmarks(pts, handedness_label=handed) seq_buffer.append(feat) if len(seq_buffer) > T: seq_buffer.pop(0) if len(seq_buffer) == T: X = np.stack(seq_buffer, 0) Xn = (X - X_mean) / X_std xt = torch.from_numpy(Xn).float().unsqueeze(0).to(device) with torch.no_grad(): logits = model(xt) probs = torch.softmax(logits, dim=1)[0].cpu().numpy() if args.smooth > 0: alpha = 1.0 - math.exp(-dt / args.smooth) if ema_probs is None: ema_probs = probs else: ema_probs = (1.0 - alpha) * ema_probs + alpha * probs use_probs = ema_probs else: use_probs = probs top_idx = int(np.argmax(use_probs)) top_p = float(use_probs[top_idx]) top_cls = classes[top_idx] if top_p >= args.threshold: overlay_text = f"{top_cls} {top_p*100:.1f}%" current_letter = top_cls else: seq_buffer, ema_probs = [], None # Only emit when a *letter* changes (ignore no-hand and repeats) if current_letter is not None and current_letter != last_emitted_letter: print(f"Detected: {current_letter}") last_emitted_letter = current_letter # Update rolling history detected_history.append(current_letter) if len(detected_history) > 3: detected_history.pop(0) # Check for special sequence "WEB" if detected_history == ["W", "E", "B"]: print("🚀 Detected WEB! Time to open the web browser app.") try: webbrowser.open(args.url) except Exception as e: print(f"⚠️ Failed to open browser: {e}") detected_history.clear() # fire once per occurrence # On-screen overlay (still shows "No hand" when nothing is detected) cv2.putText(frame, overlay_text, (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.1, (0,255,0), 2) cv2.imshow("ASL sequence demo", frame) if cv2.waitKey(1) & 0xFF == ord('q'): break cap.release() cv2.destroyAllWindows() if __name__ == "__main__": main()