#!/usr/bin/env python3

# This is a Python rewrite of this Ruby script:
# https://gist.github.com/JaciBrunning/6be34dfd11b7b8cfa0aab57ad260c518

import re
import subprocess
from pathlib import Path


def flatten(xss):
    """Return a single list holding the items of each iterable in *xss*, in order.

    Only one level of nesting is removed; the inner items themselves are
    left untouched.
    """
    flat = []
    for inner in xss:
        flat.extend(inner)
    return flat


# Parse `lsusb` output into a list of ["usb<bus>", "<device description>"]
# pairs, one per matching line.
lsusb_re = re.compile(r"Bus (\d+) Device \d+: ID [0-9a-f:]+ (.*)")
lsusb = []
lsusb_output = subprocess.run("lsusb", capture_output=True, encoding="utf-8").stdout
for line in lsusb_output.split("\n"):
    # Lines containing "root" are skipped (presumably the root hub entries).
    if "root" in line:
        continue
    found = lsusb_re.search(line)
    if found is None:
        continue
    bus, description = found.groups()
    # int() strips the zero padding of the bus number ("001" -> "usb1").
    lsusb.append([f"usb{int(bus)}", description])


# Build a list of (usb bus name, parent path component) tuples from the
# symlink targets shown by `ls -l /sys/bus/usb/devices`. The second element
# is the hex/colon/dot component directly above the usbN directory
# (presumably the PCI address of the controller).
sysbus_re = re.compile(r"(usb\d+) -> .+\/([0-9a-f:\.]+)\/usb\d+")
sysbus_list = []
sysbus_listing = subprocess.run(
    ["ls", "-l", "/sys/bus/usb/devices"], capture_output=True, encoding="utf-8"
).stdout
for entry in sysbus_listing.split("\n"):
    # findall() yields [] for lines that do not match, so extend() simply
    # adds nothing for them.
    sysbus_list.extend(sysbus_re.findall(entry))

# Collect [iommu group, device address, last bracketed hex id from lspci]
# triples for every entry under /sys/kernel/iommu_groups/*/devices/.
iommu_re = re.compile(r"iommu_groups\/(\d+)\/devices\/([0-9a-f:\.]+)")
pci_re = re.compile(r"\[([0-9a-f:]+)\]")
iommu_list = []
for path in Path("/sys/kernel/iommu_groups").glob("*/devices/*"):
    located = iommu_re.search(str(path))
    if located is None:
        continue
    group, address = located.groups()
    output = subprocess.run(
        ["lspci", "-nns", address], capture_output=True, encoding="utf-8"
    ).stdout
    bracketed = pci_re.findall(output)
    # Devices for which lspci prints no bracketed hex id are left out.
    if bracketed:
        # The last bracketed value on the line is kept (presumably the
        # vendor:device id printed by `lspci -nn`).
        iommu_list.append([group, address, bracketed[-1]])

# Group the USB devices by bus and attach, per bus, the PCI address of its
# controller, the IOMMU group of that controller and the number of devices
# sharing that group.
mapped = {}
for usb in lsusb:
    bus = usb[0]
    if bus not in mapped:
        # Seed every key the report reads. Previously "pci", "iommu" and
        # "pcidetail" were only set when a match was found, so a bus without
        # a matching sysfs entry or without an IOMMU group raised KeyError.
        entry = {
            "devices": [],
            "iommu_count": 0,
            "pci": "unknown",
            "iommu": "unknown",
            "pcidetail": "unknown",
        }
        mapped[bus] = entry

        # First sysfs entry for this bus gives the controller's address.
        pci_address = None
        for sysbus in sysbus_list:
            if sysbus[0] == bus:
                pci_address = sysbus[1]
                entry["pci"] = pci_address
                break

        # Without a PCI address there is nothing to look up in iommu_list.
        if pci_address is not None:
            for iommu in iommu_list:
                if iommu[1] == pci_address:
                    entry["iommu"] = iommu[0]
                    entry["pcidetail"] = iommu[-1]
                    # Count every device (including this controller) that
                    # sits in the same IOMMU group.
                    entry["iommu_count"] = sum(
                        1 for other in iommu_list if other[0] == iommu[0]
                    )
                    break

    mapped[bus]["devices"].append(usb[1])


# Print one stanza per USB bus: controller PCI address and id, IOMMU group
# with its device count, then the attached USB devices as a bullet list.
for usb_id, usb_dict in mapped.items():
    # Each device goes on its own line, prefixed with a newline so the list
    # starts below the "USB devices:" label.
    device_lines = "".join(f"\n  - {name}" for name in usb_dict["devices"])
    pci_address = usb_dict["pci"]
    pci_detail = usb_dict["pcidetail"]
    group = usb_dict["iommu"]
    group_size = usb_dict["iommu_count"]

    print(f"""{usb_id}
 - PCI: {pci_address} [{pci_detail}]
 - IOMMU: Group {group} ({group_size} device(s) in the same IOMMU group)
 - USB devices: {device_lines}""")
