
import argparse
import csv
import math
import os
import os.path as osp
import re
import xml.etree.ElementTree as ET

import cv2
import torch
import numpy as np
from loguru import logger
from yolox.tracker.byte_tracker import BYTETracker
from yolox.tracking_utils.timer import Timer
from ultralytics import YOLO
from supervision import Detections
from supervision.draw.color import ColorPalette
from onemetric.cv.utils.iou import box_iou_batch

# YOLO class ids kept for tracking, with their display names. The ids/names
# presumably follow the COCO numbering of the stock YOLOv8 weights.
TRACK_CLASSES = dict([
    (0, "person"),
    (1, "bicycle"),
    (2, "car"),
    (3, "motorcycle"),
    (5, "bus"),
    (7, "truck"),
])

# Annotation (XML <name>) label that each tracked YOLO class corresponds to.
# Several YOLO classes share the "Bike" label.
# NOTE(review): bus (5) maps to "Bike" here; confirm this is intended rather
# than a separate "Bus" annotation label.
XML_CLASSES = dict([
    (0, "Pedestrian"),
    (1, "Bike"),
    (2, "Car"),
    (3, "Bike"),
    (5, "Bike"),
    (7, "Truck"),
])

# Reverse lookup XML label -> YOLO class id. Because "Bike" occurs for several
# ids, the last one inserted wins ("Bike" -> 5).
xml_name_to_id = {label: class_id for class_id, label in XML_CLASSES.items()}


def make_parser():
    """Build the command-line parser for the YOLOv8 + ByteTrack fisheye tracker.

    Returns:
        argparse.ArgumentParser: parser covering the dataset location, model,
        device, detection/tracking thresholds, output frame rate, and the
        ``--record_fake`` switch that decides where untracked IDs are emitted.
    """
    parser = argparse.ArgumentParser("YOLOv8 + ByteTrack Fisheye Tracker")
    add = parser.add_argument

    # Input selection.
    add("--path", required=True, help="Path to folder containing 'train' and/or 'test'")
    add("--name", default=None, help="Optional: single scene name, e.g. camera1_A")

    # Detector.
    add("--yolo_model", default="yolov8x.pt", help="YOLOv8 model path")
    add("--device", default="cuda", choices=["cpu", "cuda"])
    add("--conf", type=float, default=0.4)

    # Tracker and output options (flag, type, default), added in this order.
    for flag, kind, value in (
        ("--track_thresh", float, 0.15),
        ("--track_buffer", int, 4),
        ("--match_thresh", float, 0.4),
        ("--min_box_area", float, 1),
        ("--fps", int, 30),
    ):
        add(flag, type=kind, default=value)

    add(
        "--record_fake",
        choices=["none", "csv", "video", "all"],
        default="all",
        help="Choose where to include untracked (fake) IDs: csv, video, both or none",
    )
    return parser

def extract_frame_number(filename):
    """Return the trailing frame index N of a ``..._<N>.png`` file name.

    Returns -1 when the name does not end in an underscore, digits and ``.png``.
    """
    found = re.search(r"_(\d+)\.png$", filename)
    if found is None:
        return -1
    return int(found.group(1))

def scene_sort_key(scene_id):
    """Sort key ordering ``camera<N>_<X>...`` scene ids by camera number, then letter.

    Only the first character after the underscore is used as the letter.
    Ids not starting with that pattern sort last as ``(inf, '')``.
    """
    hit = re.match(r"camera(\d+)_(\w)", scene_id)
    if not hit:
        return (float('inf'), '')
    camera_no, suffix = hit.groups()
    return (int(camera_no), suffix)

def detections2boxes(detections: Detections) -> np.ndarray:
    """Pack detections into one row per detection: its ``xyxy`` box followed by its confidence.

    This is the array this script converts to a tensor for ``BYTETracker.update``.
    """
    score_column = detections.confidence[:, None]
    return np.concatenate((detections.xyxy, score_column), axis=1)

def tracks2boxes(tracks) -> np.ndarray:
    """Collect each track's ``tlbr`` box into a float array with one row per track.

    Args:
        tracks: iterable of objects exposing a 4-element ``tlbr`` attribute.

    Returns:
        np.ndarray: shape ``(len(tracks), 4)``, including ``(0, 4)`` for no tracks.
    """
    boxes = np.array([track.tlbr for track in tracks], dtype=float)
    # np.array([]) is 1-D with shape (0,); keep the (N, 4) contract so column
    # slicing such as boxes[:, 0] also works when there are no tracks.
    return boxes.reshape(-1, 4)

