#!/usr/bin/env python3

import csv
import collections
import datetime
import email.utils
import git
import hashlib
import itertools
import os
import simdjson
import subprocess
import json
from textwrap import dedent, indent
from tabulate import tabulate
from typing import Dict, Tuple

# Presumably the positions of individual battleground states inside the
# `data.races` array of results.json.
# NOTE(review): none of these *_INDEX names is referenced by the code visible
# in this file - confirm whether they are still needed.
AK_INDEX = 0
AZ_INDEX = 3
GA_INDEX = 10
NC_INDEX = 27
NV_INDEX = 33
PA_INDEX = 38

# State names (matched against `state_name` of each race) that get their own
# "battleground" outputs in addition to the all-state ones.
BATTLEGROUND_STATES = ["Alaska", "Arizona", "Georgia", "North Carolina", "Nevada", "Pennsylvania"]
# Indexes into `data.races` that `fetch_all_records` walks for every commit.
STATE_INDEXES = range(51) # 50 States + DC

# Directory holding one cached JSON file per commit of results.json
CACHE_DIR = '_cache'
# Bump this with any changes to `fetch_all_records`; cache files written with
# a different version are ignored and rebuilt.
CACHE_VERSION = 2

def git_commits_for(path):
    """
    Return the full hashes of the commits `git log` reports for `path`, one
    list entry per commit, in the order `git log` prints them.
    """
    raw_output = subprocess.check_output(['git', 'log', "--format=%H", path])
    hashes_text = raw_output.strip().decode()
    return hashes_text.splitlines()

def git_show(ref, name, repo_client):
    """
    Return the raw contents of the file `name` as stored in commit `ref`,
    read through `repo_client` (an object exposing `commit(ref).tree`).
    """
    tree = repo_client.commit(ref).tree
    entry = tree[name]
    return entry.data_stream.read()

def fetch_all_records():
    """
    Build the per-state history of results from every commit of results.json.

    For each commit returned by `git_commits_for("results.json")` the rows are
    taken from the on-disk cache (`CACHE_DIR/<first 2 hash chars>/<rest>.json`)
    when a readable cache file with the current `CACHE_VERSION` exists.
    Otherwise the committed results.json is parsed, one `InputRecord` is built
    per index in `STATE_INDEXES`, and the cache file is (re)written.

    Returns:
        A `defaultdict(list)` mapping state name to that state's
        `InputRecord`s, ordered by ascending `timestamp`.
    """
    commits = git_commits_for("results.json")

    repo = git.Repo('.', odbt=git.db.GitCmdObjectDB)

    out = []

    parser = simdjson.Parser()
    for ref in commits:
        cache_path = os.path.join(CACHE_DIR, ref[:2], ref[2:] + ".json")

        cached = None
        if os.path.exists(cache_path):
            with open(cache_path) as fh:
                try:
                    cached = simdjson.load(fh)
                except ValueError:
                    # An unreadable (e.g. truncated) cache file must not make
                    # this commit's data disappear: fall through and rebuild
                    # the entry from git, overwriting the bad file.
                    cached = None
        if cached is not None and cached['version'] == CACHE_VERSION:
            out.extend(InputRecord(*row) for row in cached['rows'])
            continue

        blob = git_show(ref, 'results.json', repo)
        # Named `payload` rather than `json` so the imported `json` module is
        # not shadowed inside this function.
        payload = parser.parse(blob)
        timestamp = payload['meta']['timestamp']
        rows = []
        for index in STATE_INDEXES:
            race = payload['data']['races'][index]

            county_tot_exp_vote = [county['tot_exp_vote'] for county in race['counties']]
            if all(county_tot_exp_vote):
                tot_exp_vote = sum(county_tot_exp_vote)
            else:
                # NYT's data doesn't seem to have a county-by-county vote
                # estimate for the following states: Connecticut, Massachusetts,
                # Maine, New Hampshire, Rhode Island, and Vermont. So we have to
                # fall back to the state wide one.
                tot_exp_vote = race['tot_exp_vote']

            record = InputRecord(
                timestamp,
                race['state_name'],
                race['state_id'],
                race['electoral_votes'],
                race['candidates'].as_list(),
                race['votes'],
                tot_exp_vote,
                race['precincts_total'],
                race['precincts_reporting'],
                {n['name']: n['votes'] for n in race['counties']},
            )
            rows.append(record)
            out.append(record)

        os.makedirs(os.path.dirname(cache_path), exist_ok=True)
        with open(cache_path, 'w') as fh:
            simdjson.dump({"version": CACHE_VERSION, "rows": rows}, fh)

    out.sort(key=lambda row: row.timestamp)
    grouped = collections.defaultdict(list)
    for row in out:
        grouped[row.state_name].append(row)

    return grouped

