From 8b244a2fc92b711e308236a536f1fd97234ef144 Mon Sep 17 00:00:00 2001 From: thinkgos Date: Mon, 31 Jul 2023 21:01:22 +0800 Subject: [PATCH] fix: add foriegn key --- codegen/mapper.go | 18 ++++++++- def.go | 14 ++++++- driver/mysql/def.go | 46 +++++++++++++++++------ driver/mysql/def_utils.go | 15 +++----- entity.go | 61 ++++++++++++++++-------------- foreign_key.go | 79 +++++++++++++++++++++++++++++++++++++++ internal/sqlx/sqlx.go | 22 ++++++++--- 7 files changed, 198 insertions(+), 57 deletions(-) create mode 100644 foreign_key.go diff --git a/codegen/mapper.go b/codegen/mapper.go index ca76aa3..7e3f1dc 100644 --- a/codegen/mapper.go +++ b/codegen/mapper.go @@ -42,7 +42,7 @@ func (g *CodeGen) GenMapper() *CodeGen { g.P("// ", structName, " ", trimStructComment(et.Comment, "\n", "\n// ")) g.P("message ", structName, " {") if (et.Table != nil && et.Table.PrimaryKey() != nil) || - len(et.Indexes) > 0 { + len(et.Indexes) > 0 || len(et.ForeignKeys) > 0 { g.P("option (things_go.seaql.options) = {") g.P("index: [") remain := len(et.Indexes) @@ -51,10 +51,24 @@ func (g *CodeGen) GenMapper() *CodeGen { g.P("'", et.Table.PrimaryKey().Definition(), "'", ending) } for _, index := range et.Indexes { + remain-- + if et.Table != nil && + et.Table.PrimaryKey() != nil && + et.Table.PrimaryKey().Index().Name == index.Name { + continue + } ending := commaOrEmpty(remain) g.P("'", index.Index.Definition(), "'", ending) } - g.P("];") + g.P("],") + g.P("foreign_key: [") + remain = len(et.ForeignKeys) + for _, fk := range et.ForeignKeys { + remain-- + ending := commaOrEmpty(remain) + g.P("'", fk.ForeignKey.Definition(), "'", ending) + } + g.P("],") g.P("};") } g.P() diff --git a/def.go b/def.go index f013c0a..bdab4a3 100644 --- a/def.go +++ b/def.go @@ -19,6 +19,11 @@ type IndexDef interface { Definition() string } +type ForeignKeyDef interface { + ForeignKey() *schema.ForeignKey + Definition() string +} + type Schemaer interface { Build(opt *Option) *Schema } @@ -31,9 +36,14 @@ type Fielder interface { Build(opt *Option) *FieldDescriptor } +type ForeignKeyer interface { + Build() *ForeignKeyDescriptor +} + type MixinEntity interface { - Fields() []Fielder - Indexes() []Indexer Metadata() (string, string) Table() TableDef + Fields() []Fielder + Indexes() []Indexer + ForeignKeys() []ForeignKeyer } diff --git a/driver/mysql/def.go b/driver/mysql/def.go index 20d6bb2..ed76a4e 100644 --- a/driver/mysql/def.go +++ b/driver/mysql/def.go @@ -94,7 +94,7 @@ func (self *TableDef) Definition() string { b.Grow(64) fmt.Fprintf(b, "CREATE TABLE `%s` (\n", tb.Name) - remain := len(tb.Columns) + len(tb.Indexes) + remain := len(tb.Columns) + len(tb.Indexes) + len(tb.ForeignKeys) if tb.PrimaryKey != nil { remain++ } @@ -122,14 +122,18 @@ func (self *TableDef) Definition() string { } for _, val := range tb.Indexes { remain-- - if primaryKey(val.Name) { // ignore primary key, maybe include + if sqlx.IndexEqual(tb.PrimaryKey, val) { // ignore primary key, maybe include continue } suffix := suffixOrEmpty(remain) fmt.Fprintf(b, " %s%s\n", NewIndexDef(val).Definition(), suffix) } //* foreignKeys - // TODO: ForeignKeys + for _, val := range tb.ForeignKeys { + remain-- + suffix := suffixOrEmpty(remain) + fmt.Fprintf(b, " %s%s\n", NewForeignKey(val).Definition(), suffix) + } engine := mysql.EngineInnoDB charset := "utf8mb4" @@ -236,21 +240,21 @@ func (self *ColumnDef) GormTag(tb *schema.Table) string { fmt.Fprintf(b, ",priority:%d", pkPriority) } } - for _, v := range col.Indexes { - if primaryKey(v.Name) { // ignore primary key, may be include + for _, val := range col.Indexes { + if sqlx.IndexEqual(tb.PrimaryKey, val) { // ignore primary key, may be include continue } - if v.Unique { - fmt.Fprintf(b, ";uniqueIndex:%s", v.Name) + if val.Unique { + fmt.Fprintf(b, ";uniqueIndex:%s", val.Name) } else { - fmt.Fprintf(b, ";index:%s", v.Name) + fmt.Fprintf(b, ";index:%s", val.Name) // mysql.IndexTypeFullText // if v.IndexType == "FULLTEXT" { // b.WriteString(",class:FULLTEXT") // } } - if len(v.Parts) > 1 { - priority, ok := sqlx.FindIndexPartSeq(v.Parts, col) + if len(val.Parts) > 1 { + priority, ok := sqlx.FindIndexPartSeq(val.Parts, col) if ok { fmt.Fprintf(b, ",priority:%d", priority) } @@ -278,7 +282,7 @@ func (self *IndexDef) Definition() string { fields := sqlx.IndexPartColumnNames(index.Parts) indexType := findIndexType(index.Attrs) fieldList := "`" + strings.Join(fields, "`,`") + "`" - if primaryKey(index.Name) { + if sqlx.IndexEqual(index.Table.PrimaryKey, index) { return fmt.Sprintf("PRIMARY KEY (%s) USING %s", fieldList, indexType) } else if index.Unique { return fmt.Sprintf("UNIQUE KEY `%s` (%s) USING %s", index.Name, fieldList, indexType) @@ -286,3 +290,23 @@ func (self *IndexDef) Definition() string { return fmt.Sprintf("KEY `%s` (%s) USING %s", index.Name, fieldList, indexType) } } + +type ForeignKeyDef struct { + fk *schema.ForeignKey +} + +func NewForeignKey(fk *schema.ForeignKey) *ForeignKeyDef { + return &ForeignKeyDef{fk: fk} +} + +func (self *ForeignKeyDef) ForeignKey() *schema.ForeignKey { return self.fk } + +func (self *ForeignKeyDef) Definition() string { + fk := self.fk + columnNameList := "`" + strings.Join(sqlx.ColumnNames(fk.Columns), "`,`") + "`" + refColumnNameList := "`" + strings.Join(sqlx.ColumnNames(fk.RefColumns), "`,`") + "`" + return fmt.Sprintf( + "CONSTRAINT `%s` FOREIGN KEY (%s) REFERENCES `%s` (%s) ON DELETE %s ON UPDATE %s", + fk.Symbol, columnNameList, fk.RefTable.Name, refColumnNameList, fk.OnDelete, fk.OnUpdate, + ) +} diff --git a/driver/mysql/def_utils.go b/driver/mysql/def_utils.go index 6164682..eccc86f 100644 --- a/driver/mysql/def_utils.go +++ b/driver/mysql/def_utils.go @@ -1,8 +1,6 @@ package mysql import ( - "strings" - "ariga.io/atlas/sql/mysql" "ariga.io/atlas/sql/schema" "github.com/things-go/ens" @@ -13,11 +11,6 @@ func autoIncrement(attrs []schema.Attr) bool { return sqlx.Has(attrs, &mysql.AutoIncrement{}) } -func primaryKey(name string) bool { - return strings.EqualFold(name, "PRI") || - strings.EqualFold(name, "PRIMARY") -} - func findIndexType(attrs []schema.Attr) string { var t mysql.IndexType if sqlx.Has(attrs, &t) && t.T != "" { @@ -43,10 +36,14 @@ func IntoEntity(tb *schema.Table) ens.MixinEntity { indexers = append(indexers, ens.IndexFromDef(NewIndexDef(index))) } //* foreignKeys - // TODO: ... + fkers := make([]ens.ForeignKeyer, 0, len(tb.ForeignKeys)) + for _, fk := range tb.ForeignKeys { + fkers = append(fkers, ens.ForeignKeyFromDef(NewForeignKey(fk))) + } // * table return ens.EntityFromDef(NewTableDef(tb)). SetFields(fielders...). - SetIndexes(indexers...) + SetIndexes(indexers...). + SetForeignKeys(fkers...) } diff --git a/entity.go b/entity.go index 95ec4af..c1f2e80 100644 --- a/entity.go +++ b/entity.go @@ -1,15 +1,18 @@ package ens -import "github.com/things-go/ens/internal/sqlx" +import ( + "github.com/things-go/ens/internal/sqlx" +) // EntityDescriptor Each table corresponds to an EntityDescriptor type EntityDescriptor struct { - Name string // entity name - Comment string // entity comment - Table TableDef // entity table define - Fields []*FieldDescriptor // field information - Indexes []*IndexDescriptor // index information - ProtoMessage []*ProtoMessage // protobuf message information. + Name string // entity name + Comment string // entity comment + Table TableDef // entity table define + Fields []*FieldDescriptor // field information + Indexes []*IndexDescriptor // index information + ForeignKeys []*ForeignKeyDescriptor // foreign key information + ProtoMessage []*ProtoMessage // protobuf message information. } type EntityDescriptorSlice []*EntityDescriptor @@ -36,7 +39,11 @@ func BuildEntity(m MixinEntity, opt *Option) *EntityDescriptor { for _, v := range indexers { indexes = append(indexes, v.Build()) } - + fkers := m.ForeignKeys() + fks := make([]*ForeignKeyDescriptor, 0, len(fkers)) + for _, v := range fkers { + fks = append(fks, v.Build()) + } name, comment := m.Metadata() return &EntityDescriptor{ Name: name, @@ -44,6 +51,7 @@ func BuildEntity(m MixinEntity, opt *Option) *EntityDescriptor { Table: m.Table(), Fields: fields, Indexes: indexes, + ForeignKeys: fks, ProtoMessage: protoMessages, } } @@ -51,11 +59,12 @@ func BuildEntity(m MixinEntity, opt *Option) *EntityDescriptor { var _ MixinEntity = (*EntityBuilder)(nil) type EntityBuilder struct { - name string // schema entity name - comment string // schema entity comment - table TableDef // entity table define - fields []Fielder // field information - indexes []Indexer // index information + name string // schema entity name + comment string // schema entity comment + table TableDef // entity table define + fields []Fielder // field information + indexes []Indexer // index information + foreignKeys []ForeignKeyer // foreign key information } // EntityFromDef returns a new entity with the TableDef. @@ -76,6 +85,10 @@ func (self *EntityBuilder) SetMetadata(name, comment string) *EntityBuilder { self.comment = comment return self } +func (self *EntityBuilder) SetTable(tb TableDef) *EntityBuilder { + self.table = tb + return self +} func (self *EntityBuilder) SetFields(fields ...Fielder) *EntityBuilder { self.fields = fields return self @@ -84,20 +97,12 @@ func (self *EntityBuilder) SetIndexes(indexes ...Indexer) *EntityBuilder { self.indexes = indexes return self } -func (self *EntityBuilder) SetTable(tb TableDef) *EntityBuilder { - self.table = tb +func (self *EntityBuilder) SetForeignKeys(fks ...ForeignKeyer) *EntityBuilder { + self.foreignKeys = fks return self } - -func (self *EntityBuilder) Fields() []Fielder { - return self.fields -} -func (self *EntityBuilder) Indexes() []Indexer { - return self.indexes -} -func (self *EntityBuilder) Metadata() (name, comment string) { - return self.name, self.comment -} -func (self *EntityBuilder) Table() TableDef { - return self.table -} +func (self *EntityBuilder) Metadata() (name, comment string) { return self.name, self.comment } +func (self *EntityBuilder) Table() TableDef { return self.table } +func (self *EntityBuilder) Fields() []Fielder { return self.fields } +func (self *EntityBuilder) Indexes() []Indexer { return self.indexes } +func (self *EntityBuilder) ForeignKeys() []ForeignKeyer { return self.foreignKeys } diff --git a/foreign_key.go b/foreign_key.go new file mode 100644 index 0000000..d1d9ebe --- /dev/null +++ b/foreign_key.go @@ -0,0 +1,79 @@ +package ens + +import ( + "ariga.io/atlas/sql/schema" + "github.com/things-go/ens/internal/sqlx" +) + +type ForeignKeyDescriptor struct { + Symbol string + Table string + Columns []string + RefTable string + RefColumns []string + OnUpdate schema.ReferenceOption + OnDelete schema.ReferenceOption + ForeignKey ForeignKeyDef +} + +// ForeignKeyFromDef returns a new ForeignKey with the ForeignKeyDef. +func ForeignKey(symbol string) *foreignKeyBuilder { + return &foreignKeyBuilder{ + inner: &ForeignKeyDescriptor{ + Symbol: symbol, + Table: "", + Columns: nil, + RefTable: "", + RefColumns: nil, + OnUpdate: schema.Restrict, + OnDelete: schema.Restrict, + ForeignKey: nil, + }, + } +} + +// ForeignKeyFromDef returns a new ForeignKey with the ForeignKeyDef. +func ForeignKeyFromDef(def ForeignKeyDef) *foreignKeyBuilder { + fk := def.ForeignKey() + return &foreignKeyBuilder{ + inner: &ForeignKeyDescriptor{ + Symbol: fk.Symbol, + Table: fk.Table.Name, + Columns: sqlx.ColumnNames(fk.Columns), + RefTable: fk.RefTable.Name, + RefColumns: sqlx.ColumnNames(fk.RefColumns), + OnUpdate: fk.OnUpdate, + OnDelete: fk.OnDelete, + ForeignKey: def, + }, + } +} + +// foreignKeyBuilder is the builder for ForeignKey. +type foreignKeyBuilder struct { + inner *ForeignKeyDescriptor +} + +func (b *foreignKeyBuilder) Table(tbName string, columns []string) *foreignKeyBuilder { + b.inner.Table = tbName + b.inner.Columns = columns + return b +} +func (b *foreignKeyBuilder) RefTable(tbName string, columns []string) *foreignKeyBuilder { + b.inner.RefTable = tbName + b.inner.RefColumns = columns + return b +} +func (b *foreignKeyBuilder) OnDelete(v schema.ReferenceOption) *foreignKeyBuilder { + b.inner.OnDelete = v + return b +} +func (b *foreignKeyBuilder) OnUpdate(v schema.ReferenceOption) *foreignKeyBuilder { + b.inner.OnUpdate = v + return b +} + +// Build implements the ForeignKeyer interface by returning its descriptor. +func (b *foreignKeyBuilder) Build() *ForeignKeyDescriptor { + return b.inner +} diff --git a/internal/sqlx/sqlx.go b/internal/sqlx/sqlx.go index 363cef6..fc53c25 100644 --- a/internal/sqlx/sqlx.go +++ b/internal/sqlx/sqlx.go @@ -77,6 +77,10 @@ func DefaultValue(c *schema.Column) (string, bool) { } } +func IndexEqual(idx1, idx2 *schema.Index) bool { + return idx1 != nil && idx2 != nil && (idx1 == idx2 || idx1.Name == idx2.Name) +} + func FindIndexPartSeq(parts []*schema.IndexPart, col *schema.Column) (int, bool) { for _, p := range parts { if p.C == col || p.C.Name == col.Name { @@ -86,6 +90,14 @@ func FindIndexPartSeq(parts []*schema.IndexPart, col *schema.Column) (int, bool) return 0, false } +func IndexPartColumnNames(parts []*schema.IndexPart) []string { + fields := make([]string, 0, len(parts)) + for _, v := range parts { + fields = append(fields, v.C.Name) + } + return fields +} + func FindColumn(columns []*schema.Column, columnName string) (*schema.Column, bool) { for _, col := range columns { if col.Name == columnName { @@ -95,12 +107,12 @@ func FindColumn(columns []*schema.Column, columnName string) (*schema.Column, bo return nil, false } -func IndexPartColumnNames(parts []*schema.IndexPart) []string { - fields := make([]string, 0, len(parts)) - for _, v := range parts { - fields = append(fields, v.C.Name) +func ColumnNames(columns []*schema.Column) []string { + ns := make([]string, 0, len(columns)) + for _, col := range columns { + ns = append(ns, col.Name) } - return fields + return ns } // P returns a pointer to v.