import os

import sys

import time

import traceback

import project1 as p1

import numpy as np


# Verbosity toggle; not referenced anywhere in the visible portion of this file.
verbose = False



def green(s):
   """Wrap *s* in ANSI bold-green escape codes."""
   return f'\033[1;32m{s}\033[m'



def yellow(s):
   """Wrap *s* in ANSI bold-yellow escape codes."""
   return f'\033[1;33m{s}\033[m'



def red(s):
   """Wrap *s* in ANSI bold-red escape codes."""
   return f'\033[1;31m{s}\033[m'



def log(*m):
   """Print all arguments on one line, space-separated, via str()."""
   print(" ".join(str(item) for item in m))



def log_exit(*m):
   """Log an ERROR line and terminate the process with exit status 1.

   Uses sys.exit() instead of the site-injected exit() builtin, which is
   not guaranteed to exist (e.g. under `python -S` or in frozen apps).
   """
   log(red("ERROR:"), *m)
   sys.exit(1)



def check_real(ex_name, f, exp_res, *args):
   """Call f(*args) and check it returns the real number exp_res.

   Logs a FAIL line and returns True on any failure; returns None (falsy)
   on success, which is how the callers interpret it.
   """
   try:
      actual = f(*args)
   except NotImplementedError:
      log(red("FAIL"), ex_name, ": not implemented")
      return True
   if not np.isreal(actual):
      log(red("FAIL"), ex_name, ": does not return a real number, type: ", type(actual))
      return True
   # NOTE(review): exact equality on floats — a correct implementation with a
   # different operation order could still fail this; confirm the intent.
   if actual != exp_res:
      log(red("FAIL"), ex_name, ": incorrect answer. Expected", exp_res, ", got: ", actual)
      return True



def equals(x, y):
   """Return True when x and y compare equal, handling numpy arrays.

   Fix: the original only special-cased an ndarray on the right-hand side.
   With an ndarray on the left and a scalar expected value, `x == y` leaked
   an element-wise bool array back to the caller, whose `all(...)` then
   raised on the ambiguous truth value. np.array_equal covers either side
   being an array and returns False on shape mismatch instead of crashing.
   """
   if isinstance(x, np.ndarray) or isinstance(y, np.ndarray):
      return np.array_equal(x, y)
   return x == y



def check_tuple(ex_name, f, exp_res, *args, **kwargs):
   """Call f(*args, **kwargs) and check it returns a tuple equal to exp_res.

   Logs a FAIL line and returns True on failure; returns None on success.
   """
   try:
      result = f(*args, **kwargs)
   except NotImplementedError:
      log(red("FAIL"), ex_name, ": not implemented")
      return True
   if type(result) is not tuple:
      log(red("FAIL"), ex_name, ": does not return a tuple, type: ", type(result))
      return True
   if len(result) != len(exp_res):
      log(red("FAIL"), ex_name, ": expected a tuple of size ", len(exp_res), " but got tuple of size", len(result))
      return True
   # Element-wise comparison via equals() so numpy-array members work too.
   if not all(map(equals, result, exp_res)):
      log(red("FAIL"), ex_name, ": incorrect answer. Expected", exp_res, ", got: ", result)
      return True



def check_array(ex_name, f, exp_res, *args):
   """Call f(*args) and check it returns a numpy array equal to exp_res.

   Logs a FAIL line and returns True on failure; returns None on success.
   """
   try:
      res = f(*args)
   except NotImplementedError:
      log(red("FAIL"), ex_name, ": not implemented")
      return True
   if not type(res) == np.ndarray:
      log(red("FAIL"), ex_name, ": does not return a numpy array, type: ", type(res))
      return True
   # Fix: compare full shapes. The original compared only len() (the first
   # dimension), even though the message reports shapes, so a trailing-
   # dimension mismatch slipped through and crashed the comparison below.
   if res.shape != exp_res.shape:
      log(red("FAIL"), ex_name, ": expected an array of shape ", exp_res.shape, " but got array of shape", res.shape)
      return True
   if not all(equals(x, y) for x, y in zip(res, exp_res)):
      log(red("FAIL"), ex_name, ": incorrect answer. Expected", exp_res, ", got: ", res)
      return True



