#!/usr/bin/env python3
"""
Ara NPU Basic Video Stream & Inference Hub
==========================================
"""

import argparse
import ctypes
import glob
import os
import sys
import threading
import time
import logging
import cv2
import numpy as np
from flask import Flask, Response, jsonify, request, render_template_string
import gi

gi.require_version('Gst', '1.0')
from gi.repository import Gst
Gst.init(None)

# Werkzeug logs every request at INFO level and the dashboard polls
# /api/stream_info once a second, so only let errors through.
log = logging.getLogger(name="werkzeug")
log.setLevel(level=logging.ERROR)

# Flask application; every route below is registered on it.
app = Flask(import_name=__name__)
# Serializes multi-key reads/writes of STATE_REPO (configuration and telemetry).
lock = threading.Lock()

class AraDetection(ctypes.Structure):
    """One detection record as read from the ``postsink`` appsink buffer.

    The postsink callback in ``gstreamer_orchestration_loop`` reads a 4-byte
    count followed by back-to-back records of ``ctypes.sizeof(AraDetection)``
    bytes and copies each with ``from_buffer_copy``. Field order and
    ``_pack_`` must therefore match the producer's C struct exactly.
    NOTE(review): that C definition (dvPost output) is not visible here —
    confirm the layout against its header.
    """
    _pack_ = 1  # byte-packed: no alignment padding between fields
    _fields_ = [
        # Box corners. They are rescaled by frame_size / MODEL_W|MODEL_H before
        # drawing, so presumably in model-input pixel coordinates — confirm.
        ("xmin", ctypes.c_float), ("ymin", ctypes.c_float),
        ("xmax", ctypes.c_float), ("ymax", ctypes.c_float),
        # confidence is displayed multiplied by 100 as a percentage, so it is
        # presumably in [0, 1]; class_id is looked up in COCO_LABELS.
        ("confidence", ctypes.c_float), ("class_id", ctypes.c_int32),
        # Pointer-sized field owned by the producer; never dereferenced here
        # (labels come from COCO_LABELS instead). Its width makes the record
        # size platform dependent.
        ("class_name_ptr", ctypes.c_void_p)
    ]

# --- STATE STORAGE ---
# Process-wide mutable state shared by the GStreamer worker thread and the
# Flask request handlers.
STATE_REPO = dict(
    # Latest canvas-sized BGR frame and the most recent detection list of
    # (class_id, confidence, xmin, ymin, xmax, ymax) tuples.
    frame=None,
    detections=[],

    # Current pipeline configuration; restart_flag is raised by
    # /api/swap_config and consumed by the pipeline thread.
    active_source=None,
    active_model_name="yolov8n",
    active_model_path="/usr/share/cnn/detection/yolov8n/model.dvm",
    restart_flag=False,
    source_registry=[],
    model_registry=["yolov8n"],

    # Target pipeline resolutions.
    CANVAS_W=640,
    CANVAS_H=360,
    MODEL_W=640,
    MODEL_H=640,

    # Live telemetry metrics.
    native_w=0,
    native_h=0,
    stream_w=0,
    stream_h=0,
    inference_fps=0.0,
)

# Rolling window of inference-completion timestamps used for the FPS
# estimate; reset whenever a pipeline is torn down.
inference_timestamps = []

# COCO class-id -> display name, used to label boxes in the MJPEG stream.
# Ids are the positions in this tuple (0..79).
COCO_LABELS = dict(enumerate((
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat',
    'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
    'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
    'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
    'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle',
    'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli',
    'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'dining table',
    'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink',
    'refrigerator', 'book', 'clock', 'vase', 'scissors',
    'teddy bear', 'hair drier', 'toothbrush',
)))

def _gst_quote(value):
    """Return *value* as a double-quoted, backslash-escaped gst-launch property value."""
    escaped = value.replace("\\", "\\\\").replace('"', '\\"')
    return f'"{escaped}"'