# One state's figures as of a single snapshot of results.json, in the order
# `fetch_all_records` fills them in: the snapshot's `meta.timestamp`, the
# race's `state_name`, `state_id` and `electoral_votes`, its candidate list,
# its `votes`, the expected vote total, `precincts_total`,
# `precincts_reporting`, and a {county name: votes} dict.
InputRecord = collections.namedtuple(
    'InputRecord',
    (
        'timestamp '
        'state_name '
        'state_abbrev '
        'electoral_votes '
        'candidates '
        'votes '
        'expected_votes '
        'precincts_total '
        'precincts_reporting '
        'counties'
    ),
)

# State carried from one snapshot of a state to the next while summarizing;
# two consecutive snapshots with equal IterationInfo are treated as duplicates.
IterationInfo = collections.namedtuple(
    'IterationInfo',
    (
        'vote_diff '
        'votes '
        'precincts_reporting '
        'hurdle '
        'leading_candidate_name '
        'counties '
        'candidate_votes'
    ),
)

# One reported row (batch) for a state, as produced by `json_to_summary` and
# consumed by the text, HTML, CSV and RSS writers. The field order is part of
# the CSV header.
IterationSummary = collections.namedtuple(
    'IterationSummary',
    (
        'timestamp '
        'leading_candidate_name '
        'trailing_candidate_name '
        'leading_candidate_votes '
        'trailing_candidate_votes '
        'vote_differential '
        'votes_remaining '
        'new_votes '
        'new_votes_relevant '
        'new_votes_formatted '
        'leading_candidate_partition '
        'trailing_candidate_partition '
        'precincts_reporting '
        'precincts_total '
        'hurdle '
        'hurdle_change '
        'hurdle_mov_avg '
        'counties_partition '
        'total_votes_count'
    ),
)

def compute_hurdle_sma(summarized_state_data, newest_votes, new_partition_pct, trailing_candidate_name, min_agg_votes=30000):
    """
    Return the trailing candidate's share of the most recent two-candidate
    votes, aggregated over a window of at least `min_agg_votes` votes.

    The window starts with the newest batch and is extended with earlier
    batches until it holds `min_agg_votes` votes; the batch that crosses the
    threshold is only counted proportionally.

    Args:
        summarized_state_data: earlier summaries for the state, newest first.
            Only `new_votes_relevant`, `trailing_candidate_name` and the two
            `*_candidate_partition` attributes are read.
        newest_votes: two-candidate votes in the newest batch.
        new_partition_pct: trailing candidate's share (0..1) of that batch.
        trailing_candidate_name: name of the currently trailing candidate.
        min_agg_votes: size of the averaging window, in votes.

    Returns:
        The aggregated share as a float, or None when the window is empty.
    """
    agg_votes = newest_votes
    agg_c2_votes = round(new_partition_pct * newest_votes)

    for this_summary in summarized_state_data:
        if agg_votes >= min_agg_votes:
            break

        batch_votes = this_summary.new_votes_relevant
        if batch_votes <= 0:
            # Empty or downward-revised batches carry no usable split.
            continue

        if this_summary.trailing_candidate_name == trailing_candidate_name:
            trailing_candidate_partition = this_summary.trailing_candidate_partition
        else:
            # The lead changed hands since that batch, so today's trailing
            # candidate was the leader back then. Broken for 3 way race.
            trailing_candidate_partition = this_summary.leading_candidate_partition

        if batch_votes + agg_votes > min_agg_votes:
            # Only take the slice of this batch needed to fill the window.
            subset_pct = (min_agg_votes - agg_votes) / float(batch_votes)
            agg_votes += round(batch_votes * subset_pct)
            agg_c2_votes += round(trailing_candidate_partition * batch_votes * subset_pct)
        else:
            agg_votes += batch_votes
            agg_c2_votes += round(trailing_candidate_partition * batch_votes)

    if not agg_votes:
        return None
    return float(agg_c2_votes) / agg_votes