def check_list(ex_name, f, exp_res, *args):
   """Call f(*args) and check it returns a list equal to exp_res.

   Logs a FAIL line and returns True on failure; returns None on success.
   """
   try:
      result = f(*args)
   except NotImplementedError:
      log(red("FAIL"), ex_name, ": not implemented")
      return True
   if type(result) is not list:
      log(red("FAIL"), ex_name, ": does not return a list, type: ", type(result))
      return True
   if len(result) != len(exp_res):
      log(red("FAIL"), ex_name, ": expected a list of size ", len(exp_res), " but got list of size", len(result))
      return True
   # Element-wise comparison via equals() so numpy-array members work too.
   if not all(map(equals, result, exp_res)):
      log(red("FAIL"), ex_name, ": incorrect answer. Expected", exp_res, ", got: ", result)
      return True



def check_get_order():
   """Verify p1.get_order returns the expected fixed orderings for n=1 and n=2."""
   ex_name = "Get order"
   for n_samples, expected in ((1, [0]), (2, [1, 0])):
      if check_list(ex_name, p1.get_order, expected, n_samples):
         log("You should revert `get_order` to its original implementation for this test to pass")
         return
   log(green("PASS"), ex_name, "")



def check_hinge_loss_single():
   """Check p1.hinge_loss_single on one hand-computed example."""
   ex_name = "Hinge loss single"

   fv = np.array([1, 2])
   y = 1
   th = np.array([-1, 1])
   th0 = -0.2
   # margin = y * (th . fv + th0) = 0.8, so hinge loss = 1 - 0.8
   expected = 1 - 0.8
   if check_real(ex_name, p1.hinge_loss_single, expected, fv, y, th, th0):
      return
   log(green("PASS"), ex_name, "")



