#!/usr/bin/env python3
import argparse
import os
import os.path as osp
import re
import cv2
import torch
import numpy as np
from loguru import logger
from ultralytics import YOLO
from yolox.tracker.byte_tracker import BYTETracker
from supervision.tools.detections import Detections

def make_parser():
    """Build the command-line parser for the LOAF person tracker.

    Returns:
        argparse.ArgumentParser: parser exposing the dataset location,
        the detector settings, the ByteTrack settings and the output options.
    """
    # (flag, add_argument keyword arguments), in the order they are registered.
    option_table = (
        # Dataset root (or a single scene directory when --name is given).
        ("--path", {"required": True}),
        ("--name", {"default": None}),
        # Detector settings.
        ("--yolo_model", {"default": "yolov8x.pt"}),
        ("--device", {"default": "cuda", "choices": ["cpu", "cuda"]}),
        ("--conf", {"type": float, "default": 0.3}),
        # Tracker settings.
        ("--track_thresh", {"type": float, "default": 0.1}),
        ("--track_buffer", {"type": int, "default": 30}),
        ("--match_thresh", {"type": float, "default": 0.4}),
        ("--min_box_area", {"type": float, "default": 1}),
        # Output settings.
        ("--fps", {"type": int, "default": 30}),
        ("--record_fake", {"choices": ["none", "csv", "video", "all"],
                           "default": "all"}),
    )

    cli = argparse.ArgumentParser("LOAF Person Tracker")
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli

def extract_frame_number(fname):
    """Return the frame number encoded in a LOAF image file name.

    LOAF images are named ``{scene}_{frame}.jpg``. The digits that sit between
    an underscore and the trailing ``.jpg`` are parsed as an integer.

    Returns:
        int: the frame number, or -1 when the name does not end with
        ``_<digits>.jpg``.
    """
    found = re.search(r"_(\d+)\.jpg$", fname)
    if found is None:
        return -1
    return int(found.group(1))

def scene_sort_key(scene_id):
    """Sort key that orders scene IDs numerically rather than lexically.

    LOAF scene IDs are usually purely numeric strings, so the ID is parsed
    with ``int``; an ID that is not a valid integer raises ``ValueError``.
    """
    numeric_id = int(scene_id)
    return numeric_id

def detections2boxes(det: Detections):
    """Pack detections into one array laid out as box columns followed by score.

    Assuming ``det.xyxy`` is ``(N, 4)`` and ``det.confidence`` is ``(N,)``
    (TODO confirm against the ``Detections`` version in use), the result is an
    ``(N, 5)`` array of ``[x1, y1, x2, y2, score]`` rows.
    """
    boxes = det.xyxy
    # Promote the 1-D score vector to a column so it can be stacked beside the boxes.
    score_column = det.confidence[:, None]
    return np.hstack((boxes, score_column))

def _pairwise_iou(boxes_a, boxes_b):
    """Return the ``(len(boxes_a), len(boxes_b))`` IoU matrix of two xyxy box sets."""
    area_a = (boxes_a[:, 2] - boxes_a[:, 0]) * (boxes_a[:, 3] - boxes_a[:, 1])
    area_b = (boxes_b[:, 2] - boxes_b[:, 0]) * (boxes_b[:, 3] - boxes_b[:, 1])
    top_left = np.maximum(boxes_a[:, None, :2], boxes_b[None, :, :2])
    bottom_right = np.minimum(boxes_a[:, None, 2:], boxes_b[None, :, 2:])
    inter = np.clip(bottom_right - top_left, 0, None).prod(axis=2)
    union = area_a[:, None] + area_b[None, :] - inter
    iou = np.zeros_like(inter)
    # Degenerate (zero-area) pairs keep IoU 0 instead of producing NaN.
    np.divide(inter, union, out=iou, where=union > 0)
    return iou