def match_detections_with_tracks(
    detections: Detections,
    tracks,
    match_thresh=None,
    dist_thresh=0.2,
    img_size=None,
) -> list:
    """Assign ByteTrack track IDs to detections using IoU or center distance.

    For every track, the detection with the highest IoU is chosen when that
    IoU is at least ``match_thresh``. Otherwise the detection whose center is
    nearest is chosen, provided the distance (as a fraction of the frame
    diagonal) is at most ``dist_thresh``. Each detection receives at most one
    track ID: competing tracks are resolved by higher IoU, then smaller center
    distance, and a losing track stays unmatched.

    Args:
        detections: supervision ``Detections``; ``len()`` and ``xyxy`` are read.
        tracks: objects exposing ``tlbr`` and ``track_id`` (ByteTrack outputs).
        match_thresh: IoU threshold. ``None`` falls back to the module-level
            ``args.match_thresh`` set by the CLI entry point, or 0.4 (the CLI
            default) when no such global exists.
        dist_thresh: normalized center-distance threshold (0.2 = 20% of the
            diagonal).
        img_size: optional ``(width, height)`` of the frame used for the
            normalization. When omitted, the furthest box corner among tracks
            and detections approximates the frame size.

    Returns:
        list: one entry per detection, the matched ``track_id`` or ``None``.
    """
    num_det = len(detections)
    if num_det == 0 or len(tracks) == 0:
        return [None] * num_det

    if match_thresh is None:
        # Backward compatibility: callers that pass only (detections, tracks)
        # rely on the global CLI namespace created under __main__.
        match_thresh = getattr(globals().get("args"), "match_thresh", 0.4)

    track_boxes = tracks2boxes(tracks)                     # [N, 4]
    det_boxes = np.asarray(detections.xyxy, dtype=float)   # [M, 4]
    iou_mat = box_iou_batch(track_boxes, det_boxes)        # [N, M]

    # Diagonal used to normalize center distances.
    if img_size is not None:
        width, height = img_size
    else:
        corners = np.vstack((track_boxes, det_boxes))
        width, height = corners[:, 2].max(), corners[:, 3].max()
    diag = float(np.hypot(width, height))
    if diag <= 0:
        diag = 1.0  # all boxes degenerate at the origin: avoid dividing by zero

    track_centers = (track_boxes[:, :2] + track_boxes[:, 2:4]) / 2.0
    det_centers = (det_boxes[:, :2] + det_boxes[:, 2:4]) / 2.0
    dists = np.linalg.norm(
        track_centers[:, None, :] - det_centers[None, :, :], axis=2
    ) / diag                                               # [N, M]

    # One candidate detection per track.
    candidates = []
    for trk_idx, (iou_row, dist_row) in enumerate(zip(iou_mat, dists)):
        det_idx = int(iou_row.argmax())
        if iou_row[det_idx] < match_thresh:
            # IoU is too low: fall back to the detection with the nearest center.
            det_idx = int(dist_row.argmin())
            if dist_row[det_idx] > dist_thresh:
                continue
        candidates.append((-float(iou_row[det_idx]), float(dist_row[det_idx]), trk_idx, det_idx))

    # Best pairs first; a detection already claimed is not overwritten.
    tracker_ids = [None] * num_det
    for _, _, trk_idx, det_idx in sorted(candidates):
        if tracker_ids[det_idx] is None:
            tracker_ids[det_idx] = tracks[trk_idx].track_id
    return tracker_ids



def _list_scene_frames(scene_path, scene_name):
    """Return paths of ``<scene_name>_<N>.png`` files in ``scene_path``, sorted by N.

    Only names whose remainder after ``<scene_name>_`` is purely decimal digits
    are kept, so a scene never picks up frames of another scene sharing its prefix.
    """
    prefix = scene_name + "_"
    frames = [
        fname for fname in os.listdir(scene_path)
        if fname.endswith(".png")
        and fname.startswith(prefix)
        and fname[len(prefix):-len(".png")].isdecimal()
    ]
    frames.sort(key=extract_frame_number)
    return [osp.join(scene_path, fname) for fname in frames]


def _first_frame_size(image_paths):
    """Return ``(width, height)`` of the first readable image in ``image_paths``, else ``None``."""
    for img_path in image_paths:
        img = cv2.imread(img_path)
        if img is not None:
            height, width = img.shape[:2]
            return width, height
    return None


