Commit f3d23e0e677eb272874af14c84bdc806af390840

Authored by Austin Sun
1 parent 98c5b35993
Exists in master

think i have if done

Showing 2 changed files with 45 additions and 4 deletions Inline Diff

/* File: ast_stmt.cc 1 1 /* File: ast_stmt.cc
* ----------------- 2 2 * -----------------
* Implementation of statement node classes. 3 3 * Implementation of statement node classes.
*/ 4 4 */
#include "ast_stmt.h" 5 5 #include "ast_stmt.h"
#include "ast_type.h" 6 6 #include "ast_type.h"
#include "ast_decl.h" 7 7 #include "ast_decl.h"
#include "ast_expr.h" 8 8 #include "ast_expr.h"
#include "symtable.h" 9 9 #include "symtable.h"
10 10
#include "irgen.h" 11 11 #include "irgen.h"
#include "llvm/Bitcode/ReaderWriter.h" 12 12 #include "llvm/Bitcode/ReaderWriter.h"
#include "llvm/Support/raw_ostream.h" 13 13 #include "llvm/Support/raw_ostream.h"
14 14
15 15
Program::Program(List<Decl*> *d) { 16 16 Program::Program(List<Decl*> *d) {
Assert(d != NULL); 17 17 Assert(d != NULL);
(decls=d)->SetParentAll(this); 18 18 (decls=d)->SetParentAll(this);
} 19 19 }
20 20
void Program::PrintChildren(int indentLevel) { 21 21 void Program::PrintChildren(int indentLevel) {
decls->PrintAll(indentLevel+1); 22 22 decls->PrintAll(indentLevel+1);
printf("\n"); 23 23 printf("\n");
} 24 24 }
//pls work 25 25 //pls work
llvm::Value* Program::Emit() { 26 26 llvm::Value* Program::Emit() {
llvm::Module *module = irGen->GetOrCreateModule("swag"); 27 27 llvm::Module *module = irGen->GetOrCreateModule("program");
pushScope(); 28 28 pushScope();
for (int i = 0; i < decls->NumElements(); i++){ 29 29 for (int i = 0; i < decls->NumElements(); i++){
decls->Nth(i)->Emit(); 30 30 decls->Nth(i)->Emit();
} 31 31 }
popScope(); 32 32 popScope();
33 33
module->dump(); 34 34 module->dump();
llvm::WriteBitcodeToFile(module, llvm::outs()); 35 35 llvm::WriteBitcodeToFile(module, llvm::outs());
return NULL; 36 36 return NULL;
} 37 37 }
38 38
StmtBlock::StmtBlock(List<VarDecl*> *d, List<Stmt*> *s) { 39 39 StmtBlock::StmtBlock(List<VarDecl*> *d, List<Stmt*> *s) {
Assert(d != NULL && s != NULL); 40 40 Assert(d != NULL && s != NULL);
(decls=d)->SetParentAll(this); 41 41 (decls=d)->SetParentAll(this);
(stmts=s)->SetParentAll(this); 42 42 (stmts=s)->SetParentAll(this);
} 43 43 }
44 44
void StmtBlock::PrintChildren(int indentLevel) { 45 45 void StmtBlock::PrintChildren(int indentLevel) {
decls->PrintAll(indentLevel+1); 46 46 decls->PrintAll(indentLevel+1);
stmts->PrintAll(indentLevel+1); 47 47 stmts->PrintAll(indentLevel+1);
} 48 48 }
49 49
DeclStmt::DeclStmt(Decl *d) { 50 50 DeclStmt::DeclStmt(Decl *d) {
Assert(d != NULL); 51 51 Assert(d != NULL);
(decl=d)->SetParent(this); 52 52 (decl=d)->SetParent(this);
} 53 53 }
54 54
void DeclStmt::PrintChildren(int indentLevel) { 55 55 void DeclStmt::PrintChildren(int indentLevel) {
decl->Print(indentLevel+1); 56 56 decl->Print(indentLevel+1);
} 57 57 }
58 58
ConditionalStmt::ConditionalStmt(Expr *t, Stmt *b) { 59 59 ConditionalStmt::ConditionalStmt(Expr *t, Stmt *b) {
Assert(t != NULL && b != NULL); 60 60 Assert(t != NULL && b != NULL);
(test=t)->SetParent(this); 61 61 (test=t)->SetParent(this);
(body=b)->SetParent(this); 62 62 (body=b)->SetParent(this);
} 63 63 }
64 64
ForStmt::ForStmt(Expr *i, Expr *t, Expr *s, Stmt *b): LoopStmt(t, b) { 65 65 ForStmt::ForStmt(Expr *i, Expr *t, Expr *s, Stmt *b): LoopStmt(t, b) {
Assert(i != NULL && t != NULL && b != NULL); 66 66 Assert(i != NULL && t != NULL && b != NULL);
(init=i)->SetParent(this); 67 67 (init=i)->SetParent(this);
step = s; 68 68 step = s;
if ( s ) 69 69 if ( s )
(step=s)->SetParent(this); 70 70 (step=s)->SetParent(this);
} 71 71 }
72 72
void ForStmt::PrintChildren(int indentLevel) { 73 73 void ForStmt::PrintChildren(int indentLevel) {
init->Print(indentLevel+1, "(init) "); 74 74 init->Print(indentLevel+1, "(init) ");
test->Print(indentLevel+1, "(test) "); 75 75 test->Print(indentLevel+1, "(test) ");
if ( step ) 76 76 if ( step )
step->Print(indentLevel+1, "(step) "); 77 77 step->Print(indentLevel+1, "(step) ");
body->Print(indentLevel+1, "(body) "); 78 78 body->Print(indentLevel+1, "(body) ");
} 79 79 }
80 80
void WhileStmt::PrintChildren(int indentLevel) { 81 81 void WhileStmt::PrintChildren(int indentLevel) {
test->Print(indentLevel+1, "(test) "); 82 82 test->Print(indentLevel+1, "(test) ");
body->Print(indentLevel+1, "(body) "); 83 83 body->Print(indentLevel+1, "(body) ");
} 84 84 }
85 85
IfStmt::IfStmt(Expr *t, Stmt *tb, Stmt *eb): ConditionalStmt(t, tb) { 86 86 IfStmt::IfStmt(Expr *t, Stmt *tb, Stmt *eb): ConditionalStmt(t, tb) {
Assert(t != NULL && tb != NULL); // else can be NULL 87 87 Assert(t != NULL && tb != NULL); // else can be NULL
elseBody = eb; 88 88 elseBody = eb;
if (elseBody) elseBody->SetParent(this); 89 89 if (elseBody) elseBody->SetParent(this);
} 90 90 }
91 91
void IfStmt::PrintChildren(int indentLevel) { 92 92 void IfStmt::PrintChildren(int indentLevel) {
if (test) test->Print(indentLevel+1, "(test) "); 93 93 if (test) test->Print(indentLevel+1, "(test) ");
if (body) body->Print(indentLevel+1, "(then) "); 94 94 if (body) body->Print(indentLevel+1, "(then) ");
if (elseBody) elseBody->Print(indentLevel+1, "(else) "); 95 95 if (elseBody) elseBody->Print(indentLevel+1, "(else) ");
} 96 96 }
97 97
98 98
ReturnStmt::ReturnStmt(yyltype loc, Expr *e) : Stmt(loc) { 99 99 ReturnStmt::ReturnStmt(yyltype loc, Expr *e) : Stmt(loc) {
expr = e; 100 100 expr = e;
if (e != NULL) expr->SetParent(this); 101 101 if (e != NULL) expr->SetParent(this);
} 102 102 }
103 103
void ReturnStmt::PrintChildren(int indentLevel) { 104 104 void ReturnStmt::PrintChildren(int indentLevel) {
if ( expr ) 105 105 if ( expr )
expr->Print(indentLevel+1); 106 106 expr->Print(indentLevel+1);
} 107 107 }
108 108
SwitchLabel::SwitchLabel(Expr *l, Stmt *s) { 109 109 SwitchLabel::SwitchLabel(Expr *l, Stmt *s) {
Assert(l != NULL && s != NULL); 110 110 Assert(l != NULL && s != NULL);
(label=l)->SetParent(this); 111 111 (label=l)->SetParent(this);
(stmt=s)->SetParent(this); 112 112 (stmt=s)->SetParent(this);
} 113 113 }
114 114
SwitchLabel::SwitchLabel(Stmt *s) { 115 115 SwitchLabel::SwitchLabel(Stmt *s) {
Assert(s != NULL); 116 116 Assert(s != NULL);
label = NULL; 117 117 label = NULL;
(stmt=s)->SetParent(this); 118 118 (stmt=s)->SetParent(this);
} 119 119 }
120 120
void SwitchLabel::PrintChildren(int indentLevel) { 121 121 void SwitchLabel::PrintChildren(int indentLevel) {
if (label) label->Print(indentLevel+1); 122 122 if (label) label->Print(indentLevel+1);
if (stmt) stmt->Print(indentLevel+1); 123 123 if (stmt) stmt->Print(indentLevel+1);
} 124 124 }
125 125
SwitchStmt::SwitchStmt(Expr *e, List<Stmt *> *c, Default *d) { 126 126 SwitchStmt::SwitchStmt(Expr *e, List<Stmt *> *c, Default *d) {
Assert(e != NULL && c != NULL && c->NumElements() != 0 ); 127 127 Assert(e != NULL && c != NULL && c->NumElements() != 0 );
(expr=e)->SetParent(this); 128 128 (expr=e)->SetParent(this);
(cases=c)->SetParentAll(this); 129 129 (cases=c)->SetParentAll(this);
def = d; 130 130 def = d;
if (def) def->SetParent(this); 131 131 if (def) def->SetParent(this);
} 132 132 }
133 133
void SwitchStmt::PrintChildren(int indentLevel) { 134 134 void SwitchStmt::PrintChildren(int indentLevel) {
if (expr) expr->Print(indentLevel+1); 135 135 if (expr) expr->Print(indentLevel+1);
if (cases) cases->PrintAll(indentLevel+1); 136 136 if (cases) cases->PrintAll(indentLevel+1);
if (def) def->Print(indentLevel+1); 137 137 if (def) def->Print(indentLevel+1);
} 138 138 }
139 139 //-----------------------------------------------------------------------
//rest of the emits 140 140 //rest of the emits
141 //-----------------------------------------------------------------------
llvm::Value * StmtBlock::Emit(){ 141 142 llvm::Value * StmtBlock::Emit(){
pushScope(); 142 143 pushScope();
for (int i = 0; i < decls->NumElements(); i++){ 143 144 for (int i = 0; i < decls->NumElements(); i++){
decls->Nth(i)->Emit(); 144 145 decls->Nth(i)->Emit();
} 145 146 }
for (int i = 0; i < stmts->NumElements(); i++){ 146 147 for (int i = 0; i < stmts->NumElements(); i++){
stmts->Nth(i)->Emit(); 147 148 stmts->Nth(i)->Emit();
} 148 149 }
149 150
//TODO 150
151
popScope(); 152 151 popScope();
return NULL; 153 152 return NULL;
} 154 153 }
155 154
llvm::Value * DeclStmt::Emit(){ 156 155 llvm::Value * DeclStmt::Emit(){
llvm::Value * val; 157 156 llvm::Value * val;
if (VarDecl * vd = dynamic_cast<VarDecl*>(this->decl)){ 158 157 if (VarDecl * vd = dynamic_cast<VarDecl*>(this->decl)){
val = vd->Emit(); 159 158 val = vd->Emit();
} 160 159 }
else if (FnDecl * fd = dynamic_cast<FnDecl*>(this->decl)){ 161 160 else if (FnDecl * fd = dynamic_cast<FnDecl*>(this->decl)){
val = fd->Emit(); 162 161 val = fd->Emit();
} 163 162 }
else{ 164 163 else{
val = NULL; 165 164 val = NULL;
} 166 165 }
return val; 167 166 return val;
} 168 167 }
169 168
llvm::Value * ConditionalStmt::Emit(){ 170 169 llvm::Value * ConditionalStmt::Emit(){
return NULL; 171 170 return NULL;
} 172 171 }
//for statement 173 172 //for statement
173
//while statement 174 174 //while statement
175
//if statement 175 176 //if statement
177 llvm::Value * IfStmt::Emit(){
178 llvm::Function * func = irGen->GetFunction();
179 llvm::BasicBlock * elseBlock = NULL;
180 llvm::BasicBlock * thenBlock = llvm::BasicBlock::Create(*context, "thenBlock", func);
181 llvm::BasicBlock * footBlock = llvm::BasicBlock::Create(*context, "footBlock", func);
182 llvm::Value * val;
183 llvm::Value * cond = test->Emit();
184 llvm::LLVMContext * context = irGen->GetContext();
185 if(elseBody)
186 {
187 elseBlock = llvm::BasicBlock::Create(*context, "elseBlock", func);
188 }
189
190 val = llvm::BranchInst::Create(thenBlock, elseBody ? elseBlock : footBlock, cond, irGen->GetBasicBlock());
191 pushScope();
192 irGen->SetBasicBlock(thenBlock);
193 body->Emit();
194
195 if(!irGen->GetBasicBlock()->getTerminator())
196 {
197 val = llvm::BranchInst::Create(footBlock, irGen->GetBasicBlock());
198 }
199 popScope();
200
201 if(elseBody)
202 {
203 pushScope();
204 irGen->SetBasicBlock(elseBlock);
205 elseBody->Emit();
206
207 if(!irGen->GetBasicBlock()->getTerminator())
208 {
209 val = llvm::BranchInst::Create(footBlock, irGen->GetBasicBlock());
210 }
211 popScope();
212 }
213 irGen->SetBasicBlock(footBlock);
214 return val;
215 }
176 216
llvm::Value * BreakStmt::Emit(){ 177 217 llvm::Value * BreakStmt::Emit(){
return NULL; 178 218 return NULL;
} 179 219 }
/* File: ast_stmt.h 1 1 /* File: ast_stmt.h
* ---------------- 2 2 * ----------------
* The Stmt class and its subclasses are used to represent 3 3 * The Stmt class and its subclasses are used to represent
* statements in the parse tree. For each statment in the 4 4 * statements in the parse tree. For each statment in the
* language (for, if, return, etc.) there is a corresponding 5 5 * language (for, if, return, etc.) there is a corresponding
* node class for that construct. 6 6 * node class for that construct.
* 7 7 *
* pp3: You will need to extend the Stmt classes to generate 8 8 * pp3: You will need to extend the Stmt classes to generate
* LLVM IR instructions. 9 9 * LLVM IR instructions.
*/ 10 10 */
11 11
12 12
#ifndef _H_ast_stmt 13 13 #ifndef _H_ast_stmt
#define _H_ast_stmt 14 14 #define _H_ast_stmt
15 15
#include "list.h" 16 16 #include "list.h"
#include "ast.h" 17 17 #include "ast.h"
18 18
class Decl; 19 19 class Decl;
class VarDecl; 20 20 class VarDecl;
class Expr; 21 21 class Expr;
class IntConstant; 22 22 class IntConstant;
23 23
void yyerror(const char *msg); 24 24 void yyerror(const char *msg);
25 25
class Program : public Node 26 26 class Program : public Node
{ 27 27 {
protected: 28 28 protected:
List<Decl*> *decls; 29 29 List<Decl*> *decls;
30 30
public: 31 31 public:
Program(List<Decl*> *declList); 32 32 Program(List<Decl*> *declList);
const char *GetPrintNameForNode() { return "Program"; } 33 33 const char *GetPrintNameForNode() { return "Program"; }
void PrintChildren(int indentLevel); 34 34 void PrintChildren(int indentLevel);
virtual llvm::Value* Emit(); 35 35 virtual llvm::Value* Emit();
}; 36 36 };
37 37
class Stmt : public Node 38 38 class Stmt : public Node
{ 39 39 {
public: 40 40 public:
Stmt() : Node() {} 41 41 Stmt() : Node() {}
Stmt(yyltype loc) : Node(loc) {} 42 42 Stmt(yyltype loc) : Node(loc) {}
}; 43 43 };
44 44
class StmtBlock : public Stmt 45 45 class StmtBlock : public Stmt
{ 46 46 {
protected: 47 47 protected:
List<VarDecl*> *decls; 48 48 List<VarDecl*> *decls;
List<Stmt*> *stmts; 49 49 List<Stmt*> *stmts;
50 50
public: 51 51 public:
StmtBlock(List<VarDecl*> *variableDeclarations, List<Stmt*> *statements); 52 52 StmtBlock(List<VarDecl*> *variableDeclarations, List<Stmt*> *statements);
const char *GetPrintNameForNode() { return "StmtBlock"; } 53 53 const char *GetPrintNameForNode() { return "StmtBlock"; }
void PrintChildren(int indentLevel); 54 54 void PrintChildren(int indentLevel);
55 55
llvm::Value *Emit(); 56 56 llvm::Value *Emit();
}; 57 57 };
58 58
class DeclStmt: public Stmt 59 59 class DeclStmt: public Stmt
{ 60 60 {
protected: 61 61 protected:
Decl* decl; 62 62 Decl* decl;
63 63
public: 64 64 public:
DeclStmt(Decl *d); 65 65 DeclStmt(Decl *d);
const char *GetPrintNameForNode() { return "DeclStmt"; } 66 66 const char *GetPrintNameForNode() { return "DeclStmt"; }
void PrintChildren(int indentLevel); 67 67 void PrintChildren(int indentLevel);
68 68
llvm::Value *Emit(); 69 69 llvm::Value *Emit();
}; 70 70 };
71 71
class ConditionalStmt : public Stmt 72 72 class ConditionalStmt : public Stmt
{ 73 73 {
protected: 74 74 protected:
Expr *test; 75 75 Expr *test;
Stmt *body; 76 76 Stmt *body;
77 77
public: 78 78 public:
ConditionalStmt() : Stmt(), test(NULL), body(NULL) {} 79 79 ConditionalStmt() : Stmt(), test(NULL), body(NULL) {}
ConditionalStmt(Expr *testExpr, Stmt *body); 80 80 ConditionalStmt(Expr *testExpr, Stmt *body);
81 81
llvm::Value *Emit(); 82 82 llvm::Value *Emit();
}; 83 83 };
84 84
class LoopStmt : public ConditionalStmt 85 85 class LoopStmt : public ConditionalStmt
{ 86 86 {
public: 87 87 public:
LoopStmt(Expr *testExpr, Stmt *body) 88 88 LoopStmt(Expr *testExpr, Stmt *body)
: ConditionalStmt(testExpr, body) {} 89 89 : ConditionalStmt(testExpr, body) {}
}; 90 90 };
91 91
class ForStmt : public LoopStmt 92 92 class ForStmt : public LoopStmt
{ 93 93 {
protected: 94 94 protected:
Expr *init, *step; 95 95 Expr *init, *step;
96 96
public: 97 97 public:
ForStmt(Expr *init, Expr *test, Expr *step, Stmt *body); 98 98 ForStmt(Expr *init, Expr *test, Expr *step, Stmt *body);
const char *GetPrintNameForNode() { return "ForStmt"; } 99 99 const char *GetPrintNameForNode() { return "ForStmt"; }
void PrintChildren(int indentLevel); 100 100 void PrintChildren(int indentLevel);
101 101
}; 102 102 };
103 103
class WhileStmt : public LoopStmt 104 104 class WhileStmt : public LoopStmt
{ 105 105 {
public: 106 106 public:
WhileStmt(Expr *test, Stmt *body) : LoopStmt(test, body) {} 107 107 WhileStmt(Expr *test, Stmt *body) : LoopStmt(test, body) {}
const char *GetPrintNameForNode() { return "WhileStmt"; } 108 108 const char *GetPrintNameForNode() { return "WhileStmt"; }
void PrintChildren(int indentLevel); 109 109 void PrintChildren(int indentLevel);
110 110
}; 111 111 };
112 112
class IfStmt : public ConditionalStmt 113 113 class IfStmt : public ConditionalStmt
{ 114 114 {
protected: 115 115 protected:
Stmt *elseBody; 116 116 Stmt *elseBody;
117 117
public: 118 118 public:
IfStmt() : ConditionalStmt(), elseBody(NULL) {} 119 119 IfStmt() : ConditionalStmt(), elseBody(NULL) {}
IfStmt(Expr *test, Stmt *thenBody, Stmt *elseBody); 120 120 IfStmt(Expr *test, Stmt *thenBody, Stmt *elseBody);
const char *GetPrintNameForNode() { return "IfStmt"; } 121 121 const char *GetPrintNameForNode() { return "IfStmt"; }
void PrintChildren(int indentLevel); 122 122 void PrintChildren(int indentLevel);
123 123
124 llvm::Value *Emit();
}; 124 125 };
125 126
class IfStmtExprError : public IfStmt 126 127 class IfStmtExprError : public IfStmt
{ 127 128 {
public: 128 129 public:
IfStmtExprError() : IfStmt() { yyerror(this->GetPrintNameForNode()); } 129 130 IfStmtExprError() : IfStmt() { yyerror(this->GetPrintNameForNode()); }
const char *GetPrintNameForNode() { return "IfStmtExprError"; } 130 131 const char *GetPrintNameForNode() { return "IfStmtExprError"; }
}; 131 132 };
132 133
class BreakStmt : public Stmt 133 134 class BreakStmt : public Stmt
{ 134 135 {
public: 135 136 public:
BreakStmt(yyltype loc) : Stmt(loc) {} 136 137 BreakStmt(yyltype loc) : Stmt(loc) {}
const char *GetPrintNameForNode() { return "BreakStmt"; } 137 138 const char *GetPrintNameForNode() { return "BreakStmt"; }
138 139
llvm::Value *Emit(); 139 140 llvm::Value *Emit();
}; 140 141 };
141 142
class ContinueStmt : public Stmt 142 143 class ContinueStmt : public Stmt
{ 143 144 {
public: 144 145 public:
ContinueStmt(yyltype loc) : Stmt(loc) {} 145 146 ContinueStmt(yyltype loc) : Stmt(loc) {}
const char *GetPrintNameForNode() { return "ContinueStmt"; } 146 147 const char *GetPrintNameForNode() { return "ContinueStmt"; }
147 148
llvm::Value *Emit(); 148 149 llvm::Value *Emit();
}; 149 150 };
150 151
class ReturnStmt : public Stmt 151 152 class ReturnStmt : public Stmt
{ 152 153 {
protected: 153 154 protected:
Expr *expr; 154 155 Expr *expr;
155 156
public: 156 157 public:
ReturnStmt(yyltype loc, Expr *expr = NULL); 157 158 ReturnStmt(yyltype loc, Expr *expr = NULL);
const char *GetPrintNameForNode() { return "ReturnStmt"; } 158 159 const char *GetPrintNameForNode() { return "ReturnStmt"; }
void PrintChildren(int indentLevel); 159 160 void PrintChildren(int indentLevel);
160 161
llvm::Value *Emit(); 161 162 llvm::Value *Emit();
}; 162 163 };
163 164
class SwitchLabel : public Stmt 164 165 class SwitchLabel : public Stmt
{ 165 166 {
protected: 166 167 protected:
Expr *label; 167 168 Expr *label;
Stmt *stmt; 168 169 Stmt *stmt;
169 170
public: 170 171 public:
SwitchLabel() { label = NULL; stmt = NULL; } 171 172 SwitchLabel() { label = NULL; stmt = NULL; }
SwitchLabel(Expr *label, Stmt *stmt); 172 173 SwitchLabel(Expr *label, Stmt *stmt);
SwitchLabel(Stmt *stmt); 173 174 SwitchLabel(Stmt *stmt);
void PrintChildren(int indentLevel); 174 175 void PrintChildren(int indentLevel);
175 176