def build_source_injection_string(source_path):
    """Return the gst-launch fragment that produces raw video into ``tee name=t``.

    Paths ending in ``.mp4`` are read with ``filesrc ! decodebin``; anything
    else is treated as a V4L2 device node. The path is quoted so file names
    containing spaces, ``!`` or quotes do not break ``Gst.parse_launch``.
    The returned string ends with a trailing space so further branches can
    be appended directly.
    """
    quoted = _gst_quote(source_path)
    if source_path.endswith(".mp4"):
        return f"filesrc location={quoted} ! decodebin ! videoconvert ! tee name=t "
    return f"v4l2src device={quoted} ! videoconvert ! tee name=t "

# Bus poll interval: 50 ms as an integer nanosecond GstClockTime
# (multiplying Gst.SECOND by 0.05 would produce a float).
_BUS_POLL_TIMEOUT = Gst.SECOND // 20

# Number of inference-completion timestamps kept for the rolling FPS estimate.
_FPS_WINDOW = 30


def _decode_bgr_frame(raw_bytes, width, height):
    """Return a writable (height, width, 3) uint8 array from a packed BGR buffer.

    GStreamer's default layout for packed RGB formats may pad each row (to a
    multiple of 4 bytes), so the row stride can exceed ``width * 3``. The
    stride is recovered from the buffer size and the padding is dropped.
    Returns None if the buffer is too small for the given dimensions.
    """
    row_bytes = width * 3
    if width <= 0 or height <= 0 or len(raw_bytes) < row_bytes * height:
        return None
    stride = len(raw_bytes) // height
    rows = np.frombuffer(raw_bytes, dtype=np.uint8, count=stride * height).reshape(height, stride)
    # np.array copies into a fresh contiguous, writable array.
    return np.array(rows[:, :row_bytes]).reshape(height, width, 3)


def _parse_detections(raw_bytes):
    """Decode a postsink buffer: a native-endian uint32 count, then AraDetection records.

    Returns a list of (class_id, confidence, xmin, ymin, xmax, ymax) tuples,
    truncated at the end of the buffer, or None when the buffer is too short
    to hold the count.
    """
    if not raw_bytes or len(raw_bytes) < 4:
        return None
    count = int(np.frombuffer(raw_bytes[:4], dtype=np.uint32)[0])
    record_size = ctypes.sizeof(AraDetection)
    detections = []
    offset = 4
    for _ in range(count):
        if offset + record_size > len(raw_bytes):
            break
        det = AraDetection.from_buffer_copy(raw_bytes, offset)
        offset += record_size
        detections.append((det.class_id, det.confidence, det.xmin, det.ymin, det.xmax, det.ymax))
    return detections


def _record_inference_tick():
    """Append a completion timestamp and refresh STATE_REPO["inference_fps"]."""
    stamps = inference_timestamps
    stamps.append(time.monotonic())  # monotonic: immune to wall-clock jumps
    if len(stamps) > _FPS_WINDOW:
        del stamps[0]
    span = stamps[-1] - stamps[0]
    # N timestamps delimit N-1 intervals; guard against a zero-length span.
    if len(stamps) > 1 and span > 0:
        STATE_REPO["inference_fps"] = (len(stamps) - 1) / span


def _on_native_sample(sink):
    """nativesink callback: record the source's native resolution from the sample caps."""
    sample = sink.emit("pull-sample")
    if sample:
        struct = sample.get_caps().get_structure(0)
        STATE_REPO["native_w"] = struct.get_value("width")
        STATE_REPO["native_h"] = struct.get_value("height")
    return Gst.FlowReturn.OK


def _on_detection_sample(sink):
    """postsink callback: update the FPS estimate and publish parsed detections."""
    sample = sink.emit("pull-sample")
    if sample:
        _record_inference_tick()
        buffer = sample.get_buffer()
        detections = _parse_detections(buffer.extract_dup(0, buffer.get_size()))
        if detections is not None:
            STATE_REPO["detections"] = detections
    return Gst.FlowReturn.OK