def check_hinge_loss_full():
   """Check p1.hinge_loss_full on a fixed 20-sample, 10-feature dataset."""
   ex_name = "Hinge loss full"

   # NOTE(review): the next three assignments are dead code — theta,
   # feature_vector and label are all overwritten below before use.
   feature_vector = np.array([[1, 2], [1, 2]])
   label, theta, theta_0 = np.array([1, 1]), np.array([-1, 1]), -0.2
   exp_res = 0.7551293635113929
   theta = np.array(
       [-0.82211481, -0.18307211, -0.81219305, -0.42792084, 0.39996585, -0.22333808, -0.85901045, -0.28555658,
        0.68972721, 0.10356705])
   feature_vector = np.array([[0.853047104420466, 0.92458851363296, -0.93669308238158, 0.485652277912542,
                               -0.227688957306831, -0.840632378917697, -0.380153896598382, 0.458101465178383,
                               -0.201164708393563, -0.708552201422551],
                              [-0.435947281040453, -0.275618285508508, -0.962833099150192, 0.492466581112431,
                               -0.950740227537879, -0.730767558031207, 0.835443727085544, -0.487593201883201,
                               -0.927305707125401, -0.332255638111472],
                              [0.546800160002319, -0.68646826675953, -0.936084671843869, -0.793932156134798,
                               -0.419237513625059, -0.289511399134108, -0.740813923795108, -0.838739192444391,
                               -0.827514919584982, -0.37754524577323],
                              [0.0290103524207108, -0.259225214310641, -0.0511485235055935, 0.0817392157563362,
                               -0.598963821928654, 0.642759524333666, 0.537532481316604, 0.426377770688738,
                               -0.623050000012637, 0.272382527684435],
                              [0.839610243377748, -0.563985043559053, -0.812711062505414, -0.775198984634139,
                               -0.503145086853907, -0.0969642883977621, -0.832946596951278, 0.868607712949865,
                               -0.735727419880641, -0.203649102475537],
                              [-0.305332614246176, -0.577466592827372, -0.309148373248618, 0.505095613377923,
                               0.00839802776239684, -0.550160142512996, -0.549794998953918, -0.416474925052188,
                               -0.902864487728197, -0.498087729530643],
                              [-0.543401784307858, -0.838075831400481, -0.731847424313786, -0.923641718232543,
                               -0.0264838053587217, -0.615076094148371, 0.421337378420538, -0.462636179013577,
                               0.0877415486187202, -0.222681003564926],
                              [-0.0431997662658585, -0.64086961778174, -0.607659553509879, -0.724862215208416,
                               -0.728610588683776, -0.524117054122693, 0.368944687941134, -0.923786846156236,
                               -0.0597560752935269, 0.413222053159296],
                              [-0.399568293999173, 0.815288995886186, -0.817199557088931, -0.768432705783836,
                               -0.0684560482492637, -0.013897658307912, -0.854377866246344, 0.386983593505216,
                               0.0262234650871292, 0.338476726443852],
                              [-0.886798337089141, -0.856520738607455, 0.0762089450302142, -0.331676606944875,
                               0.961201301178152, 0.882786807725685, -0.674578989323133, -0.811985381764876,
                               0.0735253809813046, -0.0629855999551604],
                              [0.486529691155726, -0.591585285933527, -0.148303358729466, 0.554136833155947,
                               -0.233557511590573, -0.20386058679155, 0.87774995528288, 0.836622425750782,
                               -0.31736741553135, -0.502706229691884],
                              [0.83970224881889, -0.786431817088105, -0.979360279051319, 0.698900279759187,
                               0.19944058643209, -0.616082293133205, -0.321638082097873, 0.51726616847625,
                               -0.30144241787091, -0.603303318590072],
                              [0.182204140705668, -0.581972504911827, -0.745902693538881, -0.411492827372293,
                               -0.049957678714064, -0.354089896315584, 0.853027666893416, 0.357983240757515,
                               -0.642643945639318, -0.498069872091213],
                              [-0.45126965721213, -0.936962607533486, 0.646158447313732, -0.95097442793594,
                               -0.680714560504185, -0.948203489417886, -0.356801223121791, -0.0134499794097405,
                               0.993635246521581, -0.174164478830752],
                              [-0.174460481493638, -0.291366228803816, -0.760320483573548, -0.518594224103403,
                               0.116216446913184, -0.111578762983225, -0.217977936603957, 0.248027012375155,
                               0.299676073935871, -0.0946823454004537],
                              [-0.00734385930874304, 0.390413037906223, -0.493685240564596, 0.799268946959125,
                               -0.989994570691187, -0.613398467399992, -0.431262374191085, -0.65408322696627,
                               0.376993724356529, -0.175926803978519],
                              [-0.378245987757025, -0.247075383254038, 0.736269794673573, 0.171474630539081,
                               0.28299630171977, -0.802954545127667, -0.6498709727982, 0.716178583765797,
                               0.276356298324105, -0.357548338449613],
                              [-0.843611481365903, -0.77229151913469, -0.811145449109455, -0.417383698578268,
                               -0.117093110980164, 0.813478911490971, -0.111806184555956, 0.68134261623292,
                               -0.327034374502997, -0.0579041208830625],
                              [-0.746575429912273, -0.848078443307712, 0.105104872014906, -0.533593948314463,
                               0.235091219002227, -0.568317776792953, -0.426330189198796, -0.365305547214323,
                               -0.519292151765541, 0.799024519922367],
                              [-0.375738571798775, -0.798185780078637, 0.823781138160157, -0.39453620215073,
                               -0.33097178033255, -0.615726080048584, -0.0233468854373248, -0.889794158096667,
                               0.232086951127982, -0.0563817200996439]])
   label = np.array([1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, 1, 1])
   # NOTE(review): xx is never used. Presumably kept so get_order's side
   # effects (e.g. RNG state) run before the check — confirm before removing.
   xx = p1.get_order(feature_vector.shape[0])
   if check_real(
           ex_name, p1.hinge_loss_full,
           exp_res, feature_vector, label, theta, theta_0):
      return

   log(green("PASS"), ex_name, "")



def check_perceptron_single_update():
   """Check p1.perceptron_single_step_update, including the margin-0 boundary."""
   ex_name = "Perceptron single update"

   # (name suffix, feature vector, label, theta, theta_0, expected (theta, theta_0))
   cases = (
       ("", np.array([1, 2]), 1, np.array([-1, 1]), -1.5, (np.array([0, 3]), -0.5)),
       (" (boundary case)", np.array([1, 2]), 1, np.array([-1, 1]), -1, (np.array([0, 3]), 0)),
   )
   for suffix, fv, y, th, th0, expected in cases:
      if check_tuple(ex_name + suffix, p1.perceptron_single_step_update,
                     expected, fv, y, th, th0):
         return
   log(green("PASS"), ex_name, "")



