#!/usr/bin/env python

# This file is part of DarkFi (https://dark.fi)
#
# Copyright (C) 2020-2026 Dyne.org foundation
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import asyncio, re, sys, base58
import urwid as u
import networkx as nx
from datetime import datetime
# import matplotlib.pyplot as plt
import src.rpc

from os.path import join

# Branch state shared by successive graph() calls.  The event list is built
# from the reversed topological order of the DAG, so a merge (several
# parents) is met before the fork (several children) it joins back to.  That
# is why this looks counter-intuitive: merges open a side branch, forks close
# it.
#   resolved[0] is False while a side branch is open
#   resolved[1] is False while a second, nested side branch is open
resolved = [True, True]

def graph(event, longest_path):
    """Return the three-character graph gutter drawn next to *event*.

    event        -- dict with at least 'parents' and 'children' (sequences);
                    'hash' is read for plain events inside an open branch.
    longest_path -- container of event hashes forming the main line.

    The module-level ``resolved`` list is updated in place on merges and
    forks, so the result depends on the events passed in earlier calls.
    """
    is_merge = len(event['parents']) > 1
    is_fork = len(event['children']) > 1

    # Merge and fork at once: the branch state is left untouched.
    if is_merge and is_fork:
        return "M━┪"

    if is_merge:
        # Open the first side branch, or the nested one if it is already open.
        if resolved[0]:
            resolved[0] = False
        else:
            resolved[1] = False
        return "M━┑"

    if is_fork:
        # Close the nested side branch first if there is one, else the outer.
        if resolved[1]:
            resolved[0] = True
        else:
            resolved[1] = True
        return "o─┘"

    # Plain event (at most one parent and one child).
    if resolved[0]:
        return "o  "
    if event['hash'] in longest_path:
        return "o │"
    return "│ o"

# Tab characters are broken in urwid texts, so columns are aligned with
# explicit spaces instead.
def indent(num):
    """Return the spaces that pad str(num) out to an 8-character column.

    Values whose text is 8 characters or longer get an empty string.
    """
    used = len(str(num))
    padding_width = 8 - used
    return " " * padding_width

class ListItem(u.WidgetWrap):
    """One selectable row of the event list.

    The row shows the first 10 characters of the hash, the timestamp, the
    layer, the graph gutter returned by graph() and the event content.  The
    original event dict is kept in ``self.content``.
    """

    def __init__(self, event, longest_path):
        # graph() goes first: it updates the shared branch state, so rows
        # must be created in display order.
        gutter = graph(event, longest_path)
        self.content = event

        layer_num = int(event["layer"])
        if layer_num != 0:
            layer = "layer " + str(layer_num) + indent(layer_num)
        else:
            layer = "genesis       "

        # Values of 1e10 and above are taken to be milliseconds.
        raw_ts = int(event['timestamp'])
        if raw_ts < 1e10:
            when = datetime.fromtimestamp(raw_ts)
        else:
            when = datetime.fromtimestamp(int(raw_ts / 1000))
        stamp = event['hash'][:10] + " │ " + str(when)

        markup = [
            ('word', stamp),
            ('layer-num', " " + layer),
            ('word', gutter),
            ('cont', event['content']),
        ]
        # Attribute maps: unfocused colours, then the focused (selected) ones.
        normal_attrs = {'word': 'datetime', 'layer-num': 'reporter', 'cont': 'content'}
        focus_attrs = {'word': 'event_selected', 'layer-num': 'event_selected', 'cont': 'event_selected'}
        row = u.AttrMap(u.Text(markup, wrap="ellipsis"), normal_attrs, focus_attrs)
        u.WidgetWrap.__init__(self, row)

    def selectable(self):
        """Rows can take focus."""
        return True

    def keypress(self, size, key):
        """Handle no keys here; hand every key back to the caller."""
        return key

