#! /usr/bin/env python3

# Import the fastest JSON library available
try:
    import ujson as json
except ImportError:
    try:
        import simplejson as json
    except ImportError:
        import json

import argparse
import csv
import datetime
import functools
import os
import sys
from collections import namedtuple


def mergedDict(dict1, dict2):
    """Return a copy of dict1 updated with the entries of dict2.

    Neither argument is modified. For keys present in both, the value of
    dict2 ends up in the result.
    """
    combined = dict1.copy()
    combined.update(dict2)
    return combined


def requireModules(modules):
    """Terminate the program unless all given modules can be located.

    Every name in `modules` is looked up with importlib. If at least one of them
    cannot be found, the required and the missing modules are printed and the
    process exits with status 1. Otherwise the function returns None.
    """
    from importlib.util import find_spec

    def enumerateNames(names):
        # A single name is returned as is, multiple names are joined as "a, b, and c"
        if len(names) == 1:
            return names[0]
        leading, last = names[:-1], names[-1]
        return ", ".join(leading) + f", and {last}"

    missing = [name for name in modules if find_spec(name) is None]
    if not missing:
        return

    print(
        f"This command requires the following additional dependencies: {enumerateNames(modules)}"
    )
    print(f"The following is/are missing: {enumerateNames(missing)}")
    sys.exit(1)


def checkPolarsVersion():
    """Exit with status 1 if the installed polars is older than version 0.19.0.

    Only the leading "major.minor" part of ``polars.__version__`` is inspected.
    This keeps the check working for version strings with fewer or more than
    three components and for pre-release suffixes such as "1.0.0rc1", which
    cannot be converted to integers component by component.
    A version string that doesn't start with "major.minor" is not checked.
    """
    import re

    import polars

    found = re.match(r"(\d+)\.(\d+)", polars.__version__)
    if found is None:
        # Unknown version scheme: don't block the command because of it
        return

    major, minor = int(found.group(1)), int(found.group(2))
    if (major, minor) < (0, 19):
        print(
            f"This command requires a polars installation of at least version 0.19.0 (found: {polars.__version__})"
        )
        sys.exit(1)


def loadRobust(content):
    """Parse a JSON string and retry once with a closing "]}" appended.

    If `content` cannot be parsed as is, a notice is printed, the characters
    "]}" are appended, and parsing is attempted a second time.

    Returns the parsed document.
    Raises ValueError (or the JSON library's subclass of it) if the second
    attempt fails as well.
    """
    try:
        return json.loads(content)
    except ValueError:
        # json, simplejson, and ujson all report malformed input as ValueError
        # or a subclass of it. Anything else (e.g. KeyboardInterrupt) must
        # not be mistaken for damaged input.
        print("Damaged input detected")
        return json.loads(content + "]}")


def readRobust(filename):
    """Read the given file as text and parse it using loadRobust."""
    with open(filename, "r") as source:
        text = source.read()
        return loadRobust(text)


def printWide(df):
    """Print a table of per-event statistics to stdout.

    The frame needs an "eid" column, which is turned into an indented "name"
    column. The number of remaining columns must be a multiple of six, as the
    columns are displayed in blocks of six separated by vertical bars.
    Missing values are displayed as blanks.
    """
    from itertools import repeat

    import polars as pl

    def makeLabel(eventName):
        # Any name containing "_GLOBAL" is displayed as the total
        if "_GLOBAL" in eventName:
            return "total"

        indentation = "  "

        # Names are "/"-separated paths: indent once per path component
        # and display only the last component
        if "/" not in eventName:
            return indentation + eventName
        parts = eventName.split("/")
        return indentation * len(parts) + parts[-1]

    # Sort by the full event name first, then replace it by its display label
    df: pl.DataFrame = (
        df.sort("eid")
        .with_columns(pl.col("eid").map_elements(makeLabel, return_dtype=pl.String))
        .rename({"eid": "name"})
    )

    # One name column followed by blocks of blocksize columns each
    blocksize = 6
    assert len(df.columns) % blocksize == 1

    # Per column: format of the header cell, format of the body cells, and width
    headerFmts = []
    bodyFmts = []
    colWidths = []
    for col in df.iter_columns():
        # Length of the encoded column name
        headerWidth = len(col.name.encode())

        # Length of a value once formatted; floats use two decimal places
        def fmt(d):
            if col.dtype == pl.Float32 or col.dtype == pl.Float64:
                return len(f"{d:.2f}")
            else:
                return len(f"{d}")

        # The column is as wide as its widest cell or its header
        width = col.map_elements(fmt, return_dtype=pl.Int64).max()
        width = max(width, headerWidth)
        colWidths.append(width)
        sw = str(width)

        headerFmts.append("{:<" + sw + "}")

        # Floats and other non-string values are right-aligned, strings left-aligned
        if col.dtype == pl.Float32 or col.dtype == pl.Float64:
            bodyFmts.append("{:>" + sw + ".2f}")
        elif col.dtype == pl.String:
            bodyFmts.append("{:<" + sw + "}")
        else:
            bodyFmts.append("{:>" + sw + "}")

    headerFmt = ""
    hline = ""
    bodyFmt = ""

    # Assemble the format strings of complete lines.
    # Columns with index 1, 7, 13, ... start a new block and are preceded by a
    # vertical separator, all other columns by a single space.
    for col, (h, b, w) in enumerate(zip(headerFmts, bodyFmts, colWidths)):
        if col % blocksize == 1:
            headerFmt += " │ "
        else:
            headerFmt += " "
        headerFmt += h

        if col % blocksize == 1:
            hline += "─┼─"
        else:
            hline += " "

        hline += "─" * w

        if col % blocksize == 1:
            bodyFmt += " │ "
        else:
            bodyFmt += " "
        bodyFmt += b

    print(headerFmt.format(*(df.columns)))
    print(hline)
    for row in df.iter_rows():
        # Missing values are formatted as nan and then blanked out.
        # NOTE(review): the replacement is applied to the complete line, so a
        # "nan" inside an event name is blanked as well.
        print(
            bodyFmt.format(
                *map(lambda e: float("nan") if e is None else e, row)
            ).replace("nan", "   ")
        )