def _on_frame_sample(sink):
    """framesink callback: publish the canvas-sized BGR frame and its dimensions."""
    sample = sink.emit("pull-sample")
    if sample:
        struct = sample.get_caps().get_structure(0)
        w = struct.get_value("width")
        h = struct.get_value("height")
        STATE_REPO["stream_w"] = w
        STATE_REPO["stream_h"] = h

        buffer = sample.get_buffer()
        raw_bytes = buffer.extract_dup(0, buffer.get_size())
        if raw_bytes:
            frame = _decode_bgr_frame(raw_bytes, w, h)
            if frame is not None:
                STATE_REPO["frame"] = frame
    return Gst.FlowReturn.OK


def _reset_stream_state():
    """Clear the frame, detections and telemetry after a pipeline is torn down."""
    global inference_timestamps
    STATE_REPO.update(
        frame=None, detections=[],
        native_w=0, native_h=0, stream_w=0, stream_h=0,
        inference_fps=0.0,
    )
    inference_timestamps = []


def _run_pipeline_until_stop(pipeline, loop_on_eos):
    """Pump the pipeline bus until a restart is requested, an error, or a final EOS.

    When *loop_on_eos* is true, EOS rewinds to the start instead of stopping.
    Error messages are printed rather than discarded.
    """
    bus = pipeline.get_bus()
    wanted = Gst.MessageType.ERROR | Gst.MessageType.EOS
    while not STATE_REPO["restart_flag"]:
        msg = bus.timed_pop_filtered(_BUS_POLL_TIMEOUT, wanted)
        if msg is None:
            continue
        if msg.type == Gst.MessageType.EOS:
            if loop_on_eos:
                pipeline.seek_simple(Gst.Format.TIME, Gst.SeekFlags.FLUSH | Gst.SeekFlags.KEY_UNIT, 0)
                continue
            print("[PIPELINE] end of stream")
            return
        err, debug = msg.parse_error()
        print(f"[PIPELINE ERROR] {err.message}\n   {debug}")
        return