def match_detections_with_tracks(detections, tracks, match_thresh):
    """Assign a track ID to each detection by greedy one-to-one IoU matching.

    Candidate (track, detection) pairs whose IoU is at least ``match_thresh``
    are visited from the highest IoU down; a pair is accepted only when
    neither its track nor its detection has been used already.  This keeps the
    best-overlapping track on a detection when several tracks compete for it,
    instead of letting whichever track is iterated last overwrite the others.

    Args:
        detections: object supporting ``len()`` with an ``xyxy`` attribute
            holding one ``[x1, y1, x2, y2]`` row per detection.
        tracks: sequence of objects exposing ``tlbr`` (an xyxy box) and
            ``track_id``.
        match_thresh: minimum IoU for a pair to be considered a match.

    Returns:
        list: one entry per detection, the matched ``track_id`` or ``None``.
    """
    num_dets = len(detections)
    ids = [None] * num_dets
    if num_dets == 0 or len(tracks) == 0:
        return ids

    track_boxes = np.array([t.tlbr for t in tracks], dtype=float).reshape(-1, 4)
    det_boxes = np.asarray(detections.xyxy, dtype=float).reshape(-1, 4)
    iou_mat = _pairwise_iou(track_boxes, det_boxes)

    # All pairs above the threshold, best overlap first (stable for ties).
    cand_tracks, cand_dets = np.nonzero(iou_mat >= match_thresh)
    order = np.argsort(-iou_mat[cand_tracks, cand_dets], kind="stable")

    used_tracks = set()
    used_dets = set()
    for k in order:
        ti, di = int(cand_tracks[k]), int(cand_dets[k])
        if ti in used_tracks or di in used_dets:
            continue
        ids[di] = tracks[ti].track_id
        used_tracks.add(ti)
        used_dets.add(di)
    return ids

def _list_scene_frames(scene_path, scene_name):
    """Return the scene's ``{scene_name}_*.jpg`` file names sorted by frame number."""
    prefix = scene_name + "_"
    names = [f for f in os.listdir(scene_path)
             if f.endswith(".jpg") and f.startswith(prefix)]
    return sorted(names, key=extract_frame_number)


def _group_canvas_size(scenes):
    """Return ``(max_w, max_h)`` over the first readable image of every scene."""
    max_w = max_h = 0
    for scene_path, _scene_name, img_files in scenes:
        for fname in img_files:
            probe = cv2.imread(osp.join(scene_path, fname))
            if probe is None:
                # Unreadable file: try the next frame of this scene instead.
                continue
            h0, w0 = probe.shape[:2]
            max_w = max(max_w, w0)
            max_h = max(max_h, h0)
            break
    return max_w, max_h


def _fit_to_canvas(img, canvas_w, canvas_h):
    """Return ``img`` placed at the top-left of a black ``canvas_w`` x ``canvas_h`` frame."""
    h, w = img.shape[:2]
    if (w, h) == (canvas_w, canvas_h):
        return img
    canvas = np.zeros((canvas_h, canvas_w, 3), dtype=img.dtype)
    # The canvas size comes from one image per scene, so a later frame may be
    # larger than the canvas; crop it rather than fail on the assignment.
    ch, cw = min(h, canvas_h), min(w, canvas_w)
    canvas[:ch, :cw] = img[:ch, :cw]
    return canvas


def _write_results_csv(csv_path, rows):
    """Write the tracking rows to ``csv_path`` under the fixed header line."""
    with open(csv_path, "w") as f:
        f.write("frame_index,track_ID,x,y,w,h,scene_id\n")
        for row in rows:
            f.write(",".join(map(str, row)) + "\n")