@functools.lru_cache
def ns_to_unit_factor(unit):
    """Return the factor that converts a value in nanoseconds to the given unit.

    Known units are "ns", "us", "ms", "s", "m" (minutes), and "h" (hours).
    Any other unit raises a KeyError.
    """
    perSecond = 1e-9
    factors = dict(
        ns=1,
        us=1e-3,
        ms=1e-6,
        s=perSecond,
        m=perSecond / 60,
        h=perSecond / 3600,
    )
    return factors[unit]


def alignEvents(events):
    """Aligns passed events of multiple ranks and or participants.
    All ranks of a participant align at initialization, ensured by a barrier in preCICE.
    Primary ranks of all participants align after successfully establishing primary connections.

    `events` is a dict with the entries "eventDict" (event ID to event name) and
    "events" (participant name to rank to a dict with "meta" and "events").
    Timestamps are shifted in place and the same `events` object is returned.

    Ranks and participants lacking the required synchronization events are
    reported and left unaligned instead of aborting the alignment.
    """
    assert "events" in events and "eventDict" in events, "Not a loaded event file"
    assert len(events.get("events")) > 0, "No participants in the file"
    grouped = events["events"]
    eventDict = events["eventDict"]

    # ID of the event whose end synchronizes the ranks of a participant
    intraSyncID = next(
        (
            eventID
            for eventID, name in eventDict.items()
            if "com.initializeIntraCom" in name
        ),
        None,
    )

    # Align ranks of each participant
    for participant, ranks in grouped.items():
        if len(ranks) == 1:
            continue
        if intraSyncID is None:
            print(
                f"Cannot align ranks of {participant} as no intra-communication event was recorded."
            )
            continue
        print(f"Aligning {len(ranks)} ranks of {participant}")

        # Rank -> end of the synchronizing event
        syncs = {
            rank: e["ts"] + e["dur"]
            for rank, data in ranks.items()
            for e in data["events"]
            if e["eid"] == intraSyncID
        }
        if not syncs:
            print(
                f"Cannot align ranks of {participant} as none of them recorded the synchronization event."
            )
            continue

        # All ranks are shifted so that their sync coincides with the earliest one
        firstSync = min(syncs.values())
        for rank, data in ranks.items():
            if rank not in syncs:
                print(
                    f"Cannot align rank {rank} of {participant} as it didn't record the synchronization event."
                )
                continue
            shift = firstSync - syncs[rank]
            for e in data["events"]:
                e["ts"] += shift

    if len(grouped) == 1:
        return events

    # Align participants
    # NOTE(review): this expects integer rank keys as produced by loadProfilingOutputs
    primaries = [name for name, ranks in grouped.items() if 0 in ranks]

    for lonely in grouped:
        if lonely not in primaries:
            print(f"Cannot align {lonely} as event file of rank 0 is missing.")

    # Cannot align anything
    if len(primaries) < 2:
        return events

    # Find synchronization points
    syncEvents = [
        "m2n.acceptPrimaryRankConnection.",
        "m2n.requestPrimaryRankConnection.",
    ]
    # Event ID -> Remote name
    syncIDs = {
        eventID: name.rsplit(".", 1)[1]
        for eventID, name in eventDict.items()
        if any(check in name for check in syncEvents)
    }
    # Local name -> Remote name -> end of the connection event on the local primary rank
    syncs = {
        local: {
            syncIDs[e["eid"]]: e["ts"] + e["dur"]
            for e in grouped[local][0]["events"]
            if e["eid"] in syncIDs
        }
        for local in primaries
    }

    def hasSync(l, r):
        # Both sides need to have recorded the connection to each other
        return r in syncs[l] and l in syncs[r]

    shifts = {
        (local, remote): syncs[local][remote] - syncs[remote][local]
        # all unique participant combinations
        for local in primaries
        for remote in primaries
        if local < remote and hasSync(local, remote)
    }
    for (local, remote), shift in shifts.items():
        print(f"Aligning {remote} ({shift}us) with {local}")
        for data in grouped[remote].values():
            data["meta"]["unix_us"] += shift
            for e in data["events"]:
                e["ts"] += shift

    return events


