import atexit
import logging
import os

from dotenv import load_dotenv
from flask import Flask, render_template_string
from neo4j import GraphDatabase
from pyvis.network import Network

# Load environment variables from .env file
load_dotenv()

app = Flask(__name__)

# Configure your Neo4j connection using environment variables
uri = os.getenv("NEO4J_URI")
user = os.getenv("NEO4J_USER")
password = os.getenv("NEO4J_PASSWORD")

# Without a URI the driver cannot be built at all; fail at startup with an
# actionable message instead of an opaque URI-parsing error from neo4j.
if not uri:
    raise RuntimeError("NEO4J_URI is not set; define it in the environment or in .env")

driver = GraphDatabase.driver(uri, auth=(user, password))
# Close the driver's connection pool when the interpreter exits instead of
# relying on garbage collection to clean it up.
atexit.register(driver.close)

def fetch_data():
    """Fetch every Person -> Legislation relationship from Neo4j.

    Runs one Cypher query matching ``(:Person)-[r]->(:Legislation)`` and turns
    the rows into plain Python structures for the pyvis graph.

    Returns:
        ``(nodes, edges)`` where ``nodes`` is a list of unique node IDs (Person
        ``name`` values and Legislation ``number`` values) in first-seen order,
        and ``edges`` is a list of ``(source_id, target_id, {"title": rel_type})``
        tuples, one per returned row.

    Raises:
        AssertionError: if a node ID is neither ``str`` nor ``int`` (for
            example ``None`` when the property is missing). It is raised
            explicitly instead of via ``assert`` so the check is not stripped
            under ``python -O``; the type is kept because the route handler
            catches ``AssertionError``.
    """
    logger = logging.getLogger(__name__)
    query = (
        "MATCH (n:Person)-[r]->(m:Legislation) "
        "RETURN n.name AS node1, type(r) AS relationship, m.number AS node2"
    )
    node_ids = []
    edges = []

    with driver.session() as session:
        # Consume the result inside the session block: results are only
        # readable while their session is open.
        for record in session.run(query):
            node1 = record["node1"]
            node2 = record["node2"]

            # Per-row type diagnostics go to the debug log rather than stdout.
            logger.debug("Node 1 Type: %s, Value: %s", type(node1), node1)
            logger.debug("Node 2 Type: %s, Value: %s", type(node2), node2)

            # pyvis node IDs are expected to be str or int.
            for node_id in (node1, node2):
                if not isinstance(node_id, (str, int)):
                    raise AssertionError(f"Unexpected type for node ID: {type(node_id)}")

            node_ids.append(node1)
            node_ids.append(node2)
            edges.append((node1, node2, {"title": record["relationship"]}))

    # dict.fromkeys de-duplicates exactly like set() but keeps first-seen
    # order, so the node list order is deterministic for a given query result.
    return list(dict.fromkeys(node_ids)), edges

@app.route('/')
def index():
    """Serve the Person -> Legislation graph as an interactive pyvis page.

    Returns:
        The HTML document produced by pyvis, or a ``("Assertion Error: ...",
        500)`` response when an ``AssertionError`` is raised while fetching
        the data or building the graph.
    """
    try:
        nodes, edges = fetch_data()

        # Directed graph, 600px tall, spanning the full page width.
        net = Network(notebook=False, height='600px', width='100%', directed=True)

        for node in nodes:
            net.add_node(node)

        for source, target, attrs in edges:
            net.add_edge(source, target, title=attrs['title'])

        html = net.generate_html("3d-force-directed-graph.html")
        # Return the page as-is. It embeds database values (node names,
        # relationship types); passing it through render_template_string would
        # evaluate any "{{ ... }}" / "{% ... %}" in that data as Jinja code
        # (server-side template injection).
        return html
    except AssertionError as e:
        return f"Assertion Error: {e}", 500

if __name__ == '__main__':
    # os.getenv returns a string when the variable is set; convert explicitly
    # so a malformed value fails here with a clear ValueError.
    port = int(os.getenv("FLASK_PORT", "5000"))  # Default to 5000 if not specified in .env
    # Debug mode (reloader + interactive Werkzeug debugger) stays on by default
    # to match previous behaviour. Set FLASK_DEBUG=0/false/no to turn it off,
    # e.g. whenever the server is reachable from other machines. These are the
    # same truthiness rules Flask itself applies to FLASK_DEBUG.
    debug_flag = os.getenv("FLASK_DEBUG", "1")
    debug = bool(debug_flag) and debug_flag.lower() not in {"0", "false", "no"}
    app.run(debug=debug, port=port)