250 lines
13 KiB
Python
250 lines
13 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Live webcam inference on two hands + full face + pose + face-relative hand extras (1670 dims/frame = 2×63 hand + 468×3 face + 33×4 pose + 2×4 extras).
|
||
Works for letters (A..Z) or word classes (e.g., Mother, Father).
|
||
When the emitted predictions form the sequence W → E → B, opens the --url in the default browser.
|
||
"""
|
||
|
||
import os, math, argparse, time, webbrowser # stdlib
|
||
import numpy as np # arrays
|
||
import cv2 # webcam UI
|
||
import torch # inference
|
||
import mediapipe as mp # Holistic landmarks
|
||
# Quiet logs: reduce console noise from TensorFlow / glog / absl / OpenCV.
# NOTE(review): cv2, torch and mediapipe are already imported above when these
# lines run; a library that reads TF_CPP_MIN_LOG_LEVEL / GLOG_minloglevel only
# at load time may ignore them. Consider moving the os.environ update above the
# imports — TODO confirm which libraries actually honour them here.
os.environ.update({"TF_CPP_MIN_LOG_LEVEL": "2", "GLOG_minloglevel": "2"})

import absl.logging
absl.logging.set_verbosity(absl.logging.ERROR)

cv2.setLogLevel(0)  # 0 is OpenCV's LOG_LEVEL_SILENT

# Shorthand for MediaPipe's Holistic solution (face, pose and both hands are read in main()).
mp_holistic = mp.solutions.holistic
|
||
|
||
# ---------- normalization ----------
|
||
def _angle(v):
|
||
"""atan2 for 2D vector."""
|
||
return math.atan2(v[1], v[0])
|
||
|
||
def _rot2d(t):
|
||
"""2×2 rotation matrix for angle t."""
|
||
c, s = math.cos(t), math.sin(t)
|
||
return np.array([[c, -s], [s, c]], dtype=np.float32)
|
||
|
||
def normalize_hand(pts, handed=None):
    """
    Canonicalize one hand's landmarks (must match the training-time preprocessing).

    Steps:
      1. translate XY so landmark 0 (wrist) sits at the origin; z is not shifted,
      2. mirror X when ``handed`` starts with "left" (case-insensitive),
      3. rotate XY so landmark 9 (middle MCP) points along +Y,
      4. divide x, y and z by the largest pairwise XY distance between landmarks
         (1.0 is used instead when that spread is below 1e-6).

    Args:
        pts: landmark array, (21, 3) as built in main(); not modified.
        handed: optional handedness label such as "Left" / "Right" / None.

    Returns:
        A new float32 array with the same shape as ``pts``.
    """
    out = pts.astype(np.float32)  # astype copies, so the caller's array is untouched

    # 1) wrist to the XY origin (copy the origin so the in-place shift is unambiguous).
    wrist_xy = out[0, :2].copy()
    out[:, :2] -= wrist_xy

    # 2) mirror left hands so both hands share one canonical orientation.
    if handed and str(handed).lower().startswith("left"):
        out[:, 0] = -out[:, 0]

    # 3) rotate the wrist -> landmark-9 direction onto +Y.
    turn = math.pi / 2 - _angle(out[9, :2])
    out[:, :2] = out[:, :2] @ _rot2d(turn).T

    # 4) scale-normalize by the widest XY spread, guarding degenerate (collapsed) hands.
    planar = out[:, :2]
    spread = np.linalg.norm(planar[:, None, :] - planar[None, :, :], axis=-1).max()
    scale = 1.0 if spread < 1e-6 else float(spread)
    out[:, :3] /= scale
    return out
|
||
|
||
def normalize_face(face):
    """
    Put face landmarks into an eye-aligned frame.

    The midpoint of eye landmarks 33 and 263 becomes the XY origin, x/y/z are
    divided by the 33->263 distance (1.0 if that distance is exactly zero), and
    XY is rotated so the 33->263 vector lies along +X.

    Args:
        face: landmark array, (468, 3) as built in main(); not modified.

    Returns:
        A new float32 array with the same shape as ``face``.
    """
    out = face.astype(np.float32)
    mid = 0.5 * (out[33, :2] + out[263, :2])
    out[:, :2] -= mid
    # Read the eye vector after centring — translation does not change it.
    eye_vec = out[263, :2] - out[33, :2]
    eye_dist = float(np.linalg.norm(eye_vec)) or 1.0
    out[:, :3] /= eye_dist
    out[:, :2] = out[:, :2] @ _rot2d(-_angle(eye_vec)).T
    return out
|
||
|
||
def normalize_pose(pose):
    """
    Put body-pose landmarks into a shoulder-aligned frame.

    The midpoint of shoulder landmarks 11 and 12 becomes the XY origin, x/y/z
    are divided by the 11->12 distance (1.0 if that distance is exactly zero),
    and XY is rotated so the 11->12 vector lies along +X. Column 3 (visibility,
    as built in main()) is copied through unchanged.

    Args:
        pose: landmark array, (33, 4) as built in main(); not modified.

    Returns:
        A new float32 array with the same shape as ``pose``.
    """
    out = pose.astype(np.float32)
    mid = 0.5 * (out[11, :2] + out[12, :2])
    out[:, :2] -= mid
    # Read the shoulder vector after centring — translation does not change it.
    shoulder_vec = out[12, :2] - out[11, :2]
    width = float(np.linalg.norm(shoulder_vec)) or 1.0
    out[:, :3] /= width
    out[:, :2] = out[:, :2] @ _rot2d(-_angle(shoulder_vec)).T
    return out
|
||
|
||
def face_frame_transform(face_pts):
    """
    Compute the parameters of the eye-aligned face frame from raw face landmarks.

    Args:
        face_pts: raw face landmarks; rows 33 and 263 (eye landmarks) are used.

    Returns:
        (center, eye_dist, R): the XY midpoint of the two eye landmarks, their
        distance (1.0 if exactly zero), and the 2x2 matrix that rotates the
        eye line onto +X. Pass these to ``to_face_frame`` to project points.
    """
    eye_a = face_pts[33, :2]
    eye_b = face_pts[263, :2]
    eye_vec = eye_b - eye_a
    center = 0.5 * (eye_a + eye_b)
    eye_dist = float(np.linalg.norm(eye_vec)) or 1.0
    return center, eye_dist, _rot2d(-_angle(eye_vec))
|
||
|
||
def to_face_frame(pt_xy, center, eye_dist, R):
    """
    Project a 2D point into the face frame: shift by ``center``, divide by
    ``eye_dist``, then rotate by ``R``. Returns a float32 length-2 array.
    """
    shifted = pt_xy - center
    scaled = shifted / eye_dist
    rotated = scaled @ R.T
    return rotated.astype(np.float32)
|
||
|
||
# ---------- model ----------
|
||
class SeqGRU(torch.nn.Module):
    """
    Bidirectional GRU sequence classifier; mirrors the training-time model.

    The parameter names must stay ``gru.*``, ``head.0.*`` and ``head.3.*`` so
    checkpoints saved at training time load with ``load_state_dict``.

    Input is batch-first ``(B, T, input_dim)``; output is ``(B, num_classes)``
    logits computed from the GRU output at the final time step.

    NOTE(review): with ``bidirectional=True`` the backward half of the output at
    the final step has only processed that last frame. Because this class must
    match the trained model, changing the readout would require retraining.
    """

    def __init__(self, input_dim, hidden=128, num_classes=26):
        super().__init__()
        nn = torch.nn
        self.gru = nn.GRU(input_dim, hidden, batch_first=True, bidirectional=True)
        layers = [
            nn.Linear(2 * hidden, 128),   # head.0
            nn.ReLU(),                    # head.1
            nn.Dropout(0.2),              # head.2
            nn.Linear(128, num_classes),  # head.3
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        seq_out, _ = self.gru(x)        # (B, T, 2*hidden)
        final_step = seq_out[:, -1, :]  # output at the last time step
        return self.head(final_step)
|
||
|
||
# ---------- main ----------
|
||
def _parse_args():
    """Define and parse the command-line options for the live demo."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True)  # path to .pt checkpoint
    ap.add_argument("--camera", type=int, default=0)  # webcam device index
    ap.add_argument("--threshold", type=float, default=0.35)  # emit threshold for top prob
    ap.add_argument("--smooth", type=float, default=0.1, help="EMA window (seconds); 0 disables")
    ap.add_argument("--width", type=int, default=640)  # requested capture resolution
    ap.add_argument("--height", type=int, default=480)
    ap.add_argument("--holistic-complexity", type=int, default=1, choices=[0, 1, 2])  # accuracy/speed
    ap.add_argument("--det-thresh", type=float, default=0.5)  # detection and tracking confidence
    ap.add_argument("--url", type=str, default="https://www.google.com")  # opened on W, E, B
    return ap.parse_args()