def string_summary(summary):
    """
    Format one summary as the list of cells used for a row of the text tables.

    Cells, in order: batch timestamp, a `***` marker when the batch is less
    than 30 minutes old (`---` otherwise), the current lead, estimated votes
    left, the batch size with the winner's margin over 50% in that batch, the
    share of precincts reporting, the trailing candidate's hurdle, and the
    moving-average trend.
    """
    # Naive UTC "now", comparable with the naive UTC batch timestamps.
    utc_now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    thirty_ago = utc_now - datetime.timedelta(minutes=30)

    # Whoever took the larger share of this batch, and by how much over 50%.
    if summary.leading_candidate_partition < summary.trailing_candidate_partition:
        bumped = summary.trailing_candidate_partition
        bumped_name = summary.trailing_candidate_name
    else:
        bumped = summary.leading_candidate_partition
        bumped_name = summary.leading_candidate_name
    bumped -= 0.50

    if summary.leading_candidate_partition or summary.trailing_candidate_partition:
        batch_breakdown = f'{bumped_name} +{bumped:5.01%}'
    else:
        batch_breakdown = 'n/a'

    if summary.precincts_total:
        precincts = f'{summary.precincts_reporting/summary.precincts_total:.2%} precincts'
    else:
        # No precinct total reported: avoid dividing by zero.
        precincts = 'n/a precincts'

    visible_hurdle = f'{summary.trailing_candidate_name} needs {summary.hurdle:.2%}' if summary.hurdle > 0 else 'Unknown'
    trend = f'{summary.hurdle_mov_avg:.2%}' if summary.hurdle_mov_avg else 'n/a'

    return [
        f'{summary.timestamp.strftime("%Y-%m-%d %H:%M")}',
        '***' if summary.timestamp > thirty_ago else '---',
        f'{summary.leading_candidate_name} up {summary.vote_differential:,}',
        f'Left (est.): {summary.votes_remaining:,}' if summary.votes_remaining > 0 else 'Unknown',
        f'Δ: {summary.new_votes_formatted} ({batch_breakdown})',
        precincts,
        visible_hurdle,
        f'& trends  {trend}',
    ]

def html_write_state_head(state: str, state_slug: str, summary: IterationSummary):
    """
    Render the `<thead>` of one state's HTML table.

    Args:
        state: key into the module-level `state_formatted_name` dict.
        state_slug: slug used to build the flag image path
            (`flags/<state_slug>.svg`).
        summary: the summary whose candidate names and vote totals are shown
            in the banner row.

    Returns:
        The HTML for the banner row and the column-header row.
    """
    # Shares are relative to `total_votes_count`, not only to the two
    # candidates shown.
    percentage_candidate1 = summary.leading_candidate_votes / summary.total_votes_count
    percentage_candidate2 = summary.trailing_candidate_votes / summary.total_votes_count
    return f'''
        <thead class="thead-light">
        <tr>
            <td class="text-left flag-bg" style="background-image: url('flags/{state_slug}.svg')" colspan="9">
                <span class="has-tip" data-toggle="tooltip" title="Number of electoral votes contributed by this state and total votes by each candidate.">
                    <span class="statename">{state_formatted_name[state]}</span>
                </span>
                <br>
                {summary.leading_candidate_name} leads with {summary.leading_candidate_votes:,} votes ({percentage_candidate1:5.01%}), {summary.trailing_candidate_name} trails with {summary.trailing_candidate_votes:,} votes ({percentage_candidate2:5.01%}).
            </td>
        </tr>
        <tr>
            <th class="has-tip" data-toggle="tooltip" title="When did this batch of votes get reported?">Timestamp</th>
            <th class="has-tip" data-toggle="tooltip" title="Which candidate currently leads this state?">In The Lead</th>
            <th class="has-tip numeric" data-toggle="tooltip" title="How many votes separate the two candidates?">Vote Margin</th>
            <th class="has-tip numeric" data-toggle="tooltip" title="Approximately how many votes are remaining to be counted? These values might be off! Consult state websites and officials for the most accurate and up-to-date figures.">Votes Remaining (est.)</th>
            <th class="has-tip numeric" data-toggle="tooltip" title="How many votes were reported in this batch, excluding third-party votes. If available, an estimated county-level breakdown is shown, but it might be inaccurate!">Change</th>
            <th class="has-tip" data-toggle="tooltip" title="How were the votes in this batch, excluding third-party votes, split between the two candidates?">Batch Breakdown</th>
            <th class="has-tip" data-toggle="tooltip" title="How has the trailing candidate's share of recent batches trended? Computed using a moving average of previous 30k votes.">Batch Trend</th>
            <th class="has-tip" data-toggle="tooltip" title="What percentage of estimated remaining votes (excluding third-party votes) does the trailing candidate need to take the lead? Since this assumes the proportion of third-party votes remains the same, this might deviate from the breakdown in batches where this is not the case.">Hurdle</th>
        </tr>
        </thead>
    '''

