import logging
from typing import Dict
 
import pandas as pd
 
# Module-level logger named after this module; none of the functions visible
# in this section log through it.
logger = logging.getLogger(__name__)
 
 
def focus_share(
  events: pd.DataFrame,
  feature_col: str = "interaction_type",
  user_col: str = "user_id",
) -> pd.DataFrame:
  """Summarise how events and users are distributed across features.

  Args:
    events: One row per event.
    feature_col: Column holding the feature / interaction label.
    user_col: Column holding the user identifier.

  Returns:
    One row per distinct feature, sorted by event count (descending), with
    columns ``feature_col``, ``events``, ``unique_users``,
    ``share_of_events`` (events / all counted events) and
    ``share_of_users`` (users who touched the feature / all distinct users).
    ``share_of_users`` does not sum to 1 when users touch several features.
    An empty ``events`` frame yields an empty frame with those columns.

  Raises:
    ValueError: If ``events`` is non-empty and lacks a required column.
  """
  if events.empty:
    return pd.DataFrame(
      columns=[feature_col, "events", "unique_users", "share_of_events", "share_of_users"],
    )

  missing = {feature_col, user_col} - set(events.columns)
  if missing:
    raise ValueError(f"Missing required columns: {sorted(missing)}")

  # "size" counts rows regardless of the column it is attached to, so hang it
  # on the user column rather than on the grouping key itself.
  grouped = (
    events.groupby(feature_col)
    .agg(events=(user_col, "size"), unique_users=(user_col, "nunique"))
    .reset_index()
  )

  total_events = grouped["events"].sum()
  # The denominator must be the number of distinct users overall. Summing the
  # per-feature counts would count a user once per feature they touched.
  # groupby drops rows whose feature is null, so leave those rows out here too
  # to keep numerator and denominator on the same population.
  counted_rows = events[events[feature_col].notna()]
  total_users = counted_rows[user_col].nunique()

  grouped["share_of_events"] = grouped["events"] / total_events if total_events else 0.0
  grouped["share_of_users"] = grouped["unique_users"] / total_users if total_users else 0.0
  return grouped.sort_values("events", ascending=False).reset_index(drop=True)
 
 
def switch_summary(
  events: pd.DataFrame,
  feature_col: str = "interaction_type",
  user_col: str = "user_id",
  ts_col: str = "timestamp",
) -> Dict[str, float]:
  """Summarise how often users switch between features.

  Events are ordered per user by timestamp. A "switch" is an event whose
  feature differs from that user's immediately preceding event. Rows whose
  timestamp cannot be parsed or whose feature is missing are ignored, so a
  missing label neither counts as a switch nor hides the next one.

  Args:
    events: One row per event.
    feature_col: Column holding the feature / interaction label.
    user_col: Column holding the user identifier.
    ts_col: Column holding the event time; parsed with ``pd.to_datetime``
      and normalised to UTC.

  Returns:
    A dict with keys:
      ``total_switches``: number of switches (int).
      ``agent_entry_rate``: fraction of switches that land on ``"agent"``.
      ``tab_return_rate``: fraction of switches that land on ``"tab"``.
      ``avg_seconds_between_switches``: mean gap in seconds between a
        switch event and the event that preceded it.
    All values are zero when there is nothing to measure.

  Raises:
    ValueError: If ``events`` is non-empty and lacks a required column.
  """
  no_switches: Dict[str, float] = {
    "total_switches": 0,
    "agent_entry_rate": 0.0,
    "tab_return_rate": 0.0,
    "avg_seconds_between_switches": 0.0,
  }

  if events.empty:
    return no_switches

  required = {feature_col, user_col, ts_col}
  missing = required - set(events.columns)
  if missing:
    raise ValueError(f"Missing required columns: {sorted(missing)}")

  # Copy only the columns we need; the caller's frame is never mutated.
  df = events[list(required)].copy()
  df[ts_col] = pd.to_datetime(df[ts_col], utc=True, errors="coerce")
  # Dropping null features here avoids the asymmetry where "tab -> NaN" would
  # count as a switch but "NaN -> tab" would not.
  df = (
    df.dropna(subset=[ts_col, feature_col])
    .sort_values([user_col, ts_col])
    .reset_index(drop=True)
  )
  if df.empty:
    return no_switches

  # Keep the shifted values as standalone Series so no helper column can
  # collide with a column name present in the input.
  per_user = df.groupby(user_col)
  prev_feature = per_user[feature_col].shift(1)
  prev_ts = per_user[ts_col].shift(1)

  # The first event of each user has no predecessor and is never a switch.
  is_switch = prev_feature.notna() & (df[feature_col] != prev_feature)
  total = int(is_switch.sum())
  if total == 0:
    return no_switches

  landed_on = df.loc[is_switch, feature_col]
  gaps = (df.loc[is_switch, ts_col] - prev_ts[is_switch]).dt.total_seconds()

  # A switch landing on "tab" necessarily came from a different feature, so
  # no separate check on the previous feature is needed.
  return {
    "total_switches": total,
    "agent_entry_rate": float((landed_on == "agent").sum() / total),
    "tab_return_rate": float((landed_on == "tab").sum() / total),
    "avg_seconds_between_switches": float(gaps.mean()),
  }
 
 