def _landmarks_array(landmark_list, with_visibility=False):
    """Convert a MediaPipe landmark list to float32 rows of (x, y, z[, visibility]); None stays None."""
    if landmark_list is None:
        return None
    if with_visibility:
        rows = [[lm.x, lm.y, lm.z, lm.visibility] for lm in landmark_list.landmark]
    else:
        rows = [[lm.x, lm.y, lm.z] for lm in landmark_list.landmark]
    return np.array(rows, np.float32)


def _hand_face_extras(hand_pts, f_center, f_scale, f_R):
    """Return [wrist.x, wrist.y, tip.x, tip.y] (hand landmarks 0 and 8) in the face frame, or zeros if no hand."""
    if hand_pts is None:
        return np.zeros(4, np.float32)
    w = to_face_frame(hand_pts[0, :2], f_center, f_scale, f_R)
    t = to_face_frame(hand_pts[8, :2], f_center, f_scale, f_R)
    return np.array([w[0], w[1], t[0], t[1]], np.float32)


def _build_features(right_pts, left_pts, face_pts, pose_arr):
    """
    Build one frame's feature vector in the training-time order:
    right hand (63) | left hand (63) | face | pose (33*4) | right extras (4) | left extras (4).
    Missing hands / pose are zero-filled; ``face_pts`` is required. The extras keep
    where each hand is relative to the face, which per-part normalization removes.
    """
    f_norm = normalize_face(face_pts)
    f_center, f_scale, f_R = face_frame_transform(face_pts)
    rh_ex = _hand_face_extras(right_pts, f_center, f_scale, f_R)
    lh_ex = _hand_face_extras(left_pts, f_center, f_scale, f_R)
    rh = normalize_hand(right_pts, "Right").reshape(-1) if right_pts is not None else np.zeros(63, np.float32)
    lh = normalize_hand(left_pts, "Left").reshape(-1) if left_pts is not None else np.zeros(63, np.float32)
    p_norm = normalize_pose(pose_arr).reshape(-1) if pose_arr is not None else np.zeros(33 * 4, np.float32)
    return np.concatenate([rh, lh, f_norm.reshape(-1), p_norm, rh_ex, lh_ex], axis=0)