class ListView(u.WidgetWrap):
    """Scrollable list of events.

    Emits the ``show_details`` signal with the focused row's event dict each
    time the underlying list walker emits ``modified``.
    """

    def __init__(self):
        u.register_signal(self.__class__, ['show_details'])
        self.walker = u.SimpleFocusListWalker([])
        lb = u.ListBox(self.walker)
        u.WidgetWrap.__init__(self, lb)

    @staticmethod
    def _exit_empty():
        """Report that there is nothing to show and terminate the program."""
        print("DAG is empty!")
        sys.exit(-1)

    def modified(self):
        """Forward the currently focused event through ``show_details``."""
        focus_w, _ = self.walker.get_focus()
        if focus_w is None:
            self._exit_empty()
        u.emit_signal(self, 'show_details', focus_w.content)

    def set_data(self, events, longest_path):
        """Replace the list contents with one row per entry of *events*.

        events       -- event dicts, in display order.
        longest_path -- iterable of hashes on the main line of the DAG.

        Exits the program with status -1 when *events* is empty.
        """
        # Every row tests its hash against the path; a set makes that O(1).
        path_nodes = set(longest_path)
        # Rows are built in order because graph() keeps state between calls.
        events_widgets = [ListItem(e, path_nodes) for e in events]

        # Silence the signal while the walker is rebuilt so that listeners
        # never see a half-filled list.
        u.disconnect_signal(self.walker, 'modified', self.modified)
        while self.walker:
            self.walker.pop()
        self.walker.extend(events_widgets)
        u.connect_signal(self.walker, "modified", self.modified)

        # set_focus(0) raises IndexError on an empty walker, so report the
        # empty DAG explicitly instead of crashing.
        if not events_widgets:
            self._exit_empty()
        self.walker.set_focus(0)

class DetailView(u.WidgetWrap):
    """Text pane listing every field of a single event."""

    def __init__(self):
        u.WidgetWrap.__init__(self, u.Text(""))

    def set_event(self, c):
        """Show event dict *c*, one "Label: value" line per field."""
        lines = (
            f'Hash: {c["hash"]}',
            f'Children: {c["children"]}',
            f'Parents: {c["parents"]}',
            f'Content: {c["content"]}',
            f'Layer: {c["layer"]}',
            f'Timestamp: {c["timestamp"]}',
        )
        self._w.set_text("\n".join(lines))

class App(object):
    """Two-screen urwid application for browsing an event graph.

    ``frame1`` holds the event list (``ListView``) and ``frame2`` the detail
    pane (``DetailView``).  Enter switches to the detail pane, B switches
    back to the list and Q quits.
    """

    # The ``\xNN`` escapes that repr() produces for non-printable bytes.
    _HEX_ESCAPE = re.compile(r'\\x[0-9A-Fa-f]{2}')
    # Parent hash that never becomes an edge; presumably the placeholder for
    # "no parent" -- confirm against the daemon's event format.
    _NULL_PARENT = '0' * 64

    def unhandled_input(self, key):
        """Handle the global keys: quit and switching between the two views."""
        if key in ('q', 'Q'):
            raise u.ExitMainLoop()
        # TODO: an 'r' (refresh) key would have to schedule the async
        # update_data() call; it cannot be awaited from this handler.
        if key == 'enter':
            self.current_view = self.frame2
            self.loop.widget = self.frame2
        if key in ('b', 'B'):
            self.current_view = self.frame1
            self.loop.widget = self.frame1

    def show_details(self, event):
        """Signal handler: load *event* into the detail pane."""
        self.view_two.set_event(event)

    def __init__(self):
        # --- list view ---
        self.view_one = ListView()
        u.connect_signal(self.view_one, 'show_details', self.show_details)
        footer = u.AttrMap(u.Text(" Q to exit"), "footer")
        # get_cols_rows() returns (columns, rows); the list height has to be
        # derived from the row count, not the column count.
        _cols, rows = u.raw_display.Screen().get_cols_rows()
        h = rows - 2
        f1 = u.Filler(self.view_one, valign='top', height=h)
        c_list = u.LineBox(f1, title="Events")
        columns = u.Columns([('weight', 100, c_list)])
        self.frame1 = u.AttrMap(u.Frame(body=columns, footer=footer), 'bg')

        # --- detail view ---
        self.view_two = DetailView()
        f2 = u.Filler(self.view_two, valign='top')
        c_details = u.LineBox(f2, title="Details")
        footer = u.AttrMap(u.Text(" Q to exit, B to main view"), "footer")
        columns = u.Columns([('weight', 100, c_details)])
        self.frame2 = u.AttrMap(u.Frame(body=columns, footer=footer), 'bg')

        # Start on the list view.
        self.current_view = self.frame1
        # (name, foreground, background) palette entries.
        self.palette = [
            ("bg",               "white",           "black"),
            ("event",            "white",           "black"),
            ("event_selected",   "white",           "light green"),
            ('datetime',         "light blue",      "black"),
            ('reporter',         "dark green",      "black"),
            ('content',          "",                "black"),
            ("footer",           "white, bold",     "dark red"),
        ]

    @classmethod
    def _printable_content(cls, raw):
        """Return bytes *raw* as display text.

        Takes the repr of the bytes object, replaces each ``\\xNN`` escape
        with a space and strips the leading ``b'`` / ``b"`` and the closing
        quote.
        """
        text = cls._HEX_ESCAPE.sub(' ', repr(raw))
        # repr(bytes) always has a two-character prefix and a closing quote.
        return text[2:-1]

    async def update_data(self, config, replay_mode):
        """Fetch the DAG and fill the list view with it.

        config      -- dict with 'host' and 'port', passed to recreate_dag().
        replay_mode -- passed through to recreate_dag().

        Events are listed in reverse topological order.  Exits the program
        with status -1 when no events were returned.
        """
        self.config = config
        dag_dict = await recreate_dag(config, replay_mode)
        if not dag_dict:
            print("DAG is empty!")
            sys.exit(-1)

        dag = nx.DiGraph()
        dag.add_edges_from(
            (parent, child)
            for child, details in dag_dict.items()
            for parent in details['parents']
            if parent != self._NULL_PARENT
        )
        # Edges only create the nodes they mention.  Register every event as
        # well, so an event with no edge (e.g. a lone genesis event) is still
        # listed.  Done after the edges to keep the existing node order.
        dag.add_nodes_from(dag_dict)

        rows = []
        for node in reversed(list(nx.topological_sort(dag))):
            details = dag_dict.get(node)
            if details is None:
                # A parent referenced by an event but absent from the DAG.
                continue
            rows.append({
                "layer": f"{int(details['layer'])}",
                "hash": f"{node}",
                "children": list(dag.successors(node)),
                "parents": list(dag.predecessors(node)),
                "content": self._printable_content(base58.b58decode(details['content'])),
                "timestamp": f"{details['timestamp']}",
            })

        longest_path = nx.dag_longest_path(dag)
        self.view_one.set_data(rows, longest_path)

    async def start(self, config, replay_mode):
        """Load the data, then run the urwid main loop."""
        await self.update_data(config, replay_mode)
        self.loop = u.MainLoop(self.current_view, self.palette, unhandled_input=self.unhandled_input)
        self.loop.run()

