#!/usr/bin/env python

import sys
import json
import warnings
import pprint
import argparse

import metrics


def load_gt(f):
    """Load ground-truth annotations from a JSON-lines source.

    Args:
        f: An iterable of text lines (e.g. an open file). Each non-blank
            line must be one JSON object with an ``"id"`` key. Blank or
            whitespace-only lines are skipped.

    Returns:
        dict mapping each annotation's ``"id"`` to its parsed object. If an
        id occurs more than once, a warning is emitted and the last
        occurrence wins.
    """
    res = {}

    for line in f:
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing newline at end of file),
            # which json.loads would otherwise reject.
            continue

        ann = json.loads(line)
        ann_id = ann["id"]
        if ann_id in res:
            warnings.warn("Key %s duplicated in ground truth" % ann_id)

        res[ann_id] = ann

    return res


def convert_annotation(ann):
    """Group an annotation's bounding boxes by the labels used for scoring.

    Args:
        ann: A dict that may contain ``"bounding_boxes"``, a list of dicts
            with ``"label"``, ``"x0"``, ``"y0"``, ``"x1"`` and ``"y1"`` keys.
            A missing ``"bounding_boxes"`` key is treated as no boxes.

    Returns:
        dict with exactly the keys ``"name"`` and ``"handwritten"``. Each maps
        to a list of ``[x0, y0, x1, y1]`` boxes in input order. Boxes
        labelled ``"signature"`` are counted as ``"handwritten"``. Boxes
        labelled ``"phone"`` or ``"address"`` are dropped.

    Raises:
        KeyError: if a box has any other label, or lacks a required key.
    """
    res = {"name": [], "handwritten": []}

    for box in ann.get("bounding_boxes", []):
        label = box["label"]

        # identify signatures with handwritten data
        if label == "signature":
            label = "handwritten"

        # these categories are not scored
        if label in ("phone", "address"):
            continue

        # append in place instead of rebuilding the list for every box;
        # an unexpected label still raises KeyError here
        res[label].append([box["x0"], box["y0"], box["x1"], box["y1"]])

    return res


def calc_total_score(gt_file, prediction_file):
    """Score JSON-lines predictions against JSON-lines ground truth.

    Args:
        gt_file: Iterable of text lines holding the ground truth; parsed
            with ``load_gt``.
        prediction_file: Iterable of text lines, one JSON object with an
            ``"id"`` key per non-blank line. Blank lines are skipped.

    Predictions whose id is not in the ground truth are ignored. A
    repeated id is reported on stderr, and only its first occurrence is
    scored. A prediction with a truthy ``"bad_quality"`` field gets the
    fixed result ``{"score": 0.35, "bad_quality_count": 1}``. Every other
    prediction is compared to its ground truth via ``metrics.score`` after
    both go through ``convert_annotation``.

    Returns:
        dict of per-key sums over all scored predictions. Keys ending in
        ``"score"`` are divided by the number of scored predictions. Other
        keys (e.g. ``"bad_quality_count"``) stay as totals. The dict is
        empty if nothing was scored.
    """
    gt = load_gt(gt_file)

    total_res = {}
    pred_count = 0

    processed_ids = set()

    for line in prediction_file:
        line = line.strip()
        if not line:
            # tolerate blank lines, consistent with load_gt
            continue

        ann = json.loads(line)
        ann_id = ann["id"]
        if ann_id not in gt:
            continue

        if ann_id in processed_ids:
            # Diagnostics go to stderr so stdout only carries the result.
            sys.stderr.write(
                "The following id has already been processed. Skipping %s\n"
                % (ann_id,))
            continue

        processed_ids.add(ann_id)

        if ann.get("bad_quality", False):
            # NOTE(review): fixed credit for predictions flagged as bad
            # quality; the rationale for 0.35 is not visible here.
            local_res = {"score": 0.35, "bad_quality_count": 1}
        else:
            local_res = metrics.score(convert_annotation(gt[ann_id]),
                                      convert_annotation(ann))

        for k, v in local_res.items():
            total_res[k] = total_res.get(k, 0) + v

        pred_count += 1

    # Average over scores. pred_count only counts scored predictions.
    # NOTE(review): ground-truth ids with no prediction do not lower the
    # average. Also, if metrics.score returns "*score" keys other than
    # "score", bad-quality predictions add nothing to them but are still
    # counted in the divisor. Confirm both are intended.
    # (Nothing is divided when pred_count == 0, because total_res is then
    # empty.)
    for k, v in total_res.items():
        if k.endswith("score"):
            total_res[k] = float(v) / pred_count

    return total_res


if __name__ == "__main__":
    # Command-line entry point: score a prediction file against a
    # ground-truth file and pretty-print the resulting dict.
    parser = argparse.ArgumentParser(
        description="Sample implementation for score for task 1")
    parser.add_argument("ground_truth",
                        type=argparse.FileType("r"),
                        help="Ground truth given as jsonl file.")
    parser.add_argument("prediction",
                        type=argparse.FileType("r"),
                        help="Prediction given as jsonl file.")

    args = parser.parse_args()

    # argparse.FileType opens the files but never closes them; close both
    # explicitly once scoring has consumed them.
    with args.ground_truth, args.prediction:
        total_score = calc_total_score(args.ground_truth, args.prediction)

    pprint.pprint(total_score)