slr_handshapes_locations/doc/infer_seq_webcam.py
#!/usr/bin/env python3
"""
Live webcam inference for two hands + full face + pose + face-relative hand extras (1670 dims/frame).
Works for letters (A..Z) or word classes (e.g., Mother, Father).
Watches the emitted labels for the sequence W → E → B and, when seen,
opens a URL (--url).
"""
import os, math, argparse, time, webbrowser  # stdlib

# Quiet logs: set these BEFORE TensorFlow/absl are pulled in (mediapipe imports
# them); otherwise TF_CPP_MIN_LOG_LEVEL and GLOG_minloglevel have no effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["GLOG_minloglevel"] = "2"

import numpy as np      # arrays
import cv2              # webcam UI
import torch            # inference
import mediapipe as mp  # Holistic landmarks
import absl.logging

absl.logging.set_verbosity(absl.logging.ERROR)
cv2.setLogLevel(0)  # silence OpenCV's own logger
mp_holistic = mp.solutions.holistic
# ---------- normalization ----------
def _angle(v):
"""atan2 for 2D vector."""
return math.atan2(v[1], v[0])
def _rot2d(t):
"""2×2 rotation matrix for angle t."""
c, s = math.cos(t), math.sin(t)
return np.array([[c, -s], [s, c]], dtype=np.float32)
def normalize_hand(pts, handed=None):
"""
Wrist-translate, mirror left→right, rotate so middle MCP is +Y, scale by max XY spread.
Returns (21,3).
"""
pts = pts.astype(np.float32).copy()
pts[:, :2] -= pts[0, :2]
if handed and str(handed).lower().startswith("left"): pts[:, 0] *= -1.0
v = pts[9, :2]; R = _rot2d(math.pi/2 - _angle(v))
pts[:, :2] = pts[:, :2] @ R.T
xy = pts[:, :2]; d = np.linalg.norm(xy[None,:,:] - xy[:,None,:], axis=-1).max()
d = 1.0 if d < 1e-6 else float(d)
pts[:, :2] /= d; pts[:, 2] /= d
return pts
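# Invariants normalize_hand establishes (illustrative, assuming a valid
# 21-point hand; not output from a real frame):
#   pts = normalize_hand(raw, "Right")
#   pts[0, :2] -> (0, 0)             # wrist translated to the origin
#   pts[9, :2] -> along +Y           # middle-MCP direction rotated onto +Y
#   max pairwise XY distance -> 1.0  # scale removed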
def normalize_face(face):
"""Center at eye midpoint, scale by inter-ocular, rotate eye-line horizontal; returns (468,3)."""
f = face.astype(np.float32).copy()
    left, right = f[33, :2], f[263, :2]  # outer eye corners (FaceMesh 33/263)
center = 0.5*(left+right)
f[:, :2] -= center[None, :]
eye_vec = right - left; eye_dist = float(np.linalg.norm(eye_vec)) or 1.0
f[:, :2] /= eye_dist; f[:, 2] /= eye_dist
R = _rot2d(-_angle(eye_vec)); f[:, :2] = f[:, :2] @ R.T
return f
def normalize_pose(pose):
"""Center at shoulder midpoint, scale by shoulder width, rotate shoulders horizontal; returns (33,4)."""
p = pose.astype(np.float32).copy()
    ls, rs = p[11, :2], p[12, :2]  # left/right shoulders (Pose 11/12)
center = 0.5*(ls+rs); p[:, :2] -= center[None, :]
sw_vec = rs - ls; sw = float(np.linalg.norm(sw_vec)) or 1.0
p[:, :2] /= sw; p[:, 2] /= sw
R = _rot2d(-_angle(sw_vec)); p[:, :2] = p[:, :2] @ R.T
return p
def face_frame_transform(face_pts):
"""Return (center, eye_dist, R) to project points into the face-normalized frame."""
left = face_pts[33, :2]; right = face_pts[263, :2]
center = 0.5*(left + right)
eye_vec = right - left
eye_dist = float(np.linalg.norm(eye_vec)) or 1.0
R = _rot2d(-_angle(eye_vec))
return center, eye_dist, R
def to_face_frame(pt_xy, center, eye_dist, R):
"""Project a 2D point into the face frame."""
v = (pt_xy - center) / eye_dist
return (v @ R.T).astype(np.float32)
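# Usage sketch for the two helpers above (`pt` is any hypothetical 2D point in
# the same normalized image coordinates as the landmarks):
#   c, d, R = face_frame_transform(face_pts)
#   pt_ff = to_face_frame(pt, c, d, R)  # pt in eye-centered, eye-scaled coords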
# ---------- model ----------
class SeqGRU(torch.nn.Module):
"""
    BiGRU classifier; the architecture must match training exactly so the
    checkpoint's state_dict loads.
"""
def __init__(self, input_dim, hidden=128, num_classes=26):
super().__init__()
self.gru = torch.nn.GRU(input_dim, hidden, batch_first=True, bidirectional=True)
self.head = torch.nn.Sequential(
torch.nn.Linear(hidden*2, 128), torch.nn.ReLU(), torch.nn.Dropout(0.2),
torch.nn.Linear(128, num_classes),
)
def forward(self, x):
h,_ = self.gru(x) # (B,T,2H)
return self.head(h[:, -1, :]) # last-time-step logits
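# Shape flow with the defaults used below (input_dim=1670, hidden=128):
#   x (B, T, 1670) --GRU--> h (B, T, 256) --h[:, -1, :]--> (B, 256)
#   --Linear(256,128)/ReLU/Dropout--Linear(128,C)--> logits (B, C)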
# ---------- main ----------
def main():
"""
Stream webcam, build rolling window of T frames, normalize with training stats,
classify with BiGRU, overlay current top prediction, and optionally trigger
an action when the sequence 'W', 'E', 'B' is observed.
"""
ap = argparse.ArgumentParser()
ap.add_argument("--model", required=True) # path to .pt checkpoint
ap.add_argument("--camera", type=int, default=0) # webcam device index
ap.add_argument("--threshold", type=float, default=0.35) # emit threshold for top prob
ap.add_argument("--smooth", type=float, default=0.1, help="EMA window (seconds); 0 disables")
ap.add_argument("--width", type=int, default=640) # capture resolution
ap.add_argument("--height", type=int, default=480)
ap.add_argument("--holistic-complexity", type=int, default=1, choices=[0,1,2]) # accuracy/speed
ap.add_argument("--det-thresh", type=float, default=0.5) # detector confidence thresholds
ap.add_argument("--url", type=str, default="https://www.google.com") # used on WEB
args = ap.parse_args()
state = torch.load(args.model, map_location="cpu", weights_only=False) # load checkpoint dict
classes = state["classes"] # label names
T = int(state.get("frames", 32)) # window length
X_mean = state["X_mean"].cpu().numpy().astype(np.float32) # normalization stats
X_std = (state["X_std"].cpu().numpy().astype(np.float32) + 1e-6)
input_dim = X_mean.shape[-1] # expected F (1670)
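    # Checkpoint layout implied by the reads above (the training script is not
    # shown here, so treat this as the inferred contract):
    #   {"model": state_dict, "classes": [str, ...], "frames": int,
    #    "X_mean": Tensor (F,), "X_std": Tensor (F,)}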
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu") # Apple MPS if avail
model = SeqGRU(input_dim=input_dim, hidden=128, num_classes=len(classes)).to(device) # same arch
model.load_state_dict(state["model"]); model.eval() # load weights
hol = mp_holistic.Holistic( # configure detector
static_image_mode=False,
model_complexity=args.holistic_complexity,
smooth_landmarks=True,
enable_segmentation=False,
refine_face_landmarks=False,
min_detection_confidence=args.det_thresh,
min_tracking_confidence=args.det_thresh,
)
cap = cv2.VideoCapture(args.camera) # open camera
if not cap.isOpened(): raise SystemExit(f"❌ Could not open camera {args.camera}")
cap.set(cv2.CAP_PROP_FRAME_WIDTH, args.width); cap.set(cv2.CAP_PROP_FRAME_HEIGHT, args.height)
print(f"✅ Loaded {args.model} frames={T} classes={classes} input_dim={input_dim}")
print("Press 'q' to quit.")
seq_buffer, ema_probs = [], None # rolling window + smoother
last_ts = time.time() # for EMA time constant
last_emitted = None # de-bounce repeated prints
history = [] # recent emitted labels
while True:
ok, frame = cap.read() # grab a frame
if not ok: break
now = time.time(); dt = max(1e-6, now - last_ts); last_ts = now # frame delta seconds
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # BGR→RGB
res = hol.process(rgb) # run detection
overlay = "No face/hand" # default HUD text
current = None # currently confident label
# hands
right_pts = left_pts = None
if res.right_hand_landmarks is not None:
right_pts = np.array([[lm.x, lm.y, lm.z] for lm in res.right_hand_landmarks.landmark], np.float32)
if res.left_hand_landmarks is not None:
left_pts = np.array([[lm.x, lm.y, lm.z] for lm in res.left_hand_landmarks.landmark], np.float32)
# face
face_pts = None
if res.face_landmarks is not None:
face_pts = np.array([[lm.x, lm.y, lm.z] for lm in res.face_landmarks.landmark], np.float32)
# pose
pose_arr = None
if res.pose_landmarks is not None:
pose_arr = np.array([[lm.x, lm.y, lm.z, lm.visibility] for lm in res.pose_landmarks.landmark], np.float32)
if face_pts is not None and (right_pts is not None or left_pts is not None):
f_norm = normalize_face(face_pts) # normalized face (anchor)
# build extras in face frame (preserve where hands are relative to face)
f_center, f_scale, f_R = face_frame_transform(face_pts)
def hand_face_extras(hand_pts):
"""Return [wrist.x, wrist.y, tip.x, tip.y] projected into the face frame, or zeros."""
if hand_pts is None:
return np.zeros(4, np.float32)
wrist_xy = hand_pts[0, :2]
tip_xy = hand_pts[8, :2]
w = to_face_frame(wrist_xy, f_center, f_scale, f_R)
t = to_face_frame(tip_xy, f_center, f_scale, f_R)
return np.array([w[0], w[1], t[0], t[1]], np.float32)
rh_ex = hand_face_extras(right_pts)
lh_ex = hand_face_extras(left_pts)
rh = normalize_hand(right_pts, "Right").reshape(-1) if right_pts is not None else np.zeros(63, np.float32)
lh = normalize_hand(left_pts, "Left" ).reshape(-1) if left_pts is not None else np.zeros(63, np.float32)
p_norm = normalize_pose(pose_arr).reshape(-1) if pose_arr is not None else np.zeros(33*4, np.float32)
feat = np.concatenate([rh, lh, f_norm.reshape(-1), p_norm, rh_ex, lh_ex], axis=0) # (1670,)
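            # Feature layout (sums to 1670): right hand 21*3=63, left hand 63,
            # face 468*3=1404, pose 33*4=132, plus 4 face-frame extras per hand.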
seq_buffer.append(feat) # push newest feature frame
if len(seq_buffer) > T: seq_buffer.pop(0) # keep last T frames only
if len(seq_buffer) == T: # only infer when buffer full
X = np.stack(seq_buffer, 0) # (T, F)
Xn = (X - X_mean) / X_std # normalize with training stats
xt = torch.from_numpy(Xn).float().unsqueeze(0).to(device) # (1, T, F)
with torch.no_grad(): # inference (no grads)
probs = torch.softmax(model(xt), dim=1)[0].cpu().numpy() # class probabilities
if args.smooth > 0:
alpha = 1.0 - math.exp(-dt / args.smooth) # EMA with time-based alpha
ema_probs = probs if ema_probs is None else (1.0 - alpha) * ema_probs + alpha * probs
use = ema_probs
else:
use = probs
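                # With --smooth 0.1 at ~30 FPS (dt ≈ 0.033 s):
                #   alpha = 1 - exp(-0.033/0.1) ≈ 0.28,
                # so each new frame contributes ~28% of the smoothed distribution.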
top_idx = int(np.argmax(use)); top_p = float(use[top_idx]); top_cls = classes[top_idx] # best class
overlay = f"{top_cls} {top_p*100:.1f}%" # HUD text
if top_p >= args.threshold: current = top_cls # only emit when confident
else:
seq_buffer, ema_probs = [], None # reset if face+hand not available
# Emit on change & optional "WEB" sequence trigger
if current is not None and current != last_emitted:
print(f"Detected: {current}") # console feedback
last_emitted = current
history.append(current) # remember last few
if len(history) > 3: history.pop(0)
if history == ["W","E","B"]: # simple finite-seq detector
print("🚀 Detected WEB! Opening browser…")
try: webbrowser.open(args.url) # launch default browser
except Exception as e: print(f"⚠️ Browser open failed: {e}")
history.clear() # reset after triggering
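            # (The exact-match check generalizes: keep the last N emitted labels
            # and compare against any target list to add more sequence triggers.)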
# Overlay HUD
buf = f"buf={len(seq_buffer)}/{T}" # show buffer fill
if ema_probs is not None:
ti = int(np.argmax(ema_probs)); tp = float(ema_probs[ti]); tc = classes[ti]
buf += f" top={tc} {tp:.2f}" # show smoothed top prob
cv2.putText(frame, overlay, (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.1, (0,255,0), 2)
cv2.putText(frame, buf, (20, 75), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0,255,0), 2)
cv2.imshow("ASL demo (R+L hands + face + pose + extras)", frame) # preview window
if cv2.waitKey(1) & 0xFF == ord('q'): break # quit key
    cap.release(); hol.close(); cv2.destroyAllWindows()  # release camera, detector, windows
if __name__ == "__main__":
main()