#!/usr/bin/env python3
from bcc import BPF
import time
import socket
import struct

# --- 1. eBPF C 代码 ---
bpf_text = """
#include <uapi/linux/bpf.h>
#include <linux/in.h>
#include <linux/if_ether.h>
#include <linux/ip.h>
#include <linux/udp.h>

// 定义统计结构体
struct stats {
    u64 packets;
    u64 bytes;
};

// 定义哈希表：Key 为源 IP (IPv4)，Value 为统计数据
BPF_HASH(client_stats, u32, struct stats);

int xdp_hysteria2_monitor(struct xdp_md *ctx) {
    void *data_end = (void *)(long)ctx->data_end;
    void *data = (void *)(long)ctx->data;

    // 1. 解析以太网头
    struct ethhdr *eth = data;
    if ((void *)(eth + 1) > data_end) return XDP_PASS;

    // 只处理 IPv4 流量
    if (eth->h_proto != bpf_htons(ETH_P_IP)) return XDP_PASS;

    // 2. 解析 IP 头
    struct iphdr *iph = (struct iphdr *)(eth + 1);
    if ((void *)(iph + 1) > data_end) return XDP_PASS;

    // Hysteria2 基于 QUIC/UDP，只处理 UDP 协议
    if (iph->protocol != IPPROTO_UDP) return XDP_PASS;

    // 3. 解析 UDP 头
    struct udphdr *udph = (struct udphdr *)(iph + 1);
    if ((void *)(udph + 1) > data_end) return XDP_PASS;

    // 4. 过滤 Hysteria2 监听端口 (假设为 443，请按需修改)
    u16 target_port = 443; 
    if (udph->dest != bpf_htons(target_port)) return XDP_PASS;

    // 5. 提取源 IP 并更新统计信息
    u32 src_ip = iph->saddr;
    struct stats *st, zero = {0, 0};

    st = client_stats.lookup_or_try_init(&src_ip, &zero);
    if (st) {
        st->packets += 1;
        st->bytes += (data_end - data); // 累加数据包大小
    }

    // 纯观测程序，最后将数据包放行给内核网络栈
    return XDP_PASS;
}
"""

# --- 2. Python 控制平面 ---
# 替换为你的实际网卡名称，例如 eth0, ens3
INTERFACE = "eth0" 

print(f"[*] 正在编译并加载 XDP 程序到 {INTERFACE}...")
b = BPF(text=bpf_text)
in_fn = b.load_func("xdp_hysteria2_monitor", BPF.XDP)

# 将 XDP 程序挂载到网卡
b.attach_xdp(INTERFACE, in_fn, 0)
print("[*] 成功挂载 XDP 程序！按 Ctrl+C 停止观测...\n")

def inet_ntoa(addr):
    return socket.inet_ntoa(struct.pack("<I", addr))

try:
    while True:
        time.sleep(2)
        print(f"--- Hysteria2 流量实时统计 (过去2秒累积) ---")
        client_stats = b.get_table("client_stats")
        
        # 遍历 BPF Map 打印数据
        for k, v in client_stats.items():
            ip_str = inet_ntoa(k.value)
            print(f"来自客户端 IP: {ip_str:<15} | 数据包数: {v.packets:<8} | 流量: {v.bytes / 1024:.2f} KB")
        print("")
        
        # 可选：每次打印后清空统计，实现计算每秒速率
        # client_stats.clear()

except KeyboardInterrupt:
    print("\n[*] 正在卸载 XDP 程序...")
finally:
    b.remove_xdp(INTERFACE, 0)
    print("[*] 已退出。")
