import cv2
import tensorflow as tf
import numpy as np
import threading
import time

# Shared state between the capture/render loop (main) and the
# inference worker thread (LITERT_PROCESS).
bbox_lock = threading.Lock()        # guards BOUNDING_BOX_INFO
AI_PROCESSING = threading.Event()   # set when SHARED_FRAME holds a fresh frame
SHARED_FRAME = None                 # frame currently handed to the worker
BOUNDING_BOX_INFO = None            # latest detections published by the worker
RUNNING = True                      # shutdown flag shared by both threads

# COCO class names (index -> label) for the 80-class detector.
yolo_label = {
    0 : "person", 1 : "bicycle", 2 : "car", 3 : "motorcycle", 4 : "airplane", 5 : "bus",
    6 : "train", 7 : "truck", 8 : "boat", 9 : "traffic light", 10 : "fire hydrant",
    11 : "stop sign", 12 : "parking meter", 13 : "bench", 14 : "bird", 15 : "cat",
    16 : "dog", 17 : "horse", 18 : "sheep", 19 : "cow", 20 : "elephant", 21 : "bear",
    22 : "zebra", 23 : "giraffe", 24 : "backpack", 25 : "umbrella", 26 : "handbag",
    27 : "tie", 28 : "suitcase", 29 : "frisbee", 30 : "skis", 31 : "snowboard",
    32 : "sports ball", 33 : "kite", 34 : "baseball bat", 35 : "baseball glove",
    36 : "skateboard", 37 : "surfboard", 38 : "tennis racket", 39 : "bottle",
    40 : "wine glass", 41 : "cup", 42 : "fork", 43 : "knife", 44 : "spoon",
    45 : "bowl", 46 : "banana", 47 : "apple", 48 : "sandwich", 49 : "orange",
    50 : "broccoli", 51 : "carrot", 52 : "hot dog", 53 : "pizza", 54 : "donut",
    55 : "cake", 56 : "chair", 57 : "couch", 58 : "potted plant", 59 : "bed",
    60 : "dining table", 61 : "toilet", 62 : "tv", 63 : "laptop", 64 : "mouse",
    65 : "remote", 66 : "keyboard", 67 : "cell phone", 68 : "microwave", 69 : "oven",
    70 : "toaster", 71 : "sink", 72 : "refrigerator", 73 : "book", 74 : "clock",
    75 : "vase", 76 : "scissors", 77 : "teddy bear", 78 : "hair drier", 79 : "toothbrush"
}

def compute_iou(box1, box2):
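    """Intersection over Union (IoU) of two axis-aligned boxes.

    Boxes are (x1, y1, x2, y2); returns inter_area / union_area, or 0
    when the union has zero area.
    """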
    iou_ratio = 0
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])

    inter_w = max(0, x2 - x1)
    inter_h = max(0, y2 - y1)
    inter_area = inter_w * inter_h

    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])

    union_area = area1 + area2 - inter_area
    if union_area != 0:
        iou_ratio = inter_area / union_area

    return iou_ratio

def image_preprocess(frame, input_details, model_input_size):
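    """Convert a BGR frame into the model's input tensor.

    Pixels are normalized to [0, 1], then mapped through the TFLite
    affine quantization q = r / scale + zero_point for integer inputs.
    Returns None for unsupported input dtypes.
    """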
    input_type = input_details[0]['dtype']
    scale, zero_point = input_details[0]['quantization']
    frame_resized = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame_resized = cv2.resize(frame_resized, model_input_size, interpolation=cv2.INTER_LINEAR)    
    scale = 1 if scale == 0 else scale  # scale == 0 means the tensor is not quantized
    frame_resized = frame_resized.astype(np.float32) / 255.0 / scale + zero_point
        
    if input_type == np.int8:
        np.clip(frame_resized, -128, 127, out=frame_resized)
        frame_resized = frame_resized.astype(np.int8, copy=False)
    elif input_type == np.uint8:
        np.clip(frame_resized, 0, 255, out=frame_resized)
        frame_resized = frame_resized.astype(np.uint8, copy=False)
    elif input_type == np.float16:
        frame_resized = frame_resized.astype(np.float16, copy=False)
    elif input_type != np.float32:
        return None  # unsupported input dtype; the caller treats None as fatal

    # If the model expects NCHW input, move the channel axis to the front.
    input_shape = input_details[0]['shape']
    if input_shape[-1] != 3:
        frame_resized = frame_resized.transpose(2, 0, 1)
    frame_resized = np.expand_dims(frame_resized, axis=0)

    return frame_resized

def yolonas_bbox(interpreter, confidence_threshold):
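    """Fetch and dequantize both YOLO-NAS output tensors.

    Quantized outputs are mapped back to real values with
    r = scale * (q - zero_point), then the anchors whose best class
    score exceeds confidence_threshold are selected.
    """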
    # index0 (confidence): 1 x 8400 x 80
    # index1 (bounding box): 1 x 8400 x 4
    output_details = interpreter.get_output_details()
    
    # read output 0 (class confidence): 1 x 8400 x 80, then dequantize
    scale, zero_point = output_details[0]['quantization']   
    scale = 1 if scale == 0 else scale 
    class_confidence = interpreter.get_tensor(output_details[0]['index'])
    class_confidence = scale * (class_confidence.astype(np.float32) - zero_point)         
    
    # read output 1 (bounding box): 1 x 8400 x 4, then dequantize
    scale, zero_point = output_details[1]['quantization']    
    scale = 1 if scale == 0 else scale
    bbox_coord = interpreter.get_tensor(output_details[1]['index'])
    bbox_coord = scale * (bbox_coord.astype(np.float32) - zero_point)    
    
    # indices of anchors whose best class score clears the threshold,
    # and the winning class index for each of those anchors
    find_pass_class_idx = np.where(np.max(class_confidence, axis=2, keepdims=True) > confidence_threshold)[1]
    pass_class_labels_idx = np.argmax(class_confidence[:, find_pass_class_idx, :], axis=2).ravel()

    return [class_confidence, bbox_coord, find_pass_class_idx, pass_class_labels_idx]

def load_model(model_name, delegate_type):
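    """Build a TFLite interpreter, optionally attaching the QNN delegate.

    "htp" offloads to the Hexagon NPU backend, "gpu" to the QNN GPU
    backend, and "cpu" uses the default TFLite CPU kernels (no delegate).
    """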
    ai_sdk_version = "2.40.0.251030"
    hexagon_version = "hexagon-v68"
    ai_sdk_dir = f"/opt/marco/qairt/{ai_sdk_version}"
    machine_arch = "aarch64-ubuntu-gcc9.4"
    try:
        if delegate_type == "htp":
            delegate = [
                tf.lite.experimental.load_delegate(
                    library = "/opt/marco/qairt/2.40.0.251030/lib/aarch64-ubuntu-gcc9.4/libQnnTFLiteDelegate.so",
                    options ={
                        "backend_type" : "htp",
                        # "log_level": 5,
                        # "profiling" : 2,
                        "library_path" : f"{ai_sdk_dir}/lib/{machine_arch}/libQnnHtp.so",
                        "skel_library_dir" : f"{ai_sdk_dir}/lib/{hexagon_version}/unsigned/",
                        "htp_performance_mode" : 2,
                        "htp_use_fold_relu": 1,
                    }
                )
            ]
        elif delegate_type == "gpu":
            delegate = [
                tf.lite.experimental.load_delegate(
                    library = "/opt/marco/qairt/2.40.0.251030/lib/aarch64-ubuntu-gcc9.4/libQnnTFLiteDelegate.so",
                    options ={
                        "backend_type" : "gpu",
                        "library_path" : f"{ai_sdk_dir}/lib/{machine_arch}/libQnnGpu.so",
                        "skel_library_dir" : f"{ai_sdk_dir}/lib/{hexagon_version}/unsigned/",
                    }
                )
            ]
        elif delegate_type == "cpu":
            delegate = None
        else:
            raise RuntimeError("Unsupported delegate type")

        interpreter = tf.lite.Interpreter(
            model_path=model_name,
            experimental_delegates=delegate,
        )

    except Exception as e:
        raise RuntimeError("Model init failed, error code: {e}!")

    return interpreter

def LITERT_PROCESS(model_name, delegate_type,
                   camera_width, camera_height,
                   confidence_threshold=0.5, iou_threshold=0.5,
                   ):
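    """Inference worker thread.

    Protocol: main() stores a frame in SHARED_FRAME and sets AI_PROCESSING;
    this thread preprocesses it, runs the interpreter, filters boxes with a
    greedy IoU suppression, publishes the result into BOUNDING_BOX_INFO
    under bbox_lock, then clears the event to request the next frame.
    """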
    global RUNNING
    global AI_PROCESSING
    global SHARED_FRAME
    global BOUNDING_BOX_INFO

    total_process_time = 0
    total_process_count = 0

    interpreter = load_model(model_name, delegate_type)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    
    # Derive the model input size from the input tensor shape
    # (NHWC: [1, H, W, 3]; NCHW: [1, 3, H, W]); cv2.resize expects (W, H).
    input_shape = input_details[0]['shape']
    if input_shape[-1] == 3:
        model_input_size = (int(input_shape[2]), int(input_shape[1]))
    else:
        model_input_size = (int(input_shape[3]), int(input_shape[2]))
    scale_factors = np.array([
        camera_width / model_input_size[0],   # x1 scale
        camera_height / model_input_size[1],  # y1 scale
        camera_width / model_input_size[0],   # x2 scale
        camera_height / model_input_size[1]   # y2 scale
    ], dtype=np.float32)

    while RUNNING:
        AI_PROCESSING.wait()
        if not RUNNING:
            break
        
        t1 = time.perf_counter()
        
        print(f"AI Processing Start")

        input_data = image_preprocess(SHARED_FRAME, input_details, model_input_size)
        if input_data is None:
            print("Unsupported input tensor type!")
            RUNNING = False
            break

        interpreter.set_tensor(input_details[0]['index'], input_data)

        t2 = time.perf_counter()
        interpreter.invoke()
        t3 = time.perf_counter()
        delta_time = (t3 - t2) * 1000
        print(f"Inference time = {delta_time:.2f} ms")
        total_process_time += (t3 - t1)  # preprocessing + inference per frame
        total_process_count += 1

        [class_confidence,
         bbox_coord,
         find_pass_class_idx,
         pass_class_labels_idx] = yolonas_bbox(interpreter, confidence_threshold)

        if find_pass_class_idx.size:
            # Greedy same-class suppression: drop a box when it overlaps a
            # later box of the same class by more than iou_threshold.
            iou_filter_pass_class_index = list()
            for i, box1_idx in enumerate(find_pass_class_idx):
                pass_flag = True
                for j, box2_idx in enumerate(find_pass_class_idx[i+1:]):
                    if pass_class_labels_idx[i] == pass_class_labels_idx[i+j+1]:
                        iou = compute_iou(bbox_coord[0, box1_idx, :], bbox_coord[0, box2_idx, :])
                        if iou > iou_threshold:
                            pass_flag = False
                            break

                if pass_flag:
                    iou_filter_pass_class_index.append(int(box1_idx))

            filter_class_labels_idx = np.argmax(class_confidence[:, iou_filter_pass_class_index, :], axis=2).ravel()
            scale_bounding_box_list = np.array(bbox_coord[:, iou_filter_pass_class_index, :] * scale_factors, dtype=int)
            print(f"Detected {len(iou_filter_pass_class_index)} object(s)")
            t4 = time.perf_counter()
            with bbox_lock:
                BOUNDING_BOX_INFO = {
                    "bboxes_label_idx": filter_class_labels_idx,
                    "bboxes_coord": scale_bounding_box_list[0],
                    "infer_time": delta_time,
                    "bbox_fps": (1 / (t4 - t1)),
                }

        print("AI_PROCESSING END")
        SHARED_FRAME = None

        AI_PROCESSING.clear()

    print(f"Average FPS: {total_process_count/total_process_time}")

def main(demo_info):
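    """Capture loop: read frames, hand them to the worker, draw detections.

    Frames come from a camera or a video file; each frame is annotated
    with the most recent results from BOUNDING_BOX_INFO and written to
    an MP4 via cv2.VideoWriter.
    """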
    global RUNNING
    global SHARED_FRAME
    global AI_PROCESSING
    global BOUNDING_BOX_INFO

    if demo_info["input_type"] == "camera":
        cap = cv2.VideoCapture(demo_info["camera_id"])
        if not cap.isOpened():
            raise RuntimeError("Cannot open camera")
        width = demo_info["camera_width"]
        height = demo_info["camera_height"]
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)

    elif demo_info["input_type"] == "video":
        cap = cv2.VideoCapture(demo_info["filename"])  # change this to your video filename or path
        if not cap.isOpened():
            raise RuntimeError("Cannot open video file")
        width  = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    else:
        raise ValueError(f"Unsupported input_type: {demo_info['input_type']}")
    
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(demo_info['output_name'], fourcc, 30.0, (width, height))

    ai_thread = threading.Thread(
        target=LITERT_PROCESS,
        args=(
              demo_info["model_name"], demo_info["delegate_type"],
              width, height,
              demo_info["confidence_threshold"] ,demo_info["iou_threshold"]
        )
    )

    ai_thread.start()

    try:
        frame_count = 0
        bboxes_coord = None
        bboxes_label_idx = None
        while RUNNING:
            ret, frame = cap.read()
            if not ret:
                break

            if not AI_PROCESSING.is_set():
                SHARED_FRAME = frame.copy()
                AI_PROCESSING.set()
                AI_PROCESSING.set()

            if BOUNDING_BOX_INFO:
                with bbox_lock:
                    bboxes_coord = BOUNDING_BOX_INFO["bboxes_coord"]
                    bboxes_label_idx = BOUNDING_BOX_INFO["bboxes_label_idx"]
                    infer_time = BOUNDING_BOX_INFO["infer_time"]
                    bbox_fps = BOUNDING_BOX_INFO["bbox_fps"]
                    BOUNDING_BOX_INFO = None

            if bboxes_coord is not None:
                for bbox_idx, bbox_coord in enumerate(bboxes_coord):
                    bbox_coord[[0, 2]] = np.clip(bbox_coord[[0, 2]], 0, width)
                    bbox_coord[[1, 3]] = np.clip(bbox_coord[[1, 3]], 0, height)
                    if bbox_coord[0] == bbox_coord[2] or bbox_coord[1] == bbox_coord[3]:
                        continue
                    
                    cv2.rectangle(frame, (bbox_coord[0], bbox_coord[1]), (bbox_coord[2], bbox_coord[3]), (255, 0, 0), 4)
                    # Draw the class label just inside the top-left corner of the box
                    label = yolo_label[bboxes_label_idx[bbox_idx]]
                    cv2.putText(
                        frame,                                  # image
                        f"{label}",                             # text
                        (bbox_coord[0]+30, bbox_coord[1]+30),   # bottom-left origin of the text (x, y)
                        cv2.FONT_HERSHEY_SIMPLEX,               # font
                        0.8,                                    # font scale
                        (0, 0, 255),                            # color (BGR)
                        2                                       # thickness
                    )
                    cv2.putText(
                        frame,                                  # image
                        f"{infer_time:.2f} ms",                 # text
                        (bbox_coord[0]+30, bbox_coord[1]+60),   # bottom-left origin of the text (x, y)
                        cv2.FONT_HERSHEY_SIMPLEX,               # font
                        0.8,                                    # font scale
                        (0, 0, 255),                            # color (BGR)
                        2                                       # thickness
                    )
                # The FPS overlay is per frame, not per box; draw it once after the loop.
                cv2.putText(
                    frame,                                  # image
                    f"BBox {bbox_fps:.2f} fps",             # text
                    (30, 100),                              # bottom-left origin of the text (x, y)
                    cv2.FONT_HERSHEY_SIMPLEX,               # font
                    1.5,                                    # font scale
                    (0, 0, 255),                            # color (BGR)
                    4                                       # thickness
                )

            out.write(frame)
            frame_count += 1
            if frame_count % 10 == 0:
                print(f"frame count: {frame_count}")

    except KeyboardInterrupt:
        print("Recording stopped by user")

    finally:
        RUNNING = False
        cap.release()
        out.release()
        AI_PROCESSING.set()  # wake the worker so it can observe RUNNING == False
        ai_thread.join()
        print(f"Recording saved as {demo_info['output_name']}")

if __name__ == "__main__":

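    # Demo configuration: set "input_type" to "camera" to read from
    # camera_id (at camera_width x camera_height) instead of "filename".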
    demo_info = {
        "model_name" : "/opt/marco/yolo_nas_s/quantized_yolo_nas_s_int8.tflite",
        "delegate_type": "htp",
        "input_type" : "video",
        "filename" : "litert_demo.mp4",
        "confidence_threshold" : 0.5,
        "iou_threshold" : 0.5,
        "camera_id" : 2,
        "camera_width" : 1920,
        "camera_height" : 1080,
        "output_name" : "output.mp4"
    }

    main(demo_info)
