From 6aaff7ae54dbe34ea6bf3f9ee8c93e9d79db3c22 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Thu, 1 Dec 2011 22:18:27 +0000 Subject: Fix stack corruption, support func calls with arguments and automatic variables --- LinkedNode.py | 71 ++++++++++++++------------------------- Variables.py | 104 +++++++++++++++++++++++++++++++++------------------------- pp2cc.py | 57 +++++++++++++++++++++++++++----- 3 files changed, 132 insertions(+), 100 deletions(-) diff --git a/LinkedNode.py b/LinkedNode.py index 88bc6da..f068993 100644 --- a/LinkedNode.py +++ b/LinkedNode.py @@ -18,12 +18,9 @@ __version__ = "1.0" __maintainer__ = "Peter Wu" __email__ = "uwretep@gmail.com" -from Variables import Variables - class LinkedNode(object): """Stores nodes with a reference to the parent""" - def __init__(self, node, parent=None, level_increment=False, - defined_names=None): + def __init__(self, node, parent=None, level_increment=False): """Holds properties for a node Keyword arguments: @@ -31,8 +28,6 @@ class LinkedNode(object): parent -- a parent LinkedNode object level_increment -- True if the indentation level needs to be incremented, False otherwise - defined_names -- a list of names which will be used in the @DATA - section. If not set, it'll be looked up in the parent """ self.node = node if parent: @@ -46,48 +41,32 @@ class LinkedNode(object): self.variables = None # for supporting post increment and post decrement self.post_lines = [] - parent_variables = None # inherit level and variables from parent if parent: self.level = parent.level - self.variables = parent_variables = parent.variables - if not defined_names: - defined_names = parent.variables.defined_names + self.variables = parent.variables - # for is added for the init part (C99) - if self.type in ("Compound", "FileAST", "For"): - # the node appears to have an own variable scope - if defined_names is None: - raise RuntimeError("No object found for storing variables") - # pass function object for local variables allocation - function = self.getFunctionNode() - if function: - function = function.function - self.variables = Variables(parent_variables, defined_names, - function=function) - # Identifiers which are in use (think of variables and labels) - self.defined_names = defined_names - if not self.variables: - raise RuntimeError("No variables object found") if level_increment: self.incrementLevel() + # labels are limited to function contexts if self.type == "FuncDef": self.goto_labels = {} + def setVariablesObj(self, variables_obj): + """Sets a new variables scope object for this node""" + self.variables = variables_obj def handle_post_lines(self, lines): """Add post-increment lines to the lines list and clear the queue""" lines += self.post_lines self.post_lines = [] - def getScopeNode(self): - """Get the nearest node which introduces a new scope. - - If there is no node found an exception is raised because it expects at - least a global scope""" - if self.local_vars is not None: - return self - if self.parent: - return self.parent.getScopeNode() - raise RuntimeError("No global variable scope was found") + def needNewScope(self): + """Returns True if a new name scope is necessary and False otherwise""" + if self.type == "Compound": + if self.parent.type not in ("For", "FuncDef"): + return True + elif self.type in ("FileAST", "For", "FuncDef"): + return True + return False def isTypeStatement(self): """Returns True if the node is a statement type""" return self.type in ("Compound", "If", "Return", "DoWhile", "While", @@ -118,11 +97,14 @@ class LinkedNode(object): return self.parent.getFunctionNode() return None def setFunction(self, function): - """Sets the function object containing label information""" + """Sets the function object containing label information and a stack + allocation function + """ self.function = function + self.variables.setFunction(function) def getLocation(self): if hasattr(self.node, "coord"): - return self.node.coord + return str(self.node.coord) return "Unknown" def setBreak(self, break_label): """Marks this node as a loop or switch by setting the label for break @@ -153,9 +135,12 @@ class LinkedNode(object): if self.parent: return self.parent.getContinueNode() return None - def setLabel(self, label_name): - """Sets the label for this node and return the label name as it appears - in assembly + def setLabel(self, label_name, label_asm): + """Sets the label for this node + + Keyword arguments: + label_name -- The label name as can be used in C + label_asm -- The label as it appears in assembly """ if self.parent: function = self.parent.getFunctionNode() @@ -164,12 +149,6 @@ class LinkedNode(object): if label_name in function.goto_labels: raise RuntimeError("Duplicate label '{}'".format(label_name)) - label_asm = "lbl_" + label_name - i = 0 - while label_asm in self.defined_names: - label_asm = "lbl_" + label_name + str(i) - i += 1 - function.goto_labels[label_name] = label_asm return label_asm def lookupLabel(self, label_name): diff --git a/Variables.py b/Variables.py index 338cfdf..bab1a34 100644 --- a/Variables.py +++ b/Variables.py @@ -19,31 +19,71 @@ __maintainer__ = "Peter Wu" __email__ = "uwretep@gmail.com" class Variables(object): - def __init__(self, parent_variables, defined_names, function=None): + def __init__(self, parent_variables): """A scope for holding variable names Keywords arguments: parent_variables -- the parent Variables object. If None, it's a global variable scope - defined_names -- A list of defined variables to which additional - variables might be appended - function -- The function object used for allocating memory. If not set, - the variables object will ask the parent_variables object for it """ self.parent_variables = parent_variables + self.function = None + if self.parent_variables and hasattr(self.parent_variables, "function"): + self.function = self.parent_variables.function + # key: name, value: address n relative to BP (n >= 1) + self.local_vars = {} + self.param_vars = [] + def setFunction(self, function): + """Sets the function object containing the stack allocation function""" self.function = function - # if there is a parent_variables object, it must be a function + def getAddress(self, name): + """Gets the address for a variable as a tuple containing a register and + displacement + + To get the address of the variable, add the register value and + displacement + """ + # first try the local scope + if name in self.local_vars: + # XXX don't hardcode R5 + return ("R5", str(-self.local_vars[name])) + try: + return ("R5", str(1 + self.param_vars.index(name))) + except ValueError: + pass + # lookup in the parent if self.parent_variables: - # key: name, value: address n relative to BP (n >= 1) - self.local_vars = {} - self.param_vars = [] - if not self.function: - self.function = self.parent_variables.function + return self.parent_variables.getAddress(name) + raise RuntimeError("Use of undefined variable '{}'".format(name)) + def declName(self, name, size=1, is_param=False): + """Declares a variable in the nearest scope + + Keyword arguments: + name -- The symbolic name of the variable + size -- The size of the memory to be allocated in words (default 1) + is_param -- Whether the name is a function parameter or not + """ + if name in self.local_vars or name in self.param_vars: + raise RuntimeError("Redeclaration of variable '{}'".format(name)) + + # parameters do not need a special allocation because the callee + # pushes it into the stack + if is_param: + self.param_vars.append(name) else: - # key: name of var, value: label of var - self.global_vars = {} + self.local_vars[name] = self.function.allocStack(size) + +class GlobalVariables(Variables): + def __init__(self, defined_names): + """A scope for holding variable names + + Keywords arguments: + defined_names -- A dictionary holding identifiers which are already + defined in assembly + """ self.defined_names = defined_names - def uniqName(self, name): + self.global_vars = {} + def _uniqName(self, name): """Returns an unique global variable name for assembly""" uniq_name = name i = 0 @@ -58,22 +98,8 @@ class Variables(object): To get the address of the variable, add the register value and displacement """ - # first try the local scope - if self.function: - if name in self.local_vars: - # XXX don't hardcode R5 - return ("R5", str(self.local_vars[name])) - try: - return ("R5", str(-self.param_vars.index(name))) - except ValueError: - pass - else: - # lookup in global vars - if name in self.global_vars: - return ("GB", self.global_vars[name]) - # lookup in the parent - if self.parent_variables: - return self.parent_variables.getAddress(name) + if name in self.global_vars: + return ("GB", self.global_vars[name]) raise RuntimeError("Use of undefined variable '{}'".format(name)) def declName(self, name, size=1, is_param=False): """Declares a variable in the nearest scope @@ -83,25 +109,13 @@ class Variables(object): size -- The size of the memory to be allocated in words (default 1) is_param -- Whether the name is a function parameter or not """ - already_defined = False - if self.function: - already_defined = name in self.local_vars or name in self.param_vars - else: - already_defined = name in self.global_vars - if already_defined: + if name in self.global_vars: raise RuntimeError("Redeclaration of variable '{}'".format(name)) - if self.function: - # parameters do not need a special allocation because the callee - # pushes it into the stack - if is_param: - self.param_vars.append(name) - else: - self.local_vars[name] = self.function.allocStack(size) - elif is_param: + if is_param: raise RuntimeError("Parameter '{}' declared in global context".format(name)) else: # global variables are prefixed "var_" - var_name = "var_" + name + var_name = self._uniqName("var_" + name) self.global_vars[name] = var_name self.defined_names[var_name] = size diff --git a/pp2cc.py b/pp2cc.py index 0ae2803..c50bff2 100755 --- a/pp2cc.py +++ b/pp2cc.py @@ -16,6 +16,7 @@ from Asm import Asm from Registers import Registers from LinkedNode import LinkedNode from Function import Function +from Variables import Variables, GlobalVariables __author__ = "Peter Wu" __copyright__ = "Copyright 2011, Peter Wu" @@ -115,7 +116,9 @@ class Parse(object): self.labels.add(labelname) def compile(self): """Loops through the nodes in an AST syntax tree and generates ASM""" - root_node = LinkedNode(self.node, defined_names=self.varNames) + root_node = LinkedNode(self.node) + variables = GlobalVariables(self.varNames) + root_node.setVariablesObj(variables) for thing in self.node.ext: if not isinstance(thing, c_ast.Decl) and not isinstance(thing, c_ast.FuncDef): linked_node = LinkedNode(thing, parent=root_node) @@ -161,22 +164,30 @@ class Parse(object): self.addLabel(lbl_func) self.addLabel(lbl_end) linked_node.setFunction(function) - # save Base Pointer lines = [self.asm.push(self.registers.BP, lbl_func)] lines.append(self.asm.binary_op("LOAD", self.registers.BP, "SP")) # parse function declaration (which will parse params as well) lines += self.parseStatement(node.decl, linked_node) - lines += self.parseStatement(node.body, linked_node) + body = self.parseStatement(node.body, linked_node) self.asm.level = linked_node.level - # restore stack pointer - lines.append(self.asm.binary_op("LOAD", "SP", self.registers.BP)) + # Reserve space on the stack for local variables if necessary + if function.reserved_stack: + lines.append(self.asm.binary_op("SUB", "SP", function.reserved_stack)) + lines += body + # restore SP + lines.append(self.asm.binary_op("LOAD", "SP", self.registers.BP, + label=lbl_end)) + else: + lines += body + lines.append(self.asm.noop(lbl_end)) + # restore Base Pointer lines.append(self.asm.pull(self.registers.BP)) # return from function - lines.append(self.asm.format_line("RTS", label=lbl_end)) + lines.append(self.asm.format_line("RTS")) # add an extra newline lines.append("") return lines @@ -677,10 +688,27 @@ class Parse(object): #lines.append(self.asm.format_line("RTS")) return lines def parseFuncCall(self, linked_node): - # XXX function arguments + lines = [] # node.name is a c_ast.ID, the real function name is in name funcname = linked_node.node.name.name - return [self.asm.branch_op("BRS", self.functionLabel(funcname))] + params = linked_node.node.args + if params: + linked_params = LinkedNode(params, linked_node) + # call convention: params in reverse order + for expr in reversed(params.exprs): + line = self.parseExpression(expr, linked_params) + result_reg = self.registers.find_register(line, fatal=True) + lines += line + lines.append(self.asm.push(result_reg)) + + lines.append(self.asm.branch_op("BRS", self.functionLabel(funcname))) + + if params: + lines.append(self.asm.binary_op("ADD", "SP", len(params.exprs))) + # by convention, a function call must return the result in R0, we + # also make sure that the PSW flags are set properly + lines.append(self.asm.binary_op("LOAD", "R0", "R0")) + return lines def parseCast(self, linked_node): self.logger.warning("Found a type cast, but these are unsupported.", linked_node=linked_node) return self.parseExpression(linked_node.node.expr, linked_node) @@ -765,6 +793,9 @@ class Parse(object): lines.append(self.asm.noop(lbl_for_end)) return lines def parseExprList(self, linked_node): + """Parse an expression list. This method should not be used when + parsing a list declaration (done by parseDecl) nor function params + """ lines = [] for expr in linked_node.node.exprs: lines += self.parseExpression(expr, linked_node) @@ -1062,7 +1093,9 @@ class Parse(object): def parseLabel(self, linked_node): lines = [] try: - label_name = linked_node.setLabel(linked_node.node.name) + name = linked_node.node.name + label_name = self.uniqLbl(name) + linked_node.setLabel(name, label_name) except RuntimeError as errmsg: self.logger.error(errmsg, linked_node=linked_node) lines.append(self.asm.noop(label_name)) @@ -1099,9 +1132,13 @@ class Parse(object): return lines def parseStatement(self, node, parent_linked_node, level_increment=False): linked_node = LinkedNode(node, parent_linked_node, level_increment=level_increment) + if linked_node.needNewScope(): + linked_node.setVariablesObj(Variables(linked_node.parent.variables)) self.asm.level = linked_node.level lines = [] if linked_node.isTypeStatement(): + lines.append(self.asm.format_line("; " + linked_node.type + " " + + linked_node.getLocation())) if hasattr(self, "parse" + linked_node.type): lines += getattr(self, "parse" + linked_node.type)(linked_node) # add for support post increment and decrement @@ -1120,6 +1157,8 @@ class Parse(object): self.asm.level = linked_node.level lines = [] + lines.append(self.asm.format_line("; " + linked_node.type + " " + + linked_node.getLocation())) if linked_node.type in ("ID", "Constant", "UnaryOp", "FuncCall", "Cast", "BinaryOp", "TernaryOp", "Assignment", "ExprList", "ArrayRef"): lines += getattr(self, "parse" + linked_node.type)(linked_node) -- cgit v1.2.1