summaryrefslogtreecommitdiff
path: root/LinkedNode.py
blob: 61ffa746d894e470b1bfc3bdabf60a6edede837a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#!/usr/bin/env python
"""Compiles C into assembly for the practicum processor (PP2)

This module provides LinkedNode, a wrapper around c_ast nodes that keeps a
reference to the parent node together with per-node compiler state.

Copyright (C) 2011-2014 Peter Wu <lekensteyn@gmail.com>
Licensed under the MIT license <http://opensource.org/licenses/MIT>.
"""

# Module metadata; mirrors the copyright and license stated in the docstring.
__author__ = "Peter Wu"
__copyright__ = "Copyright (C) 2011-2014 Peter Wu"
__credits__ = ["Peter Wu"]
__license__ = "MIT"
__version__ = "1.0"
__maintainer__ = "Peter Wu"
__email__ = "lekensteyn@gmail.com"

class LinkedNode(object):
    """Stores nodes with a reference to the parent

    Besides the parent link, every LinkedNode carries compiler state that is
    either inherited from the parent when the node is created (indentation
    level, variables scope object) or found by walking towards the root
    (enclosing function, break/continue targets, goto labels).
    """

    # c_ast class names that isTypeStatement() treats as statement types.
    # A frozenset gives constant-time membership tests.
    _STATEMENT_TYPES = frozenset((
        "Compound", "If", "Return", "DoWhile", "While",
        "For", "Decl", "FuncDef", "Break", "Continue",
        "EmptyStatement", "Switch", "DeclList",
        "FuncDecl", "ArrayDecl", "Case",
        "Default", "EllipsisParam",  # (int a, ...)
        "Enum",  # enum type
        "Enumerator",  # enum value
        "EnumeratorList",  # list of enum values
        "Goto", "Label", "ParamList", "PtrDecl", "Struct",
        "TypeDecl", "Typedef", "Union",
    ))

    def __init__(self, node, parent=None, level_increment=False):
        """Holds properties for a node

        Keyword arguments:
        node -- a Node object which is an object from the c_ast class
        parent -- a parent LinkedNode object, or None for the root
        level_increment -- True if the indentation level needs to be
        incremented, False otherwise
        """
        self.node = node
        if parent is not None:
            assert isinstance(parent, LinkedNode), "parent is not a LinkedNode!"
        self.parent = parent
        self.function = None
        self.break_label = None
        self.continue_label = None
        # the name of the c_ast class, e.g. "FuncDef" or "Compound"
        self.type = type(node).__name__
        self.level = 0
        self.variables = None
        # for supporting post increment and post decrement
        self.post_lines = []
        # inherit level and variables from parent
        if parent is not None:
            self.level = parent.level
            self.variables = parent.variables

        if level_increment:
            self.incrementLevel()

        # labels are limited to function contexts
        if self.type == "FuncDef":
            self.goto_labels = {}

    def _findAncestor(self, predicate):
        """Returns the nearest LinkedNode, starting with self and walking up
        through the parents, for which predicate(node) is true, or None if
        no such node exists.
        """
        current = self
        while current is not None:
            if predicate(current):
                return current
            current = current.parent
        return None

    def setVariablesObj(self, variables_obj):
        """Sets a new variables scope object for this node"""
        self.variables = variables_obj

    def handle_post_lines(self, lines):
        """Add post-increment lines to the lines list and clear the queue

        lines is extended in place (list +=), so it must be a mutable list.
        """
        lines += self.post_lines
        self.post_lines = []

    def needNewScope(self):
        """Returns True if a new name scope is necessary and False otherwise"""
        if self.type == "Compound":
            # A compound statement that directly forms the body of a For or
            # FuncDef reuses the scope that node already opened (presumably
            # so that loop initializers and parameters are visible in the
            # body). A Compound without a parent always gets its own scope.
            if self.parent is None:
                return True
            return self.parent.type not in ("For", "FuncDef")
        return self.type in ("FileAST", "For", "FuncDef")

    def isTypeStatement(self):
        """Returns True if the node is a statement type"""
        return self.type in self._STATEMENT_TYPES

    def getStatementNode(self):
        """Returns the nearest LinkedNode which is a statement node type"""
        return self._findAncestor(lambda n: n.isTypeStatement())

    def incrementLevel(self):
        """Increases the indentation level of this node by one"""
        self.level += 1

    def getFunctionNode(self):
        """Returns the nearest LinkedNode which is a function definition node
        type, or None when this node is not inside a function"""
        return self._findAncestor(lambda n: n.type == "FuncDef")

    def setFunction(self, function):
        """Sets the function object containing label information and a stack
        allocation function

        The function object is also given a back-reference to this node
        (function.setLinkedNode) and is registered with the current variables
        scope object (variables.setFunction).
        """
        self.function = function
        function.setLinkedNode(self)
        self.variables.setFunction(function)

    def getLocation(self):
        """Returns the source location of the wrapped node as a string, or
        "Unknown" when the node has no coord attribute"""
        if hasattr(self.node, "coord"):
            return str(self.node.coord)
        return "Unknown"

    def setBreak(self, break_label):
        """Marks this node as a loop or switch by setting the label for break

        Keywords arguments:
        break_label -- The label to jump to when using the break keyword
        """
        self.break_label = break_label

    def setContinue(self, continue_label):
        """Marks this node as a loop by setting the label for continue

        Keywords arguments:
        continue_label -- The label to jump to when using the continue keyword
        """
        self.continue_label = continue_label

    def getBreakNode(self):
        """Returns the nearest LinkedNode (self or an ancestor) that has a
        break label set via setBreak, or None if there is none"""
        return self._findAncestor(lambda n: n.break_label is not None)

    def getContinueNode(self):
        """Returns the nearest LinkedNode (self or an ancestor) that has a
        continue label set via setContinue, or None if there is none"""
        return self._findAncestor(lambda n: n.continue_label is not None)

    def _getLabelFunctionNode(self):
        """Returns the FuncDef LinkedNode enclosing the parent of this node,
        whose goto_labels dict holds the labels for this node.

        Raises RuntimeError when there is no enclosing function.
        """
        function = None
        if self.parent is not None:
            function = self.parent.getFunctionNode()
        if function is None:
            raise RuntimeError("Labels are only allowed in functions")
        return function

    def setLabel(self, label_name, label_asm):
        """Sets the label for this node

        Keyword arguments:
        label_name -- The label name as can be used in C
        label_asm -- The label as it appears in assembly

        Returns label_asm. Raises RuntimeError if the node is not inside a
        function or if label_name is already defined in that function.
        """
        function = self._getLabelFunctionNode()
        if label_name in function.goto_labels:
            raise RuntimeError("Duplicate label '{}'".format(label_name))

        function.goto_labels[label_name] = label_asm
        return label_asm

    def lookupLabel(self, label_name):
        """Returns the label name as it appears in assembly for label_name

        Raises RuntimeError if the node is not inside a function or if
        label_name has not been registered with setLabel in that function.
        """
        function = self._getLabelFunctionNode()
        if label_name in function.goto_labels:
            return function.goto_labels[label_name]
        raise RuntimeError(
            "Label '{}' used but not defined".format(label_name))