def process_scene_group(scene_group, args, model, group_name):
    """Detect and track people in every scene of a group; write one CSV and one MP4.

    Each scene is processed frame by frame with a fresh ByteTrack tracker.
    Frames are padded (or cropped) to a common canvas size so that all scenes
    fit in a single video.  Detections that cannot be associated with a track
    receive a "fake" ID starting at 10000; ``args.record_fake`` decides whether
    those are drawn in the video and/or written to the CSV.

    Args:
        scene_group: iterable of ``(scene_path, scene_name)`` pairs; images are
            the ``{scene_name}_{frame}.jpg`` files inside ``scene_path``.
        args: parsed command-line options (``conf``, ``device``, ``fps``,
            ``track_thresh``, ``track_buffer``, ``match_thresh``,
            ``min_box_area``, ``record_fake``).
        model: detector exposing ``predict(source=..., conf=..., device=...,
            verbose=...)``.
        group_name: label used in the output file names.

    Side effects:
        Creates ``track_result/LOAF_{group_name}_result.csv`` and
        ``videos/LOAF_{group_name}_result.mp4`` relative to the working
        directory.  When no image of the group can be read, only a header-only
        CSV is written.
    """
    # 1. Prepare output directories and file names.
    os.makedirs("track_result", exist_ok=True)
    os.makedirs("videos", exist_ok=True)
    csv_path = osp.join("track_result", f"LOAF_{group_name}_result.csv")
    mp4_path = osp.join("videos",      f"LOAF_{group_name}_result.mp4")

    # 2. List every scene once and derive the common canvas size.
    scenes = [(scene_path, scene_name, _list_scene_frames(scene_path, scene_name))
              for scene_path, scene_name in scene_group]
    max_w, max_h = _group_canvas_size(scenes)

    results = []
    if max_w == 0 or max_h == 0:
        logger.warning(f"[GROUP] {group_name}: no readable images, skipping video")
        _write_results_csv(csv_path, results)
        return

    # 3. Open the video writer.
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(mp4_path, fourcc, args.fps, (max_w, max_h))
    if not writer.isOpened():
        logger.warning(f"[GROUP] {group_name}: could not open video writer for {mp4_path}")

    # IDs handed to detections that no track claims.
    fake_id_ctr = 10000

    try:
        # 4. Process scene by scene.
        for scene_path, scene_name, img_files in scenes:
            # 4.1 A fresh ByteTrack tracker per scene.
            # NOTE(review): frame_rate=1 shrinks ByteTrack's lost-track buffer
            # relative to track_buffer — confirm this is intended.
            tracker = BYTETracker(
                argparse.Namespace(
                    track_thresh=args.track_thresh,
                    track_buffer=args.track_buffer,
                    match_thresh=args.match_thresh,
                    min_box_area=args.min_box_area,
                    mot20=False
                ),
                frame_rate=1
            )
            # Trajectories never continue across scenes, so start empty.
            trajectories = {}

            # 4.2 Process every frame of the scene.
            for fname in img_files:
                frame_id = extract_frame_number(fname)
                img = cv2.imread(osp.join(scene_path, fname))
                if img is None:
                    continue

                # 4.2.1 Bring the frame to the common (max_w, max_h) canvas.
                frame = _fit_to_canvas(img, max_w, max_h)

                # 4.2.2 YOLOv8 inference on the clean frame (before any
                # overlay is drawn), keeping only persons (class 0).
                res = model.predict(source=frame, conf=args.conf,
                                    device=args.device, verbose=False)[0]
                cls_ids = res.boxes.cls.cpu().numpy().astype(int)
                mask = (cls_ids == 0)
                dets = Detections(
                    xyxy=res.boxes.xyxy[mask].cpu().numpy(),
                    confidence=res.boxes.conf[mask].cpu().numpy(),
                    class_id=cls_ids[mask]
                )

                # 4.2.3 Update the tracker.  BYTETracker.update rescales the
                # boxes by min(img_size[0]/img_info[0], img_size[1]/img_info[1]);
                # passing the same (h, w) for both keeps that scale at 1 so the
                # track boxes stay in frame coordinates.
                det_array = detections2boxes(dets)
                online_tracks = tracker.update(
                    torch.from_numpy(det_array).float(),
                    (max_h, max_w), (max_h, max_w))
                # NOTE(review): match_thresh is reused here as a minimum IoU.
                tracker_ids = match_detections_with_tracks(
                    dets, online_tracks, args.match_thresh)

                # 4.2.4 Overlay the scene name and frame number.
                cv2.putText(frame, f"{scene_name} frame {frame_id}", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 255), 2)

                # 4.2.5 Draw boxes, IDs and trajectories; collect CSV rows.
                for xyxy, tid, conf in zip(dets.xyxy, tracker_ids, dets.confidence):
                    x1, y1, x2, y2 = map(int, xyxy)
                    is_fake = (tid is None)
                    if is_fake:
                        tid = fake_id_ctr
                        fake_id_ctr += 1

                    show_box = args.record_fake in ("video", "all") or not is_fake
                    save_box = args.record_fake in ("csv",   "all") or not is_fake

                    if show_box:
                        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                        cv2.putText(frame, f"ID{tid}", (x1, y1 - 10),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
                        # A fake ID is used exactly once, so it can never form
                        # a trajectory; only real tracks are remembered.
                        if not is_fake:
                            cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
                            trajectories.setdefault(tid, []).append((cx, cy))
                            trace = trajectories[tid][-30:]
                            for p, q in zip(trace, trace[1:]):
                                cv2.line(frame, p, q, (0, 255, 0), 2)

                    if save_box:
                        w_box, h_box = x2 - x1, y2 - y1
                        results.append([frame_id, tid, x1, y1, w_box, h_box, scene_name])

                # 4.2.6 Append the annotated frame to the video.
                writer.write(frame)
                print(f"[INFO] {scene_name} - Processed frame {frame_id}/{len(img_files)}")
    finally:
        # 5. Always finalize the video, even if processing failed midway.
        writer.release()

    # 6. Write the CSV.
    _write_results_csv(csv_path, results)

    logger.info(f"[GROUP] {group_name} done → {csv_path}, {mp4_path}")

def main():
    """Parse the command line, load the detector and process the requested scenes.

    With ``--name`` the directory ``--path`` is processed as a single scene.
    Otherwise ``--path/train`` and ``--path/test`` are scanned (when present)
    and every scene found in each is processed as one group.
    """
    args = make_parser().parse_args()
    if args.device == "cuda" and not torch.cuda.is_available():
        # The default device is "cuda"; degrade gracefully on CPU-only hosts.
        logger.warning("CUDA requested but not available, falling back to CPU")
        args.device = "cpu"
    model = YOLO(args.yolo_model)

    if args.name:
        # Single scene.
        process_scene_group([(args.path, args.name)], args, model, "single")
        return

    def _scene_order(scene_id):
        # Numeric IDs sort numerically and first; anything else sorts by name
        # afterwards instead of aborting the run with a ValueError.
        try:
            return (0, scene_sort_key(scene_id), "")
        except ValueError:
            return (1, 0, scene_id)

    found_group = False
    for group in ("train", "test"):
        imgdir = osp.join(args.path, group)
        if not osp.isdir(imgdir):
            continue
        found_group = True

        # Collect every scene ID: the part of "{scene}_{frame}.jpg" before the
        # first underscore.  Files without that shape are ignored.
        scenes = set()
        for fname in os.listdir(imgdir):
            scene_id, sep, _rest = fname.partition("_")
            if fname.endswith(".jpg") and sep and scene_id:
                scenes.add(scene_id)

        proc = [(imgdir, sc) for sc in sorted(scenes, key=_scene_order)]
        print(f"[GROUP] {group} → {len(proc)} scenes")
        process_scene_group(proc, args, model, group)

    if not found_group:
        logger.warning(f"No 'train' or 'test' directory found under {args.path}")

# Run the tracker only when this file is executed as a script.
if __name__ == "__main__":
    main()
