import os
import sys
from datetime import datetime

import mariadb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler


# Connect to the "voteninja" MariaDB database.
# Credentials are read from environment variables rather than hard-coded:
# a password committed to source control is readable by anyone with access to
# the repository (the previously hard-coded password should be rotated).
#   VOTENINJA_DB_HOST      (default: the RDS endpoint below)
#   VOTENINJA_DB_PORT      (default: 1433)
#   VOTENINJA_DB_USER      (default: "superuser")
#   VOTENINJA_DB_NAME      (default: "voteninja")
#   VOTENINJA_DB_PASSWORD  (required)
# NOTE(review): 1433 is SQL Server's conventional port; MariaDB usually listens
# on 3306. 1433 is kept as the default because that is what was configured —
# confirm against the RDS instance.
import os  # also added to the top-of-file imports; repeated so this block stands alone

DB_PASSWORD = os.environ.get("VOTENINJA_DB_PASSWORD")
if not DB_PASSWORD:
    print("Set the VOTENINJA_DB_PASSWORD environment variable.", file=sys.stderr)
    sys.exit(1)

try:
    conn = mariadb.connect(
        host=os.environ.get(
            "VOTENINJA_DB_HOST",
            "vote-ninja.crq8coopcoqx.us-west-1.rds.amazonaws.com",
        ),
        user=os.environ.get("VOTENINJA_DB_USER", "superuser"),
        password=DB_PASSWORD,
        port=int(os.environ.get("VOTENINJA_DB_PORT", "1433")),
        database=os.environ.get("VOTENINJA_DB_NAME", "voteninja"),
    )
except mariadb.Error as e:
    # The script cannot do anything without the database: report and exit non-zero.
    print(e, file=sys.stderr)
    sys.exit(1)

# Fetch the only feature used for clustering: the 'Sex' column of 'User'.
# NOTE(review): assumes a table named 'User' with a 'Sex' column — confirm schema.
cursor = conn.cursor()
query = "SELECT Sex FROM User"
try:
    cursor.execute(query)
    data = cursor.fetchall()
finally:
    # Release the cursor and the connection even when the query fails, so the
    # database session is not left open.
    cursor.close()
    conn.close()

# The encoder and KMeans below cannot work on zero rows; stop with a clear
# message instead of an obscure scikit-learn error.
if not data:
    print("No rows returned from User; nothing to cluster.", file=sys.stderr)
    sys.exit(1)

# Put the query rows into a DataFrame; the database column 'Sex' is called
# 'Gender' from here on.
columns = ['Gender']
df = pd.DataFrame(data, columns=columns)

# One-hot encode 'Gender' into a dense array (sparse_output=False returns a
# NumPy array rather than a sparse matrix).
encoder = OneHotEncoder(sparse_output=False)
gender_encoded = encoder.fit_transform(df[['Gender']])

# Name the indicator columns after the categories the encoder actually found,
# in the encoder's own order. OneHotEncoder sorts categories, so hard-coding
# ['Male', 'Female'] swapped the labels ('Female' sorts first) and raised an
# error whenever the column held any number of distinct values other than two.
df_encoded = pd.DataFrame(
    gender_encoded,
    columns=[str(category) for category in encoder.categories_[0]],
    index=df.index,
)

# Split the one-hot rows into two clusters. Using 40 restarts (n_init) and a
# fixed random_state makes the clustering reproducible between runs.
# NOTE(review): with a single one-hot-encoded feature, every row is one of only
# as many distinct points as there are Gender values, so the clusters will
# essentially mirror Gender — confirm this is the intended analysis.
kmeans = KMeans(n_init=40, n_clusters=2, random_state=42)

# fit_predict fits the model and hands back its labels_ array, which is then
# attached to the original rows as a 'Cluster' column.
df['Cluster'] = kmeans.fit_predict(df_encoded)
labels = kmeans.labels_

# Bar chart: for each cluster, how many rows carry each Gender value.
# unstack(fill_value=0) fills absent (cluster, gender) pairs with 0 and keeps
# the counts as integers (unstack() followed by fillna(0) produced floats).
cluster_counts = (
    df.groupby('Cluster')['Gender']
    .value_counts()
    .unstack(fill_value=0)
)
# NOTE(review): the two fixed colours assume exactly two Gender values.
cluster_counts.plot(kind='bar', color=['red', 'blue'], alpha=0.7)
plt.title('Total Count of Males and Females in Each Cluster')
# Bars are grouped by cluster (the DataFrame index) and their heights are
# counts; the previous labels ('Gender' on both axes) described neither axis.
plt.xlabel('Cluster')
plt.ylabel('Count')
plt.show()
