import os
from dotenv import load_dotenv
from neo4j import GraphDatabase

# Pull settings from a local .env file into the process environment
load_dotenv()

# Neo4j connection settings; each one is None when its variable is not set
NEO4J_URI = os.environ.get('NEO4J_URI')
NEO4J_USER = os.environ.get('NEO4J_USER')
NEO4J_PASSWORD = os.environ.get('NEO4J_PASSWORD')

def remove_duplicate_relationships(uri, user, password) -> None:
    """Delete duplicate relationships, keeping one per (start, type, end).

    Two relationships count as duplicates when they have the same type and
    connect the same start node to the same end node in the same direction.
    Relationship properties are NOT compared, so relationships that differ
    only in their properties are still treated as duplicates. For each group
    the first id in the collected list is kept and every other relationship
    in the group is deleted. Progress is printed to stdout.

    Args:
        uri: Connection URI passed to ``GraphDatabase.driver``.
        user: User name, first element of the driver's ``auth`` pair.
        password: Password, second element of the driver's ``auth`` pair.

    Raises:
        Any exception raised by the Neo4j driver while connecting or running
        a query. The driver is closed before the exception propagates.
    """
    # Group relationships by (type, start node, end node) and report only the
    # groups that contain more than one relationship.
    identify_query = """
        MATCH (n)-[r]->(m)
        WITH type(r) AS relType, n AS startNode, m AS endNode, collect(id(r)) AS relIds
        WHERE size(relIds) > 1
        RETURN relType, startNode, endNode, relIds
        ORDER BY size(relIds) DESC
        """
    delete_query = "MATCH ()-[r]->() WHERE id(r) = $relId DELETE r"

    driver = GraphDatabase.driver(uri, auth=(user, password))
    try:
        with driver.session() as session:
            # Read every duplicate group up front so the identify result is
            # not being iterated while delete queries run on the same session.
            duplicate_groups = [
                record['relIds'] for record in session.run(identify_query)
            ]

            for rel_ids in duplicate_groups:
                # The first id is the survivor; everything after it is removed.
                keep_rel_id, *delete_rel_ids = rel_ids

                for del_rel_id in delete_rel_ids:
                    session.run(delete_query, relId=del_rel_id)
                    print(f"Deleted relationship with ID: {del_rel_id}")

                print("Remaining Relationship ID:", keep_rel_id)
                print("\n")
    finally:
        # Release the connection pool even when a query fails part-way.
        driver.close()

# Run the de-duplication against the database described by the NEO4J_* settings
remove_duplicate_relationships(
    uri=NEO4J_URI,
    user=NEO4J_USER,
    password=NEO4J_PASSWORD,
)