def html_summary(state_slug: str, summary: IterationSummary, always_visible: bool):
    """
    Render one summary as a `<tr>` of a state's HTML table.

    Args:
        state_slug: slug of the state; combined with the batch timestamp to
            form the row's `id`.
        summary: the batch to render.
        always_visible: when true the row gets `class="always-visible"`.

    Returns:
        The HTML for the row.
    """
    if summary.counties_partition:
        # Largest contributors first; shown as a tooltip on the "Change" cell.
        counties_partition = sorted(summary.counties_partition.items(), key=lambda x: x[1], reverse=True)
        counties_total = sum(summary.counties_partition.values())
        counties_tooltip_attributes = (
            'class="has-tip numeric" data-toggle="tooltip" data-html="true" data-title="' +
            '<strong>Estimated county-level breakdown:</strong><br>' +
            '<br>'.join(f'<strong>{name}:</strong> {value / counties_total:.0%}' for name, value in counties_partition) +
            '" title="' +
            'Estimated county-level breakdown: ' +
            ', '.join(f'{name}: {value / counties_total:.0%}' for name, value in counties_partition) +
            '"'
        )
    else:
        counties_tooltip_attributes = 'class="numeric"'
    shown_votes_remaining = f'{summary.votes_remaining:,}' if summary.votes_remaining > 0 else 'Unknown'
    always_visible_attribute = ' class="always-visible"' if always_visible else ''
    html = f'''
        <tr id="{state_slug}-{summary.timestamp.isoformat()}"{always_visible_attribute}>
            <td class="timestamp">{summary.timestamp.strftime('%Y-%m-%d %H:%M:%S')} UTC</td>
            <td class="{summary.leading_candidate_name}">{summary.leading_candidate_name}</td>
            <td class="numeric">{summary.vote_differential:,}</td>
            <td class="numeric">{shown_votes_remaining}</td>
            <td {counties_tooltip_attributes}>{summary.new_votes_formatted}</td>
    '''

    if summary.leading_candidate_partition == 0 and summary.trailing_candidate_partition == 0:
        html += '<td>N/A</td>'
    elif summary.leading_candidate_partition >= 0 and summary.trailing_candidate_partition >= 0:
        # Both shares are positive or zero here. A negative share (possible
        # when the source revises a candidate's count downwards) is not a
        # meaningful split and falls through to "Unknown" below. We still abs
        # the values because `-0` is a thing for floats.
        html += f'''
            <td>
                {summary.leading_candidate_name} {abs(summary.leading_candidate_partition):5.01%} /
                {abs(summary.trailing_candidate_partition):5.01%} {summary.trailing_candidate_name}
            </td>
        '''
    else:
        html += '<td>Unknown</td>'

    if summary.hurdle_mov_avg and summary.hurdle_mov_avg >= 0:
        html += f'''
            <td>
                {summary.trailing_candidate_name} is averaging {summary.hurdle_mov_avg:5.01%}
            </td>
        '''
    elif summary.hurdle_mov_avg and summary.hurdle_mov_avg < 0:
        html += '<td>Unknown</td>'
    else:
        # None (no votes in the averaging window) or exactly zero.
        html += '<td>N/A</td>'

    visible_hurdle = f'{summary.trailing_candidate_name} needs {summary.hurdle:.1%}' if summary.hurdle > 0 else 'Unknown'
    html += f'''
            <td class="hurdle">{visible_hurdle}</td>
        </tr>
    '''

    return html

