Files
slr_handshapes/prep_landmarks_binary.py
2026-01-19 22:19:15 -05:00

138 lines
4.4 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Prepare landmarks for a single binary task (Letter vs Not_Letter).

Data layout (per letter):
    data/asl/
        train/
            A/
            Not_A/
        val/
            A/
            Not_A/

Usage (no --outdir needed):
    python prep_landmarks_binary.py --letter A
    # -> saves into landmarks_A/

Optional:
    python prep_landmarks_binary.py --letter B --data /path/to/dataset
"""
import os, argparse, json, math
from pathlib import Path
import numpy as np
import cv2
import mediapipe as mp
# ---------- geometry helpers ----------
def _angle(v):
    """Return the polar angle of the 2-D vector ``v`` = (x, y), in radians.

    Computed as ``atan2(y, x)``, so the result lies in [-pi, pi].
    """
    x, y = v[0], v[1]
    return math.atan2(y, x)
def _rot2d(t):
    """Return the 2x2 float32 rotation matrix for an angle of ``t`` radians.

    Applied to column vectors, it rotates counter-clockwise in a standard
    (x right, y up) frame.
    """
    cos_t = math.cos(t)
    sin_t = math.sin(t)
    rows = [
        [cos_t, -sin_t],
        [sin_t, cos_t],
    ]
    return np.array(rows, dtype=np.float32)
def normalize_landmarks(pts, handed=None):
    """
    Canonicalize one hand's landmarks into a flat feature vector.

    pts: (21, 3) array of MediaPipe landmarks (x, y, z) in normalized image
        coordinates; row 0 is the wrist and row 9 the middle-finger MCP.
        The input array is not modified.
    handed: optional handedness label such as "Left"/"Right". Any truthy
        value whose lowercase form starts with "left" mirrors the x axis.

    Pipeline:
        1) shift x/y so the wrist sits at the origin (z is left as-is)
        2) mirror left hands onto the right-hand frame by negating x
        3) rotate in the x/y plane so the wrist -> middle-MCP vector points
           along +Y
        4) divide x, y and z by the largest pairwise x/y distance
           (no rescaling when that distance is below 1e-6)

    returns: (63,) float32
    """
    # C-ordered float32 working copy, so the caller's array is never touched.
    work = pts.astype(np.float32, order="C", copy=True)

    # 1) wrist to the origin in the image plane.
    work[:, :2] -= work[0, :2].copy()

    # 2) canonicalize handedness.
    is_left = bool(handed) and handed.lower().startswith("left")
    if is_left:
        work[:, 0] *= -1.0

    # 3) rotate so the middle-MCP direction becomes +Y.
    ref_x, ref_y = work[9, 0], work[9, 1]
    theta = math.pi / 2 - math.atan2(ref_y, ref_x)
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    rot = np.array([[cos_t, -sin_t], [sin_t, cos_t]], dtype=np.float32)
    work[:, :2] = work[:, :2] @ rot.T

    # 4) scale by the hand's x/y extent; guard against a degenerate (all
    #    coincident) hand to avoid dividing by ~0.
    plane = work[:, :2]
    peak = np.linalg.norm(plane[:, None, :] - plane[None, :, :], axis=-1).max()
    span = 1.0 if peak < 1e-6 else float(peak)
    work /= span

    return work.reshape(-1)
# ---------- extraction ----------
def collect(split_dir: Path, pos_name: str, neg_name: str, min_det_conf: float):
    """
    Extract normalized hand-landmark features for one dataset split.

    Images are searched recursively under ``split_dir/<pos_name>`` (label 1)
    and ``split_dir/<neg_name>`` (label 0). Each image goes through MediaPipe
    Hands in static-image mode with at most one hand. Images that cannot be
    decoded, or in which no hand is detected, are skipped.

    Args:
        split_dir: split root, e.g. ``data/asl/train``.
        pos_name: name of the positive-class sub-directory.
        neg_name: name of the negative-class sub-directory.
        min_det_conf: MediaPipe ``min_detection_confidence``.

    Returns:
        X: float32 array of shape (n_used, 63); shape (0, 63) when nothing
            was usable.
        y: int64 array of shape (n_used,) holding labels 1 (pos) / 0 (neg).
        paths: list of the n_used image paths, row-aligned with X and y.
    """
    image_exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
    X, y, paths = [], [], []
    total = used = unreadable = 0

    # The context manager closes the MediaPipe graph once extraction is done.
    with mp.solutions.hands.Hands(
        static_image_mode=True,
        max_num_hands=1,
        min_detection_confidence=min_det_conf,
    ) as hands:
        for label, cls in ((1, pos_name), (0, neg_name)):
            cls_dir = split_dir / cls
            if not cls_dir.is_dir():
                # Surface misconfiguration (wrong --letter / --data) instead of
                # silently producing a one-class or empty split.
                print(f"Warning: class directory '{cls_dir}' not found; skipping.")
                continue
            # Sorted so the sample order (and saved *_paths.txt) is reproducible.
            for p in sorted(cls_dir.rglob("*")):
                if not p.is_file() or p.suffix.lower() not in image_exts:
                    continue
                total += 1
                bgr = cv2.imread(str(p))
                if bgr is None:
                    unreadable += 1
                    continue
                res = hands.process(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
                if not res.multi_hand_landmarks:
                    continue
                hand_lms = res.multi_hand_landmarks[0]
                handed = None
                if res.multi_handedness:
                    # "Left"/"Right". NOTE(review): MediaPipe documents this
                    # label as assuming a mirrored (selfie) image; for
                    # non-mirrored photos it may be swapped. Train and val go
                    # through this same path, so the mirroring is at least
                    # consistent across splits.
                    handed = res.multi_handedness[0].classification[0].label
                pts = np.array(
                    [[lm.x, lm.y, lm.z] for lm in hand_lms.landmark],
                    dtype=np.float32,
                )
                X.append(normalize_landmarks(pts, handed))
                y.append(label)
                paths.append(str(p))
                used += 1

    X = np.stack(X) if X else np.zeros((0, 63), np.float32)
    y = np.array(y, dtype=np.int64)
    print(f"Split '{split_dir.name}': found {total}, used {used} (hands detected).")
    if unreadable:
        print(f"Split '{split_dir.name}': {unreadable} image(s) could not be decoded.")
    return X, y, paths
def main():
    """
    CLI entry point: extract landmarks for one binary "<L> vs Not_<L>" task.

    Reads ``<data>/train`` and ``<data>/val`` and writes into ``--outdir``
    (default ``landmarks_<L>``):
        train_X.npy, train_y.npy, val_X.npy, val_y.npy  feature / label arrays
        class_names.json                                ["Not_<L>", "<L>"]
        train_paths.txt, val_paths.txt                  one image path per line,
                                                        aligned with the X rows
    """
    ap = argparse.ArgumentParser(
        description="Prepare landmarks for a single binary task (Letter vs Not_Letter)."
    )
    ap.add_argument("--letter", required=True, help="Target letter (A-Z)")
    ap.add_argument("--data", default="data/asl", help="Root with train/ and val/ (default: data/asl)")
    ap.add_argument("--outdir", default=None, help="Output dir (default: landmarks_<LETTER>)")
    ap.add_argument("--min_det_conf", type=float, default=0.5, help="MediaPipe min detection confidence")
    args = ap.parse_args()

    letter = args.letter.strip().upper()
    if not letter:
        ap.error("--letter must be a non-empty string")
    pos_name = letter
    neg_name = f"Not_{letter}"

    outdir = Path(args.outdir or f"landmarks_{letter}")
    outdir.mkdir(parents=True, exist_ok=True)

    data_root = Path(args.data)
    Xtr, ytr, ptr = collect(data_root / "train", pos_name, neg_name, args.min_det_conf)
    Xva, yva, pva = collect(data_root / "val", pos_name, neg_name, args.min_det_conf)
    for split, feats in (("train", Xtr), ("val", Xva)):
        if len(feats) == 0:
            print(f"Warning: no usable samples in the {split} split; saved arrays will be empty.")

    # Save arrays + metadata.
    np.save(outdir / "train_X.npy", Xtr)
    np.save(outdir / "train_y.npy", ytr)
    np.save(outdir / "val_X.npy", Xva)
    np.save(outdir / "val_y.npy", yva)

    # Index 0: Not_L, index 1: L -- matches the 0/1 labels assigned by collect().
    class_names = [neg_name, pos_name]
    with open(outdir / "class_names.json", "w", encoding="utf-8") as f:
        json.dump(class_names, f)
    # write_text opens, writes and closes the file; UTF-8 so non-ASCII paths
    # round-trip regardless of the platform's locale encoding.
    (outdir / "train_paths.txt").write_text("\n".join(ptr), encoding="utf-8")
    (outdir / "val_paths.txt").write_text("\n".join(pva), encoding="utf-8")

    print(f"✅ Saved {letter}: train {Xtr.shape}, val {Xva.shape}, classes={class_names} -> {outdir}")
if __name__ == "__main__":
    # Run the CLI only when executed directly, not when this module is
    # imported to reuse its helpers.
    main()