def main():
    """
    Run the live demo.

    Stream the webcam and keep a rolling window of the last T feature frames. The
    window resets whenever the face or both hands are lost. Each frame is normalized
    with the checkpoint's training stats and classified with the BiGRU; the top
    prediction (optionally EMA-smoothed) is drawn on the HUD. A prediction is printed
    when it clears --threshold and differs from the last one printed. If the printed
    sequence is W, E, B, --url is opened in the default browser. Press 'q' to quit.
    """
    from collections import deque  # local import keeps the file's top-level imports unchanged

    args = _parse_args()

    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects, which can
    # execute code. Only pass checkpoints from a trusted source.
    state = torch.load(args.model, map_location="cpu", weights_only=False)
    classes = state["classes"]  # label names
    T = int(state.get("frames", 32))  # window length
    X_mean = state["X_mean"].cpu().numpy().astype(np.float32)  # normalization stats
    X_std = state["X_std"].cpu().numpy().astype(np.float32) + 1e-6  # epsilon avoids divide-by-zero
    input_dim = X_mean.shape[-1]  # expected features per frame

    weights = state["model"]
    # nn.GRU's weight_hh_l0 has shape (3*hidden, hidden): read the width from the checkpoint
    # so non-default widths load; fall back to the historical 128.
    hidden = int(weights["gru.weight_hh_l0"].shape[1]) if "gru.weight_hh_l0" in weights else 128

    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    model = SeqGRU(input_dim=input_dim, hidden=hidden, num_classes=len(classes)).to(device)
    model.load_state_dict(weights)
    model.eval()

    cap = cv2.VideoCapture(args.camera)
    hol = None
    try:
        if not cap.isOpened():
            raise SystemExit(f"❌ Could not open camera {args.camera}")
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, args.width)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, args.height)

        hol = mp_holistic.Holistic(
            static_image_mode=False,
            model_complexity=args.holistic_complexity,
            smooth_landmarks=True,
            enable_segmentation=False,
            refine_face_landmarks=False,
            min_detection_confidence=args.det_thresh,
            min_tracking_confidence=args.det_thresh,
        )

        print(f"✅ Loaded {args.model} frames={T} classes={classes} input_dim={input_dim}")
        print("Press 'q' to quit.")

        seq_buffer = deque(maxlen=T)  # rolling window; the oldest frame drops automatically
        ema_probs = None  # time-based EMA of class probabilities
        last_ts = time.monotonic()  # monotonic clock: immune to wall-clock jumps
        last_emitted = None  # de-bounce repeated prints
        history = deque(maxlen=3)  # last three emitted labels

        while True:
            ok, frame = cap.read()
            if not ok:
                break
            now = time.monotonic()
            dt = max(1e-6, now - last_ts)  # frame delta in seconds
            last_ts = now

            res = hol.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # MediaPipe expects RGB

            overlay = "No face/hand"  # default HUD text
            current = None  # confident label for this frame, if any

            right_pts = _landmarks_array(res.right_hand_landmarks)
            left_pts = _landmarks_array(res.left_hand_landmarks)
            face_pts = _landmarks_array(res.face_landmarks)
            pose_arr = _landmarks_array(res.pose_landmarks, with_visibility=True)

            if face_pts is not None and (right_pts is not None or left_pts is not None):
                feat = _build_features(right_pts, left_pts, face_pts, pose_arr)
                if feat.shape[0] != input_dim:
                    raise SystemExit(
                        f"❌ Feature size {feat.shape[0]} does not match checkpoint input_dim {input_dim}"
                    )
                seq_buffer.append(feat)
                overlay = "Buffering..."  # face+hand present, window not yet full (ASCII for putText)

                if len(seq_buffer) == T:  # only infer on a full window
                    X = np.stack(seq_buffer, 0)  # (T, F)
                    Xn = (X - X_mean) / X_std  # normalize with training stats
                    xt = torch.from_numpy(Xn).float().unsqueeze(0).to(device)  # (1, T, F)
                    with torch.no_grad():
                        probs = torch.softmax(model(xt), dim=1)[0].cpu().numpy()

                    if args.smooth > 0:
                        alpha = 1.0 - math.exp(-dt / args.smooth)  # time-constant EMA
                        ema_probs = probs if ema_probs is None else (1.0 - alpha) * ema_probs + alpha * probs
                        use = ema_probs
                    else:
                        use = probs

                    top_idx = int(np.argmax(use))
                    top_p = float(use[top_idx])
                    top_cls = classes[top_idx]
                    overlay = f"{top_cls} {top_p*100:.1f}%"
                    if top_p >= args.threshold:
                        current = top_cls
            else:
                seq_buffer.clear()  # face or hands lost: restart the window
                ema_probs = None

            # Emit on change, then check for the W, E, B trigger.
            if current is not None and current != last_emitted:
                print(f"Detected: {current}")
                last_emitted = current
                history.append(current)
                if list(history) == ["W", "E", "B"]:
                    print("🚀 Detected WEB! Opening browser…")
                    try:
                        webbrowser.open(args.url)
                    except Exception as e:  # best-effort: a failed launch must not stop the demo
                        print(f"⚠️ Browser open failed: {e}")
                    history.clear()

            # HUD: prediction line plus buffer fill / smoothed top class.
            buf = f"buf={len(seq_buffer)}/{T}"
            if ema_probs is not None:
                ti = int(np.argmax(ema_probs))
                buf += f" top={classes[ti]} {float(ema_probs[ti]):.2f}"
            cv2.putText(frame, overlay, (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.1, (0, 255, 0), 2)
            cv2.putText(frame, buf, (20, 75), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
            cv2.imshow("ASL demo (R+L hands + face + pose + extras)", frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the camera, the MediaPipe graph and any windows even on error / SystemExit.
        cap.release()
        if hol is not None:
            hol.close()
        cv2.destroyAllWindows()
|
||
|
||
if __name__ == "__main__":
    # Start the demo only when executed as a script, not when imported.
    main()
|