def gstreamer_orchestration_loop():
    """Build, run and rebuild the GStreamer/NPU pipeline forever (worker thread body).

    Waits until a source is selected, then launches a pipeline whose tee
    feeds three appsinks:
      * nativesink: the untouched source, used only to report its resolution.
      * framesink: BGR frames scaled to CANVAS_W x CANVAS_H for display.
      * postsink: BGRA frames scaled to MODEL_W x MODEL_H, run through
        dvPre/dvInf/dvPost, producing detection records.
    The pipeline is torn down and rebuilt when ``restart_flag`` is raised.
    After a runtime error or a non-looping EOS, the same pipeline is rebuilt
    one second later. If the description fails to parse, the loop waits for
    a new selection before trying again.
    """
    from gi.repository import GLib  # only needed to catch Gst.parse_launch failures

    CANVAS_W = STATE_REPO["CANVAS_W"]
    CANVAS_H = STATE_REPO["CANVAS_H"]
    MODEL_W = STATE_REPO["MODEL_W"]
    MODEL_H = STATE_REPO["MODEL_H"]

    while True:
        while STATE_REPO["active_source"] is None:
            time.sleep(0.2)
            if STATE_REPO["restart_flag"]:
                break

        # Snapshot the configuration and acknowledge the restart request under
        # the same lock swap_config holds, so a swap landing between the reads
        # and the flag reset cannot be lost.
        with lock:
            current_target_source = STATE_REPO["active_source"]
            current_target_model = STATE_REPO["active_model_path"]
            STATE_REPO["restart_flag"] = False

        if current_target_source is None:
            continue

        source_segment = build_source_injection_string(current_target_source)

        pipe_str = (
            f"{source_segment} "
            f"t. ! queue max-size-buffers=2 leaky=downstream ! appsink name=nativesink sync=false async=false emit-signals=true "
            f"t. ! queue max-size-buffers=2 leaky=downstream ! videoscale ! video/x-raw,width={CANVAS_W},height={CANVAS_H} ! videoconvert ! video/x-raw,format=BGR ! appsink name=framesink sync=false async=false emit-signals=true "
            f"t. ! queue max-size-buffers=2 leaky=downstream ! "
            f"videoscale ! video/x-raw,width={MODEL_W},height={MODEL_H} ! videoconvert ! video/x-raw,format=BGRA ! "
            f"dvPre model={current_target_model} ! "
            f"dvInf model={current_target_model} sock=/var/run/proxy.sock use-shm=true shm-path=/dev/shm/ara_inf_ ! "
            f"dvPost model={current_target_model} orig-width={MODEL_W} orig-height={MODEL_H} ! "
            f"appsink name=postsink sync=false async=false emit-signals=true"
        )

        print(f"[LAUNCH PIPELINE]\n   {pipe_str}\n")
        try:
            pipeline = Gst.parse_launch(pipe_str)
        except GLib.Error as exc:
            # Without this the daemon thread would die and streaming would stop
            # for good. Re-parsing the same string cannot succeed, so wait for
            # the user to pick a different source/model.
            print(f"[PIPELINE PARSE FAILED] {exc.message}")
            while not STATE_REPO["restart_flag"]:
                time.sleep(0.2)
            continue

        pipeline.get_by_name("nativesink").connect("new-sample", _on_native_sample)
        pipeline.get_by_name("postsink").connect("new-sample", _on_detection_sample)
        pipeline.get_by_name("framesink").connect("new-sample", _on_frame_sample)

        try:
            pipeline.set_state(Gst.State.PLAYING)
            _run_pipeline_until_stop(pipeline, loop_on_eos=current_target_source.endswith(".mp4"))
        finally:
            # NULL stops the streaming threads before the shared state is cleared.
            pipeline.set_state(Gst.State.NULL)
            _reset_stream_state()
        time.sleep(1.0)

