import os
import glob
import argparse
from pathlib import Path

import pandas as pd
import numpy as np
import h3  # top-level API


# Cached point->cell function, resolved lazily on first use.  h3_index() is
# called once per row per resolution in a tight loop, so re-probing the h3
# module with hasattr() on every call would be pure overhead.
_h3_cell_fn = None


def h3_index(lat: float, lon: float, res: int):
    """
    Return the H3 cell index for (lat, lon) at resolution ``res``.

    Wrapper that works with both the older h3-py API (``geo_to_h3``) and the
    v4+ API (``latlng_to_cell``); the available function is looked up once
    and cached for subsequent calls.

    Raises:
        AttributeError: if the installed h3 package exposes neither API.
    """
    global _h3_cell_fn
    if _h3_cell_fn is None:
        # Prefer the old name first, matching the original probe order.
        _h3_cell_fn = getattr(h3, "geo_to_h3", None) or getattr(h3, "latlng_to_cell", None)
        if _h3_cell_fn is None:
            raise AttributeError(
                "Installed h3 package has neither geo_to_h3 nor latlng_to_cell. "
                "Please check your h3 version."
            )
    return _h3_cell_fn(lat, lon, res)


def _h3_or_none(lat: float, lon: float, res: int):
    """Compute one H3 cell, returning None if the h3 library rejects the point."""
    try:
        return h3_index(float(lat), float(lon), res)
    except Exception:
        # Defensive: h3 may still reject edge-case coordinates; treat as empty.
        return None


def add_h3_columns(
    df: pd.DataFrame,
    lat_col: str,
    lon_col: str,
    res_start: int,
    res_end: int
) -> pd.DataFrame:
    """
    Add H3 index columns ``h3_r{res}`` for resolutions [res_start, res_end]
    (inclusive) based on latitude/longitude columns in the DataFrame.

    The DataFrame is modified in place and also returned for convenience.
    Rows whose coordinates are missing, non-numeric, or out of WGS84 range
    get None in every H3 column.

    Raises:
        ValueError: if either coordinate column is missing from ``df``.
    """
    if lat_col not in df.columns or lon_col not in df.columns:
        raise ValueError(f"Latitude/longitude columns '{lat_col}'/'{lon_col}' not found in dataframe.")

    # Coerce to numeric; invalid strings become NaN.
    lats = pd.to_numeric(df[lat_col], errors="coerce")
    lons = pd.to_numeric(df[lon_col], errors="coerce")

    # Rows usable for indexing: numeric and within valid lat/lon bounds.
    # Series.between is inclusive on both ends, matching >= / <= pairs.
    valid_mask = (
        lats.notna()
        & lons.notna()
        & lats.between(-90, 90)
        & lons.between(-180, 180)
    )

    invalid_count = (~valid_mask).sum()
    if invalid_count > 0:
        print(f"  Warning: {invalid_count} rows have invalid lat/lon and will get empty H3 cells.")

    # Hoist the per-row data out of the per-resolution loop: the arrays and
    # the validity mask are loop-invariant, so convert/zip them exactly once.
    rows = list(zip(lats.to_numpy(), lons.to_numpy(), valid_mask.to_numpy()))

    for res in range(res_start, res_end + 1):
        df[f"h3_r{res}"] = [
            _h3_or_none(lat, lon, res) if ok else None
            for lat, lon, ok in rows
        ]

    return df


def process_folder(
    input_dir: str,
    output_dir: str,
    lat_col: str,
    lon_col: str,
    res_start: int,
    res_end: int,
    combine_output: bool = False,
    combined_filename: str = "combined_with_h3.csv",
    pattern: str = "*.csv"
):
    """
    Process all CSV files in input_dir, add H3 columns, and write outputs.

    - If combine_output = False:
        writes one output CSV per input CSV into output_dir
        (files with the same name as an input are overwritten)
    - If combine_output = True:
        writes a single combined CSV (all files appended) into
        output_dir/combined_filename

    Raises:
        FileNotFoundError: if no files match ``pattern`` in ``input_dir``.
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    files = sorted(glob.glob(str(input_path / pattern)))
    if not files:
        raise FileNotFoundError(f"No files matching pattern '{pattern}' found in {input_dir}")

    print(f"Found {len(files)} files to process.")

    combined_path = output_path / combined_filename
    header_written = False  # for combined mode

    for fpath in files:
        fname = os.path.basename(fpath)
        print(f"Processing {fname}...")

        df = pd.read_csv(fpath)

        df = add_h3_columns(df, lat_col=lat_col, lon_col=lon_col,
                            res_start=res_start, res_end=res_end)

        if combine_output:
            # BUG FIX: the first write truncates ("w") so a combined CSV left
            # over from a previous run is not appended to; subsequent files
            # append without repeating the header.
            # NOTE(review): pandas appends rows positionally here — if input
            # files have differing columns the combined CSV will misalign;
            # confirm all inputs share a schema.
            df.to_csv(
                combined_path,
                mode="a" if header_written else "w",
                index=False,
                header=not header_written
            )
            header_written = True
        else:
            out_file = output_path / fname
            df.to_csv(out_file, index=False)

    if combine_output:
        print(f"Combined output written to: {combined_path}")
    else:
        print(f"Per-file outputs written to: {output_path}")


def main():
    """CLI entry point: parse arguments, validate them, run process_folder()."""
    parser = argparse.ArgumentParser(
        description="Add H3 index columns to all CSVs in a folder."
    )
    parser.add_argument(
        "--input-dir",
        required=True,
        help="Folder containing input CSV files (e.g. d:\\sourcedata)."
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="Folder where output CSV files will be written."
    )
    parser.add_argument(
        "--lat-col",
        required=True,
        help="Name of the latitude column in the CSV files."
    )
    parser.add_argument(
        "--lon-col",
        required=True,
        help="Name of the longitude column in the CSV files."
    )
    parser.add_argument(
        "--res-start",
        type=int,
        default=5,
        help="Starting H3 resolution (inclusive)."
    )
    parser.add_argument(
        "--res-end",
        type=int,
        default=12,
        help="Ending H3 resolution (inclusive)."
    )
    parser.add_argument(
        "--combine-output",
        action="store_true",
        help="If set, writes a single combined CSV instead of per-file outputs."
    )
    parser.add_argument(
        "--combined-filename",
        default="combined_with_h3.csv",
        help="Filename for the combined CSV when --combine-output is used."
    )
    parser.add_argument(
        "--pattern",
        default="*.csv",
        help="Glob pattern for CSV files in the input folder (default: *.csv)."
    )

    args = parser.parse_args()

    # Validate the resolution range up front: H3 defines resolutions 0..15,
    # and an inverted range would otherwise silently add no columns at all.
    if not (0 <= args.res_start <= 15) or not (0 <= args.res_end <= 15):
        parser.error("--res-start and --res-end must be within H3's valid range 0..15.")
    if args.res_start > args.res_end:
        parser.error("--res-start must not be greater than --res-end.")

    process_folder(
        input_dir=args.input_dir,
        output_dir=args.output_dir,
        lat_col=args.lat_col,
        lon_col=args.lon_col,
        res_start=args.res_start,
        res_end=args.res_end,
        combine_output=args.combine_output,
        combined_filename=args.combined_filename,
        pattern=args.pattern
    )


if __name__ == "__main__":
    main()