def groupEvents(events: list, initTime: int):
    """Combine the raw records of one profiling file into complete events.

    Each record is a dict whose "et" entry gives its type:
      "n" registers the name "en" for the ID "eid",
      "b" begins and "e" ends the event "eid" at the timestamp "ts",
      "d" attaches the value "dv" under the name with ID "dn" to the active event "eid".

    Returns a list of event dicts sorted by "ts". Each has the entries "eid"
    (the names of the enclosing events and its own name joined by "/"), "ts"
    (begin timestamp offset by `initTime`), "dur" (duration), and "data" if
    data was attached. The event "_GLOBAL" never becomes part of a path.

    Events still active at the end of the input (truncated file) are closed at
    the latest point in time known from the input.

    The passed records are modified in place.
    """

    # Expands event names
    def namedEvents():
        nameMap = {int(e["eid"]): e["en"] for e in events if e["et"] == "n"}
        for e in events:
            kind = e["et"]
            if kind == "n":
                continue
            e["eid"] = nameMap[e["eid"]]
            if kind == "d":
                e["dn"] = nameMap[e["dn"]]
            yield e

    completed = []
    active = {}  # name to event data
    stack = []  # names of the currently active events, outermost first

    for event in namedEvents():
        kind = event["et"]

        name: str = event["eid"]
        assert isinstance(name, str)

        # Handle event starts
        if kind == "b":
            if name in active:
                print(f"Ignoring start of active event {name}")
                continue
            event["ts"] = int(event["ts"])
            event["eid"] = "/".join(stack + [name])
            active[name] = event
            if name != "_GLOBAL":
                stack.append(name)
        # Handle event stops
        elif kind == "e":
            if name not in active:
                print(f"Ignoring end of inactive event {name}")
                continue
            begin = active.pop(name)
            # The duration is computed before the timestamp gets offset
            begin["dur"] = int(event["ts"]) - begin["ts"]
            begin["ts"] = begin["ts"] + initTime
            begin.pop("et")
            completed.append(begin)
            if name != "_GLOBAL":
                assert stack[-1] == name
                stack.pop()
        # Handle event data
        elif kind == "d":
            if name not in active:
                print(f"Dropping data event {name} as event isn't yet known.")
                continue
            dname = event["dn"]
            assert isinstance(dname, str)
            active[name].setdefault("data", {})[dname] = int(event["dv"])

    # Handle leftover events in case of a truncated input file
    if active:
        # Latest known point in time: the end of a completed event or the
        # begin of an event that is still active. Considering the latter keeps
        # durations non-negative and covers inputs without any completed event.
        lastTS = max(
            [e["ts"] + e["dur"] for e in completed]
            + [e["ts"] + initTime for e in active.values()]
        )
        # active is keyed by the short name, while "eid" already holds the
        # full path, hence the events are taken from the values directly.
        for begin in active.values():
            fullName = begin["eid"]
            print(f"Truncating event without end {fullName}")
            begin["ts"] = begin["ts"] + initTime
            begin["dur"] = lastTS - begin["ts"]
            begin.pop("et")
            completed.append(begin)

    assert all(isinstance(e["eid"], str) for e in completed)

    return sorted(completed, key=lambda e: e["ts"])


def compressNames(events):
    """Replace the event names of all events by integer IDs.

    `events` maps participant names to ranks to dicts whose "events" entry is
    a list of event dicts. The "eid" of every event is replaced in place by an
    ID. IDs are assigned in sorted order of the names, which makes them
    reproducible between runs.

    Returns the dict mapping each ID to its event name.
    """
    allNames = {
        e["eid"]
        for ranks in events.values()
        for data in ranks.values()
        for e in data["events"]
    }
    nameToId = {name: id for id, name in enumerate(sorted(allNames))}

    for ranks in events.values():
        for data in ranks.values():
            for e in data["events"]:
                name = e["eid"]
                assert isinstance(name, str)
                e["eid"] = nameToId[name]

    return {id: name for name, id in nameToId.items()}


