#
# The classes for the abstract syntax tree (ASTNode)
#
#  ASTNode 
#   |
#   +-- Exp 
#   |    |
#   |    +-- IntLiteral
#   |    +-- Identifier
#   |    +-- UnaryExp 
#   |    +-- BinOpExp 
#   |
#   +-- Stmt 
#        |
#        +-- IfStmt 
#        +-- WhileStmt 
#        +-- AssignStmt

#-----------------------------------------------
# Simple symbol table
# Maps a variable name to its integer value (its final value once the
# program has run).  No type is stored because integers are the only type.
symtab = dict()

#-----------------------------------------------
# ASTNode - Abstract Syntax Tree
#-----------------------------------------------

class ASTNode:
    '''Base class for every node of the abstract syntax tree.'''

    def __init__(self, line_no = '', parent = None):
        '''Create an abstract syntax tree node.

        line_no -- source line the node came from; '' when not given
        parent  -- parent node, or None when not given
        '''
        self.line_no = line_no           # may be null (i.e. empty string)
        self.parent = parent

    def __repr__(self):
        '''Return a generic description of this node.

        Subclasses in this file override __repr__ to print themselves as
        source text; this fallback only names the class and the line.
        '''
        # This must not call str(self): __str__ delegates back to __repr__,
        # and the two previously called each other until the recursion
        # limit was hit.  self.__class__ is used rather than type(self)
        # because the class is declared without a base class.
        return '%s(line_no=%r)' % (self.__class__.__name__, self.line_no)

    def __str__(self):
        '''Return the same text as repr(), so str(node) uses the subclass __repr__.'''
        return repr(self)

    def eval(self):
        '''Do nothing, this is an abstract class; returns the empty string.'''
        return ''
    
#-----------------------------------------------
# Expression
#-----------------------------------------------

class Exp(ASTNode):
    '''Abstract base class shared by all expression nodes.'''

    def __init__(self, line_no = '', parent = None):
        '''Initialise the common node fields (line number and parent).'''
        ASTNode.__init__(self, line_no = line_no, parent = parent)

    def eval(self):
        '''Abstract placeholder; concrete expressions override this. Returns None.'''
        return None

#-----------------------------------------------
# Number Literal
#-----------------------------------------------

class IntLiteral(Exp):
    '''Numeric literal appearing in an expression.'''

    def __init__(self, val, line_no = '', parent = None):
        '''Store the literal's value and initialise the common node fields.'''
        self.val = val
        Exp.__init__(self, line_no = line_no, parent = parent)

    def eval(self):
        '''Return the stored value unchanged (expected to be an int).'''
        return self.val

    def __repr__(self):
        '''Return the value formatted with str().'''
        text = str(self.val)
        return text
        
#-----------------------------------------------
# Identifier
#-----------------------------------------------

class Identifier(Exp):
    '''Variable reference; its value lives in the module-level symtab.'''

    def __init__(self, name, line_no = ''):
        '''Create an identifier.

        name    -- the variable's name, used as the key into symtab
        line_no -- source line the identifier came from; '' when not given
        '''
        Exp.__init__(self, line_no)
        self.name = name

    def __repr__(self):
        '''Return the variable's name.'''
        return self.name

    def eval(self):
        '''Return the variable's current value from the symbol table.

        A variable that has not been seen yet is added to the table with
        the value 0, and 0 is returned.
        '''
        # setdefault does the lookup-or-insert in one dictionary operation,
        # instead of testing membership against symtab.keys() first.
        return symtab.setdefault(self.name, 0)


#-----------------------------------------------
# Unary Expression
#-----------------------------------------------