@app.route('/')
def index():
    """Render the single-page dashboard (source/model pickers, stats banner, stream).

    Registry entries (file paths, model directory names) are passed to Jinja
    as template variables instead of being concatenated into the template
    source. Concatenation let any ``{{``/``{%`` in a file name be evaluated as
    template code and left ``<``, ``&`` and ``"`` unescaped in the markup.
    Flask autoescapes variables rendered through ``render_template_string``.
    """
    html_template = """<!DOCTYPE html>
    <html>
    <head>
        <title>Ara Stream Client</title>
        <style>
            body { font-family: sans-serif; background: #0c0c0e; color: #e1e1e6; margin: 0; padding: 20px; display: flex; flex-direction: column; align-items: center; }
            .dashboard-layout { display: flex; flex-direction: column; gap: 15px; width: 660px; }
            .panel { background: #121216; padding: 12px 15px; border-radius: 6px; border: 1px solid #1f1f24; display: flex; flex-direction: column; gap: 10px; }
            .control-row { display: flex; align-items: center; justify-content: space-between; }
            label { font-size: 12px; font-weight: bold; color: #8f8f9d; text-transform: uppercase; }
            select { background: #0c0c0e; color: #fff; border: 1px solid #04d361; padding: 6px 10px; border-radius: 4px; width: 420px; outline: none; }
            .stats-banner { display: flex; justify-content: space-between; background: #17171f; padding: 10px 15px; border: 1px solid #1f1f24; border-radius: 4px; font-family: monospace; font-size: 13px; color: #8f8f9d; }
            .stats-banner span strong { color: #04d361; }
            .media-container { background: #121216; padding: 8px; border-radius: 6px; border: 1px solid #1f1f24; position: relative; min-height: 480px; display: flex; align-items: center; justify-content: center; }
            img { display: block; border-radius: 4px; width: 100%; height: auto; }
            .overlay { position: absolute; top: 0; left: 0; width: 100%; height: 100%; background: rgba(12,12,14,0.9); display: flex; flex-direction: column; align-items: center; justify-content: center; border-radius: 6px; text-align: center; }
            .prompt-text { color: #04d361; font-weight: bold; font-size: 16px; margin-bottom: 10px; }
        </style>
        <script>
            let streamStarted = {% if active_src %}true{% else %}false{% endif %};

            async function switchConfig() {
                const src = document.getElementById('source-picker').value;
                const mdl = document.getElementById('model-picker').value;
                if(!src) return;

                await fetch('/api/swap_config', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({ "source": src, "model": mdl })
                });

                streamStarted = true;
                // The overlay is only rendered when no source was active at page load.
                const overlay = document.getElementById('gatekeeper-overlay');
                if (overlay) overlay.style.display = 'none';
                setTimeout(() => {
                    document.getElementById('stream-player').src = '/stream.mjpg';
                }, 1000);
            }

            async function updateStreamMetrics() {
                if (!streamStarted) return;
                try {
                    const response = await fetch('/api/stream_info');
                    const data = await response.json();

                    document.getElementById('metric-res').innerText = 'Source:' + data.native_w + 'x' + data.native_h + ' Canvas:' + data.width + 'x' + data.height;
                    document.getElementById('metric-fps').innerText = data.fps.toFixed(1);
                    document.getElementById('metric-dets').innerText = data.detections;
                } catch (err) {}
            }
            setInterval(updateStreamMetrics, 1000);
        </script>
    </head>
    <body>
        <h2>Ara Vision Engine</h2>
        <div class="dashboard-layout">
            <div class="panel">
                <div class="control-row">
                    <label for="source-picker">Media Stream Target:</label>
                    <select id="source-picker" onchange="switchConfig()">
                    {%- if not sources %}
                        <option value="" disabled selected>-- NO VALID INPUT SOURCES AVAILABLE --</option>
                    {%- else %}
                        {%- if active_src is none %}
                        <option value="" disabled selected>-- SELECT TARGET SOURCE CHANNEL --</option>
                        {%- endif %}
                        {%- for s in sources %}
                        <option value="{{ s }}" {{ 'selected' if s == active_src else '' }}>{{ s }}</option>
                        {%- endfor %}
                    {%- endif %}
                    </select>
                </div>
                <div class="control-row">
                    <label for="model-picker">NPU Pipeline Model:</label>
                    <select id="model-picker" onchange="switchConfig()">
                    {%- for m in models %}
                        <option value="{{ m }}" {{ 'selected' if m == active_mdl else '' }}>{{ m }}</option>
                    {%- endfor %}
                    </select>
                </div>
            </div>

            <div class="stats-banner">
                <span id="metric-res">Source:0x0 Canvas:0x0</span>
                <span>NPU Inference: <span id="metric-fps">0.0</span> FPS</span>
                <span>Active Detections: <span id="metric-dets">0</span></span>
            </div>

            <div class="media-container">
                {% if not active_src %}
                <div class="overlay" id="gatekeeper-overlay">
                    <div class="prompt-text">Awaiting Source Context</div>
                    <div style="color: #8f8f9d; font-size: 13px; max-width: 400px;">Please select a media path and model from the drop-downs above to mount your pipeline.</div>
                </div>
                {% endif %}
                <img id="stream-player" {% if active_src %}src="/stream.mjpg"{% endif %} style="max-width: {{ canvas_w }}px;" />
            </div>
        </div>
    </body>
    </html>"""
    return render_template_string(
        html_template,
        active_src=STATE_REPO["active_source"],
        active_mdl=STATE_REPO["active_model_name"],
        sources=STATE_REPO["source_registry"],
        models=STATE_REPO["model_registry"],
        canvas_w=STATE_REPO["CANVAS_W"],
    )