def loadProfilingOutputs(filenames):
    """Load the given profiling files and combine them into one structure.

    All files are read first. The records of each file are then grouped into
    events and stored by participant name and rank, both taken from the "meta"
    entry of the file. Finally, event names are replaced by IDs.

    Returns a dict with the entries "eventDict" (ID to event name) and
    "events" (participant name to rank to a dict with "meta" and "events").
    """
    # Load all files
    print("Loading event files")
    contents = [readRobust(filename) for filename in filenames]

    # Grouping events
    print("Grouping events")
    events = {}
    for content in contents:
        meta = content["meta"]
        name = meta["name"]
        rank = int(meta["rank"])
        unix_us = int(meta["unix_us"])
        rankData = {
            "meta": {
                "name": name,
                "rank": rank,
                "size": int(meta["size"]),
                "unix_us": unix_us,
                "tinit": meta["tinit"],
            },
            "events": groupEvents(content["events"], unix_us),
        }
        events.setdefault(name, {})[rank] = rankData

    print("Compressing names")
    eventDict = compressNames(events)

    return {"eventDict": eventDict, "events": events}


class RankData:
    """The meta information and events of a single rank of a participant."""

    def __init__(self, data):
        """Take over the "meta" fields and the "events" of the given dict."""
        meta = data["meta"]
        for field in ("name", "rank", "size", "unix_us", "tinit"):
            setattr(self, field, meta[field])

        self.events = data["events"]

    @property
    def type(self):
        """Display name of the rank: rank 0 is the primary, others are secondaries."""
        if self.rank == 0:
            return "Primary (0)"
        return f"Secondary ({self.rank})"

    def toListOfTuples(self, eventLookup):
        """Yield (participant, rank, event name, timestamp, duration) per event.

        `eventLookup` resolves the "eid" of an event to its name.
        Timestamp and duration are converted to int.
        """
        for event in self.events:
            eventName = eventLookup[event["eid"]]
            begin, duration = int(event["ts"]), int(event["dur"])
            yield (self.name, self.rank, eventName, begin, duration)


class Run:
    """Profiling data loaded from a JSON file with "eventDict" and "events" entries."""

    def __init__(self, filename):
        """Load the file; event IDs of the "eventDict" are converted to int."""
        print(f"Reading events file {filename}")
        import json

        with open(filename, "r") as stream:
            loaded = json.load(stream)

        lookup = {}
        for key, eventName in loaded["eventDict"].items():
            lookup[int(key)] = eventName
        self.eventLookup = lookup
        self.data = loaded["events"]

    def iterRanks(self):
        """Yield the RankData of every rank, per participant in ascending rank order."""

        def rankOf(entry):
            return int(entry["meta"]["rank"])

        for perParticipant in self.data.values():
            ordered = list(perParticipant.values())
            ordered.sort(key=rankOf)
            for entry in ordered:
                yield RankData(entry)

    def iterParticipant(self, name):
        """Yield the RankData of every rank of the given participant in stored order."""
        for entry in self.data[name].values():
            yield RankData(entry)

    def participants(self):
        """Return the names of all participants."""
        return self.data.keys()

    def lookupEvent(self, id):
        """Return the name of the event with the given ID."""
        return self.eventLookup[int(id)]

    def toTrace(self, selectRanks):
        """Return the events as a dict in the Trace Event Format.

        Participants become processes and ranks become threads. If
        `selectRanks` is not empty, only ranks contained in it are included.
        """
        if selectRanks:
            listing = ",".join(str(r) for r in sorted(selectRanks))
            print(f"Selected ranks: {listing}")

        def isSelected(rankData):
            if not selectRanks:
                return True
            return rankData.rank in selectRanks

        # Participant name -> process ID
        pids = {}
        for index, participant in enumerate(self.participants()):
            pids[participant] = index

        # (Participant name, rank) -> thread ID
        tids = {}
        for index, rankData in enumerate(filter(isSelected, self.iterRanks())):
            tids[(rankData.name, rankData.rank)] = index

        # Meta events naming the processes, followed by those naming the threads
        metaEvents = []
        for participant, pid in pids.items():
            metaEvents.append(
                {
                    "name": "process_name",
                    "ph": "M",
                    "pid": pid,
                    "args": {"name": participant},
                }
            )
        for participant, pid in pids.items():
            for rankData in self.iterParticipant(participant):
                if not isSelected(rankData):
                    continue
                metaEvents.append(
                    {
                        "name": "thread_name",
                        "ph": "M",
                        "pid": pid,
                        "tid": tids[(rankData.name, rankData.rank)],
                        "args": {"name": rankData.type},
                    }
                )

        mainEvents = []
        for rankData in filter(isSelected, self.iterRanks()):
            pid = pids[rankData.name]
            tid = tids[(rankData.name, rankData.rank)]
            for e in rankData.events:
                eventName = self.lookupEvent(e["eid"])
                if eventName.startswith("solver"):
                    category = "Solver"
                else:
                    category = "preCICE"
                traceEvent = {
                    "name": eventName,
                    "cat": category,
                    "ph": "X",  # complete event
                    "pid": pid,
                    "tid": tid,
                    "ts": e["ts"],
                    "dur": e["dur"],
                }
                # Attached data is passed on as arguments of the trace event
                arguments = {"args": e["data"]} if "data" in e else {}
                mainEvents.append(mergedDict(traceEvent, arguments))

        return {"traceEvents": metaEvents + mainEvents}

    def allDataFields(self):
        """Return the names of all data fields attached to any event, in no particular order."""
        fields = set()
        for rankData in self.iterRanks():
            for e in rankData.events:
                if "data" not in e:
                    continue
                for dname in e["data"].keys():
                    fields.add(dname)
        return list(fields)

    def toExportList(self, unit, dataNames):
        """Yield one tuple per event for the export.

        Each tuple holds the participant, rank, size, event name, timestamp,
        and duration followed by the values of the data fields `dataNames`,
        using "" for fields the event doesn't have.
        Durations are multiplied by 1000 times the ns factor of `unit` if a
        unit is given.
        """
        if unit:
            factor = ns_to_unit_factor(unit) * 1e3
        else:
            factor = 1

        def dataColumns(e):
            if "data" in e:
                attached = e["data"]
                return tuple(attached.get(dname, "") for dname in dataNames)
            return tuple("" for _ in dataNames)

        for rankData in self.iterRanks():
            for e in rankData.events:
                fixed = (
                    rankData.name,
                    rankData.rank,
                    rankData.size,
                    self.lookupEvent(e["eid"]),
                    e["ts"],
                    e["dur"] * factor,
                )
                yield fixed + dataColumns(e)

    def toDataFrame(self):
        """Return all events as a polars DataFrame.

        The columns are participant, rank, eid (the event name), ts (cast to a
        datetime with microsecond resolution), and dur.
        """
        import itertools

        import polars as pl

        rows = itertools.chain.from_iterable(
            rankData.toListOfTuples(self.eventLookup) for rankData in self.iterRanks()
        )
        schema = [
            ("participant", pl.Utf8),
            ("rank", pl.Int32),
            ("eid", pl.Utf8),
            ("ts", pl.Int64),
            ("dur", pl.Int64),
        ]
        frame = pl.DataFrame(data=rows, schema=schema)
        return frame.with_columns([pl.col("ts").cast(pl.Datetime("us"))])