# Capture the time at the top of the main script logic so it's closer to when the pull of data happened
scrape_time = datetime.datetime.utcnow()

# Dict[str, List[InputRecord]]: every state's records, keyed by state name
# and ordered by timestamp (see `fetch_all_records`)
records = fetch_all_records()

# Where we’ll aggregate the data from the JSON files:
# state name -> list of IterationSummary, newest first
summarized = {}
# state name -> display name including the electoral-vote count
state_formatted_name = {}
# state name -> `state_abbrev` of the state's records
state_abbrev = {}

def json_to_summary(
    state_name: str,
    row: InputRecord,
    last_iteration_info: IterationInfo,
    latest_candidate_votes: Dict[str, int],
    expected_votes: int,
    batch_time: datetime.datetime,
) -> Tuple[IterationInfo, IterationSummary]:
    """
    Turn one snapshot of a state into a summary row.

    Args:
        state_name: key into the module-level `summarized` dict; the summaries
            already stored there feed the moving average.
        row: the snapshot to summarize. `row.candidates[0]` is taken as the
            leading candidate and `row.candidates[1]` as the trailing one.
        last_iteration_info: info returned for the previous snapshot of the
            same state (all-`None` placeholders for the first one).
        latest_candidate_votes: candidate key -> votes in the state's newest
            snapshot; used to estimate the two-candidate share of all votes.
        expected_votes: expected total number of votes for the state.
        batch_time: timestamp stored in the summary (the caller passes the
            parsed `row.timestamp`).

    Returns:
        `(iteration_info, summary)`: the info to hand to the next call and the
        summary for this snapshot.
    """
    # Retrieve relevant data from the state’s JSON blob
    candidate1 = row.candidates[0] # Leading candidate
    candidate2 = row.candidates[1] # Trailing candidate
    candidate1_name = candidate1['last_name']
    candidate2_name = candidate2['last_name']
    candidate1_votes = candidate1['votes']
    candidate2_votes = candidate2['votes']
    candidate1_key = candidate1['candidate_key']
    candidate2_key = candidate2['candidate_key']
    total_votes = sum(candidate['votes'] for candidate in row.candidates)
    vote_diff = candidate1_votes - candidate2_votes
    votes = row.votes
    votes_remaining = expected_votes - votes
    precincts_reporting = row.precincts_reporting
    precincts_total = row.precincts_total
    new_votes = 0 if last_iteration_info.votes is None else (votes - last_iteration_info.votes)

    # Per-county increase since the previous snapshot (only counties that grew)
    counties_partition = {}
    if new_votes != 0:
        assert row.counties.keys() == last_iteration_info.counties.keys()
        for k, v in row.counties.items():
            partition = (v - last_iteration_info.counties[k])
            if partition > 0:
                counties_partition[k] = partition

    # Did the lead change hands since the previous snapshot?
    bumped = candidate1_name != last_iteration_info.leading_candidate_name

    # If we're trying to estimate the proportion of third-party votes across
    # the entire vote population, we want to use the largest sample size we
    # have - i.e. the newest data we've received
    latest_total_votes = sum(latest_candidate_votes.values())
    if latest_total_votes:
        latest_relevant_proportion = (latest_candidate_votes[candidate1_key] + latest_candidate_votes[candidate2_key]) / latest_total_votes
    else:
        # No votes counted at all yet: nothing to estimate from. This leaves
        # the hurdle at 0, which the writers display as "Unknown".
        latest_relevant_proportion = 0
    votes_remaining_relevant = votes_remaining * latest_relevant_proportion
    # Share of the remaining two-candidate votes the trailing candidate needs
    # in order to draw level.
    hurdle = (vote_diff + votes_remaining_relevant) / (2 * votes_remaining_relevant) if votes_remaining_relevant > 0 else 0

    candidate_votes = {c['candidate_key']: c['votes'] for c in row.candidates}

    # We need to use the votes delta for our two leading candidates, not for
    # all the candidates, when calculating the breakdown - especially since our
    # data source frequently revises write-in and third party figures.
    if new_votes == 0:
        new_votes_relevant = 0
    else:
        new_votes_relevant = sum(candidate_votes[k] - last_iteration_info.candidate_votes[k] for k in (candidate1_key, candidate2_key))

    if new_votes_relevant == 0:
        trailing_candidate_partition = 0
        leading_candidate_partition = 0
    else:
        trailing_candidate_partition = (candidate2_votes - last_iteration_info.candidate_votes[candidate2_key]) / new_votes_relevant
        leading_candidate_partition = 1 - trailing_candidate_partition

    # Info we’ll need for the next loop iteration
    iteration_info = IterationInfo(
        vote_diff=vote_diff,
        votes=votes,
        precincts_reporting=precincts_reporting,
        hurdle=hurdle,
        leading_candidate_name=candidate1_name,
        counties=row.counties,
        candidate_votes=candidate_votes,
    )

    # Trailing candidate's share over the most recent batches (this one plus
    # the summaries already stored for the state)
    hurdle_mov_avg = compute_hurdle_sma(summarized[state_name], new_votes_relevant, trailing_candidate_partition, candidate2_name)

    summary = IterationSummary(
        batch_time,
        candidate1_name,
        candidate2_name,
        candidate1_votes,
        candidate2_votes,
        vote_diff,
        votes_remaining,
        new_votes,
        new_votes_relevant,
        f"{new_votes_relevant:7,}" if new_votes_relevant >= 0 else "Unknown",
        leading_candidate_partition,
        trailing_candidate_partition,
        precincts_reporting,
        precincts_total,
        hurdle,
        # When the lead flipped, the previous hurdle belonged to the other
        # candidate, so compare against its complement.
        hurdle-(1-last_iteration_info.hurdle if bumped else last_iteration_info.hurdle),
        hurdle_mov_avg,
        counties_partition,
        total_votes,
    )

    return iteration_info, summary