@app.route('/api/stream_info')
def stream_info():
    """Report live pipeline telemetry for the dashboard's once-per-second poll.

    Values are read while holding ``lock``. ``detections`` is the number of
    boxes in the most recent inference result.
    """
    with lock:
        telemetry = {
            public_key: STATE_REPO[state_key]
            for public_key, state_key in (
                ("native_w", "native_w"),
                ("native_h", "native_h"),
                ("width", "stream_w"),
                ("height", "stream_h"),
                ("fps", "inference_fps"),
            )
        }
        telemetry["detections"] = len(STATE_REPO["detections"])
    return jsonify(telemetry)

@app.route('/api/swap_config', methods=['POST'])
def swap_config():
    """Apply a source/model selection posted by the dashboard.

    Expects a JSON object with optional ``"source"`` and ``"model"`` keys.
    Values that are not in the corresponding registry are ignored. When the
    active source or model actually changes, ``restart_flag`` is raised so
    the pipeline thread rebuilds its pipeline.

    Returns:
        ``{"status": "success"}``, or ``({"status": "error", ...}, 400)``
        when the request body is not a JSON object.
    """
    # silent=True: a missing/invalid JSON body yields None instead of raising,
    # and a non-object body (list, string, null) has no .get().
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return jsonify({"status": "error", "message": "expected a JSON object body"}), 400

    src_selected = payload.get("source")
    mdl_selected = payload.get("model")

    with lock:
        trigger_restart = False
        if src_selected in STATE_REPO["source_registry"] and STATE_REPO["active_source"] != src_selected:
            STATE_REPO["active_source"] = src_selected
            trigger_restart = True
        if mdl_selected in STATE_REPO["model_registry"] and STATE_REPO["active_model_name"] != mdl_selected:
            # Fall back to the CLI default when the app was started without main().
            base_dir = app.config.get("MODEL_DIR", "/usr/share/cnn/detection")
            STATE_REPO["active_model_name"] = mdl_selected
            STATE_REPO["active_model_path"] = os.path.join(base_dir, mdl_selected, "model.dvm")
            trigger_restart = True
        if trigger_restart:
            STATE_REPO["restart_flag"] = True
    return jsonify({"status": "success"})