def traceCommand(profilingfile, outfile, rankfilter, limit):
    """Convert a profiling file to a trace file and return the exit code 0.

    The selection of ranks is the union of `rankfilter` and the first `limit`
    ranks. If both are empty or None, the selection is empty.
    """
    run = Run(profilingfile)

    selection = set(rankfilter or [])
    if limit:
        selection.update(range(limit))

    traces = run.toTrace(selection)

    print(f"Writing to {outfile}")
    with open(outfile, "w") as target:
        json.dump(traces, target)
    return 0


def exportCommand(profilingfile, outfile, unit):
    """Export the events of a profiling file as CSV and return the exit code 0.

    The header consists of six fixed columns followed by one column per data
    field found in the profiling file.
    """
    run = Run(profilingfile)
    dataFields = run.allDataFields()
    header = ["participant", "rank", "size", "event", "timestamp", "duration"]
    header = header + dataFields

    print(f"Writing to {outfile}")
    with open(outfile, "w", newline="") as target:
        out = csv.writer(target)
        out.writerow(header)
        out.writerows(run.toExportList(unit, dataFields))
    return 0


def analyzeCommand(profilingfile, participant, event, outfile=None, unit="us"):
    """Print per-event statistics of one participant and return the exit code 0.

    profilingfile: the profiling file to load
    participant: name of the participant to analyze; must exist in the file
    event: name of the event used to select the secondary ranks to display
    outfile: if given, the table is additionally written to this CSV file
    unit: the unit used for durations

    Requires polars of at least version 0.19.0.
    """
    requireModules(["polars"])
    import polars as pl

    checkPolarsVersion()

    run = Run(profilingfile)
    df = run.toDataFrame()

    participants = set(df.get_column("participant").unique())
    assert (
        participant in participants
    ), f"Given participant {participant} doesn't exist. Known: " + ", ".join(
        participants
    )

    print(f"Output timing are in {unit}.")

    # Filter by participant
    # Convert duration to requested unit
    # NOTE(review): the factor 1000 presumably converts stored microseconds to
    # nanoseconds first - confirm against the writer of the profiling files
    dur_factor = 1000 * ns_to_unit_factor(unit)
    df = (
        df.filter(pl.col("participant") == participant)
        .drop("participant")
        .with_columns(
            (pl.col("dur") * dur_factor),
            # Duration in percent of the longest event of the same rank
            (pl.col("dur") / pl.col("dur").max().over(["rank"]) * 100).alias("rel"),
        )
    )

    ranks = df.select("rank").unique()

    if len(ranks) == 1:
        # Single rank: one block of statistics per event
        joined = (
            df.group_by("eid")
            .agg(
                pl.sum("dur").alias("sum"),
                pl.sum("rel").alias("%"),
                pl.count("dur").alias("count"),
                pl.mean("dur").alias("mean"),
                pl.min("dur").alias("min"),
                pl.max("dur").alias("max"),
            )
            .sort("eid")
        )
    else:
        # Multiple ranks: total duration of the given event per secondary rank
        # (rank > 0), sorted in ascending order
        ldf = df.lazy()
        rankAdvance = (
            ldf.filter((pl.col("eid") == event) & (pl.col("rank") > 0))
            .group_by("rank")
            .agg(pl.col("dur").sum())
            .sort("dur")
            .collect()
        )
        # Secondary ranks with the smallest and the largest total duration
        minSecRank = rankAdvance.select(pl.first("rank")).item()
        maxSecRank = rankAdvance.select(pl.last("rank")).item()

        if minSecRank is None or maxSecRank is None:
            ranksToPrint = (0,)
            print(
                "Selection only contains the primary rank 0 as event isn't available on secondary ranks."
            )
        elif minSecRank == maxSecRank:
            ranksToPrint = (0, minSecRank)
            print(
                f"Selection contains the primary rank 0 and secondary rank {minSecRank}."
            )
        else:
            ranksToPrint = (0, minSecRank, maxSecRank)
            print(
                f"Selection contains the primary rank 0, the cheapest secondary rank {minSecRank}, and the most expensive secondary rank {maxSecRank}."
            )

        # One block of statistics per selected rank, with columns prefixed by
        # the rank, combined using the "align" strategy of pl.concat
        # (presumably this joins the blocks on the shared eid column - confirm
        # against the polars documentation)
        ldf = df.lazy()
        joined = (
            pl.concat(
                [
                    (
                        ldf.filter(pl.col("rank") == rank)
                        .group_by("eid")
                        .agg(
                            pl.sum("dur").alias(f"R{rank}:sum"),
                            pl.sum("rel").alias(f"R{rank}:%"),
                            pl.count("dur").alias(f"R{rank}:count"),
                            pl.mean("dur").alias(f"R{rank}:mean"),
                            pl.min("dur").alias(f"R{rank}:min"),
                            pl.max("dur").alias(f"R{rank}:max"),
                        )
                    )
                    for rank in ranksToPrint
                ],
                how="align",
            )
            .sort("eid")
            .collect()
        )

    printWide(joined)

    if outfile:
        print(f"Writing to {outfile}")
        joined.write_csv(outfile)

    return 0