def _load_gt_boxes(xml_file):
    """Read one frame's annotation file into ground-truth box tuples.

    Returns a list of ``(idx, cls_id, name, x1, y1, x2, y2)``, one per
    ``<object>`` whose ``<name>`` is a value of XML_CLASSES. ``idx`` is the
    object's position among all ``<object>`` elements of the file, and
    ``cls_id`` is the first XML_CLASSES key carrying that name (e.g. "Bike" -> 1).
    A missing file yields ``[]``; an unparsable file is reported and yields ``[]``.
    """
    if not osp.isfile(xml_file):
        return []
    try:
        root = ET.parse(xml_file).getroot()
    except ET.ParseError as err:
        print(f"[WARN] Skipping unreadable annotation {xml_file}: {err}")
        return []

    first_id = {}
    for cid, label in XML_CLASSES.items():
        first_id.setdefault(label, cid)

    gt_boxes = []
    for idx, obj in enumerate(root.findall("object")):
        label = obj.findtext("name")
        bndbox = obj.find("bndbox")
        if label not in first_id or bndbox is None:
            continue
        # float() first so coordinates written as e.g. "12.0" are accepted too.
        x1, y1, x2, y2 = (int(float(bndbox.findtext(tag)))
                          for tag in ("xmin", "ymin", "xmax", "ymax"))
        gt_boxes.append((idx, first_id[label], label, x1, y1, x2, y2))
    return gt_boxes