# Names of the states whose newest summary belongs to their newest snapshot
states_updated = []

# Summarize every state's snapshots, oldest to newest
for rows in records.values():
    latest_batch_time = datetime.datetime.strptime(rows[-1].timestamp, '%Y-%m-%dT%H:%M:%S.%fZ')

    state_name = rows[0].state_name

    summarized[state_name] = []
    state_formatted_name[state_name] = f"{state_name} (EV: {rows[0].electoral_votes})"
    state_abbrev[state_name] = rows[0].state_abbrev

    # Placeholder "previous snapshot" for the first row of the state
    last_iteration_info = IterationInfo(
        vote_diff=None,
        votes=None,
        precincts_reporting=None,
        hurdle=0,
        leading_candidate_name=None,
        counties=None,
        candidate_votes=None,
    )

    # Figures from the newest snapshot, used when summarizing every row
    latest_candidate_votes = {c['candidate_key']: c['votes'] for c in rows[-1].candidates}
    latest_expected_votes = rows[-1].expected_votes

    for row in rows:
        iteration_info, summary = json_to_summary(
            state_name,
            row,
            last_iteration_info,
            latest_candidate_votes,
            latest_expected_votes,
            batch_time=datetime.datetime.strptime(row.timestamp, '%Y-%m-%dT%H:%M:%S.%fZ'),
        )

        # Avoid writing duplicate rows: a snapshot whose IterationInfo equals
        # the previous one is skipped
        if last_iteration_info == iteration_info:
            continue

        # Store the summary at the front so the list stays newest first;
        # `compute_hurdle_sma` walks it from index 0
        summarized[state_name].insert(0, summary)

        # Save info for the next iteration
        last_iteration_info = iteration_info

    # The state counts as updated when its newest snapshot produced a row
    if summarized[state_name] and summarized[state_name][0].timestamp == latest_batch_time:
        states_updated.append(state_name)

# Pull out the battleground state summaries
battlegrounds_summarized = {
    state: summarized[state]
    for state in BATTLEGROUND_STATES
}
battleground_states_updated = [
    state for state in states_updated
    if state in BATTLEGROUND_STATES
]

# Timestamp of the newest summary among the battleground states; the text,
# HTML and RSS writers below read it as a global, also for the all-state
# outputs.
batch_time = max(itertools.chain.from_iterable(battlegrounds_summarized.values()), key=lambda s: s.timestamp).timestamp
# SHA-256 of the results.json currently on disk; embedded in the page metadata
# and in notification-updates.json
with open("results.json", "r", encoding='utf8') as f:
    RESULTS_HASH = hashlib.sha256(f.read().encode('utf8')).hexdigest()