def check_perceptron():
   """Check p1.perceptron on small datasets for T=1 and T=2 epochs."""
   ex_name = "Perceptron"

   # (feature matrix, labels, T, expected (theta, theta_0))
   cases = (
       (np.array([[1, 2]]), np.array([1]), 1, (np.array([1, 2]), 1)),
       (np.array([[1, 2], [-1, 0]]), np.array([1, 1]), 1, (np.array([0, 2]), 2)),
       (np.array([[1, 2]]), np.array([1]), 2, (np.array([1, 2]), 1)),
       (np.array([[1, 2], [-1, 0]]), np.array([1, 1]), 2, (np.array([0, 2]), 2)),
   )
   for feature_matrix, labels, T, expected in cases:
      if check_tuple(ex_name, p1.perceptron, expected, feature_matrix, labels, T):
         return
   log(green("PASS"), ex_name, "")



def check_average_perceptron():
   """Check p1.average_perceptron on small datasets for T=1 and T=2 epochs."""
   ex_name = "Average perceptron"

   # (feature matrix, labels, T, expected averaged (theta, theta_0))
   cases = (
       (np.array([[1, 2]]), np.array([1]), 1, (np.array([1, 2]), 1)),
       (np.array([[1, 2], [-1, 0]]), np.array([1, 1]), 1, (np.array([-0.5, 1]), 1.5)),
       (np.array([[1, 2]]), np.array([1]), 2, (np.array([1, 2]), 1)),
       (np.array([[1, 2], [-1, 0]]), np.array([1, 1]), 2, (np.array([-0.25, 1.5]), 1.75)),
   )
   for feature_matrix, labels, T, expected in cases:
      if check_tuple(ex_name, p1.average_perceptron, expected, feature_matrix, labels, T):
         return
   log(green("PASS"), ex_name, "")



def check_pegasos_single_update():
   """Check p1.pegasos_single_step_update, including the margin-1 boundary."""
   ex_name = "Pegasos single update"

   L = 0.2   # regularization strength, shared by all cases
   eta = 0.1  # learning rate, shared by all cases
   # (name suffix, feature vector, label, theta, theta_0, expected (theta, theta_0))
   cases = (
       ("", np.array([1, 2]), 1, np.array([-1, 1]), -1.5, (np.array([-0.88, 1.18]), -1.4)),
       (" (boundary case)", np.array([1, 1]), 1, np.array([-1, 1]), 1, (np.array([-0.88, 1.08]), 1.1)),
       ("", np.array([1, 2]), 1, np.array([-1, 1]), -2, (np.array([-0.88, 1.18]), -1.9)),
   )
   for suffix, fv, y, th, th0, expected in cases:
      if check_tuple(ex_name + suffix, p1.pegasos_single_step_update,
                     expected, fv, y, L, eta, th, th0):
         return
   log(green("PASS"), ex_name, "")



def check_pegasos():
   """Check p1.pegasos on two tiny datasets with known closed-form results."""
   ex_name = "Pegasos"

   # (feature matrix, labels, T, L, expected (theta, theta_0))
   cases = (
       (np.array([[1, 2]]), np.array([1]), 1, 0.2, (np.array([1, 2]), 1)),
       (np.array([[1, 1], [1, 1]]), np.array([1, 1]), 1, 1,
        (np.array([1 - 1 / np.sqrt(2), 1 - 1 / np.sqrt(2)]), 1)),
   )
   for feature_matrix, labels, T, L, expected in cases:
      if check_tuple(ex_name, p1.pegasos, expected, feature_matrix, labels, T, L):
         return
   log(green("PASS"), ex_name, "")