class UnaryExp(Exp):
    '''Unary operation applied to a single sub-expression.'''
    NOT = 1

    def __init__(self, exp, op_type, line_no = ''):
        '''Create a unary operation expression.

        exp     -- operand node (must provide eval())
        op_type -- operator code; NOT is the only one defined here
        line_no -- source line the expression came from; '' when not given
        '''
        Exp.__init__(self, line_no)
        self.exp = exp
        self.op_type = op_type

    def __repr__(self):
        '''Return the expression as source text, e.g. "!x".'''
        if self.op_type == self.NOT:
            return '!' + str(self.exp)
        # Always return a string: falling off the end would return None,
        # which makes repr()/str() raise TypeError.
        return '<unknown unary op %s>' % (self.op_type,) + str(self.exp)

    def eval(self):
        '''Evaluate the operand and apply the operator.

        For NOT the result is a Python bool.  An unrecognized operator
        prints an interpreter error and terminates with exit status 1
        (SystemExit).
        '''
        if self.op_type == self.NOT:
            return not self.exp.eval()
        # Single-argument print(...) emits the same text as the former
        # print statement and parses on both Python 2 and Python 3.
        print("IMP Interpreter error: unrecognized unary operator %s" % (self.op_type,))
        # Raised directly: the exit() builtin exists only when the site
        # module has been loaded.
        raise SystemExit(1)

#-----------------------------------------------
# Binary Operation
#-----------------------------------------------

class BinOpExp(Exp):
    '''Binary operation applied to a left and a right sub-expression.'''
    SUB = 1
    MUL = 2
    AND = 3
    LE = 4
    GE = 5
    PLUS = 6
    DIVIDE = 7
    MOD = 8
    EQ = 9

    # Source-text symbol for each operator code (used by __repr__).
    _SYMBOLS = {
        SUB: '-',
        MUL: '*',
        AND: '&&',
        LE: '<=',
        GE: '>=',
        PLUS: '+',
        DIVIDE: '/',
        MOD: '%',
        EQ: '==',
    }

    # Operators that always evaluate both operands.  AND is deliberately
    # absent: eval() handles it separately so it can short-circuit.
    _STRICT_OPS = {
        SUB: lambda a, b: a - b,
        MUL: lambda a, b: a * b,
        LE: lambda a, b: a <= b,
        GE: lambda a, b: a >= b,
        PLUS: lambda a, b: a + b,
        DIVIDE: lambda a, b: a // b,
        MOD: lambda a, b: a % b,
        EQ: lambda a, b: a == b,
    }

    def __init__(self, lhs, rhs, op_type, line_no = ''):
        '''Create a binary operation expression.

        lhs, rhs -- operand nodes (must provide eval())
        op_type  -- one of the operator codes defined on this class
        line_no  -- source line the expression came from; '' when not given
        '''
        Exp.__init__(self, line_no)
        self.lhs = lhs
        self.rhs = rhs
        self.op_type = op_type

    def __repr__(self):
        '''Return the expression as source text, e.g. "a<=b".'''
        op = self._SYMBOLS.get(self.op_type)
        if op is None:
            # An unknown code used to leave 'op' unassigned and raise
            # UnboundLocalError; show a placeholder instead.
            op = '<unknown op %s>' % (self.op_type,)
        return str(self.lhs) + op + str(self.rhs)

    def eval(self):
        '''Evaluate both operands and combine them with the operator.

        AND short-circuits: the right operand is evaluated only when the
        left one is truthy.  DIVIDE is floor division (//).  An
        unrecognized operator prints an interpreter error and terminates
        with exit status 1 (SystemExit) without evaluating either operand.
        '''
        if self.op_type == self.AND:
            return self.lhs.eval() and self.rhs.eval()

        apply_op = self._STRICT_OPS.get(self.op_type)
        if apply_op is None:
            # Single-argument print(...) emits the same text as the former
            # print statement and parses on both Python 2 and Python 3.
            print("IMP Interpreter error: unrecognized binary operation %s" % (self.op_type,))
            # Raised directly: the exit() builtin exists only when the
            # site module has been loaded.
            raise SystemExit(1)

        # Left operand first, then right, as in the original ladder.
        left = self.lhs.eval()
        right = self.rhs.eval()
        return apply_op(left, right)

       
    
#-----------------------------------------------
# Statement
#-----------------------------------------------

