Skip to content

Commit

Permalink
Merge branch 'go-gorm:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
freehere107 authored Sep 29, 2022
2 parents b0f2f14 + 328f301 commit f01ca29
Show file tree
Hide file tree
Showing 36 changed files with 656 additions and 197 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
sqlite:
strategy:
matrix:
go: ['1.18', '1.17', '1.16']
go: ['1.19', '1.18', '1.17', '1.16']
platform: [ubuntu-latest] # can not run in windows OS
runs-on: ${{ matrix.platform }}

Expand All @@ -42,7 +42,7 @@ jobs:
strategy:
matrix:
dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest']
go: ['1.18', '1.17', '1.16']
go: ['1.19', '1.18', '1.17', '1.16']
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}

Expand Down Expand Up @@ -86,7 +86,7 @@ jobs:
strategy:
matrix:
dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10']
go: ['1.18', '1.17', '1.16']
go: ['1.19', '1.18', '1.17', '1.16']
platform: [ubuntu-latest] # can not run in macOS and Windows
runs-on: ${{ matrix.platform }}

Expand Down Expand Up @@ -128,7 +128,7 @@ jobs:
sqlserver:
strategy:
matrix:
go: ['1.18', '1.17', '1.16']
go: ['1.19', '1.18', '1.17', '1.16']
platform: [ubuntu-latest] # can not run test in macOS and windows
runs-on: ${{ matrix.platform }}

Expand Down
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ documents
coverage.txt
_book
.idea
vendor
vendor
.vscode
4 changes: 3 additions & 1 deletion association.go
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,9 @@ func (association *Association) buildCondition() *DB {
joinStmt.AddClause(queryClause)
}
joinStmt.Build("WHERE")
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
if len(joinStmt.SQL.String()) > 0 {
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
}
}

tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{
Expand Down
37 changes: 6 additions & 31 deletions callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package gorm

import (
"context"
"database/sql"
"errors"
"fmt"
"reflect"
Expand All @@ -16,13 +15,12 @@ import (
func initializeCallbacks(db *DB) *callbacks {
return &callbacks{
processors: map[string]*processor{
"create": {db: db},
"query": {db: db},
"update": {db: db},
"delete": {db: db},
"row": {db: db},
"raw": {db: db},
"transaction": {db: db},
"create": {db: db},
"query": {db: db},
"update": {db: db},
"delete": {db: db},
"row": {db: db},
"raw": {db: db},
},
}
}
Expand Down Expand Up @@ -74,29 +72,6 @@ func (cs *callbacks) Raw() *processor {
return cs.processors["raw"]
}

func (cs *callbacks) Transaction() *processor {
return cs.processors["transaction"]
}

func (p *processor) Begin(tx *DB, opt *sql.TxOptions) *DB {
var err error

switch beginner := tx.Statement.ConnPool.(type) {
case TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
case ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
default:
err = ErrInvalidTransaction
}

if err != nil {
_ = tx.AddError(err)
}

return tx
}

func (p *processor) Execute(db *DB) *DB {
// call scopes
for len(db.Statement.scopes) > 0 {
Expand Down
4 changes: 2 additions & 2 deletions callbacks/associations.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
}
}

cacheKey := utils.ToStringKey(relPrimaryValues)
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
identityMap[cacheKey] = true
if isPtr {
Expand Down Expand Up @@ -292,7 +292,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
}
}

cacheKey := utils.ToStringKey(relPrimaryValues)
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
identityMap[cacheKey] = true
distinctElems = reflect.Append(distinctElems, elem)
Expand Down
32 changes: 17 additions & 15 deletions callbacks/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,12 @@ func Update(config *Config) func(db *gorm.DB) {
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Update{})
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
db.Statement.AddClause(set)
} else if _, ok := db.Statement.Clauses["SET"]; !ok {
return
if _, ok := db.Statement.Clauses["SET"]; !ok {
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
db.Statement.AddClause(set)
} else {
return
}
}

db.Statement.Build(db.Statement.BuildClauses...)
Expand Down Expand Up @@ -158,21 +160,21 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
if size := stmt.ReflectValue.Len(); size > 0 {
var primaryKeyExprs []clause.Expression
var isZero bool
for i := 0; i < size; i++ {
exprs := make([]clause.Expression, len(stmt.Schema.PrimaryFields))
var notZero bool
for idx, field := range stmt.Schema.PrimaryFields {
value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
notZero = notZero || !isZero
}
if notZero {
primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...))
for _, field := range stmt.Schema.PrimaryFields {
_, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
if !isZero {
break
}
}
}

stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}})
if !isZero {
_, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues)
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
}
}
case reflect.Struct:
for _, field := range stmt.Schema.PrimaryFields {
Expand Down
Loading

0 comments on commit f01ca29

Please sign in to comment.