async def recreate_dag(config, replay_mode):
    """Fetch the event DAG from the daemon over JSON-RPC.

    config      -- dict with 'host' and 'port' of the RPC endpoint.
    replay_mode -- when true, request "eventgraph.replay" instead of
                   "eventgraph.get_info".

    Returns the 'dag' entry of the reply's 'eventgraph_info', or an empty
    dict when 'eventgraph_info' is empty.  Exits the program with status -1
    if the connection fails with OSError.
    """
    host = config['host']
    port = config['port']
    rpc = src.rpc.JsonRpc()
    try:
        await rpc.start(host, port)
    except OSError:
        print(f"Error: Connection Refused to '{host}:{port}', Either because the daemon is down, is currently syncing or wrong url.")
        sys.exit(-1)

    # NOTE(review): the switch is turned on and straight off again before the
    # query; what it controls is not visible from this file.
    await rpc.deg_switch(True)
    await rpc.deg_switch(False)

    method = "eventgraph.replay" if replay_mode else "eventgraph.get_info"
    json_result = await rpc._make_request(method, [])

    info = json_result['result']['eventgraph_info']
    if not info:
        # Return an empty mapping rather than an implicit None, so callers
        # can treat the result as a dict.
        return {}
    return info['dag']
        

async def main(argv):
    """Parse the command line in *argv* and run the application.

    Recognised arguments:
      -e HOST:PORT     RPC endpoint (default 127.0.0.1:26660)
      -r               replay mode
      darkirc | irc    use port 26660
      taud | tau       use port 23330

    NOTE(review): a daemon name overrides the port even when -e supplied
    one, and stores it as an int while -e stores a str -- confirm intended.
    """
    val = '127.0.0.1:26660'
    # Only the first -e after the program name is honoured.
    if "-e" in argv[1:]:
        flag_at = argv.index("-e", 1)
        try:
            val = argv[flag_at + 1]
        except IndexError:
            print("Please provide a value for \'-e\'")
            sys.exit(-1)

    config = {}
    try:
        config['host'], config['port'] = val.split(':')
    except ValueError:
        print("Please provide a port as in: 127.0.0.1:26660")
        sys.exit(-1)

    replay_mode = False
    if len(argv) > 1:
        replay_mode = '-r' in argv
        words = set(argv)
        if words & {'darkirc', 'irc'}:
            config['port'] = 26660
        elif words & {'taud', 'tau'}:
            config['port'] = 23330

    app = App()
    await app.start(config, replay_mode)
    

# Script entry point: run main() with the process's command line.
# NOTE(review): there is no __main__ guard, so this also runs on import.
asyncio.run(main(sys.argv))