def txt_output(path, summarized, states_updated):
    """
    Write the plain-text report to `path`.

    The file starts with a small header table (scrape time, latest batch time
    and the states it covered, link to the web version), followed by one
    table per state in alphabetical order.

    Args:
        path: file to (over)write.
        summarized: state name -> list of IterationSummary, newest first.
        states_updated: names of the states covered by the latest batch.
    """
    # The rows contain non-ASCII text ("Δ"), so the encoding is pinned rather
    # than left to the platform default.
    with open(path, "w", encoding='utf8') as f:
        print(tabulate([
            ["Last updated:", scrape_time.strftime("%Y-%m-%d %H:%M UTC")],
            ["Latest batch received: {}".format(f"({', '.join(states_updated)})" if states_updated else ""), batch_time.strftime("%Y-%m-%d %H:%M UTC")],
            ["Prettier web version:", "https://alex.github.io/nyt-2020-election-scraper/battleground-state-changes.html"],
        ]), file=f)

        for (state, timestamped_results) in sorted(summarized.items()):
            # Totals in the heading come from the newest summary.
            latest = timestamped_results[0]
            print(f'\n{state_formatted_name[state]} Total Votes: ({latest.leading_candidate_name}: {latest.leading_candidate_votes:,}, {latest.trailing_candidate_name}: {latest.trailing_candidate_votes:,})', file=f)
            print(tabulate([string_summary(summary) for summary in timestamped_results]), file=f)

# Text reports: battleground states only, then every state
txt_output("battleground-state-changes.txt", battlegrounds_summarized, battleground_states_updated)
txt_output("all-state-changes.txt", summarized, states_updated)

def html_table(summarized):
    """
    Yield the HTML fragments of one table per state, in alphabetical order.

    For each state: the opening tags, the table head (built from the newest
    summary), one row per summary and the closing tags. At most three rows
    per state are flagged as always visible; a row is never flagged when both
    its own change and the change of the row before it are '0' or 'Unknown'.
    """
    uninformative_changes = {'0', 'Unknown'}

    # The NYTimes array of states is not sorted alphabetically, so we'll use `sorted`
    for state, timestamped_results in sorted(summarized.items()):
        # Drop anything from a '(' onwards, then 'North Carolina' -> 'north-carolina'
        state_slug = state.split('(')[0].strip().replace(' ', '-').lower()
        yield f"<div class='table-responsive'><table id='{state_slug}' class='table table-bordered'>"
        yield html_write_state_head(state, state_slug, timestamped_results[0])

        pinned_count = 0
        previous_change = None
        for entry in timestamped_results:
            current_change = entry.new_votes_formatted.strip()
            nothing_new = {current_change, previous_change}.issubset(uninformative_changes)
            pinned = pinned_count < 3 and not nothing_new
            if pinned:
                pinned_count += 1

            yield html_summary(state_slug=state_slug, summary=entry, always_visible=pinned)
            previous_change = current_change
        yield "</table></div><hr>"

def html_output(path, html_table, states_updated, other_page_html):
    """
    Fill the HTML template and write the resulting page to `path`.

    Args:
        path: file to (over)write.
        html_table: iterable of HTML fragments, joined with newlines into the
            `{% TABLES %}` slot.
        states_updated: names of the states covered by the latest batch.
        other_page_html: HTML placed in the `{% OTHER_PAGE_TEXT %}` slot.
    """
    banner = "<!-- Don't update me by hand, I'm generated by a program -->\n\n"
    with open("battleground-state-changes.html.tmpl", "r", encoding='utf8') as template_file:
        html_template = banner + template_file.read()
    # The hash covers the banner as well as the template file's text
    template_hash = hashlib.sha256(html_template.encode('utf8')).hexdigest()

    states_updated_abbrev = [state_abbrev[state_name] for state_name in states_updated]

    with open(path,"w", encoding='utf8') as page_file:
        page_metadata = json.dumps({
            "template_hash": template_hash,
            "results_hash": RESULTS_HASH,
            "states_updated": states_updated_abbrev,
        })
        last_batch = f"({', '.join(states_updated)})" if states_updated else ""

        # Applied in this order, each placeholder is replaced everywhere
        substitutions = (
            ('{% TABLES %}', "\n".join(html_table)),
            ('{% SCRAPE_TIME %}', scrape_time.strftime('%Y-%m-%d %H:%M:%S UTC')),
            ('{% BATCH_TIME %}', batch_time.strftime('%Y-%m-%d %H:%M:%S UTC')),
            ('{% LAST_BATCH %}', last_batch),
            ('{% TEMPLATE_HASH %}', template_hash),
            ('{% PAGE_METADATA %}', page_metadata),
            ('{% OTHER_PAGE_TEXT %}', other_page_html),
        )
        html = html_template
        for placeholder, value in substitutions:
            html = html.replace(placeholder, value)

        page_file.write(html)