def rolling_focus_share(
  events: pd.DataFrame,
  ts_col: str = "timestamp",
  feature_col: str = "interaction_type",
  freq: str = "1D",
  window: int = 7,
) -> pd.DataFrame:
  """Compute a rolling mean of each feature's share of events per period.

  Events are bucketed into periods of length ``freq``. Within each period
  every feature's share of that period's events is computed, and the shares
  are then smoothed with a trailing rolling mean over ``window`` periods.

  Periods between the first and last observed period that contain no events
  are included with a share of 0.0 for every feature. This makes ``window``
  cover consecutive calendar periods rather than consecutive *observed*
  periods. Note that the output has one row per period in that range, so a
  very fine ``freq`` over a long span produces a large frame.

  Args:
    events: One row per event.
    ts_col: Column holding the event time; parsed with ``pd.to_datetime``
      and normalised to UTC.
    feature_col: Column holding the feature / interaction label.
    freq: pandas offset alias defining the period length.
    window: Number of periods in the rolling window.

  Returns:
    A frame indexed by period start (UTC) with one column per feature.
    An empty frame is returned when there is nothing to aggregate.

  Raises:
    ValueError: If ``events`` is non-empty and lacks a required column.
  """
  if events.empty:
    return pd.DataFrame()

  required = {ts_col, feature_col}
  missing = required - set(events.columns)
  if missing:
    raise ValueError(f"Missing required columns: {sorted(missing)}")

  df = events.copy()
  df[ts_col] = pd.to_datetime(df[ts_col], utc=True, errors="coerce")
  # groupby would discard rows with a null feature anyway; dropping them up
  # front lets the emptiness check below catch the "nothing usable" case.
  df = df.dropna(subset=[ts_col, feature_col])
  if df.empty:
    return pd.DataFrame()

  counts = (
    df.groupby([pd.Grouper(key=ts_col, freq=freq), feature_col])
    .size()
    .unstack(fill_value=0)
    .sort_index()
  )

  if counts.empty:
    return pd.DataFrame()

  # Grouping on (period, feature) only yields periods that contain events.
  # Fill in the gaps so the row-based rolling window spans calendar time.
  # The union keeps every observed label even if a non-uniform offset makes
  # the generated range miss one.
  all_periods = pd.date_range(
    start=counts.index.min(),
    end=counts.index.max(),
    freq=freq,
    name=ts_col,
  ).union(counts.index)
  counts = counts.reindex(all_periods, fill_value=0)

  totals = counts.sum(axis=1)
  # Mask empty periods with NaN to avoid dividing by zero, then report their
  # share as 0.0. A float mask keeps the result numeric for rolling().
  shares = counts.div(totals.where(totals > 0), axis=0).fillna(0.0)
  return shares.rolling(window, min_periods=1).mean()