diff --git a/callbacks/associations.go b/callbacks/associations.go index f3cd464ae6..0fa7655915 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -178,7 +178,8 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { } else if ref.PrimaryValue != "" { db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue)) } - assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + + assignmentColumns = append(assignmentColumns, getAssignmentColumnsForForeignKey(ref.ForeignKey)...) } saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns) @@ -431,6 +432,33 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Val return db.AddError(tx.Create(values).Error) } +func getAssignmentColumnsForForeignKey(foreignKey *schema.Field) []string { + var assignmentColumns []string + + if foreignKey.Schema == nil { + return assignmentColumns + } + + if !foreignKey.PrimaryKey { + assignmentColumns = append(assignmentColumns, foreignKey.DBName) + return assignmentColumns + } + + for _, field := range foreignKey.Schema.Fields { + if field.PrimaryKey { + continue + } + + if field.DBName == "" { + continue + } + + assignmentColumns = append(assignmentColumns, field.DBName) + } + + return assignmentColumns +} + // check association values has been saved // if values kind is Struct, check it has been saved // if values kind is Slice/Array, check all items have been saved diff --git a/tests/associations_has_one_test.go b/tests/associations_has_one_test.go index 78290ce90b..ca3f71967c 100644 --- a/tests/associations_has_one_test.go +++ b/tests/associations_has_one_test.go @@ -256,14 +256,44 @@ func TestPolymorphicHasOneAssociationForSlice(t *testing.T) { AssertAssociationCount(t, pets, "Toy", 0, "After Clear") } -func TestHasOneAssociationReplaceWithNonValidValue(t *testing.T) { - user := User{Name: "jinzhu", Account: Account{Number: "1"}} +func TestReplaceHasOneAssociationWithCustomPK(t *testing.T) { + if DB.Dialector.Name() == "sqlite" { + return + } - if err := DB.Create(&user).Error; err != nil { + DB.Migrator().DropTable(&Owner{}) + DB.Migrator().DropTable(&CreditCard{}) + + DB.AutoMigrate(&CreditCard{}) + DB.AutoMigrate(&Owner{}) + + owner := Owner{ + Name: "jinzhu", + CreditCard: CreditCard{ + Number: "123", + UserName: "jinzhu", + }, + } + + if err := DB.Create(&owner).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } - if err := DB.Model(&user).Association("Languages").Replace(Account{Number: "2"}); err == nil { - t.Error("expected association error to be not nil") + wantNumber := "456" + + if err := DB.Model(&owner).Association("CreditCard").Replace(&CreditCard{ + Number: wantNumber, + UserName: "jinzhu", + }); err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var result Owner + if err := DB.Preload("CreditCard").First(&result, owner.ID).Error; err != nil { + t.Fatalf("errors happened when getting credit card: %v", err) + } + + if result.CreditCard.Number != wantNumber { + t.Fatal("wrong credit card number") } } diff --git a/utils/tests/models.go b/utils/tests/models.go index f9f4f50ecd..f104169d4e 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -102,3 +102,14 @@ type Child struct { ParentID *uint Parent *Parent } + +type Owner struct { + gorm.Model + Name string `gorm:"index"` + CreditCard CreditCard `gorm:"foreignKey:OwnerName;references:name"` +} + +type CreditCard struct { + Number string + UserName string `gorm:"primaryKey;unique;size:255"` +}