# HTML pages: battleground states only, then every state; each page links to
# the other one
html_output(
    "battleground-state-changes.html",
    html_table(battlegrounds_summarized),
    battleground_states_updated,
    'Data for all 50 states and DC is <a href="all-state-changes.html">also available</a>.'
)
html_output(
    "all-state-changes.html",
    html_table(summarized),
    states_updated,
    'View <a href="battleground-state-changes.html">battleground states only</a>.'
)

def csv_output(path, summarized):
    """
    Write every summary of every state to `path` as CSV.

    The header row is `state` followed by the IterationSummary field names;
    each data row is the state's display name followed by the summary's
    values.

    Args:
        path: file to (over)write.
        summarized: state name -> list of IterationSummary.
    """
    # `newline=''` lets the csv writer control line endings itself, as the csv
    # module requires; otherwise platforms that translate '\n' write '\r\r\n'.
    with open(path, 'w', newline='', encoding='utf8') as csvfile:
        wr = csv.writer(csvfile)
        wr.writerow(('state',) + IterationSummary._fields)
        for state, results in summarized.items():
            for row in results:
                wr.writerow((state_formatted_name[state],) + row)

# CSV exports: battleground states only, then every state
csv_output('battleground-state-changes.csv', battlegrounds_summarized)
csv_output('all-state-changes.csv', summarized)

def rss_output(path, summarized):
    """
    Write an RSS 2.0 feed to `path` with one item per state, describing the
    state's newest summary. States without any summary are skipped.

    Args:
        path: file to (over)write.
        summarized: state name -> list of IterationSummary, newest first.
    """
    def utc_epoch(naive_utc):
        # The summaries carry naive datetimes that are in UTC. A bare
        # `.timestamp()` would interpret them as local time and shift every
        # date on hosts whose timezone is not UTC.
        return naive_utc.replace(tzinfo=datetime.timezone.utc).timestamp()

    # The XML declaration below promises UTF-8, so write UTF-8.
    with open(path, 'w', encoding='utf8') as rssfile:
        print(dedent(f'''
            <?xml version="1.0" encoding="UTF-8"?>
            <rss version="2.0">
             <channel>
              <title>NYT 2020 Election Scraper RSS Feed</title>
              <link>https://alex.github.io/nyt-2020-election-scraper/battleground-state-changes.html</link>
              <description>Latest results from battleground states.</description>
              <lastBuildDate>{email.utils.formatdate(utc_epoch(batch_time))}</lastBuildDate>
            ''').strip(), file=rssfile)

        for state, results in summarized.items():
            if not results:
                continue
            result = results[0]
            state_slug = state.split('(')[0].strip().replace(' ', '-').lower()
            timestamp = utc_epoch(result.timestamp)
            print(indent(dedent(f'''
                <item>
                 <description>{state_formatted_name[state]}: {result.leading_candidate_name} +{result.vote_differential}</description>
                 <pubDate>{email.utils.formatdate(timestamp)}</pubDate>
                 <guid isPermaLink="false">{state_slug}@{timestamp}</guid>
                </item>
            ''').strip(), "  "), file=rssfile)

        print(dedent('''
             </channel>
            </rss>'''
        ), file=rssfile)

# RSS feeds: battleground states only, then every state
rss_output('battleground-state-changes.xml', battlegrounds_summarized)
rss_output('all-state-changes.xml', summarized)

# This file is deprecated and should not be used! It will be removed soon(tm).
# It repeats the results hash and the updated battleground states that the
# HTML pages already carry in their page metadata.
with open("notification-updates.json", "w") as f:
    f.write(json.dumps({
        "_comment": "this file is deprecated and should not be used",
        "results_hash": RESULTS_HASH,
        "states_updated": battleground_states_updated,
    }))