def histogramCommand(profilingfile, outfile, participant, event, rank, bins, unit="us"):
    """Plot a histogram of the durations of one event and return the exit code 0.

    The durations of `event` on `participant` are plotted, restricted to
    `rank` unless it is None. The plot is saved to `outfile` if given and
    displayed otherwise.

    Requires matplotlib and polars of at least version 0.19.0.
    """
    requireModules(["polars", "matplotlib"])
    import matplotlib.pyplot as plt
    import polars as pl

    checkPolarsVersion()

    df = Run(profilingfile).toDataFrame()

    def exists(column, value):
        # True if the value occurs in the given column
        return df.select(pl.col(column).is_in([value]).any()).item()

    # Check user input
    assert exists(
        "participant", participant
    ), f"Given participant {participant} doesn't exist."
    assert exists("eid", event), f"Given event {event} doesn't exist."
    if rank is not None:
        assert exists("rank", rank), f"Given rank {rank} doesn't exist."

    # Select by participant and event, and optionally by rank
    selection = (pl.col("participant") == participant) & (pl.col("eid") == event)
    if rank is not None:
        selection = selection & (pl.col("rank") == rank)

    # Scaled durations of the selected events
    scale = 1000 * ns_to_unit_factor(unit)
    durations = df.filter(selection).select(pl.col("dur") * scale)

    # Few ranks are listed individually, many are only counted
    ranks = df["rank"].unique()
    if len(ranks) < 5:
        rankDesc = "ranks: " + ",".join(map(str, ranks))
    else:
        rankDesc = f"{len(ranks)} ranks"

    fig, ax = plt.subplots(figsize=(14, 7), tight_layout=True)
    ax.set_title(f'Histogram of event "{event}" on {participant} ({rankDesc})')
    ax.set_xlabel(f"Duration [{unit}]")
    ax.set_ylabel("Occurrence")
    _, binborders, bars = ax.hist(
        durations, bins=bins, histtype="barstacked", align="mid"
    )
    ax.bar_label(bars)
    # Ticks are placed at the borders of the bins
    tickLabels = [f"{d:,.2f}" for d in binborders]
    ax.set_xticks(binborders, labels=tickLabels, rotation=90)
    ax.set_xlim(left=min(binborders), right=max(binborders))
    ax.grid(axis="x")

    if not outfile:
        plt.show()
        return 0

    print(f"Writing to {outfile}")
    plt.savefig(outfile)
    return 0