def check_classify():
   """Check p1.classify, including a point exactly on the decision boundary."""
   ex_name = "Classify"

   theta = np.array([1, 1])
   theta_0 = 0
   # (name suffix, feature matrix, expected label array)
   cases = (
       ("", np.array([[1, 1], [1, 1], [1, 1]]), np.array([1, 1, 1])),
       (" (boundary case)", np.array([[-1, 1]]), np.array([-1])),
   )
   for suffix, feature_matrix, expected in cases:
      if check_array(ex_name + suffix, p1.classify,
                     expected, feature_matrix, theta, theta_0):
         return
   log(green("PASS"), ex_name, "")



def check_classifier_accuracy():
   """Check p1.classifier_accuracy with both perceptron and pegasos learners."""
   ex_name = "Classifier accuracy"

   # One shared fixture: the original duplicated these arrays verbatim.
   train_feature_matrix = np.array([[1, 0], [1, -1], [2, 3]])
   val_feature_matrix = np.array([[1, 1], [2, -1]])
   train_labels = np.array([1, -1, 1])
   val_labels = np.array([-1, 1])
   exp_res = 1, 0

   # (learner, keyword arguments forwarded to it)
   learners = (
       (p1.perceptron, {"T": 1}),
       (p1.pegasos, {"T": 1, "L": 0.2}),
   )
   for learner, kwargs in learners:
      if check_tuple(ex_name, p1.classifier_accuracy,
                     exp_res,
                     learner,
                     train_feature_matrix, val_feature_matrix,
                     train_labels, val_labels,
                     **kwargs):
         return
   log(green("PASS"), ex_name, "")



def check_bag_of_words():
   """Check p1.bag_of_words: dict type, contiguous indices, then stopword handling.

   The stopword section is advisory: keeping stopwords is a WARN, removing
   them is a PASS, anything else is a FAIL.
   """
   ex_name = "Bag of words"

   texts = [
       "He loves to walk on the beach",
       "There is nothing better"]

   try:
      res = p1.bag_of_words(texts)
   except NotImplementedError:
      log(red("FAIL"), ex_name, ": not implemented")
      return
   if not type(res) == dict:
      # Fixed message: this branch checks for a dict, not a tuple.
      log(red("FAIL"), ex_name, ": does not return a dict, type: ", type(res))
      return

   # The word indices must be exactly 0..len-1 (a valid feature mapping).
   vals = sorted(res.values())
   exp_vals = list(range(len(res.keys())))
   if not vals == exp_vals:
      log(red("FAIL"), ex_name, ": wrong set of indices. Expected: ", exp_vals, " got ", vals)
      return

   log(green("PASS"), ex_name, "")

   keys = sorted(res.keys())
   exp_keys = ['beach', 'better', 'he', 'is', 'loves', 'nothing', 'on', 'the', 'there', 'to', 'walk']
   stop_keys = ['beach', 'better', 'loves', 'nothing', 'walk']

   if keys == exp_keys:
      log(yellow("WARN"), ex_name, ": does not remove stopwords:", [k for k in keys if k not in stop_keys])
   elif keys == stop_keys:
      log(green("PASS"), ex_name, " stopwords removed")
   else:
      # Fixed wording: "are not unexpected" -> "are unexpected".
      log(red("FAIL"), ex_name, ": keys are missing:", [k for k in stop_keys if k not in keys],
          " or are unexpected:", [k for k in keys if k not in stop_keys])



def check_extract_bow_feature_vectors():
   """Check p1.extract_bow_feature_vectors: type, shape, then binary-vs-count features.

   Binary indicator features earn a WARN; word-count features earn a PASS.
   """
   ex_name = "Extract bow feature vectors"
   texts = [
       "He loves her ",
       "He really really loves her"]
   keys = ["he", "loves", "her", "really"]
   dictionary = {k: i for i, k in enumerate(keys)}
   # Binary (presence) features.
   exp_res = np.array(
       [[1, 1, 1, 0],
        [1, 1, 1, 1]])
   # Count features ("really" appears twice in the second text).
   non_bin_res = np.array(
       [[1, 1, 1, 0],
        [1, 1, 1, 2]])

   try:
      res = p1.extract_bow_feature_vectors(texts, dictionary)
   except NotImplementedError:
      log(red("FAIL"), ex_name, ": not implemented")
      return

   if not type(res) == np.ndarray:
      log(red("FAIL"), ex_name, ": does not return a numpy array, type: ", type(res))
      return
   # Fix: compare full shapes. The original compared only len() (row count),
   # even though the message reports shapes, so a wrong column count slipped
   # through to the element-wise comparison below.
   if res.shape != exp_res.shape:
      log(red("FAIL"), ex_name, ": expected an array of shape ", exp_res.shape, " but got array of shape", res.shape)
      return

   log(green("PASS"), ex_name)

   if (res == exp_res).all():
      log(yellow("WARN"), ex_name, ": uses binary indicators as features")
   elif (res == non_bin_res).all():
      log(green("PASS"), ex_name, ": correct non binary features")
   else:
      log(red("FAIL"), ex_name, ": unexpected feature matrix")
      return



