diff --git a/gofmt.sh b/gofmt.sh old mode 100644 new mode 100755 diff --git a/parser.go b/parser.go index a882c6c9..df95b89c 100644 --- a/parser.go +++ b/parser.go @@ -1,6 +1,7 @@ package sql import ( + "fmt" "io" "strings" ) @@ -1304,9 +1305,9 @@ func (p *Parser) parseTriggerBodyStatement() (stmt Statement, err error) { case INSERT, REPLACE: stmt, err = p.parseInsertStatement(nil) case UPDATE: - stmt, err = p.parseUpdateStatement(nil) + stmt, err = p.parseTriggerBodyUpdateStatement(nil) case DELETE: - stmt, err = p.parseDeleteStatement(nil) + stmt, err = p.parseTriggerBodyDeleteStatement(nil) case WITH: stmt, err = p.parseWithStatement() default: @@ -1325,6 +1326,175 @@ func (p *Parser) parseTriggerBodyStatement() (stmt Statement, err error) { return stmt, nil } +// parseTriggerBodyDeleteStatement parses a DELETE statement within a trigger body. +// It differs from parseDeleteStatement by only allowing unqualified table names. +func (p *Parser) parseTriggerBodyDeleteStatement(withClause *WithClause) (_ *DeleteStatement, err error) { + assert(p.peek() == DELETE) + + var stmt DeleteStatement + stmt.WithClause = withClause + + // Parse "DELETE FROM tbl" + stmt.Delete, _, _ = p.scan() + if p.peek() != FROM { + return &stmt, p.errorExpected(p.pos, p.tok, "FROM") + } + stmt.From, _, _ = p.scan() + if !isIdentToken(p.peek()) { + return nil, p.errorExpected(p.pos, p.tok, "table name") + } + ident, _ := p.parseIdent("table name") + + // In trigger bodies, only unqualified table names are allowed + if err = p.validateUnqualifiedTableName(ident); err != nil { + return &stmt, err + } + stmt.Table = &QualifiedTableName{Name: ident} + + // Parse WHERE clause. + if p.peek() == WHERE { + stmt.Where, _, _ = p.scan() + if stmt.WhereExpr, err = p.ParseExpr(); err != nil { + return &stmt, err + } + } + + // Parse ORDER BY clause. This differs from the SELECT parsing in that + // if an ORDER BY is specified then the LIMIT is required. + if p.peek() == ORDER || p.peek() == LIMIT { + if p.peek() == ORDER { + stmt.Order, _, _ = p.scan() + if p.peek() != BY { + return &stmt, p.errorExpected(p.pos, p.tok, "BY") + } + stmt.OrderBy, _, _ = p.scan() + + for { + term, err := p.parseOrderingTerm() + if err != nil { + return &stmt, err + } + stmt.OrderingTerms = append(stmt.OrderingTerms, term) + + if p.peek() != COMMA { + break + } + p.scan() + } + } + + // Parse LIMIT/OFFSET clause. + if p.peek() != LIMIT { + return &stmt, p.errorExpected(p.pos, p.tok, "LIMIT") + } + stmt.Limit, _, _ = p.scan() + if stmt.LimitExpr, err = p.ParseExpr(); err != nil { + return &stmt, err + } + + if p.peek() == OFFSET { + stmt.Offset, _, _ = p.scan() + if stmt.OffsetExpr, err = p.ParseExpr(); err != nil { + return &stmt, err + } + } + } + + return &stmt, nil +} + +// parseTriggerBodyUpdateStatement parses an UPDATE statement within a trigger body. +// It differs from parseUpdateStatement by only allowing unqualified table names. +func (p *Parser) parseTriggerBodyUpdateStatement(withClause *WithClause) (_ *UpdateStatement, err error) { + assert(p.peek() == UPDATE) + + var stmt UpdateStatement + stmt.WithClause = withClause + + stmt.Update, _, _ = p.scan() + if p.peek() == OR { + stmt.UpdateOr, _, _ = p.scan() + + switch p.peek() { + case ROLLBACK: + stmt.UpdateOrRollback, _, _ = p.scan() + case REPLACE: + stmt.UpdateOrReplace, _, _ = p.scan() + case ABORT: + stmt.UpdateOrAbort, _, _ = p.scan() + case FAIL: + stmt.UpdateOrFail, _, _ = p.scan() + case IGNORE: + stmt.UpdateOrIgnore, _, _ = p.scan() + default: + return &stmt, p.errorExpected(p.pos, p.tok, "ROLLBACK, REPLACE, ABORT, FAIL, or IGNORE") + } + } + + if !isIdentToken(p.peek()) { + return nil, p.errorExpected(p.pos, p.tok, "table name") + } + ident, _ := p.parseIdent("table name") + + // In trigger bodies, only unqualified table names are allowed + if err = p.validateUnqualifiedTableName(ident); err != nil { + return &stmt, err + } + stmt.Table = &QualifiedTableName{Name: ident} + + // Parse SET + list of assignments. + if p.peek() != SET { + return &stmt, p.errorExpected(p.pos, p.tok, "SET") + } + stmt.Set, _, _ = p.scan() + + for { + assignment, err := p.parseAssignment() + if err != nil { + return &stmt, err + } + stmt.Assignments = append(stmt.Assignments, assignment) + + if p.peek() != COMMA { + break + } + p.scan() + } + + // Parse WHERE clause. + if p.peek() == WHERE { + stmt.Where, _, _ = p.scan() + if stmt.WhereExpr, err = p.ParseExpr(); err != nil { + return &stmt, err + } + } + + // Parse optional RETURNING clause. + if p.peek() == RETURNING { + if stmt.ReturningClause, err = p.parseReturningClause(); err != nil { + return &stmt, err + } + } + + return &stmt, nil +} + +// validateUnqualifiedTableName ensures that the next tokens do not form +// a qualified table name (schema.table or table alias). +func (p *Parser) validateUnqualifiedTableName(ident *Ident) error { + // Check for schema qualification (schema.table) + if p.peek() == DOT { + return fmt.Errorf("qualified table names not allowed in trigger body") + } + + // Check for table alias + if tok := p.peek(); tok == AS || isIdentToken(tok) { + return fmt.Errorf("qualified table names not allowed in trigger body") + } + + return nil +} + func (p *Parser) parseDropTriggerStatement(dropPos Pos) (_ *DropTriggerStatement, err error) { assert(p.peek() == TRIGGER) diff --git a/parser_test.go b/parser_test.go index 659de07c..0f1cb301 100644 --- a/parser_test.go +++ b/parser_test.go @@ -1860,6 +1860,12 @@ func TestParser_ParseStatement(t *testing.T) { End: pos(83), }) + // Test cases that should fail due to qualified table names in trigger body + AssertParseStatementError(t, `CREATE TRIGGER trig AFTER DELETE ON tbl BEGIN DELETE FROM host h; END`, `qualified table names not allowed in trigger body`) + AssertParseStatementError(t, `CREATE TRIGGER trig AFTER DELETE ON tbl BEGIN UPDATE host h SET x = 1; END`, `qualified table names not allowed in trigger body`) + AssertParseStatementError(t, `CREATE TRIGGER trig AFTER DELETE ON tbl BEGIN DELETE FROM schema.host; END`, `qualified table names not allowed in trigger body`) + AssertParseStatementError(t, `CREATE TRIGGER trig AFTER DELETE ON tbl BEGIN UPDATE schema.host SET x = 1; END`, `qualified table names not allowed in trigger body`) + AssertParseStatementError(t, `CREATE TRIGGER`, `1:14: expected index name, found 'EOF'`) AssertParseStatementError(t, `CREATE TRIGGER IF`, `1:17: expected NOT, found 'EOF'`) AssertParseStatementError(t, `CREATE TRIGGER IF NOT`, `1:21: expected EXISTS, found 'EOF'`)