def detectFiles(files):
    """Resolve the given paths to a list of unique profiling files.

    Files are taken as they are. For a directory, a fixed list of locations
    relative to it is searched for files matching "*-*-*.json" and the files
    of the first location containing any are used. Other paths are reported
    and skipped. Without any paths, the current directory is searched.
    The order of the returned files is unspecified.
    """
    import glob

    # Default to loading from the local directory
    if not files:
        files = ["."]

    locations = (
        ".",
        "precice-profiling",
        "../precice-profiling",
        "../../precice-profiling",
    )

    def searchLocations(path):
        # Files of the first location containing matching files, otherwise an empty list
        for location in locations:
            folder = os.path.join(path, location)
            if os.path.isdir(folder):
                matches = glob.glob(os.path.join(folder, "*-*-*.json"))
                if matches:
                    return matches
        return []

    resolved = []
    for path in files:
        if os.path.isfile(path):
            resolved.append(path)
        elif os.path.isdir(path):
            detected = searchLocations(path)
            if detected:
                print(f"Found {len(detected)} files in {path}")
                resolved.extend(detected)
            else:
                print(f"Nothing found in {path}")
        else:
            print(f'Cannot interpret "{path}"')

    unique = list(set(resolved))
    if len(files) > 1:
        print(f"Found {len(unique)} profiling files in total")
    return unique


def findFilesOfLatestRun(name, sizes):
    """Return the files of the most recent run of a participant.

    `sizes` maps the size of each run to a dict of rank to filename. One file
    of every run is read to get its "unix_us" timestamp. The files of the run
    with the largest timestamp are returned; if several share it, the first
    one in `sizes` is selected.
    """
    assert len(sizes) > 1
    print(f"Found multiple runs for participant {name}")

    latestSize = None
    latestStamp = None
    for size, ranks in sizes.items():
        assert len(ranks) > 0
        sample = next(iter(ranks.values()))  # Get some file of this run
        stamp = int(readRobust(sample)["meta"]["unix_us"])
        # Only a strictly newer run replaces the current selection
        if latestStamp is None or stamp > latestStamp:
            latestSize, latestStamp = size, stamp

    print(f"└ selected latest run of size {latestSize}")

    return list(sizes[latestSize].values())


def groupRuns(files):
    """Group profiling files by participant name, run size, and rank.

    Name, rank, and size are taken from filenames of the form
    "<name>-<rank>-<size>.<extension>", where the name may contain dashes.
    If multiple files share name, size, and rank, the first one is kept and
    a warning is printed for each of the others.

    Returns a dict of name to size to rank to filename.
    """
    PieceFile = namedtuple("PieceFile", ["name", "rank", "size", "filename"])

    def parse(filename):
        stem = os.path.splitext(os.path.basename(filename))[0]
        tokens = stem.split("-")
        return PieceFile(
            name="-".join(tokens[:-2]),
            rank=int(tokens[-2]),
            size=int(tokens[-1]),
            filename=filename,
        )

    # Parse all filenames before grouping any of them
    pieces = list(map(parse, files))

    runs = {}
    for piece in pieces:
        perSize = runs.setdefault(piece.name, {})
        perRank = perSize.setdefault(piece.size, {})
        if piece.rank not in perRank:
            perRank[piece.rank] = piece.filename
            continue
        print(
            f"{piece.filename}: \033[33mwarning:\033[0m Ignored due to conflict with '{perRank[piece.rank]}'"
        )
    return runs


def sanitizeFiles(files):
    """Select the files of a single run per participant.

    The files are grouped by participant and run size. For a participant with
    a single run, all of its files are selected. For a participant with
    multiple runs, only the files of the latest run are selected.

    Returns the list of selected filenames of all participants.
    """
    filesToLoad = []
    for name, sizes in groupRuns(files).items():
        if len(sizes) == 1:
            print(f"Found a single run for participant {name}")
            for ranks in sizes.values():
                filesToLoad.extend(ranks.values())
        else:
            # Append rather than assign: the files already selected for other
            # participants have to be kept
            filesToLoad += findFilesOfLatestRun(name, sizes)
    return filesToLoad


def mergeCommand(files, outfile, align):
    """Merge profiling files into a single JSON file and return the exit code 0.

    The given paths are resolved to files, reduced to one run per participant,
    and loaded. If `align` is set, the events are aligned before writing.
    """
    selected = sanitizeFiles(detectFiles(files))
    profile = loadProfilingOutputs(selected)
    if align:
        profile = alignEvents(profile)

    print(f"Writing to {outfile}")
    with open(outfile, "w", newline="") as target:
        json.dump(profile, target)

    return 0


