From 63c11a20b750dede6641fc0f71d13c589e4b40f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=A1s=20De=20los=20Santos?= Date: Tue, 2 Apr 2024 22:38:01 +0200 Subject: [PATCH] feat: completion for non default schema non default schema tables where not considered as candidates for completion --- internal/completer/candidates.go | 72 +++++++++--- internal/completer/completer.go | 13 ++- internal/database/cache.go | 18 +-- internal/database/database_mock.go | 159 ++++++++++++++++++------- internal/handler/completion_test.go | 175 +++++++++++++++++++++++++--- 5 files changed, 353 insertions(+), 84 deletions(-) diff --git a/internal/completer/candidates.go b/internal/completer/candidates.go index c512a7c..6f39a71 100644 --- a/internal/completer/candidates.go +++ b/internal/completer/candidates.go @@ -68,7 +68,17 @@ func (c *Completer) columnCandidates(targetTables []*parseutil.TableInfo, parent if table.Name != parent.Name && table.Alias != parent.Name { continue } - columns, ok := c.DBCache.ColumnDescs(table.Name) + + var columns []*database.ColumnDesc + var ok bool + + if table.DatabaseSchema != "" { + columns, ok = c.DBCache.ColumnDatabase(table.DatabaseSchema, table.Name) + + } else { + columns, ok = c.DBCache.ColumnDescs(table.Name) + } + if !ok { continue } @@ -113,7 +123,13 @@ func (c *Completer) ReferencedTableCandidates(targetTables []*parseutil.TableInf for _, targetTable := range targetTables { includeTables := []*parseutil.TableInfo{} - for _, table := range c.DBCache.SortedTables() { + var schemaTables []string + if (targetTable.DatabaseSchema != "") { + schemaTables, _ = c.DBCache.SortedTablesByDBName(targetTable.DatabaseSchema) + } else { + schemaTables = c.DBCache.SortedTables() + } + for _, table := range schemaTables { if table == targetTable.Name { includeTables = append(includeTables, targetTable) } @@ -129,24 +145,29 @@ func (c *Completer) TableCandidates(parent *completionParent, targetTables []*pa switch parent.Type { case ParentTypeNone: - excludeTables := []string{} - for _, table := range c.DBCache.SortedTables() { - isExclude := false - for _, targetTable := range targetTables { - if table == targetTable.Name { - isExclude = true + targetTablesMap := make(map[string]*parseutil.TableInfo) + for _, targetTable := range targetTables { + targetTablesMap[targetTable.Name] = targetTable + } + for schemaKey, schema := range c.DBCache.Schemas { + excludeTables := []string{} + + tables := c.DBCache.SchemaTables[schemaKey] + for _, table := range tables { + _, isExclude := targetTablesMap[table] + if isExclude { + continue } + excludeTables = append(excludeTables, table) } - if isExclude { - continue - } - excludeTables = append(excludeTables, table) + + schemaCandidates := generateTableCandidatesBySchema(schema, excludeTables, c.DBCache) + candidates = append(candidates, schemaCandidates...) } - candidates = append(candidates, generateTableCandidates(excludeTables, c.DBCache)...) case ParentTypeSchema: tables, ok := c.DBCache.SortedTablesByDBName(parent.Name) if ok { - candidates = append(candidates, generateTableCandidates(tables, c.DBCache)...) + candidates = append(candidates, generateTableCandidatesBySchema(parent.Name, tables, c.DBCache)...) } case ParentTypeTable: // pass @@ -320,14 +341,24 @@ func generateForeignKeyCandidate(target string, } func generateTableCandidates(tables []string, dbCache *database.DBCache) []lsp.CompletionItem { + return generateTableCandidatesBySchema(dbCache.DefaultSchema, tables, dbCache) +} + +func generateTableCandidatesBySchema(schemaName string, tables []string, dbCache *database.DBCache) []lsp.CompletionItem { candidates := []lsp.CompletionItem{} for _, tableName := range tables { + var label string + if schemaName != dbCache.DefaultSchema { + label = fmt.Sprintf("%s.%s", schemaName, tableName) + } else { + label = tableName + } candidate := lsp.CompletionItem{ - Label: tableName, + Label: label, Kind: lsp.ClassCompletion, Detail: "table", } - cols, ok := dbCache.ColumnDescs(tableName) + cols, ok := dbCache.ColumnDatabase(schemaName, tableName) if ok { candidate.Documentation = lsp.MarkupContent{ Kind: lsp.Markdown, @@ -353,7 +384,14 @@ func generateTableCandidatesByInfos(tables []*parseutil.TableInfo, dbCache *data Kind: lsp.ClassCompletion, Detail: detail, } - cols, ok := dbCache.ColumnDescs(table.Name) + var cols []*database.ColumnDesc + var ok bool + + if table.DatabaseSchema != "" { + cols, ok = dbCache.ColumnDatabase(table.DatabaseSchema, table.Name) + } else { + cols, ok = dbCache.ColumnDescs(table.Name) + } if ok { candidate.Documentation = lsp.MarkupContent{ Kind: lsp.Markdown, diff --git a/internal/completer/completer.go b/internal/completer/completer.go index a79f32a..7d0050c 100644 --- a/internal/completer/completer.go +++ b/internal/completer/completer.go @@ -411,8 +411,19 @@ func getCompletionTypes(nw *parseutil.NodeWalker) *CompletionContext { func filterCandidates(candidates []lsp.CompletionItem, lastWord string) []lsp.CompletionItem { filtered := []lsp.CompletionItem{} + withBackQuote := strings.HasPrefix(lastWord, "`") + for _, candidate := range candidates { - if strings.HasPrefix(strings.ToUpper(candidate.Label), strings.ToUpper(lastWord)) { + label := strings.ToUpper(candidate.Label) + + if !withBackQuote && candidate.Kind != lsp.SnippetCompletion { + index := strings.LastIndex(label, ".") + if index != -1 { + label = label[index+1:] + } + } + + if strings.HasPrefix(label, strings.ToUpper(lastWord)) { filtered = append(filtered, candidate) } } diff --git a/internal/database/cache.go b/internal/database/cache.go index dd0ba48..1da6ab7 100644 --- a/internal/database/cache.go +++ b/internal/database/cache.go @@ -19,7 +19,7 @@ func NewDBCacheUpdater(repo DBRepository) *DBCacheGenerator { func (u *DBCacheGenerator) GenerateDBCachePrimary(ctx context.Context) (*DBCache, error) { var err error dbCache := &DBCache{} - dbCache.defaultSchema, err = u.repo.CurrentSchema(ctx) + dbCache.DefaultSchema, err = u.repo.CurrentSchema(ctx) if err != nil { return nil, err } @@ -32,13 +32,13 @@ func (u *DBCacheGenerator) GenerateDBCachePrimary(ctx context.Context) (*DBCache dbCache.Schemas[strings.ToUpper(index)] = element } - if dbCache.defaultSchema == "" { + if dbCache.DefaultSchema == "" { var topKey string for k := range dbCache.Schemas { topKey = k continue } - dbCache.defaultSchema = dbCache.Schemas[topKey] + dbCache.DefaultSchema = dbCache.Schemas[topKey] } schemaTables, err := u.repo.SchemaTables(ctx) if err != nil { @@ -49,11 +49,11 @@ func (u *DBCacheGenerator) GenerateDBCachePrimary(ctx context.Context) (*DBCache dbCache.SchemaTables[strings.ToUpper(index)] = element } - dbCache.ColumnsWithParent, err = u.genColumnCacheCurrent(ctx, dbCache.defaultSchema) + dbCache.ColumnsWithParent, err = u.genColumnCacheCurrent(ctx, dbCache.DefaultSchema) if err != nil { return nil, err } - dbCache.ForeignKeys, err = u.genForeignKeysCache(ctx, dbCache.defaultSchema) + dbCache.ForeignKeys, err = u.genForeignKeysCache(ctx, dbCache.DefaultSchema) if err != nil { return nil, err } @@ -128,7 +128,7 @@ func genColumnMap(columnDescs []*ColumnDesc) map[string][]*ColumnDesc { } type DBCache struct { - defaultSchema string + DefaultSchema string Schemas map[string]string SchemaTables map[string][]string ColumnsWithParent map[string][]*ColumnDesc @@ -156,12 +156,12 @@ func (dc *DBCache) SortedTablesByDBName(dbName string) (tbls []string, ok bool) } func (dc *DBCache) SortedTables() []string { - tbls, _ := dc.SortedTablesByDBName(dc.defaultSchema) + tbls, _ := dc.SortedTablesByDBName(dc.DefaultSchema) return tbls } func (dc *DBCache) ColumnDescs(tableName string) (cols []*ColumnDesc, ok bool) { - cols, ok = dc.ColumnsWithParent[columnDatabaseKey(dc.defaultSchema, tableName)] + cols, ok = dc.ColumnsWithParent[columnDatabaseKey(dc.DefaultSchema, tableName)] return } @@ -171,7 +171,7 @@ func (dc *DBCache) ColumnDatabase(dbName, tableName string) (cols []*ColumnDesc, } func (dc *DBCache) Column(tableName, colName string) (*ColumnDesc, bool) { - cols, ok := dc.ColumnsWithParent[columnDatabaseKey(dc.defaultSchema, tableName)] + cols, ok := dc.ColumnsWithParent[columnDatabaseKey(dc.DefaultSchema, tableName)] if !ok { return nil, false } diff --git a/internal/database/database_mock.go b/internal/database/database_mock.go index 1f3be58..0afabd5 100644 --- a/internal/database/database_mock.go +++ b/internal/database/database_mock.go @@ -21,35 +21,44 @@ type MockDBRepository struct { } func NewMockDBRepository(_ *sql.DB) DBRepository { + defaultDatabase := "world" + return &MockDBRepository{ - MockDatabase: func(ctx context.Context) (string, error) { return "world", nil }, + MockDatabase: func(ctx context.Context) (string, error) { return defaultDatabase, nil }, MockDatabases: func(ctx context.Context) ([]string, error) { return dummyDatabases, nil }, MockDatabaseTables: func(ctx context.Context) (map[string][]string, error) { return dummyDatabaseTables, nil }, - MockTables: func(ctx context.Context) ([]string, error) { return dummyTables, nil }, + MockTables: func(ctx context.Context) ([]string, error) { return dummyDatabaseTables[defaultDatabase], nil }, MockDescribeTable: func(ctx context.Context, tableName string) ([]*ColumnDesc, error) { - switch tableName { - case "city": - return dummyCityColumns, nil - case "country": - return dummyCountryColumns, nil - case "countrylanguage": - return dummyCountryLanguageColumns, nil + var res []*ColumnDesc + schemaTables, ok := dummyColumns[defaultDatabase] + if !ok { + return res, nil + } + + columnTables, ok := schemaTables[tableName] + if !ok { + return res, nil } - return nil, nil + return columnTables, nil }, MockDescribeDatabaseTable: func(ctx context.Context) ([]*ColumnDesc, error) { var res []*ColumnDesc - res = append(res, dummyCityColumns...) - res = append(res, dummyCountryColumns...) - res = append(res, dummyCountryLanguageColumns...) + for _, tc := range dummyColumns { + for _, columns := range tc { + res = append(res, columns...) + } + } return res, nil - }, MockDescribeDatabaseTableBySchema: func(ctx context.Context, schemaName string) ([]*ColumnDesc, error) { var res []*ColumnDesc - res = append(res, dummyCityColumns...) - res = append(res, dummyCountryColumns...) - res = append(res, dummyCountryLanguageColumns...) + schemaTables, ok := dummyColumns[schemaName] + if !ok { + return res, nil + } + for _, cd := range schemaTables { + res = append(res, cd...) + } return res, nil }, @@ -63,6 +72,10 @@ func NewMockDBRepository(_ *sql.DB) DBRepository { return &sql.Rows{}, nil }, MockDescribeForeignKeysBySchema: func(ctx context.Context, schemaName string) ([]*ForeignKey, error) { + foreignKeys, ok := foreignKeysBySchema[schemaName] + if !ok { + return nil, nil + } return foreignKeys, nil }, } @@ -129,11 +142,9 @@ var dummyDatabaseTables = map[string][]string{ "country", "countrylanguage", }, -} -var dummyTables = []string{ - "city", - "country", - "countrylanguage", + "mysql": { + "city_population", + }, } var dummyCityColumns = []*ColumnDesc{ { @@ -502,32 +513,94 @@ var dummyCountryLanguageColumns = []*ColumnDesc{ }, } -var foreignKeys = []*ForeignKey{ +var dummyCityPopulationColumns = []*ColumnDesc{ { - [2]*ColumnBase{ - { - Schema: "world", - Table: "city", - Name: "CountryCode", - }, - { - Schema: "world", - Table: "country", - Name: "Code", - }, + ColumnBase: ColumnBase{ + Schema: "mysql", + Table: "city_population", + Name: "population", }, + Type: "int(11)", + Null: "NO", + Key: "", + Default: sql.NullString{ + String: "0", + Valid: false, + }, + Extra: "", }, { - [2]*ColumnBase{ - { - Schema: "world", - Table: "countrylanguage", - Name: "CountryCode", + ColumnBase: ColumnBase{ + Schema: "mysql", + Table: "city_population", + Name: "city_id", + }, + Type: "int(11)", + Null: "NO", + Key: "PRI", + Default: sql.NullString{ + String: "", + Valid: false, + }, + Extra: "", + }, +} + +var dummyColumns = map[string]map[string][]*ColumnDesc{ + "world": { + "city": dummyCityColumns, + "county": dummyCountryColumns, + "countrylanguage": dummyCountryLanguageColumns, + }, + "mysql": { + "city_population": dummyCityPopulationColumns, + }, +} + +var foreignKeysBySchema = map[string][]*ForeignKey{ + "world": { + { + [2]*ColumnBase{ + { + Schema: "world", + Table: "city", + Name: "CountryCode", + }, + { + Schema: "world", + Table: "country", + Name: "Code", + }, + }, + }, + { + [2]*ColumnBase{ + { + Schema: "world", + Table: "countrylanguage", + Name: "CountryCode", + }, + { + Schema: "world", + Table: "country", + Name: "Code", + }, }, - { - Schema: "world", - Table: "country", - Name: "Code", + }, + }, + "mysql": { + { + [2]*ColumnBase{ + { + Schema: "mysql", + Table: "city_population", + Name: "city_id", + }, + { + Schema: "world", + Table: "city", + Name: "ID", + }, }, }, }, diff --git a/internal/handler/completion_test.go b/internal/handler/completion_test.go index e58a6b5..96cd722 100644 --- a/internal/handler/completion_test.go +++ b/internal/handler/completion_test.go @@ -1,20 +1,25 @@ package handler import ( + "errors" + "fmt" "testing" + "github.com/google/go-cmp/cmp" "github.com/sqls-server/sqls/internal/config" "github.com/sqls-server/sqls/internal/database" "github.com/sqls-server/sqls/internal/lsp" ) type completionTestCase struct { - name string - input string - line int - col int - want []string - bad []string + name string + input string + line int + col int + want []string + bad []string + filter func(*lsp.CompletionItem) bool + validator func(*lsp.CompletionItem) error } var statementCase = []completionTestCase{ @@ -306,6 +311,26 @@ var selectExprCase = []completionTestCase{ "countrylanguage", }, }, + { + name: "columns of table from non default database", + input: "select from mysql.city_population", + line: 0, + col: 7, + want: []string{ + "city_id", + "population", + }, + }, + { + name: "columns of table from non default database filter by name", + input: "select city_population. from mysql.city_population", + line: 0, + col: 23, + want: []string{ + "city_id", + "population", + }, + }, } var tableReferenceCase = []completionTestCase{ @@ -323,6 +348,16 @@ var tableReferenceCase = []completionTestCase{ "performance_schema", "sys", "world", + "mysql.city_population", + }, + }, + { + name: "from tables from a non default database", + input: "select * from mysql.", + line: 0, + col: 20, + want: []string{ + "mysql.city_population", }, }, { @@ -339,6 +374,7 @@ var tableReferenceCase = []completionTestCase{ "`performance_schema`", "`sys`", "`world`", + "`mysql.city_population`", }, }, { @@ -351,6 +387,16 @@ var tableReferenceCase = []completionTestCase{ "countrylanguage", }, }, + { + name: "from filtered tables including non default schemas", + input: "select city_id from ci", + line: 0, + col: 22, + want: []string{ + "city", + "mysql.city_population", + }, + }, { name: "from quoted filtered tables", input: "select CountryCode from `co", @@ -370,6 +416,7 @@ var tableReferenceCase = []completionTestCase{ "city", "country", "countrylanguage", + "mysql.city_population", }, }, { @@ -382,6 +429,37 @@ var tableReferenceCase = []completionTestCase{ "countrylanguage", }, }, + { + name: "join filtered tables by schema", + input: "select CountryCode from city join mysql.c", + line: 0, + col: 40, + want: []string{ + "mysql.city_population", + }, + }, + { + name: "join table referenced on clause", + input: "select CountryCode from city join mysql.city_population ON ", + line: 0, + col: 59, + want: []string{ + "city", + "city_population", + }, + filter: func(item *lsp.CompletionItem) bool { + return item.Detail == "referenced table" + }, + validator: func(item *lsp.CompletionItem) error { + if item.Label == "city_population" { + expectedDoc := "# `city_population` table\n\n\n| Name   | Type   | Primary key   | Default   | Extra   |\n| :--------------- | :--------------- | :---------------------- | :------------------ | :---------------- |\n| `population` | `int(11)` | `` | `0` | |\n| `city_id` | `int(11)` | `PRI` | `` | |\n" + if diff := cmp.Diff(expectedDoc, item.Documentation.Value); diff != "" { + return errors.New(fmt.Sprintf("Expected different documentation for city_population, diff:\n%s", diff)) + } + } + return nil + }, + }, { name: "left join tables", input: "select CountryCode from city left join ", @@ -391,6 +469,7 @@ var tableReferenceCase = []completionTestCase{ "city", "country", "countrylanguage", + "mysql.city_population", }, }, { @@ -402,6 +481,7 @@ var tableReferenceCase = []completionTestCase{ "city", "country", "countrylanguage", + "mysql.city_population", }, }, { @@ -413,6 +493,7 @@ var tableReferenceCase = []completionTestCase{ "city", "country", "countrylanguage", + "mysql.city_population", }, }, { @@ -424,6 +505,7 @@ var tableReferenceCase = []completionTestCase{ "`city`", "`country`", "`countrylanguage`", + "`mysql.city_population`", }, }, { @@ -436,6 +518,15 @@ var tableReferenceCase = []completionTestCase{ "countrylanguage", }, }, + { + name: "insert filtered tables from non default schema", + input: "INSERT INTO mysql.c", + line: 0, + col: 19, + want: []string{ + "mysql.city_population", + }, + }, { name: "insert columns", input: "INSERT INTO city (", @@ -454,6 +545,21 @@ var tableReferenceCase = []completionTestCase{ "countrylanguage", }, }, + { + name: "insert columns from non default schema", + input: "INSERT INTO mysql.city_population (", + line: 0, + col: 35, + want: []string{ + "city_id", + "population", + }, + bad: []string{ + "city", + "country", + "countrylanguage", + }, + }, { name: "insert filtered columns", input: "INSERT INTO city (cou", @@ -481,6 +587,7 @@ var tableReferenceCase = []completionTestCase{ "city", "country", "countrylanguage", + "mysql.city_population", }, }, { @@ -533,6 +640,7 @@ var tableReferenceCase = []completionTestCase{ "city", "country", "countrylanguage", + "mysql.city_population", }, }, { @@ -561,6 +669,16 @@ var whereCondition = []completionTestCase{ "Population", }, }, + { + name: "where columns in non default schema", + input: "select * from mysql.city_population where ", + line: 0, + col: 43, + want: []string{ + "city_id", + "population", + }, + }, { name: "where columns of specified table", input: "select * from city where city.", @@ -574,6 +692,16 @@ var whereCondition = []completionTestCase{ "Population", }, }, + { + name: "where columns of specified table in non default schema", + input: "select * from mysql.city_population where city_population.", + line: 0, + col: 58, + want: []string{ + "city_id", + "population", + }, + }, { name: "where columns in left of comparison", input: "select * from city where = ID", @@ -884,6 +1012,17 @@ var joinClauseCase = []completionTestCase{ "countrylanguage", }, }, + { + name: "join filtered tables reversed", + input: "select CountryCode from country join cou", + line: 0, + col: 40, + want: []string{ + "countrylanguage c1 ON c1.CountryCode = country.Code", + "country", + "countrylanguage", + }, + }, { name: "join filtered tables with reference", input: "select c.CountryCode from city c join co", @@ -1122,7 +1261,7 @@ func TestCompleteMain(t *testing.T) { if err := tx.conn.Call(tx.ctx, "textDocument/completion", completionParams, &got); err != nil { t.Fatal("conn.Call textDocument/completion:", err) } - testCompletionItem(t, tt.want, tt.bad, got) + testCompletionItem(t, &tt, got) }) } } @@ -1172,7 +1311,7 @@ func TestCompleteJoin(t *testing.T) { if err := tx.conn.Call(tx.ctx, "textDocument/completion", completionParams, &got); err != nil { t.Fatal("conn.Call textDocument/completion:", err) } - testCompletionItem(t, tt.want, tt.bad, got) + testCompletionItem(t, &tt, got) }) } } @@ -1228,22 +1367,30 @@ func TestCompleteNoneDBConnection(t *testing.T) { } } -func testCompletionItem(t *testing.T, expectLabels []string, badLabels []string, gotItems []lsp.CompletionItem) { +func testCompletionItem(t *testing.T, tc *completionTestCase, gotItems []lsp.CompletionItem) { t.Helper() - itemMap := map[string]struct{}{} + itemMap := map[string]*lsp.CompletionItem{} for _, item := range gotItems { - itemMap[item.Label] = struct{}{} + if tc.filter != nil && !tc.filter(&item) { + continue + } + itemMap[item.Label] = &item } - for _, el := range expectLabels { - _, ok := itemMap[el] + for _, el := range tc.want { + item, ok := itemMap[el] if !ok { t.Errorf("expected to be included in the results, expect candidate %q", el) + } else if tc.validator != nil { + if err := tc.validator(item); err != nil { + t.Errorf("item %v didnt pass validaton: %s", item, err) + } + } } - for _, el := range badLabels { + for _, el := range tc.bad { _, ok := itemMap[el] if ok { t.Errorf("should not be included in the results, got candidate %q", el)