def generate_mjpeg_stream_generator():
    """Yield an endless multipart MJPEG stream of the latest canvas frame.

    Roughly every 40 ms the current ``STATE_REPO["frame"]`` is copied, the
    current detections are drawn on the copy, and the result is JPEG-encoded
    as one ``--frame`` part. Box coordinates are treated as model-input
    pixels (``MODEL_W`` x ``MODEL_H``) and rescaled to the frame size. While
    no frame is available, a static placeholder image is sent instead.
    Frames that fail to encode are skipped.
    """
    model_w = float(STATE_REPO["MODEL_W"])
    model_h = float(STATE_REPO["MODEL_H"])
    box_color = (0, 255, 97)

    def _as_part(jpeg):
        # One multipart section; the boundary matches the route's mimetype.
        return (b'--frame\r\n'
                b'Content-Type: image/jpeg\r\n\r\n' + jpeg.tobytes() + b'\r\n')

    # The placeholder never changes, so encode it once instead of on every tick.
    waiting_canvas = np.zeros((480, 640, 3), dtype=np.uint8)
    cv2.putText(waiting_canvas, "AWAITING MEDIA INPUT SELECTION...", (140, 240),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, box_color, 1)
    _, waiting_jpeg = cv2.imencode(".jpg", waiting_canvas)
    waiting_part = _as_part(waiting_jpeg)

    while True:
        time.sleep(0.04)  # caps output at ~25 frames per second
        latest = STATE_REPO["frame"]
        detections = list(STATE_REPO["detections"])
        if latest is None:
            yield waiting_part
            continue

        frame = latest.copy()  # never draw on the shared frame
        frame_h, frame_w = frame.shape[:2]
        scale_x = float(frame_w) / model_w
        scale_y = float(frame_h) / model_h
        for class_id, confidence, x1, y1, x2, y2 in detections:
            top_left = (int(x1 * scale_x), int(y1 * scale_y))
            bottom_right = (int(x2 * scale_x), int(y2 * scale_y))
            label = f"{COCO_LABELS.get(class_id, f'Class {class_id}')} ({confidence*100:.1f}%)"
            cv2.rectangle(frame, top_left, bottom_right, box_color, 2)
            # Keep the label on-screen for boxes touching the top edge.
            cv2.putText(frame, label, (top_left[0], max(15, top_left[1] - 5)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, box_color, 2)

        ok, jpeg = cv2.imencode(".jpg", frame)
        if ok:
            yield _as_part(jpeg)

@app.route('/stream.mjpg')
def video_feed_stream_route():
    """Serve the annotated MJPEG stream as multipart/x-mixed-replace."""
    frames = generate_mjpeg_stream_generator()
    return Response(frames, mimetype='multipart/x-mixed-replace; boundary=frame')

def _discover_sources(camera, mp4):
    """Return selectable input sources: the camera node (if present), then MP4 files.

    *mp4* may be a directory (its ``*.mp4`` files are listed in sorted order)
    or a single ``.mp4`` file. Missing paths are reported rather than
    silently dropped.
    """
    sources = []
    if camera:
        if os.path.exists(camera):
            sources.append(camera)
        else:
            print(f"[WARN] Camera device not found, skipping: {camera}")
    if mp4:
        if os.path.isdir(mp4):
            # glob.escape: directory names containing [, * or ? must match literally.
            sources.extend(sorted(glob.glob(os.path.join(glob.escape(mp4), "*.mp4"))))
        elif os.path.isfile(mp4) and mp4.endswith(".mp4"):
            sources.append(mp4)
        else:
            print(f"[WARN] --mp4 is neither a directory nor an .mp4 file, skipping: {mp4}")
    return sources


def _discover_models(model_dir):
    """Return sorted names of sub-directories of *model_dir* that contain model.dvm."""
    if not os.path.isdir(model_dir):
        return []
    return [
        entry for entry in sorted(os.listdir(model_dir))
        if os.path.isdir(os.path.join(model_dir, entry))
        and os.path.exists(os.path.join(model_dir, entry, "model.dvm"))
    ]


def main():
    """Parse CLI args, populate the source/model registries, start the pipeline thread and serve."""
    parser = argparse.ArgumentParser(description="Wiki Template: Ara Flask Video Engine")
    parser.add_argument("--camera", default=None, help="Camera context device node path")
    parser.add_argument("--mp4", default=None, help="Directory containing target mp4 sample videos (or a single .mp4 file)")
    parser.add_argument("--port", type=int, default=8080, help="Target port mapping")
    parser.add_argument("--model-dir", default="/usr/share/cnn/detection", help="Directory containing target models")
    parser.add_argument("--model", default="yolov8n", help="Initial model selection")
    args = parser.parse_args()

    app.config["MODEL_DIR"] = args.model_dir
    STATE_REPO["source_registry"] = _discover_sources(args.camera, args.mp4)

    discovered_models = _discover_models(args.model_dir)
    if discovered_models:
        if args.model in discovered_models:
            chosen_model = args.model
        else:
            chosen_model = discovered_models[0]
            print(f"[WARN] Model '{args.model}' not found in {args.model_dir}; using '{chosen_model}'")
        STATE_REPO["model_registry"] = discovered_models
        STATE_REPO["active_model_name"] = chosen_model
        STATE_REPO["active_model_path"] = os.path.join(args.model_dir, chosen_model, "model.dvm")
    else:
        print(f"[WARN] No */model.dvm found under {args.model_dir}; "
              f"keeping default model path {STATE_REPO['active_model_path']}")

    threading.Thread(target=gstreamer_orchestration_loop, daemon=True).start()

    print(f"Server serving on: http://localhost:{args.port}/")
    app.run(host='0.0.0.0', port=args.port, threaded=True, use_reloader=False, debug=False)

if __name__ == '__main__':
    main()
