import pandas as pd

# Input Excel workbook; main() reads its first sheet. Relative path, so it is
# resolved against the current working directory.
SRC = "nsw_road_crash_data_2020-2024_crash.xlsx"
# Output CSV written by main(); also a path relative to the working directory.
OUT = "cyclistsafe_cyclist_crashes.csv"

def find_col(df, keywords, default=None):
    """Return the first column label of *df* whose name contains every keyword.

    Matching is a case-insensitive substring test against ``str(label)``, so
    non-string labels (e.g. numeric header cells) can still be matched. An
    empty *keywords* matches the first column.

    Args:
        df: DataFrame whose ``columns`` are searched in order.
        keywords: Iterable of substrings that must all appear in the column
            name. It is materialised once, so generators are safe to pass.
        default: Value returned when no column matches.

    Returns:
        The matching label exactly as stored in ``df.columns`` (so it can be
        used directly for indexing or ``DataFrame.rename``), or *default*.
    """
    # Lower-case the keywords once, rather than once per column. Building a
    # list also avoids exhausting a one-shot iterable after the first column.
    needles = [str(k).lower() for k in keywords]
    for label in df.columns:
        name = str(label).lower().strip()
        if all(k in name for k in needles):
            # Return the original label, not its string form, so lookups and
            # renames keyed on it work for non-string headers too.
            return label
    return default

# Traffic-unit type columns searched for a pedal cycle; both must be present.
_TU_COLUMNS = ("Key TU type", "Other TU type")

# Exact source header -> output column name. Headers absent from the data
# are simply skipped (and later filled with NA).
_EXACT_RENAMES = {
    "Crash ID": "crash_id",
    "Year of crash": "crash_year",
    "Month of crash": "crash_month",
    "Day of week of crash": "day_of_week",
    "Two-hour intervals": "time_interval",
    "Latitude": "latitude",
    "Longitude": "longitude",
    "Degree of crash": "degree",
    "Degree of crash - detailed": "degree_detailed",
    "Speed limit": "speed_limit",
    "Road surface": "road_surface",
    "Surface condition": "surface_condition",
    "Weather": "weather",
    "Street lighting": "street_lighting",
    "Natural lighting": "natural_lighting",
    "DCA - code": "dca_code",
    "DCA - description": "dca_description",
    "No. killed": "no_killed",
    "No. seriously injured": "no_serious_inj",
    "No. moderately injured": "no_mod_inj",
    "No. minor-other injured": "no_minor_inj",
}

# Output columns whose source header is located by case-insensitive keyword
# match (via find_col) instead of an exact name: keyword -> output name.
_KEYWORD_RENAMES = {
    "lga": "lga",
    "town": "town",
}

# Output schema, in CSV column order.
_OUTPUT_COLUMNS = [
    "crash_id", "crash_year", "crash_month", "day_of_week",
    "time_interval", "lga", "town",
    "latitude", "longitude",
    "degree", "degree_detailed",
    "speed_limit", "road_surface", "surface_condition",
    "weather", "street_lighting", "natural_lighting",
    "dca_code", "dca_description",
    "no_killed", "no_serious_inj", "no_mod_inj", "no_minor_inj",
]


def _cyclist_mask(df):
    """Return a boolean Series marking rows where either TU type mentions "Pedal cycle".

    Raises:
        KeyError: if a column listed in ``_TU_COLUMNS`` is missing from *df*.
    """
    mask = pd.Series(False, index=df.index)
    for col in _TU_COLUMNS:
        if col not in df.columns:
            raise KeyError(f"Expected column '{col}' not found in dataset.")
        # astype(str) lets non-string cells (e.g. NaN) go through the .str
        # accessor; the search term is literal text, so regex is disabled.
        mask |= df[col].astype(str).str.contains(
            "Pedal cycle", case=False, regex=False, na=False
        )
    return mask


def _build_rename_map(cdf):
    """Return a {source label: output name} map for the columns present in *cdf*."""
    rename_map = {
        src: dst for src, dst in _EXACT_RENAMES.items() if src in cdf.columns
    }
    for keyword, target in _KEYWORD_RENAMES.items():
        src = find_col(cdf, [keyword])
        # find_col returns None when no column matches.
        if src is not None:
            rename_map[src] = target
    return rename_map


def main():
    """Extract cyclist-involved crashes from ``SRC`` and save them to ``OUT``.

    Steps:
      1. Read the first sheet of the workbook.
      2. Keep rows where either traffic-unit type contains "Pedal cycle"
         (case-insensitive).
      3. Rename the columns of interest to snake_case.
      4. Fill any output column missing from the source with NA, warning
         about it.
      5. Drop rows lacking latitude or longitude.
      6. Write the result as CSV.

    Raises:
        KeyError: if the "Key TU type" or "Other TU type" column is absent.
    """
    print("Loading dataset...")
    df = pd.read_excel(SRC, sheet_name=0, engine="openpyxl")
    print(f"Total rows loaded: {len(df)}")

    cdf = df[_cyclist_mask(df)].copy()
    print(f"Filtered cyclist crash records: {len(cdf)}")

    cdf = cdf.rename(columns=_build_rename_map(cdf))

    missing = [col for col in _OUTPUT_COLUMNS if col not in cdf.columns]
    if missing:
        # Keep the output schema fixed, but make the gap visible. A missing
        # latitude/longitude column would otherwise drop every row silently.
        print(
            "Warning: columns not found in dataset, filled with NA: "
            + ", ".join(missing)
        )
    for col in missing:
        cdf[col] = pd.NA

    cdf = cdf[_OUTPUT_COLUMNS]
    cdf = cdf.dropna(subset=["latitude", "longitude"])
    print(f"Records with coordinates: {len(cdf)}")
    cdf.to_csv(OUT, index=False)
    print(f"Saved {len(cdf)} cyclist crash records to {OUT}")

# Run the extraction only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