class Stmt(ASTNode):
    '''Abstract base class shared by all statement nodes.'''

    def __init__(self, line_no = ''):
        '''Initialise the common node fields; no parent is passed on.'''
        ASTNode.__init__(self, line_no = line_no)

    def eval(self):
        '''Abstract placeholder; concrete statements override this. Returns None.'''
        return None
    
      
#-----------------------------------------------
# If-Then-Else
#-----------------------------------------------

class IfStmt(Stmt):
    '''Conditional statement with an optional else branch.'''

    def __init__(self, test, true_stmts, false_stmts = None, line_no = ''):
        '''Create an if statement.

        test        -- condition node (must provide eval())
        true_stmts  -- iterable of statements run when the test is truthy
        false_stmts -- iterable of statements run otherwise, or None when
                       there is no else branch
        line_no     -- source line the statement came from; '' when not given
        '''
        Stmt.__init__(self, line_no)
        self.test = test
        self.true_stmts = true_stmts
        self.false_stmts = false_stmts

    def __repr__(self):
        '''Return the statement as source text.

        The else clause is omitted when false_stmts is None.
        '''
        pieces = ["if ( " + str(self.test) + ") {\n"]
        pieces.extend('\t' + str(stmt) for stmt in self.true_stmts)
        # None means "no else branch"; iterating it would raise TypeError.
        if self.false_stmts is not None:
            pieces.append("\n} else {\n")
            pieces.extend('\t' + str(stmt) for stmt in self.false_stmts)
        pieces.append("\n}\n")
        return ''.join(pieces)

    def eval(self):
        '''Evaluate the test and run the statements of the chosen branch.

        A false test with no else branch (false_stmts is None) runs nothing.
        '''
        if self.test.eval():
            branch = self.true_stmts
        elif self.false_stmts is None:
            branch = ()
        else:
            branch = self.false_stmts
        for stmt in branch:
            stmt.eval()
    
#-----------------------------------------------
# While Loop
#-----------------------------------------------

class WhileStmt(Stmt):
    '''Loop that repeats its body for as long as the test is truthy.'''

    def __init__(self, test, stmts, line_no = ''):
        '''Create a while-loop statement.

        test    -- condition node (must provide eval())
        stmts   -- iterable of statements forming the loop body
        line_no -- source line the statement came from; '' when not given
        '''
        Stmt.__init__(self, line_no)
        # NOTE(review): an earlier comment said test "may be null", but
        # eval() calls self.test.eval() unconditionally -- confirm whether
        # None can actually be passed here.
        self.test = test
        self.stmts = stmts

    def __repr__(self):
        '''Return the loop as source text.'''
        header = "while (" + str(self.test) + ") {\n"
        body = ''.join('\t' + str(stmt) for stmt in self.stmts)
        return header + body + "\n}\n"

    def eval(self):
        '''Re-evaluate the test before each pass and run the body while it holds.'''
        while True:
            if not self.test.eval():
                break
            for stmt in self.stmts:
                stmt.eval()
            
        
#-----------------------------------------------
# Assignment
#-----------------------------------------------

class AssignStmt(Stmt):
    '''Assignment of an expression's value to a variable.'''

    def __init__(self, var, exp, line_no = ''):
        '''Create an assignment statement.

        var     -- assignment target; used directly as the key into symtab
        exp     -- expression node whose value is stored (must provide eval())
        line_no -- source line the statement came from; '' when not given
        '''
        Stmt.__init__(self, line_no)
        self.var = var
        self.exp = exp

    def __repr__(self):
        '''Return the statement as source text, e.g. "x = 1;" plus a newline.'''
        return str(self.var) + ' = ' + str(self.exp) + ';\n'

    def eval(self):
        '''Evaluate the right-hand side once and store it in the symbol table.'''
        theval = self.exp.eval()
        # Store the value just computed; the expression used to be
        # evaluated a second time here and the first result thrown away.
        # NOTE(review): Identifier.eval looks variables up by name string,
        # so this presumably expects var to be the name -- confirm caller.
        symtab[self.var] = theval