# Command line interface: one sub-command per *Command function.
# The exit code of the process is the return value of the selected command.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="", epilog="")
    subparsers = parser.add_subparsers(
        title="commands",
        description="Note that some commands may require polars.",
        help="additional help",
        dest="cmd",
    )
    # parser.add_argument("-v", "--verbose", help="Print verbose output")
    # The unit is an option of the main parser and shared by all commands
    parser.add_argument(
        "-u",
        "--unit",
        choices=["h", "m", "s", "ms", "us"],
        default="us",
        help="The duration unit to use",
    )

    # Command: analyze
    analyze_help = """Analyze profiling data of a given solver.
    Event durations are displayed in the unit of choice.
    Parallel solvers show events of the primary rank next to the secondary ranks spending the least and most time in advance of preCICE.
    """
    analyze = subparsers.add_parser(
        "analyze", help=analyze_help.splitlines()[0], description=analyze_help
    )
    analyze.add_argument("participant", type=str, help="The participant to analyze")
    analyze.add_argument(
        "profilingfile",
        nargs="?",
        type=str,
        default="profiling.json",
        help="The profiling file to process",
    )
    analyze.add_argument(
        "-e",
        "--event",
        nargs="?",
        type=str,
        default="advance",
        help="The event used to determine the most expensive and cheapest rank.",
    )
    analyze.add_argument("-o", "--output", help="Write the result to CSV file")

    def try_int(s):
        # Returns the argument as int if possible and unchanged otherwise.
        # Used for --bins, which is either a number of bins or the name of a strategy.
        try:
            return int(s)
        except ValueError:
            return s

    # Command: histogram
    histogram_help = """Plots the duration distribution of a single event of a given solver.
    Event durations are displayed in the unit of choice.
    """
    histogram = subparsers.add_parser(
        "histogram", help=histogram_help.splitlines()[0], description=histogram_help
    )
    histogram.add_argument(
        "-o",
        "--output",
        default=None,
        help="Write to file instead of displaying the plot",
    )
    histogram.add_argument(
        "-r", "--rank", type=int, default=None, help="Display only the given rank"
    )
    histogram.add_argument(
        "-b",
        "--bins",
        type=try_int,
        default="fd",
        help="Number of bins or strategy. Must be a valid argument to numpy.histogram_bin_edges",
    )
    histogram.add_argument("participant", type=str, help="The participant to analyze")
    histogram.add_argument("event", type=str, help="The event to analyze")
    histogram.add_argument(
        "profilingfile",
        nargs="?",
        type=str,
        default="profiling.json",
        help="The profiling file to process",
    )

    # Command: trace
    trace_help = "Transform profiling to the Trace Event Format."
    trace = subparsers.add_parser(
        "trace", help=trace_help.splitlines()[0], description=trace_help
    )
    trace.add_argument(
        "profilingfile",
        type=str,
        nargs="?",
        default="profiling.json",
        help="The profiling file to process",
    )
    trace.add_argument(
        "-o", "--output", default="trace.json", help="The resulting trace file"
    )
    trace.add_argument(
        "-l", "--limit", type=int, metavar="n", help="Select the first n ranks"
    )
    trace.add_argument(
        "-r", "--rank", type=int, nargs="*", help="Select individual ranks"
    )

    # Command: export
    export_help = "Export the profiling data as a CSV file."
    export = subparsers.add_parser(
        "export", help=export_help.splitlines()[0], description=export_help
    )
    export.add_argument(
        "profilingfile",
        nargs="?",
        type=str,
        default="profiling.json",
        help="The profiling files to process",
    )
    export.add_argument(
        "-o",
        "--output",
        type=str,
        default="profiling.csv",
        help="The CSV file to export to.",
    )

    # Command: merge
    merge_help = "Merges preCICE profiling output files to a single file used by the other commands."
    merge = subparsers.add_parser(
        "merge", help=merge_help.splitlines()[0], description=merge_help
    )
    merge.add_argument(
        "files",
        nargs="*",
        help="The profiling files to process, directories to search, or nothing to autodetect",
    )
    merge.add_argument(
        "-o",
        "--output",
        type=str,
        default="profiling.json",
        help="The resulting profiling file.",
    )
    merge.add_argument(
        "-n", "--no-align", action="store_true", help="Don't align participants?"
    )
    args = parser.parse_args()

    # Maps the name of each command to a callable taking the parsed arguments
    dispatcher = {
        "trace": lambda ns: traceCommand(
            ns.profilingfile, ns.output, ns.rank, ns.limit
        ),
        "export": lambda ns: exportCommand(ns.profilingfile, ns.output, ns.unit),
        "analyze": lambda ns: analyzeCommand(
            ns.profilingfile, ns.participant, ns.event, ns.output, ns.unit
        ),
        "histogram": lambda ns: histogramCommand(
            ns.profilingfile,
            ns.output,
            ns.participant,
            ns.event,
            ns.rank,
            ns.bins,
            ns.unit,
        ),
        "merge": lambda ns: mergeCommand(ns.files, ns.output, not ns.no_align),
    }

    # Without a command, the help is displayed and the exit code is 1
    def showHelp(ns):
        parser.print_help()
        return 1

    sys.exit(dispatcher.get(args.cmd, showHelp)(args))