def _pair_iou(box_a, box_b):
    """IoU of two ``(x1, y1, x2, y2)`` boxes; 1e-6 is added to the union to avoid /0."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b
    inter = max(0, min(ax2, bx2) - max(ax1, bx1)) * max(0, min(ay2, by2) - max(ay1, by1))
    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    return inter / (area_a + area_b - inter + 1e-6)


def _assign_gt_ids(detections, tracker_ids, gt_boxes, scene_name, gt_trajs,
                   iou_thresh=0.3, dist_thresh=100):
    """Give detections that ByteTrack left unmatched an annotation-driven ID.

    For each detection whose entry in ``tracker_ids`` is ``None``, choose the
    not-yet-used GT box with the same XML label (via XML_CLASSES) and the
    highest IoU. Accept it when IoU > ``iou_thresh`` and the box centers are
    closer than ``dist_thresh`` pixels. The ID is the key
    ``(scene_name, gt_cls_id, gt_index)``. ``tracker_ids`` and ``gt_trajs`` are
    updated in place; the set of GT keys used in this frame is returned.
    """
    matched_gt = set()
    for i, (det_box, tid, cls_id) in enumerate(
            zip(detections.xyxy, tracker_ids, detections.class_id)):
        if tid is not None:
            continue  # already tracked by ByteTrack
        det_label = XML_CLASSES.get(int(cls_id))
        x1, y1, x2, y2 = map(int, det_box)

        best_key, best_iou, best_box = None, 0.0, None
        for idx, gt_cls_id, gt_label, gx1, gy1, gx2, gy2 in gt_boxes:
            key = (scene_name, gt_cls_id, idx)
            # Compare labels, not ids: several YOLO classes share one XML label.
            if gt_label != det_label or key in matched_gt:
                continue
            iou = _pair_iou((x1, y1, x2, y2), (gx1, gy1, gx2, gy2))
            if iou > best_iou:
                best_key, best_iou, best_box = key, iou, (gx1, gy1, gx2, gy2)

        if best_key is None or best_iou <= iou_thresh:
            continue
        det_cx, det_cy = (x1 + x2) / 2, (y1 + y2) / 2
        gt_cx, gt_cy = (best_box[0] + best_box[2]) / 2, (best_box[1] + best_box[3]) / 2
        if math.hypot(det_cx - gt_cx, det_cy - gt_cy) >= dist_thresh:
            continue

        tracker_ids[i] = best_key
        matched_gt.add(best_key)
        gt_trajs.setdefault(best_key, []).append(((x1 + x2) // 2, (y1 + y2) // 2))
    return matched_gt


def _write_results_csv(output_csv, results):
    """Write result rows to ``output_csv`` under the expected header.

    csv.writer quotes fields that contain commas, so annotation-driven IDs
    (tuples) stay within the single ``track_ID`` column.
    """
    with open(output_csv, "w", newline="") as f:
        out = csv.writer(f, lineterminator="\n")
        out.writerow(["frame_index", "track_ID", "x", "y", "w", "h", "category", "scene_id"])
        out.writerows(results)


def process_scene_group(scene_group, args, model, group_name):
    """Detect, track and render every frame of a group of scenes.

    Each ``(scene_path, scene_name)`` in ``scene_group`` has its
    ``<scene_name>_<N>.png`` frames processed in frame order with a fresh
    BYTETracker. YOLO detections of TRACK_CLASSES are matched to ByteTrack
    tracks. Detections left unmatched may get an annotation-driven ID (see
    ``_assign_gt_ids``); otherwise they get a synthetic "fake" ID counting up
    from 10000.

    Annotated frames are written to
    ``videos/Fisheye8k_<group_name>_result.mp4``, resized to the largest
    first-frame size in the group. Rows go to
    ``track_result/Fisheye8k_<group_name>_result.csv``. ``args.record_fake``
    controls whether fake or annotation-driven IDs appear in the video
    ("video"), the CSV ("csv"), both ("all") or neither ("none").

    Args:
        scene_group: iterable of ``(scene_path, scene_name)`` pairs.
        args: parsed CLI namespace (thresholds, device, fps, record_fake).
        model: the YOLO model created in ``main``.
        group_name: suffix used in the output file names.
    """
    output_csv = osp.join("track_result", f"Fisheye8k_{group_name}_result.csv")
    output_mp4 = osp.join("videos", f"Fisheye8k_{group_name}_result.mp4")
    os.makedirs("track_result", exist_ok=True)
    os.makedirs("videos", exist_ok=True)

    # List each scene's frames once; the scenes of a split share one directory.
    scenes = [(path, name, _list_scene_frames(path, name)) for path, name in scene_group]

    # One writer for the whole group, sized to the largest first frame. The
    # first readable frame of each scene is probed, so numbering may start anywhere.
    sizes = [size for size in (_first_frame_size(frames) for _, _, frames in scenes) if size]
    max_w = max((w for w, _ in sizes), default=0)
    max_h = max((h for _, h in sizes), default=0)
    writer = None
    if max_w and max_h:
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(output_mp4, fourcc, args.fps, (max_w, max_h))
        if not writer.isOpened():
            print(f"[WARN] Could not open video writer for {output_mp4}; video output disabled.")
            writer = None
    else:
        print(f"[WARN] No readable frames in group {group_name}; video output disabled.")

    results = []
    fake_id_counter = 10000
    trajectories = {}  # ByteTrack id -> last 30 box centers (green trails)
    gt_trajs = {}      # annotation key -> centers while matched on consecutive frames

    try:
        for scene_path, scene_name, image_paths in scenes:
            annotation_dir = scene_path.replace(os.sep + "images", os.sep + "annotations")
            # NOTE(review): frame_rate=1 is used instead of args.fps. In YOLOX's
            # BYTETracker the lost-track buffer is believed to scale with
            # frame_rate / 30, so a small --track_buffer may round down to 0;
            # confirm this is intended.
            tracker = BYTETracker(
                argparse.Namespace(
                    track_thresh=args.track_thresh,
                    track_buffer=args.track_buffer,
                    match_thresh=args.match_thresh,
                    aspect_ratio_thresh=3.0,
                    min_box_area=args.min_box_area,
                    mot20=False,
                ),
                frame_rate=1,
            )

            for img_path in image_paths:
                frame_id = extract_frame_number(osp.basename(img_path))
                img = cv2.imread(img_path)
                if img is None:
                    continue
                orig_img = img.copy()
                h, w = orig_img.shape[:2]

                gt_boxes = _load_gt_boxes(osp.join(annotation_dir, f"{scene_name}_{frame_id}.xml"))

                cv2.putText(orig_img, f"{scene_name} | Frame {frame_id}", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 255), 2)

                # Detect, then keep only the classes listed in TRACK_CLASSES.
                boxes = model.predict(source=img, conf=args.conf, device=args.device,
                                      verbose=False)[0].boxes
                cls_all = boxes.cls.cpu().numpy().astype(int)
                keep = np.isin(cls_all, list(TRACK_CLASSES))
                detections = Detections(
                    xyxy=boxes.xyxy.cpu().numpy()[keep],
                    confidence=boxes.conf.cpu().numpy()[keep],
                    class_id=cls_all[keep],
                )

                online_targets = tracker.update(
                    torch.tensor(detections2boxes(detections), dtype=torch.float32),
                    [h, w], (w, h),
                )
                tracker_ids = match_detections_with_tracks(detections, online_targets)
                matched_gt = _assign_gt_ids(detections, tracker_ids, gt_boxes, scene_name, gt_trajs)

                # Draw boxes/trails and collect CSV rows.
                for det, tracker_id, class_id in zip(detections.xyxy, tracker_ids, detections.class_id):
                    x1, y1, x2, y2 = map(int, det)
                    is_annotation_id = isinstance(tracker_id, tuple)
                    is_fake = is_annotation_id or tracker_id is None
                    if tracker_id is None:
                        tracker_id = fake_id_counter
                        fake_id_counter += 1

                    if args.record_fake in ("video", "all") or not is_fake:
                        # BGR: blue for annotation-driven IDs, green otherwise.
                        color = (255, 0, 0) if is_annotation_id else (0, 255, 0)
                        label = f"ID:{tracker_id} {TRACK_CLASSES.get(class_id, class_id)}"
                        cv2.rectangle(orig_img, (x1, y1), (x2, y2), color, 2)
                        cv2.putText(orig_img, label, (x1, y1 - 20),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

                        # Green trail for genuine ByteTrack IDs only; fake IDs are
                        # never reused, so trails for them would only accumulate.
                        if not is_fake:
                            trail = trajectories.setdefault(tracker_id, [])
                            trail.append(((x1 + x2) // 2, (y1 + y2) // 2))
                            del trail[:-30]  # keep the 30 most recent points
                            for p, q in zip(trail, trail[1:]):
                                cv2.line(orig_img, p, q, (0, 255, 0), 2)

                    if args.record_fake in ("csv", "all") or not is_fake:
                        results.append([frame_id, tracker_id, x1, y1, x2 - x1, y2 - y1,
                                        class_id, scene_name])

                # Blue trails for annotation-driven IDs. A key not matched this
                # frame has left the scene, so its trail is discarded.
                for key in list(gt_trajs):
                    if key not in matched_gt:
                        del gt_trajs[key]
                        continue
                    pts = gt_trajs[key]
                    for p, q in zip(pts, pts[1:]):
                        cv2.line(orig_img, p, q, (255, 0, 0), 2)

                if writer is not None:
                    if (w, h) != (max_w, max_h):
                        orig_img = cv2.resize(orig_img, (max_w, max_h), interpolation=cv2.INTER_LINEAR)
                    writer.write(orig_img)
                print(f"[INFO] {scene_name} - Processed frame {frame_id}/{len(image_paths)}")
    finally:
        if writer is not None:
            writer.release()

    _write_results_csv(output_csv, results)
    print(f"[INFO] Finished group {group_name}. Results saved to {output_csv}, {output_mp4}")

def main(args):
    """Run detection and tracking on a single scene or on the train/test splits.

    With ``--name``, ``--path`` is treated as the image folder of that one
    scene and outputs use the group name "single". Otherwise
    ``<path>/train/images`` and ``<path>/test/images`` are scanned. Every
    ``<scene>_<frame>.png`` contributes its scene id, and each split is
    processed as one group, ordered by ``scene_sort_key`` with the id as a
    tiebreaker. Splits whose folder is missing or contains no scenes are
    skipped rather than producing empty outputs.
    """
    model = YOLO(args.yolo_model)

    if args.name:
        process_scene_group([(args.path, args.name)], args, model, "single")
        return

    for group_name in ("train", "test"):
        dir_path = osp.join(args.path, group_name, "images")
        if not osp.isdir(dir_path):
            print(f"[GROUP] Skipping {group_name}: {dir_path} not found.")
            continue

        # Scene id = file stem without its trailing "_<frame>" part.
        scene_ids = {
            fname[:-4].rsplit("_", 1)[0]
            for fname in os.listdir(dir_path)
            if fname.endswith(".png") and "_" in fname[:-4]
        }
        if not scene_ids:
            print(f"[GROUP] Skipping {group_name}: no scenes found in {dir_path}.")
            continue

        unique_scenes = sorted(
            ((dir_path, scene_id) for scene_id in scene_ids),
            key=lambda item: (scene_sort_key(item[1]), item[1]),
        )
        print(f"[GROUP] Processing {group_name} with {len(unique_scenes)} scenes.")
        process_scene_group(unique_scenes, args, model, group_name)

if __name__ == "__main__":
    # Parse the CLI and run. `args` is deliberately bound at module scope:
    # match_detections_with_tracks looks up the global `args` to obtain its
    # IoU threshold.
    args = make_parser().parse_args()
    main(args)