def main():
   """Run every checker in order; on any crash, log the traceback and exit(1)."""
   log(green("PASS"), "Import project1")
   checks = (
       check_get_order,
       check_hinge_loss_single,
       check_hinge_loss_full,
       check_perceptron_single_update,
       check_perceptron,
       check_average_perceptron,
       check_pegasos_single_update,
       check_pegasos,
       check_classify,
       check_classifier_accuracy,
       check_bag_of_words,
       check_extract_bow_feature_vectors,
   )
   try:
      for check in checks:
         check()
   except Exception:
      log_exit(traceback.format_exc())



class event():
   """One '#'-separated record of KEY=VALUE fields, parsed into self.dic.

   Fixes two defects in the original:
   - `dic = {}` was a class attribute, so every instance shared (and merged
     into) one dict; it is now a per-instance attribute.
   - `for t3, t4 in t2` tried to unpack each '='-split fragment as a 2-item
     sequence, which raised ValueError for any fragment not exactly two
     characters long; fields are now split into key/value pairs properly.
   """

   def __init__(self, line):
      # dic: mapping of field name -> field value for this record only.
      self.dic = {}
      for field in line.split('#'):
         # partition keeps any further '=' characters inside the value.
         key, sep, value = field.partition('=')
         if sep:  # skip fragments that contain no '=' at all
            self.dic[key] = value



def read_file():
   """Read 'madhavi.ics' as '$'-delimited CSV and build event objects per row.

   Prints the number of rows read. NOTE(review): within each row,
   data[count] is reassigned for every field, so only the row's last field
   survives — confirm whether each field was meant to get its own key.
   """
   import csv
   data = {}  # row index -> event built from that row's last field
   with open('madhavi.ics', newline='') as csvfile:
      datareader = csv.reader(csvfile, delimiter='$', quotechar='|')
      count = 0
      for row in datareader:
         for r in row:
            # Overwrites any event stored earlier for this same row.
            data[count] = event(r)
         count += 1
   print(count)

def testPy():
   """Tiny PyTorch demo: one pass of SGD-style updates on a linear model.

   Fixes two crashes in the original:
   - `data, targets = ...` raised TypeError (cannot unpack Ellipsis); small
     synthetic tensors are used instead.
   - `loss.backward()` was called on a non-scalar (shape-(3,)) loss, which
     raises RuntimeError; the squared error is now summed to a scalar.
   """
   import torch
   from torch.autograd import Variable
   import torch.optim as optim

   def linear_model(x, W, b):
      # Affine map: x @ W + b
      return torch.matmul(x, W) + b

   # Synthetic stand-ins for the original `data, targets = ...` placeholder:
   # 8 samples with 4 features, 3 regression targets each.
   data = torch.randn(8, 4)
   targets = torch.randn(8, 3)

   W = Variable(torch.randn(4, 3), requires_grad=True)
   b = Variable(torch.randn(3), requires_grad=True)

   optimizer = optim.Adam([W, b])

   for sample, target in zip(data, targets):
      # clear out the gradients of all Variables
      # in this optimizer (i.e. W, b)
      optimizer.zero_grad()
      output = linear_model(sample, W, b)
      # Sum to a scalar — backward() needs a 0-dim loss.
      loss = ((output - target) ** 2).sum()
      loss.backward()
      optimizer.step()


if __name__ == "__main__":
   # NOTE(review): this runs the PyTorch scratch demo, not the grader's
   # main() defined above — confirm which entry point was intended.
   testPy()