From 19ac396a22668e2cdbd77a262de84478787989d0 Mon Sep 17 00:00:00 2001 From: li-jin-gou <97824201+li-jin-gou@users.noreply.github.com> Date: Tue, 15 Feb 2022 20:32:03 +0800 Subject: [PATCH 01/92] fix: isPrintable incorrect (#5076) * fix: isPrintable incorrect * fix: isPrintable incorrect * style: use ReplaceAll instead of Replace --- logger/sql.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index e0be57c01..04a2dbd49 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -19,9 +19,9 @@ const ( nullStr = "NULL" ) -func isPrintable(s []byte) bool { +func isPrintable(s string) bool { for _, r := range s { - if !unicode.IsPrint(rune(r)) { + if !unicode.IsPrint(r) { return false } } @@ -84,8 +84,8 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a } } case []byte: - if isPrintable(v) { - vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper + if s := string(v); isPrintable(s) { + vars[idx] = escaper + strings.ReplaceAll(s, escaper, "\\"+escaper) + escaper } else { vars[idx] = escaper + "" + escaper } From 39d84cba5f7403dd60aee6f7aa2cb0b6bb48f82b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 16 Feb 2022 15:30:43 +0800 Subject: [PATCH 02/92] Add serializer support (#5078) * Update context * Update GormFieldValuer * Add Serializer * Add Serializer Interface * Refactor gorm field * Refactor setter, valuer * Add sync.Pool * Fix test * Add pool manager * Fix pool manager * Add poolInitializer * Add Serializer Scan support * Add Serializer Value method * Add serializer test * Finish Serializer * Fix JSONSerializer for postgres * Fix JSONSerializer for sqlserver * Test serializer tag * Add unixtime serializer * Update go.mod --- association.go | 64 ++-- callbacks/associations.go | 58 ++-- callbacks/create.go | 40 +-- callbacks/delete.go | 8 +- callbacks/preload.go | 28 +- callbacks/query.go | 2 +- callbacks/update.go | 14 +- finisher_api.go | 12 +- interfaces.go | 4 + scan.go | 32 +- schema/field.go | 552 ++++++++++++++++++++--------------- schema/field_test.go | 13 +- schema/interfaces.go | 11 + schema/pool.go | 62 ++++ schema/relationship.go | 5 +- schema/schema_helper_test.go | 3 +- schema/serializer.go | 125 ++++++++ schema/utils.go | 17 +- soft_delete.go | 4 +- statement.go | 16 +- tests/create_test.go | 2 +- tests/go.mod | 2 +- tests/serializer_test.go | 71 +++++ utils/utils.go | 17 +- 24 files changed, 767 insertions(+), 395 deletions(-) create mode 100644 schema/pool.go create mode 100644 schema/serializer.go create mode 100644 tests/serializer_test.go diff --git a/association.go b/association.go index 62c25b711..09e79ca60 100644 --- a/association.go +++ b/association.go @@ -79,10 +79,10 @@ func (association *Association) Replace(values ...interface{}) error { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { - association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) + association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) } case reflect.Struct: - association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) + association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) } for _, ref := range rel.References { @@ -96,12 +96,12 @@ func (association *Association) Replace(values ...interface{}) error { primaryFields []*schema.Field foreignKeys []string updateMap = map[string]interface{}{} - relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel}) + relValues = schema.GetRelationsValues(association.DB.Statement.Context, reflectValue, []*schema.Relationship{rel}) modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() tx = association.DB.Model(modelValue) ) - if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { + if _, rvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 { tx.Not(clause.IN{Column: column, Values: values}) } @@ -117,7 +117,7 @@ func (association *Association) Replace(values ...interface{}) error { } } - if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 { + if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 { column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error } @@ -143,14 +143,14 @@ func (association *Association) Replace(values ...interface{}) error { } } - _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 { tx.Where(clause.IN{Column: column, Values: values}) } else { return ErrPrimaryKeyRequired } - _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields) if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 { tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues})) } @@ -186,11 +186,11 @@ func (association *Association) Delete(values ...interface{}) error { case schema.BelongsTo: tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) - _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields) + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields) pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) - _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields) relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) @@ -198,11 +198,11 @@ func (association *Association) Delete(values ...interface{}) error { case schema.HasOne, schema.HasMany: tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) - _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) - _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) @@ -228,11 +228,11 @@ func (association *Association) Delete(values ...interface{}) error { } } - _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) - _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields) relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) @@ -241,11 +241,11 @@ func (association *Association) Delete(values ...interface{}) error { if association.Error == nil { // clean up deleted values's foreign key - relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) + relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) cleanUpDeletedRelations := func(data reflect.Value) { - if _, zero := rel.Field.ValueOf(data); !zero { - fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) + if _, zero := rel.Field.ValueOf(association.DB.Statement.Context, data); !zero { + fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(association.DB.Statement.Context, data)) primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields)) switch fieldValue.Kind() { @@ -253,7 +253,7 @@ func (association *Association) Delete(values ...interface{}) error { validFieldValues := reflect.Zero(rel.Field.IndirectFieldType) for i := 0; i < fieldValue.Len(); i++ { for idx, field := range rel.FieldSchema.PrimaryFields { - primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i)) + primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue.Index(i)) } if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok { @@ -261,23 +261,23 @@ func (association *Association) Delete(values ...interface{}) error { } } - association.Error = rel.Field.Set(data, validFieldValues.Interface()) + association.Error = rel.Field.Set(association.DB.Statement.Context, data, validFieldValues.Interface()) case reflect.Struct: for idx, field := range rel.FieldSchema.PrimaryFields { - primaryValues[idx], _ = field.ValueOf(fieldValue) + primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue) } if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok { - if association.Error = rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil { + if association.Error = rel.Field.Set(association.DB.Statement.Context, data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil { break } if rel.JoinTable == nil { for _, ref := range rel.References { if ref.OwnPrimaryKey || ref.PrimaryValue != "" { - association.Error = ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } else { - association.Error = ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } } } @@ -329,14 +329,14 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ switch rv.Kind() { case reflect.Slice, reflect.Array: if rv.Len() > 0 { - association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface()) + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Index(0).Addr().Interface()) if association.Relationship.Field.FieldType.Kind() == reflect.Struct { assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)}) } } case reflect.Struct: - association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface()) + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface()) if association.Relationship.Field.FieldType.Kind() == reflect.Struct { assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv}) @@ -344,7 +344,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } case schema.HasMany, schema.Many2Many: elemType := association.Relationship.Field.IndirectFieldType.Elem() - fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source)) + fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source)) if clear { fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem() } @@ -373,7 +373,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } if association.Error == nil { - association.Error = association.Relationship.Field.Set(source, fieldValue.Interface()) + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, fieldValue.Interface()) } } } @@ -421,7 +421,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ // clear old data if clear && len(values) == 0 { for i := 0; i < reflectValue.Len(); i++ { - if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { + if err := association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { association.Error = err break } @@ -429,7 +429,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ if association.Relationship.JoinTable == nil { for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { - if err := ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil { + if err := ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil { association.Error = err break } @@ -453,12 +453,12 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ case reflect.Struct: // clear old data if clear && len(values) == 0 { - association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) if association.Relationship.JoinTable == nil && association.Error == nil { for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { - association.Error = ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } } } @@ -475,7 +475,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } for _, assignBack := range assignBacks { - fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source)) + fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, assignBack.Source)) if assignBack.Index > 0 { reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1)) } else { @@ -486,7 +486,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ func (association *Association) buildCondition() *DB { var ( - queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) + queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.Context, association.DB.Statement.ReflectValue) modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() tx = association.DB.Model(modelValue) ) diff --git a/callbacks/associations.go b/callbacks/associations.go index 75bd6c6a1..d6fd21ded 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -24,8 +24,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { setupReferences := func(obj reflect.Value, elem reflect.Value) { for _, ref := range rel.References { if !ref.OwnPrimaryKey { - pv, _ := ref.PrimaryKey.ValueOf(elem) - db.AddError(ref.ForeignKey.Set(obj, pv)) + pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv)) if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { dest[ref.ForeignKey.DBName] = pv @@ -57,8 +57,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { break } - if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value - rv := rel.Field.ReflectValueOf(obj) // relation reflect value + if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value + rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value objs = append(objs, obj) if isPtr { elems = reflect.Append(elems, rv) @@ -76,8 +76,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { } } case reflect.Struct: - if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value + if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero { + rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value if rv.Kind() != reflect.Ptr { rv = rv.Addr() } @@ -120,18 +120,18 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { obj := db.Statement.ReflectValue.Index(i) if reflect.Indirect(obj).Kind() == reflect.Struct { - if _, zero := rel.Field.ValueOf(obj); !zero { - rv := rel.Field.ReflectValueOf(obj) + if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { + rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) if rv.Kind() != reflect.Ptr { rv = rv.Addr() } for _, ref := range rel.References { if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(obj) - db.AddError(ref.ForeignKey.Set(rv, fv)) + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv)) } else if ref.PrimaryValue != "" { - db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue)) } } @@ -149,8 +149,8 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) } case reflect.Struct: - if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero { + f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) if f.Kind() != reflect.Ptr { f = f.Addr() } @@ -158,10 +158,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { assignmentColumns := make([]string, 0, len(rel.References)) for _, ref := range rel.References { if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) - ref.ForeignKey.Set(f, fv) + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue) + ref.ForeignKey.Set(db.Statement.Context, f, fv) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(f, ref.PrimaryValue) + ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue) } assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } @@ -185,23 +185,23 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) identityMap := map[string]bool{} appendToElems := func(v reflect.Value) { - if _, zero := rel.Field.ValueOf(v); !zero { - f := reflect.Indirect(rel.Field.ReflectValueOf(v)) + if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) for i := 0; i < f.Len(); i++ { elem := f.Index(i) for _, ref := range rel.References { if ref.OwnPrimaryKey { - pv, _ := ref.PrimaryKey.ValueOf(v) - ref.ForeignKey.Set(elem, pv) + pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v) + ref.ForeignKey.Set(db.Statement.Context, elem, pv) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(elem, ref.PrimaryValue) + ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue) } } relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) for _, pf := range rel.FieldSchema.PrimaryFields { - if pfv, ok := pf.ValueOf(elem); !ok { + if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok { relPrimaryValues = append(relPrimaryValues, pfv) } } @@ -260,21 +260,21 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { joinValue := reflect.New(rel.JoinTable.ModelType) for _, ref := range rel.References { if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(obj) - ref.ForeignKey.Set(joinValue, fv) + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj) + ref.ForeignKey.Set(db.Statement.Context, joinValue, fv) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(joinValue, ref.PrimaryValue) + ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue) } else { - fv, _ := ref.PrimaryKey.ValueOf(elem) - ref.ForeignKey.Set(joinValue, fv) + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem) + ref.ForeignKey.Set(db.Statement.Context, joinValue, fv) } } joins = reflect.Append(joins, joinValue) } appendToElems := func(v reflect.Value) { - if _, zero := rel.Field.ValueOf(v); !zero { - f := reflect.Indirect(rel.Field.ReflectValueOf(v)) + if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) for i := 0; i < f.Len(); i++ { elem := f.Index(i) diff --git a/callbacks/create.go b/callbacks/create.go index 291131283..b0964e2b6 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -117,9 +117,9 @@ func Create(config *Config) func(db *gorm.DB) { break } - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID) insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } } @@ -130,16 +130,16 @@ func Create(config *Config) func(db *gorm.DB) { break } - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID) insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } } } case reflect.Struct: - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue) + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID) } } } @@ -219,23 +219,23 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { values.Values[i] = make([]interface{}, len(values.Columns)) for idx, column := range values.Columns { field := stmt.Schema.FieldsByDBName[column.Name] - if values.Values[i][idx], isZero = field.ValueOf(rv); isZero { + if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero { if field.DefaultValueInterface != nil { values.Values[i][idx] = field.DefaultValueInterface - field.Set(rv, field.DefaultValueInterface) + field.Set(stmt.Context, rv, field.DefaultValueInterface) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { - field.Set(rv, curTime) - values.Values[i][idx], _ = field.ValueOf(rv) + field.Set(stmt.Context, rv, curTime) + values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) } } else if field.AutoUpdateTime > 0 && updateTrackTime { - field.Set(rv, curTime) - values.Values[i][idx], _ = field.ValueOf(rv) + field.Set(stmt.Context, rv, curTime) + values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) } } for _, field := range stmt.Schema.FieldsWithDefaultDBValue { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if rvOfvalue, isZero := field.ValueOf(rv); !isZero { + if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero { if len(defaultValueFieldsHavingValue[field]) == 0 { defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen) } @@ -259,23 +259,23 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} for idx, column := range values.Columns { field := stmt.Schema.FieldsByDBName[column.Name] - if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero { + if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero { if field.DefaultValueInterface != nil { values.Values[0][idx] = field.DefaultValueInterface - field.Set(stmt.ReflectValue, field.DefaultValueInterface) + field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { - field.Set(stmt.ReflectValue, curTime) - values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) + field.Set(stmt.Context, stmt.ReflectValue, curTime) + values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) } } else if field.AutoUpdateTime > 0 && updateTrackTime { - field.Set(stmt.ReflectValue, curTime) - values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) + field.Set(stmt.Context, stmt.ReflectValue, curTime) + values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) } } for _, field := range stmt.Schema.FieldsWithDefaultDBValue { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if rvOfvalue, isZero := field.ValueOf(stmt.ReflectValue); !isZero { + if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) values.Values[0] = append(values.Values[0], rvOfvalue) } diff --git a/callbacks/delete.go b/callbacks/delete.go index 7f1e09cee..1fb5261cb 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -42,7 +42,7 @@ func DeleteBeforeAssociations(db *gorm.DB) { switch rel.Type { case schema.HasOne, schema.HasMany: - queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) + queryConds := rel.ToQueryConditions(db.Statement.Context, db.Statement.ReflectValue) modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) withoutConditions := false @@ -97,7 +97,7 @@ func DeleteBeforeAssociations(db *gorm.DB) { } } - _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) + _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, foreignFields) column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) queryConds = append(queryConds, clause.IN{Column: column, Values: values}) @@ -123,7 +123,7 @@ func Delete(config *Config) func(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Delete{}) if db.Statement.Schema != nil { - _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) + _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { @@ -131,7 +131,7 @@ func Delete(config *Config) func(db *gorm.DB) { } if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { - _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) + _, queryValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { diff --git a/callbacks/preload.go b/callbacks/preload.go index 41405a22a..2363a8cab 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -48,7 +48,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) + joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields) if len(joinForeignValues) == 0 { return } @@ -63,11 +63,11 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload for i := 0; i < joinResults.Len(); i++ { joinIndexValue := joinResults.Index(i) for idx, field := range joinForeignFields { - fieldValues[idx], _ = field.ValueOf(joinIndexValue) + fieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue) } for idx, field := range joinRelForeignFields { - joinFieldValues[idx], _ = field.ValueOf(joinIndexValue) + joinFieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue) } if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { @@ -76,7 +76,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - _, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields) + _, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, joinResults, joinRelForeignFields) } else { for _, ref := range rel.References { if ref.OwnPrimaryKey { @@ -92,7 +92,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) + identityMap, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields) if len(foreignValues) == 0 { return } @@ -125,17 +125,17 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload case reflect.Struct: switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + rel.Field.Set(db.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) default: - rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface()) + rel.Field.Set(db.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()) } case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) default: - rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) } } } @@ -143,7 +143,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload for i := 0; i < reflectResults.Len(); i++ { elem := reflectResults.Index(i) for idx, field := range relForeignFields { - fieldValues[idx], _ = field.ValueOf(elem) + fieldValues[idx], _ = field.ValueOf(db.Statement.Context, elem) } datas, ok := identityMap[utils.ToStringKey(fieldValues...)] @@ -154,7 +154,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } for _, data := range datas { - reflectFieldValue := rel.Field.ReflectValueOf(data) + reflectFieldValue := rel.Field.ReflectValueOf(db.Statement.Context, data) if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) } @@ -162,12 +162,12 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload reflectFieldValue = reflect.Indirect(reflectFieldValue) switch reflectFieldValue.Kind() { case reflect.Struct: - rel.Field.Set(data, elem.Interface()) + rel.Field.Set(db.Statement.Context, data, elem.Interface()) case reflect.Slice, reflect.Array: if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { - rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) + rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()) } else { - rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) + rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) } } } diff --git a/callbacks/query.go b/callbacks/query.go index 490863549..03798859d 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -40,7 +40,7 @@ func BuildQuerySQL(db *gorm.DB) { if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType { var conds []clause.Expression for _, primaryField := range db.Statement.Schema.PrimaryFields { - if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero { + if v, isZero := primaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !isZero { conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v}) } } diff --git a/callbacks/update.go b/callbacks/update.go index 511e994e7..4f07ca304 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -21,7 +21,7 @@ func SetupUpdateReflectValue(db *gorm.DB) { if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { for _, rel := range db.Statement.Schema.Relationships.BelongsTo { if _, ok := dest[rel.Name]; ok { - rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name]) + rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name]) } } } @@ -137,13 +137,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { case reflect.Slice, reflect.Array: assignValue = func(field *schema.Field, value interface{}) { for i := 0; i < stmt.ReflectValue.Len(); i++ { - field.Set(stmt.ReflectValue.Index(i), value) + field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) } } case reflect.Struct: assignValue = func(field *schema.Field, value interface{}) { if stmt.ReflectValue.CanAddr() { - field.Set(stmt.ReflectValue, value) + field.Set(stmt.Context, stmt.ReflectValue, value) } } default: @@ -165,7 +165,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { exprs := make([]clause.Expression, len(stmt.Schema.PrimaryFields)) var notZero bool for idx, field := range stmt.Schema.PrimaryFields { - value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) + value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i)) exprs[idx] = clause.Eq{Column: field.DBName, Value: value} notZero = notZero || !isZero } @@ -178,7 +178,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } case reflect.Struct: for _, field := range stmt.Schema.PrimaryFields { - if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { + if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } } @@ -258,7 +258,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field := updatingSchema.LookUpField(dbName); field != nil { if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { - value, isZero := field.ValueOf(updatingValue) + value, isZero := field.ValueOf(stmt.Context, updatingValue) if !stmt.SkipHooks && field.AutoUpdateTime > 0 { if field.AutoUpdateTime == schema.UnixNanosecond { value = stmt.DB.NowFunc().UnixNano() @@ -278,7 +278,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } } else { - if value, isZero := field.ValueOf(updatingValue); !isZero { + if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } } diff --git a/finisher_api.go b/finisher_api.go index 3a1799778..d2a8b981e 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -83,7 +83,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { case reflect.Struct: if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { for _, pf := range tx.Statement.Schema.PrimaryFields { - if _, isZero := pf.ValueOf(reflectValue); isZero { + if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero { return tx.callbacks.Create().Execute(tx) } } @@ -199,7 +199,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat break } - primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) + primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } @@ -216,11 +216,11 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { switch column := eq.Column.(type) { case string: if field := tx.Statement.Schema.LookUpField(column); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + tx.AddError(field.Set(tx.Statement.Context, tx.Statement.ReflectValue, eq.Value)) } case clause.Column: if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + tx.AddError(field.Set(tx.Statement.Context, tx.Statement.ReflectValue, eq.Value)) } } } else if andCond, ok := expr.(clause.AndConditions); ok { @@ -238,9 +238,9 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { case reflect.Struct: for _, f := range s.Fields { if f.Readable { - if v, isZero := f.ValueOf(reflectValue); !isZero { + if v, isZero := f.ValueOf(tx.Statement.Context, reflectValue); !isZero { if field := tx.Statement.Schema.LookUpField(f.Name); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, v)) + tx.AddError(field.Set(tx.Statement.Context, tx.Statement.ReflectValue, v)) } } } diff --git a/interfaces.go b/interfaces.go index 44b2fcedb..ff0ca60ae 100644 --- a/interfaces.go +++ b/interfaces.go @@ -40,14 +40,17 @@ type SavePointerDialectorInterface interface { RollbackTo(tx *DB, name string) error } +// TxBeginner tx beginner type TxBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) } +// ConnPoolBeginner conn pool beginner type ConnPoolBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) } +// TxCommitter tx commiter type TxCommitter interface { Commit() error Rollback() error @@ -58,6 +61,7 @@ type Valuer interface { GormValue(context.Context, *DB) clause.Expr } +// GetDBConnector SQL db connector type GetDBConnector interface { GetDBConn() (*sql.DB, error) } diff --git a/scan.go b/scan.go index b03b79b45..0da12dafb 100644 --- a/scan.go +++ b/scan.go @@ -10,6 +10,7 @@ import ( "gorm.io/gorm/schema" ) +// prepareValues prepare values slice func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) { if db.Statement.Schema != nil { for idx, name := range columns { @@ -54,11 +55,13 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re if sch == nil { values[idx] = reflectValue.Interface() } else if field := sch.LookUpField(column); field != nil && field.Readable { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + values[idx] = field.NewValuePool.Get() + defer field.NewValuePool.Put(values[idx]) } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := sch.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + values[idx] = field.NewValuePool.Get() + defer field.NewValuePool.Put(values[idx]) continue } } @@ -77,21 +80,21 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re if sch != nil { for idx, column := range columns { if field := sch.LookUpField(column); field != nil && field.Readable { - field.Set(reflectValue, values[idx]) + field.Set(db.Statement.Context, reflectValue, values[idx]) } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := sch.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - relValue := rel.Field.ReflectValueOf(reflectValue) - value := reflect.ValueOf(values[idx]).Elem() + relValue := rel.Field.ReflectValueOf(db.Statement.Context, reflectValue) if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if value.IsNil() { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { continue } + relValue.Set(reflect.New(relValue.Type().Elem())) } - field.Set(relValue, values[idx]) + field.Set(db.Statement.Context, relValue, values[idx]) } } } @@ -99,14 +102,17 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re } } +// ScanMode scan data mode type ScanMode uint8 +// scan modes const ( ScanInitialized ScanMode = 1 << 0 // 1 ScanUpdate ScanMode = 1 << 1 // 2 ScanOnConflictDoNothing ScanMode = 1 << 2 // 4 ) +// Scan scan rows into db statement func Scan(rows *sql.Rows, db *DB, mode ScanMode) { var ( columns, _ = rows.Columns() @@ -138,7 +144,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { } scanIntoMap(mapValue, values, columns) } - case *[]map[string]interface{}, []map[string]interface{}: + case *[]map[string]interface{}: columnTypes, _ := rows.ColumnTypes() for initialized || rows.Next() { prepareValues(values, db, columnTypes, columns) @@ -149,11 +155,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { mapValue := map[string]interface{}{} scanIntoMap(mapValue, values, columns) - if values, ok := dest.([]map[string]interface{}); ok { - values = append(values, mapValue) - } else if values, ok := dest.(*[]map[string]interface{}); ok { - *values = append(*values, mapValue) - } + *dest = append(*dest, mapValue) } case *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64, *uintptr, @@ -174,7 +176,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { reflectValue = db.Statement.ReflectValue ) - if reflectValue.Kind() == reflect.Interface { + for reflectValue.Kind() == reflect.Interface { reflectValue = reflectValue.Elem() } @@ -244,7 +246,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { elem = reflectValue.Index(int(db.RowsAffected)) if onConflictDonothing { for _, field := range fields { - if _, ok := field.ValueOf(elem); !ok { + if _, ok := field.ValueOf(db.Statement.Context, elem); !ok { db.RowsAffected++ goto BEGIN } diff --git a/schema/field.go b/schema/field.go index 485bbdf3d..319f36934 100644 --- a/schema/field.go +++ b/schema/field.go @@ -1,6 +1,7 @@ package schema import ( + "context" "database/sql" "database/sql/driver" "fmt" @@ -14,12 +15,21 @@ import ( "gorm.io/gorm/utils" ) -type DataType string - -type TimeType int64 +// special types' reflect type +var ( + TimeReflectType = reflect.TypeOf(time.Time{}) + TimePtrReflectType = reflect.TypeOf(&time.Time{}) + ByteReflectType = reflect.TypeOf(uint8(0)) +) -var TimeReflectType = reflect.TypeOf(time.Time{}) +type ( + // DataType GORM data type + DataType string + // TimeType GORM time type + TimeType int64 +) +// GORM time types const ( UnixTime TimeType = 1 UnixSecond TimeType = 2 @@ -27,6 +37,7 @@ const ( UnixNanosecond TimeType = 4 ) +// GORM fields types const ( Bool DataType = "bool" Int DataType = "int" @@ -37,6 +48,7 @@ const ( Bytes DataType = "bytes" ) +// Field is the representation of model schema's field type Field struct { Name string DBName string @@ -49,9 +61,9 @@ type Field struct { Creatable bool Updatable bool Readable bool - HasDefaultValue bool AutoCreateTime TimeType AutoUpdateTime TimeType + HasDefaultValue bool DefaultValue string DefaultValueInterface interface{} NotNull bool @@ -60,6 +72,7 @@ type Field struct { Size int Precision int Scale int + IgnoreMigration bool FieldType reflect.Type IndirectFieldType reflect.Type StructField reflect.StructField @@ -68,27 +81,39 @@ type Field struct { Schema *Schema EmbeddedSchema *Schema OwnerSchema *Schema - ReflectValueOf func(reflect.Value) reflect.Value - ValueOf func(reflect.Value) (value interface{}, zero bool) - Set func(reflect.Value, interface{}) error - IgnoreMigration bool + ReflectValueOf func(context.Context, reflect.Value) reflect.Value + ValueOf func(context.Context, reflect.Value) (value interface{}, zero bool) + Set func(context.Context, reflect.Value, interface{}) error + Serializer SerializerInterface + NewValuePool FieldNewValuePool } +// ParseField parses reflect.StructField to Field func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { - var err error + var ( + err error + tagSetting = ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";") + ) field := &Field{ Name: fieldStruct.Name, + DBName: tagSetting["COLUMN"], BindNames: []string{fieldStruct.Name}, FieldType: fieldStruct.Type, IndirectFieldType: fieldStruct.Type, StructField: fieldStruct, + Tag: fieldStruct.Tag, + TagSettings: tagSetting, + Schema: schema, Creatable: true, Updatable: true, Readable: true, - Tag: fieldStruct.Tag, - TagSettings: ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"), - Schema: schema, + PrimaryKey: utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]), + AutoIncrement: utils.CheckTruth(tagSetting["AUTOINCREMENT"]), + HasDefaultValue: utils.CheckTruth(tagSetting["AUTOINCREMENT"]), + NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]), + Unique: utils.CheckTruth(tagSetting["UNIQUE"]), + Comment: tagSetting["COMMENT"], AutoIncrementIncrement: 1, } @@ -97,7 +122,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } fieldValue := reflect.New(field.IndirectFieldType) - // if field is valuer, used its value or first fields as data type + // if field is valuer, used its value or first field as data type valuer, isValuer := fieldValue.Interface().(driver.Valuer) if isValuer { if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok { @@ -105,31 +130,37 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { fieldValue = reflect.ValueOf(v) } + // Use the field struct's first field type as data type, e.g: use `string` for sql.NullString var getRealFieldValue func(reflect.Value) getRealFieldValue = func(v reflect.Value) { - rv := reflect.Indirect(v) - if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) { - for i := 0; i < rv.Type().NumField(); i++ { - newFieldType := rv.Type().Field(i).Type + var ( + rv = reflect.Indirect(v) + rvType = rv.Type() + ) + + if rv.Kind() == reflect.Struct && !rvType.ConvertibleTo(TimeReflectType) { + for i := 0; i < rvType.NumField(); i++ { + for key, value := range ParseTagSetting(rvType.Field(i).Tag.Get("gorm"), ";") { + if _, ok := field.TagSettings[key]; !ok { + field.TagSettings[key] = value + } + } + } + + for i := 0; i < rvType.NumField(); i++ { + newFieldType := rvType.Field(i).Type for newFieldType.Kind() == reflect.Ptr { newFieldType = newFieldType.Elem() } fieldValue = reflect.New(newFieldType) - - if rv.Type() != reflect.Indirect(fieldValue).Type() { + if rvType != reflect.Indirect(fieldValue).Type() { getRealFieldValue(fieldValue) } if fieldValue.IsValid() { return } - - for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") { - if _, ok := field.TagSettings[key]; !ok { - field.TagSettings[key] = value - } - } } } } @@ -138,19 +169,23 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if dbName, ok := field.TagSettings["COLUMN"]; ok { - field.DBName = dbName - } - - if val, ok := field.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { - field.PrimaryKey = true - } else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { - field.PrimaryKey = true - } - - if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) { - field.AutoIncrement = true - field.HasDefaultValue = true + if v, isSerializer := fieldValue.Interface().(SerializerInterface); isSerializer { + field.DataType = String + field.Serializer = v + } else { + var serializerName = field.TagSettings["JSON"] + if serializerName == "" { + serializerName = field.TagSettings["SERIALIZER"] + } + if serializerName != "" { + if serializer, ok := GetSerializer(serializerName); ok { + // Set default data type to string for serializer + field.DataType = String + field.Serializer = serializer + } else { + schema.err = fmt.Errorf("invalid serializer type %v", serializerName) + } + } } if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok { @@ -176,20 +211,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Scale, _ = strconv.Atoi(s) } - if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) { - field.NotNull = true - } else if val, ok := field.TagSettings["NOTNULL"]; ok && utils.CheckTruth(val) { - field.NotNull = true - } - - if val, ok := field.TagSettings["UNIQUE"]; ok && utils.CheckTruth(val) { - field.Unique = true - } - - if val, ok := field.TagSettings["COMMENT"]; ok { - field.Comment = val - } - // default value is function or null or blank (primary keys) field.DefaultValue = strings.TrimSpace(field.DefaultValue) skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") && @@ -225,7 +246,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } case reflect.String: field.DataType = String - if field.HasDefaultValue && !skipParseDefaultValue { field.DefaultValue = strings.Trim(field.DefaultValue, "'") field.DefaultValue = strings.Trim(field.DefaultValue, `"`) @@ -236,17 +256,15 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DataType = Time } else if fieldValue.Type().ConvertibleTo(TimeReflectType) { field.DataType = Time - } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) { + } else if fieldValue.Type().ConvertibleTo(TimePtrReflectType) { field.DataType = Time } case reflect.Array, reflect.Slice: - if reflect.Indirect(fieldValue).Type().Elem() == reflect.TypeOf(uint8(0)) { + if reflect.Indirect(fieldValue).Type().Elem() == ByteReflectType && field.DataType == "" { field.DataType = Bytes } } - field.GORMDataType = field.DataType - if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { field.DataType = DataType(dataTyper.GormDataType()) } @@ -346,8 +364,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if _, ok := field.TagSettings["EMBEDDED"]; field.GORMDataType != Time && field.GORMDataType != Bytes && - (ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable))) { + // Normal anonymous field or having `EMBEDDED` tag + if _, ok := field.TagSettings["EMBEDDED"]; ok || (field.GORMDataType != Time && field.GORMDataType != Bytes && !isValuer && + fieldStruct.Anonymous && (field.Creatable || field.Updatable || field.Readable)) { kind := reflect.Indirect(fieldValue).Kind() switch kind { case reflect.Struct: @@ -410,95 +429,122 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { // create valuer, setter when parse struct func (field *Field) setupValuerAndSetter() { - // ValueOf - switch { - case len(field.StructField.Index) == 1: - field.ValueOf = func(value reflect.Value) (interface{}, bool) { - fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) - return fieldValue.Interface(), fieldValue.IsZero() - } - case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: - field.ValueOf = func(value reflect.Value) (interface{}, bool) { - fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) - return fieldValue.Interface(), fieldValue.IsZero() + // Setup NewValuePool + var fieldValue = reflect.New(field.FieldType).Interface() + if field.Serializer != nil { + field.NewValuePool = &sync.Pool{ + New: func() interface{} { + return &serializer{ + Field: field, + Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface), + } + }, } - default: - field.ValueOf = func(value reflect.Value) (interface{}, bool) { - v := reflect.Indirect(value) + } else if _, ok := fieldValue.(sql.Scanner); !ok { + // set default NewValuePool + switch field.IndirectFieldType.Kind() { + case reflect.String: + field.NewValuePool = stringPool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + field.NewValuePool = intPool + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + field.NewValuePool = uintPool + case reflect.Float32, reflect.Float64: + field.NewValuePool = floatPool + case reflect.Bool: + field.NewValuePool = boolPool + default: + if field.IndirectFieldType == TimeReflectType { + field.NewValuePool = timePool + } + } + } - for _, idx := range field.StructField.Index { - if idx >= 0 { - v = v.Field(idx) - } else { - v = v.Field(-idx - 1) + if field.NewValuePool == nil { + field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType)) + } - if v.Type().Elem().Kind() != reflect.Struct { - return nil, true - } + // ValueOf returns field's value and if it is zero + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + v = reflect.Indirect(v) + for _, fieldIdx := range field.StructField.Index { + if fieldIdx >= 0 { + v = v.Field(fieldIdx) + } else { + v = v.Field(-fieldIdx - 1) - if !v.IsNil() { - v = v.Elem() - } else { - return nil, true - } + if !v.IsNil() { + v = v.Elem() + } else { + return nil, true } } - return v.Interface(), v.IsZero() } + + fv, zero := v.Interface(), v.IsZero() + return fv, zero } - // ReflectValueOf - switch { - case len(field.StructField.Index) == 1: - field.ReflectValueOf = func(value reflect.Value) reflect.Value { - return reflect.Indirect(value).Field(field.StructField.Index[0]) - } - case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr: - field.ReflectValueOf = func(value reflect.Value) reflect.Value { - return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) + if field.Serializer != nil { + oldValuerOf := field.ValueOf + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + value, zero := oldValuerOf(ctx, v) + if zero { + return value, zero + } + + s, ok := value.(SerializerValuerInterface) + if !ok { + s = field.Serializer + } + + return serializer{ + Field: field, + SerializeValuer: s, + Destination: v, + Context: ctx, + fieldValue: value, + }, false } - default: - field.ReflectValueOf = func(value reflect.Value) reflect.Value { - v := reflect.Indirect(value) - for idx, fieldIdx := range field.StructField.Index { - if fieldIdx >= 0 { - v = v.Field(fieldIdx) - } else { - v = v.Field(-fieldIdx - 1) - } + } - if v.Kind() == reflect.Ptr { - if v.Type().Elem().Kind() == reflect.Struct { - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) - } - } + // ReflectValueOf returns field's reflect value + field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { + v = reflect.Indirect(v) + for idx, fieldIdx := range field.StructField.Index { + if fieldIdx >= 0 { + v = v.Field(fieldIdx) + } else { + v = v.Field(-fieldIdx - 1) - if idx < len(field.StructField.Index)-1 { - v = v.Elem() - } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + + if idx < len(field.StructField.Index)-1 { + v = v.Elem() } } - return v } + return v } - fallbackSetter := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { + fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) { if v == nil { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else { reflectV := reflect.ValueOf(v) // Optimal value type acquisition for v reflectValType := reflectV.Type() if reflectValType.AssignableTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV) + field.ReflectValueOf(ctx, value).Set(reflectV) return } else if reflectValType.ConvertibleTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) + field.ReflectValueOf(ctx, value).Set(reflectV.Convert(field.FieldType)) return } else if field.FieldType.Kind() == reflect.Ptr { - fieldValue := field.ReflectValueOf(value) + fieldValue := field.ReflectValueOf(ctx, value) fieldType := field.FieldType.Elem() if reflectValType.AssignableTo(fieldType) { @@ -521,13 +567,16 @@ func (field *Field) setupValuerAndSetter() { if reflectV.Kind() == reflect.Ptr { if reflectV.IsNil() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Type().Elem().AssignableTo(field.FieldType) { + field.ReflectValueOf(ctx, value).Set(reflectV.Elem()) + return } else { - err = setter(value, reflectV.Elem().Interface()) + err = setter(ctx, value, reflectV.Elem().Interface()) } } else if valuer, ok := v.(driver.Valuer); ok { if v, err = valuer.Value(); err == nil { - err = setter(value, v) + err = setter(ctx, value, v) } } else { return fmt.Errorf("failed to set value %+v to field %s", v, field.Name) @@ -540,191 +589,201 @@ func (field *Field) setupValuerAndSetter() { // Set switch field.FieldType.Kind() { case reflect.Bool: - field.Set = func(value reflect.Value, v interface{}) error { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { - case bool: - field.ReflectValueOf(value).SetBool(data) - case *bool: - if data != nil { - field.ReflectValueOf(value).SetBool(*data) - } else { - field.ReflectValueOf(value).SetBool(false) + case **bool: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetBool(**data) } + case bool: + field.ReflectValueOf(ctx, value).SetBool(data) case int64: - if data > 0 { - field.ReflectValueOf(value).SetBool(true) - } else { - field.ReflectValueOf(value).SetBool(false) - } + field.ReflectValueOf(ctx, value).SetBool(data > 0) case string: b, _ := strconv.ParseBool(data) - field.ReflectValueOf(value).SetBool(b) + field.ReflectValueOf(ctx, value).SetBool(b) default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return nil } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { + case **int64: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(**data) + } case int64: - field.ReflectValueOf(value).SetInt(data) + field.ReflectValueOf(ctx, value).SetInt(data) case int: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case int8: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case int16: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case int32: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint8: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint16: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint32: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint64: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case float32: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case float64: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case []byte: - return field.Set(value, string(data)) + return field.Set(ctx, value, string(data)) case string: if i, err := strconv.ParseInt(data, 0, 64); err == nil { - field.ReflectValueOf(value).SetInt(i) + field.ReflectValueOf(ctx, value).SetInt(i) } else { return err } case time.Time: if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { - field.ReflectValueOf(value).SetInt(data.UnixNano()) + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6) + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) } else { - field.ReflectValueOf(value).SetInt(data.Unix()) + field.ReflectValueOf(ctx, value).SetInt(data.Unix()) } case *time.Time: if data != nil { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { - field.ReflectValueOf(value).SetInt(data.UnixNano()) + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6) + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) } else { - field.ReflectValueOf(value).SetInt(data.Unix()) + field.ReflectValueOf(ctx, value).SetInt(data.Unix()) } } else { - field.ReflectValueOf(value).SetInt(0) + field.ReflectValueOf(ctx, value).SetInt(0) } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return err } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { + case **uint64: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(**data) + } case uint64: - field.ReflectValueOf(value).SetUint(data) + field.ReflectValueOf(ctx, value).SetUint(data) case uint: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case uint8: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case uint16: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case uint32: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int64: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int8: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int16: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int32: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case float32: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case float64: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case []byte: - return field.Set(value, string(data)) + return field.Set(ctx, value, string(data)) case time.Time: if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { - field.ReflectValueOf(value).SetUint(uint64(data.UnixNano())) + field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano())) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(value).SetUint(uint64(data.UnixNano() / 1e6)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6)) } else { - field.ReflectValueOf(value).SetUint(uint64(data.Unix())) + field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix())) } case string: if i, err := strconv.ParseUint(data, 0, 64); err == nil { - field.ReflectValueOf(value).SetUint(i) + field.ReflectValueOf(ctx, value).SetUint(i) } else { return err } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return err } case reflect.Float32, reflect.Float64: - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { + case **float64: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetFloat(**data) + } case float64: - field.ReflectValueOf(value).SetFloat(data) + field.ReflectValueOf(ctx, value).SetFloat(data) case float32: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int64: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int8: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int16: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int32: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint8: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint16: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint32: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint64: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case []byte: - return field.Set(value, string(data)) + return field.Set(ctx, value, string(data)) case string: if i, err := strconv.ParseFloat(data, 64); err == nil { - field.ReflectValueOf(value).SetFloat(i) + field.ReflectValueOf(ctx, value).SetFloat(i) } else { return err } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return err } case reflect.String: - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { + case **string: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetString(**data) + } case string: - field.ReflectValueOf(value).SetString(data) + field.ReflectValueOf(ctx, value).SetString(data) case []byte: - field.ReflectValueOf(value).SetString(string(data)) + field.ReflectValueOf(ctx, value).SetString(string(data)) case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - field.ReflectValueOf(value).SetString(utils.ToString(data)) + field.ReflectValueOf(ctx, value).SetString(utils.ToString(data)) case float64, float32: - field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) + field.ReflectValueOf(ctx, value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return err } @@ -732,41 +791,49 @@ func (field *Field) setupValuerAndSetter() { fieldValue := reflect.New(field.FieldType) switch fieldValue.Elem().Interface().(type) { case time.Time: - field.Set = func(value reflect.Value, v interface{}) error { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { + case **time.Time: + if data != nil && *data != nil { + field.Set(ctx, value, *data) + } case time.Time: - field.ReflectValueOf(value).Set(reflect.ValueOf(v)) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v)) case *time.Time: if data != nil { - field.ReflectValueOf(value).Set(reflect.ValueOf(data).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(data).Elem()) } else { - field.ReflectValueOf(value).Set(reflect.ValueOf(time.Time{})) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(time.Time{})) } case string: if t, err := now.Parse(data); err == nil { - field.ReflectValueOf(value).Set(reflect.ValueOf(t)) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(t)) } else { return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return nil } case *time.Time: - field.Set = func(value reflect.Value, v interface{}) error { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { + case **time.Time: + if data != nil { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data)) + } case time.Time: - fieldValue := field.ReflectValueOf(value) + fieldValue := field.ReflectValueOf(ctx, value) if fieldValue.IsNil() { fieldValue.Set(reflect.New(field.FieldType.Elem())) } fieldValue.Elem().Set(reflect.ValueOf(v)) case *time.Time: - field.ReflectValueOf(value).Set(reflect.ValueOf(v)) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v)) case string: if t, err := now.Parse(data); err == nil { - fieldValue := field.ReflectValueOf(value) + fieldValue := field.ReflectValueOf(ctx, value) if fieldValue.IsNil() { if v == "" { return nil @@ -778,27 +845,27 @@ func (field *Field) setupValuerAndSetter() { return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return nil } default: if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { // pointer scanner - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else if reflectV.Type().AssignableTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV) + field.ReflectValueOf(ctx, value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { if reflectV.IsNil() || !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else { - return field.Set(value, reflectV.Elem().Interface()) + return field.Set(ctx, value, reflectV.Elem().Interface()) } } else { - fieldValue := field.ReflectValueOf(value) + fieldValue := field.ReflectValueOf(ctx, value) if fieldValue.IsNil() { fieldValue.Set(reflect.New(field.FieldType.Elem())) } @@ -813,32 +880,61 @@ func (field *Field) setupValuerAndSetter() { } } else if _, ok := fieldValue.Interface().(sql.Scanner); ok { // struct scanner - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else if reflectV.Type().AssignableTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV) + field.ReflectValueOf(ctx, value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { if reflectV.IsNil() || !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else { - return field.Set(value, reflectV.Elem().Interface()) + return field.Set(ctx, value, reflectV.Elem().Interface()) } } else { if valuer, ok := v.(driver.Valuer); ok { v, _ = valuer.Value() } - err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) + err = field.ReflectValueOf(ctx, value).Addr().Interface().(sql.Scanner).Scan(v) } return } } else { - field.Set = func(value reflect.Value, v interface{}) (err error) { - return fallbackSetter(value, v, field.Set) + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + return fallbackSetter(ctx, value, v, field.Set) } } } } + + if field.Serializer != nil { + var ( + oldFieldSetter = field.Set + sameElemType bool + sameType = field.FieldType == reflect.ValueOf(field.Serializer).Type() + ) + + if reflect.ValueOf(field.Serializer).Kind() == reflect.Ptr { + sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem() + } + + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + if s, ok := v.(*serializer); ok { + if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil { + if sameElemType { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem()) + s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) + } else if sameType { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer)) + s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) + } + } + } else { + err = oldFieldSetter(ctx, value, v) + } + return + } + } } diff --git a/schema/field_test.go b/schema/field_test.go index 8fa46b876..300e375b4 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -1,6 +1,7 @@ package schema_test import ( + "context" "database/sql" "reflect" "sync" @@ -57,7 +58,7 @@ func TestFieldValuerAndSetter(t *testing.T) { } for k, v := range newValues { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -80,7 +81,7 @@ func TestFieldValuerAndSetter(t *testing.T) { } for k, v := range newValues2 { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -132,7 +133,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) { } for k, v := range newValues { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -151,7 +152,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) { } for k, v := range newValues2 { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -202,7 +203,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { } for k, v := range newValues { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -219,7 +220,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { } for k, v := range newValues2 { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } diff --git a/schema/interfaces.go b/schema/interfaces.go index 98abffbd4..a75a33c0d 100644 --- a/schema/interfaces.go +++ b/schema/interfaces.go @@ -4,22 +4,33 @@ import ( "gorm.io/gorm/clause" ) +// GormDataTypeInterface gorm data type interface type GormDataTypeInterface interface { GormDataType() string } +// FieldNewValuePool field new scan value pool +type FieldNewValuePool interface { + Get() interface{} + Put(interface{}) +} + +// CreateClausesInterface create clauses interface type CreateClausesInterface interface { CreateClauses(*Field) []clause.Interface } +// QueryClausesInterface query clauses interface type QueryClausesInterface interface { QueryClauses(*Field) []clause.Interface } +// UpdateClausesInterface update clauses interface type UpdateClausesInterface interface { UpdateClauses(*Field) []clause.Interface } +// DeleteClausesInterface delete clauses interface type DeleteClausesInterface interface { DeleteClauses(*Field) []clause.Interface } diff --git a/schema/pool.go b/schema/pool.go new file mode 100644 index 000000000..f5c73153d --- /dev/null +++ b/schema/pool.go @@ -0,0 +1,62 @@ +package schema + +import ( + "reflect" + "sync" + "time" +) + +// sync pools +var ( + normalPool sync.Map + stringPool = &sync.Pool{ + New: func() interface{} { + var v string + ptrV := &v + return &ptrV + }, + } + intPool = &sync.Pool{ + New: func() interface{} { + var v int64 + ptrV := &v + return &ptrV + }, + } + uintPool = &sync.Pool{ + New: func() interface{} { + var v uint64 + ptrV := &v + return &ptrV + }, + } + floatPool = &sync.Pool{ + New: func() interface{} { + var v float64 + ptrV := &v + return &ptrV + }, + } + boolPool = &sync.Pool{ + New: func() interface{} { + var v bool + ptrV := &v + return &ptrV + }, + } + timePool = &sync.Pool{ + New: func() interface{} { + var v time.Time + ptrV := &v + return &ptrV + }, + } + poolInitializer = func(reflectType reflect.Type) FieldNewValuePool { + v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{ + New: func() interface{} { + return reflect.New(reflectType).Interface() + }, + }) + return v.(FieldNewValuePool) + } +) diff --git a/schema/relationship.go b/schema/relationship.go index c5d3dcad9..eae8ab0b1 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -1,6 +1,7 @@ package schema import ( + "context" "fmt" "reflect" "strings" @@ -576,7 +577,7 @@ func (rel *Relationship) ParseConstraint() *Constraint { return &constraint } -func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) { +func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue reflect.Value) (conds []clause.Expression) { table := rel.FieldSchema.Table foreignFields := []*Field{} relForeignKeys := []string{} @@ -616,7 +617,7 @@ func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds [] } } - _, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields) + _, foreignValues := GetIdentityFieldValuesMap(ctx, reflectValue, foreignFields) column, values := ToQueryValues(table, relForeignKeys, foreignValues) conds = append(conds, clause.IN{Column: column, Values: values}) diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 6d2bc6647..9abaecba3 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -1,6 +1,7 @@ package schema_test import ( + "context" "fmt" "reflect" "strings" @@ -203,7 +204,7 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { for k, v := range values { t.Run("CheckField/"+k, func(t *testing.T) { - fv, _ := s.FieldsByDBName[k].ValueOf(value) + fv, _ := s.FieldsByDBName[k].ValueOf(context.Background(), value) tests.AssertEqual(t, v, fv) }) } diff --git a/schema/serializer.go b/schema/serializer.go new file mode 100644 index 000000000..68597538d --- /dev/null +++ b/schema/serializer.go @@ -0,0 +1,125 @@ +package schema + +import ( + "context" + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + "reflect" + "strings" + "sync" + "time" +) + +var serializerMap = sync.Map{} + +// RegisterSerializer register serializer +func RegisterSerializer(name string, serializer SerializerInterface) { + serializerMap.Store(strings.ToLower(name), serializer) +} + +// GetSerializer get serializer +func GetSerializer(name string) (serializer SerializerInterface, ok bool) { + v, ok := serializerMap.Load(strings.ToLower(name)) + if ok { + serializer, ok = v.(SerializerInterface) + } + return serializer, ok +} + +func init() { + RegisterSerializer("json", JSONSerializer{}) + RegisterSerializer("unixtime", UnixSecondSerializer{}) +} + +// Serializer field value serializer +type serializer struct { + Field *Field + Serializer SerializerInterface + SerializeValuer SerializerValuerInterface + Destination reflect.Value + Context context.Context + value interface{} + fieldValue interface{} +} + +// Scan implements sql.Scanner interface +func (s *serializer) Scan(value interface{}) error { + s.value = value + return nil +} + +// Value implements driver.Valuer interface +func (s serializer) Value() (driver.Value, error) { + return s.SerializeValuer.Value(s.Context, s.Field, s.Destination, s.fieldValue) +} + +// SerializerInterface serializer interface +type SerializerInterface interface { + Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) error + SerializerValuerInterface +} + +// SerializerValuerInterface serializer valuer interface +type SerializerValuerInterface interface { + Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) +} + +// JSONSerializer json serializer +type JSONSerializer struct { +} + +// Scan implements serializer interface +func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { + fieldValue := reflect.New(field.FieldType) + + if dbValue != nil { + var bytes []byte + switch v := dbValue.(type) { + case []byte: + bytes = v + case string: + bytes = []byte(v) + default: + return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", dbValue)) + } + + err = json.Unmarshal(bytes, fieldValue.Interface()) + } + + field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) + return +} + +// Value implements serializer interface +func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + result, err := json.Marshal(fieldValue) + return string(result), err +} + +// UnixSecondSerializer json serializer +type UnixSecondSerializer struct { +} + +// Scan implements serializer interface +func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { + t := sql.NullTime{} + if err = t.Scan(dbValue); err == nil { + err = field.Set(ctx, dst, t.Time) + } + + return +} + +// Value implements serializer interface +func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { + switch v := fieldValue.(type) { + case int64, int, uint, uint64, int32, uint32, int16, uint16: + result = time.Unix(reflect.ValueOf(v).Int(), 0) + default: + err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) + } + return +} diff --git a/schema/utils.go b/schema/utils.go index e005cc740..2720c5304 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -1,6 +1,7 @@ package schema import ( + "context" "reflect" "regexp" "strings" @@ -59,13 +60,13 @@ func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.Struct } // GetRelationsValues get relations's values from a reflect value -func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { +func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { for _, rel := range rels { reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1) appendToResults := func(value reflect.Value) { - if _, isZero := rel.Field.ValueOf(value); !isZero { - result := reflect.Indirect(rel.Field.ReflectValueOf(value)) + if _, isZero := rel.Field.ValueOf(ctx, value); !isZero { + result := reflect.Indirect(rel.Field.ReflectValueOf(ctx, value)) switch result.Kind() { case reflect.Struct: reflectResults = reflect.Append(reflectResults, result.Addr()) @@ -97,7 +98,7 @@ func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (refle } // GetIdentityFieldValuesMap get identity map from fields -func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { +func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { var ( results = [][]interface{}{} dataResults = map[string][]reflect.Value{} @@ -110,7 +111,7 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map results = [][]interface{}{make([]interface{}, len(fields))} for idx, field := range fields { - results[0][idx], zero = field.ValueOf(reflectValue) + results[0][idx], zero = field.ValueOf(ctx, reflectValue) notZero = notZero || !zero } @@ -135,7 +136,7 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map fieldValues := make([]interface{}, len(fields)) notZero = false for idx, field := range fields { - fieldValues[idx], zero = field.ValueOf(elem) + fieldValues[idx], zero = field.ValueOf(ctx, elem) notZero = notZero || !zero } @@ -155,12 +156,12 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map } // GetIdentityFieldValuesMapFromValues get identity map from fields -func GetIdentityFieldValuesMapFromValues(values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { +func GetIdentityFieldValuesMapFromValues(ctx context.Context, values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { resultsMap := map[string][]reflect.Value{} results := [][]interface{}{} for _, v := range values { - rm, rs := GetIdentityFieldValuesMap(reflect.Indirect(reflect.ValueOf(v)), fields) + rm, rs := GetIdentityFieldValuesMap(ctx, reflect.Indirect(reflect.ValueOf(v)), fields) for k, v := range rm { resultsMap[k] = append(resultsMap[k], v...) } diff --git a/soft_delete.go b/soft_delete.go index 4582161dd..ba6d2118d 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -135,7 +135,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { stmt.SetColumn(sd.Field.DBName, curTime, true) if stmt.Schema != nil { - _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) + _, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields) column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { @@ -143,7 +143,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { } if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil { - _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) + _, queryValues = schema.GetIdentityFieldValuesMap(stmt.Context, reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { diff --git a/statement.go b/statement.go index 232126426..cb4717766 100644 --- a/statement.go +++ b/statement.go @@ -389,7 +389,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, field := range s.Fields { selected := selectedColumns[field.DBName] || selectedColumns[field.Name] if selected || (!restricted && field.Readable) { - if v, isZero := field.ValueOf(reflectValue); !isZero || selected { + if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected { if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { @@ -403,7 +403,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, field := range s.Fields { selected := selectedColumns[field.DBName] || selectedColumns[field.Name] if selected || (!restricted && field.Readable) { - if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected { + if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected { if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { @@ -562,7 +562,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . switch destValue.Kind() { case reflect.Struct: - field.Set(destValue, value) + field.Set(stmt.Context, destValue, value) default: stmt.AddError(ErrInvalidData) } @@ -572,10 +572,10 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . case reflect.Slice, reflect.Array: if len(fromCallbacks) > 0 { for i := 0; i < stmt.ReflectValue.Len(); i++ { - field.Set(stmt.ReflectValue.Index(i), value) + field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) } } else { - field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) + field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value) } case reflect.Struct: if !stmt.ReflectValue.CanAddr() { @@ -583,7 +583,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . return } - field.Set(stmt.ReflectValue, value) + field.Set(stmt.Context, stmt.ReflectValue, value) } } else { stmt.AddError(ErrInvalidField) @@ -603,7 +603,7 @@ func (stmt *Statement) Changed(fields ...string) bool { selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) changed := func(field *schema.Field) bool { - fieldValue, _ := field.ValueOf(modelValue) + fieldValue, _ := field.ValueOf(stmt.Context, modelValue) if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := stmt.Dest.(map[string]interface{}); ok { if fv, ok := v[field.Name]; ok { @@ -617,7 +617,7 @@ func (stmt *Statement) Changed(fields ...string) bool { destValue = destValue.Elem() } - changedValue, zero := field.ValueOf(destValue) + changedValue, zero := field.ValueOf(stmt.Context, destValue) return !zero && !utils.AssertEqual(changedValue, fieldValue) } } diff --git a/tests/create_test.go b/tests/create_test.go index af2abdb08..2b23d4409 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -123,7 +123,7 @@ func TestCreateFromMap(t *testing.T) { {"name": "create_from_map_3", "Age": 20}, } - if err := DB.Model(&User{}).Create(datas).Error; err != nil { + if err := DB.Model(&User{}).Create(&datas).Error; err != nil { t.Fatalf("failed to create data from slice of map, got error: %v", err) } diff --git a/tests/go.mod b/tests/go.mod index 3453f77b0..35db92e65 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.11 // indirect - golang.org/x/crypto v0.0.0-20220208233918-bba287dce954 // indirect + golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect gorm.io/driver/mysql v1.2.3 gorm.io/driver/postgres v1.2.3 gorm.io/driver/sqlite v1.2.6 diff --git a/tests/serializer_test.go b/tests/serializer_test.go new file mode 100644 index 000000000..3ed733d9f --- /dev/null +++ b/tests/serializer_test.go @@ -0,0 +1,71 @@ +package tests_test + +import ( + "bytes" + "context" + "fmt" + "reflect" + "strings" + "testing" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/schema" + . "gorm.io/gorm/utils/tests" +) + +type SerializerStruct struct { + gorm.Model + Name []byte `gorm:"json"` + Roles Roles `gorm:"serializer:json"` + Contracts map[string]interface{} `gorm:"serializer:json"` + CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type + EncryptedString EncryptedString +} + +type Roles []string +type EncryptedString string + +func (es *EncryptedString) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) { + switch value := dbValue.(type) { + case []byte: + *es = EncryptedString(bytes.TrimPrefix(value, []byte("hello"))) + case string: + *es = EncryptedString(strings.TrimPrefix(value, "hello")) + default: + return fmt.Errorf("unsupported data %v", dbValue) + } + return nil +} + +func (es EncryptedString) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + return "hello" + string(es), nil +} + +func TestSerializer(t *testing.T) { + DB.Migrator().DropTable(&SerializerStruct{}) + if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil { + t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) + } + + createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + + data := SerializerStruct{ + Name: []byte("jinzhu"), + Roles: []string{"r1", "r2"}, + Contracts: map[string]interface{}{"name": "jinzhu", "age": 10}, + EncryptedString: EncryptedString("pass"), + CreatedTime: createdAt.Unix(), + } + + if err := DB.Create(&data).Error; err != nil { + t.Fatalf("failed to create data, got error %v", err) + } + + var result SerializerStruct + if err := DB.First(&result, data.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } + + AssertEqual(t, result, data) +} diff --git a/utils/utils.go b/utils/utils.go index f00f92ba3..28ca0daf3 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -36,17 +36,14 @@ func IsValidDBNameChar(c rune) bool { return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' } -func CheckTruth(val interface{}) bool { - if v, ok := val.(bool); ok { - return v - } - - if v, ok := val.(string); ok { - v = strings.ToLower(v) - return v != "false" +// CheckTruth check string true or not +func CheckTruth(vals ...string) bool { + for _, val := range vals { + if !strings.EqualFold(val, "false") && val != "" { + return true + } } - - return !reflect.ValueOf(val).IsZero() + return false } func ToStringKey(values ...interface{}) string { From 0af95f509a3284bb94393946e0a83aeaf954f304 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 19 Feb 2022 16:59:22 +0800 Subject: [PATCH 03/92] Enhance migrator Columntype interface (#5088) * Update Migrator ColumnType interface * Update MigrateColumn Test * Upgrade test drivers * Fix typo --- migrator.go | 13 ++++- migrator/column_type.go | 107 ++++++++++++++++++++++++++++++++++++++++ migrator/migrator.go | 39 +++++++++++++-- tests/go.mod | 9 ++-- tests/migrate_test.go | 31 ++++++++++-- 5 files changed, 185 insertions(+), 14 deletions(-) create mode 100644 migrator/column_type.go diff --git a/migrator.go b/migrator.go index 2a8b42548..524438770 100644 --- a/migrator.go +++ b/migrator.go @@ -1,6 +1,8 @@ package gorm import ( + "reflect" + "gorm.io/gorm/clause" "gorm.io/gorm/schema" ) @@ -33,14 +35,23 @@ type ViewOption struct { Query *DB } +// ColumnType column type interface type ColumnType interface { Name() string - DatabaseTypeName() string + DatabaseTypeName() string // varchar + ColumnType() (columnType string, ok bool) // varchar(64) + PrimaryKey() (isPrimaryKey bool, ok bool) + AutoIncrement() (isAutoIncrement bool, ok bool) Length() (length int64, ok bool) DecimalSize() (precision int64, scale int64, ok bool) Nullable() (nullable bool, ok bool) + Unique() (unique bool, ok bool) + ScanType() reflect.Type + Comment() (value string, ok bool) + DefaultValue() (value string, ok bool) } +// Migrator migrator interface type Migrator interface { // AutoMigrate AutoMigrate(dst ...interface{}) error diff --git a/migrator/column_type.go b/migrator/column_type.go new file mode 100644 index 000000000..eb8d1b7f8 --- /dev/null +++ b/migrator/column_type.go @@ -0,0 +1,107 @@ +package migrator + +import ( + "database/sql" + "reflect" +) + +// ColumnType column type implements ColumnType interface +type ColumnType struct { + SQLColumnType *sql.ColumnType + NameValue sql.NullString + DataTypeValue sql.NullString + ColumnTypeValue sql.NullString + PrimayKeyValue sql.NullBool + UniqueValue sql.NullBool + AutoIncrementValue sql.NullBool + LengthValue sql.NullInt64 + DecimalSizeValue sql.NullInt64 + ScaleValue sql.NullInt64 + NullableValue sql.NullBool + ScanTypeValue reflect.Type + CommentValue sql.NullString + DefaultValueValue sql.NullString +} + +// Name returns the name or alias of the column. +func (ct ColumnType) Name() string { + if ct.NameValue.Valid { + return ct.NameValue.String + } + return ct.SQLColumnType.Name() +} + +// DatabaseTypeName returns the database system name of the column type. If an empty +// string is returned, then the driver type name is not supported. +// Consult your driver documentation for a list of driver data types. Length specifiers +// are not included. +// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL", +// "INT", and "BIGINT". +func (ct ColumnType) DatabaseTypeName() string { + if ct.DataTypeValue.Valid { + return ct.DataTypeValue.String + } + return ct.SQLColumnType.DatabaseTypeName() +} + +// ColumnType returns the database type of the column. lke `varchar(16)` +func (ct ColumnType) ColumnType() (columnType string, ok bool) { + return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid +} + +// PrimaryKey returns the column is primary key or not. +func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) { + return ct.PrimayKeyValue.Bool, ct.PrimayKeyValue.Valid +} + +// AutoIncrement returns the column is auto increment or not. +func (ct ColumnType) AutoIncrement() (isAutoIncrement bool, ok bool) { + return ct.AutoIncrementValue.Bool, ct.AutoIncrementValue.Valid +} + +// Length returns the column type length for variable length column types +func (ct ColumnType) Length() (length int64, ok bool) { + if ct.LengthValue.Valid { + return ct.LengthValue.Int64, true + } + return ct.SQLColumnType.Length() +} + +// DecimalSize returns the scale and precision of a decimal type. +func (ct ColumnType) DecimalSize() (precision int64, scale int64, ok bool) { + if ct.DecimalSizeValue.Valid { + return ct.DecimalSizeValue.Int64, ct.ScaleValue.Int64, true + } + return ct.SQLColumnType.DecimalSize() +} + +// Nullable reports whether the column may be null. +func (ct ColumnType) Nullable() (nullable bool, ok bool) { + if ct.NullableValue.Valid { + return ct.NullableValue.Bool, true + } + return ct.SQLColumnType.Nullable() +} + +// Unique reports whether the column may be unique. +func (ct ColumnType) Unique() (unique bool, ok bool) { + return ct.UniqueValue.Bool, ct.UniqueValue.Valid +} + +// ScanType returns a Go type suitable for scanning into using Rows.Scan. +func (ct ColumnType) ScanType() reflect.Type { + if ct.ScanTypeValue != nil { + return ct.ScanTypeValue + } + return ct.SQLColumnType.ScanType() +} + +// Comment returns the comment of current column. +func (ct ColumnType) Comment() (value string, ok bool) { + return ct.CommentValue.String, ct.CommentValue.Valid +} + +// DefaultValue returns the default value of current column. +func (ct ColumnType) DefaultValue() (value string, ok bool) { + return ct.DefaultValueValue.String, ct.DefaultValueValue.Valid +} diff --git a/migrator/migrator.go b/migrator/migrator.go index 80c4e2b3c..9695f3129 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -30,10 +30,12 @@ type Config struct { gorm.Dialector } +// GormDataTypeInterface gorm data type interface type GormDataTypeInterface interface { GormDBDataType(*gorm.DB, *schema.Field) string } +// RunWithValue run migration with statement value func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { stmt := &gorm.Statement{DB: m.DB} if m.DB.Statement != nil { @@ -50,6 +52,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error return fc(stmt) } +// DataTypeOf return field's db data type func (m Migrator) DataTypeOf(field *schema.Field) string { fieldValue := reflect.New(field.IndirectFieldType) if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { @@ -61,6 +64,7 @@ func (m Migrator) DataTypeOf(field *schema.Field) string { return m.Dialector.DataTypeOf(field) } +// FullDataTypeOf returns field's db full data type func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { expr.SQL = m.DataTypeOf(field) @@ -85,7 +89,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { return } -// AutoMigrate +// AutoMigrate auto migrate values func (m Migrator) AutoMigrate(values ...interface{}) error { for _, value := range m.ReorderModels(values, true) { tx := m.DB.Session(&gorm.Session{}) @@ -156,12 +160,14 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { return nil } +// GetTables returns tables func (m Migrator) GetTables() (tableList []string, err error) { err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()). Scan(&tableList).Error return } +// CreateTable create table in database for values func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range m.ReorderModels(values, false) { tx := m.DB.Session(&gorm.Session{}) @@ -252,6 +258,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { return nil } +// DropTable drop table for values func (m Migrator) DropTable(values ...interface{}) error { values = m.ReorderModels(values, false) for i := len(values) - 1; i >= 0; i-- { @@ -265,6 +272,7 @@ func (m Migrator) DropTable(values ...interface{}) error { return nil } +// HasTable returns table exists or not for value, value could be a struct or string func (m Migrator) HasTable(value interface{}) bool { var count int64 @@ -276,6 +284,7 @@ func (m Migrator) HasTable(value interface{}) bool { return count > 0 } +// RenameTable rename table from oldName to newName func (m Migrator) RenameTable(oldName, newName interface{}) error { var oldTable, newTable interface{} if v, ok := oldName.(string); ok { @@ -303,12 +312,13 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error { return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error } -func (m Migrator) AddColumn(value interface{}, field string) error { +// AddColumn create `name` column for value +func (m Migrator) AddColumn(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { // avoid using the same name field - f := stmt.Schema.LookUpField(field) + f := stmt.Schema.LookUpField(name) if f == nil { - return fmt.Errorf("failed to look up field with name: %s", field) + return fmt.Errorf("failed to look up field with name: %s", name) } if !f.IgnoreMigration { @@ -322,6 +332,7 @@ func (m Migrator) AddColumn(value interface{}, field string) error { }) } +// DropColumn drop value's `name` column func (m Migrator) DropColumn(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(name); field != nil { @@ -334,6 +345,7 @@ func (m Migrator) DropColumn(value interface{}, name string) error { }) } +// AlterColumn alter value's `field` column' type based on schema definition func (m Migrator) AlterColumn(value interface{}, field string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { @@ -348,6 +360,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { }) } +// HasColumn check has column `field` for value or not func (m Migrator) HasColumn(value interface{}, field string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { @@ -366,6 +379,7 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { return count > 0 } +// RenameColumn rename value's field name from oldName to newName func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(oldName); field != nil { @@ -383,6 +397,7 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error }) } +// MigrateColumn migrate column func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { // found, smart migrate fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL) @@ -448,7 +463,7 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { } for _, c := range rawColumnTypes { - columnTypes = append(columnTypes, c) + columnTypes = append(columnTypes, ColumnType{SQLColumnType: c}) } return @@ -457,10 +472,12 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { return columnTypes, execErr } +// CreateView create view func (m Migrator) CreateView(name string, option gorm.ViewOption) error { return gorm.ErrNotImplemented } +// DropView drop view func (m Migrator) DropView(name string) error { return gorm.ErrNotImplemented } @@ -487,6 +504,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter return } +// GuessConstraintAndTable guess statement's constraint and it's table based on name func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) { if stmt.Schema == nil { return nil, nil, stmt.Table @@ -531,6 +549,7 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ return nil, nil, stmt.Schema.Table } +// CreateConstraint create constraint func (m Migrator) CreateConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { constraint, chk, table := m.GuessConstraintAndTable(stmt, name) @@ -554,6 +573,7 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { }) } +// DropConstraint drop constraint func (m Migrator) DropConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { constraint, chk, table := m.GuessConstraintAndTable(stmt, name) @@ -566,6 +586,7 @@ func (m Migrator) DropConstraint(value interface{}, name string) error { }) } +// HasConstraint check has constraint or not func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { @@ -586,6 +607,7 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool { return count > 0 } +// BuildIndexOptions build index options func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { for _, opt := range opts { str := stmt.Quote(opt.DBName) @@ -607,10 +629,12 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem return } +// BuildIndexOptionsInterface build index options interface type BuildIndexOptionsInterface interface { BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{} } +// CreateIndex create index `name` func (m Migrator) CreateIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if idx := stmt.Schema.LookIndex(name); idx != nil { @@ -642,6 +666,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { }) } +// DropIndex drop index `name` func (m Migrator) DropIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if idx := stmt.Schema.LookIndex(name); idx != nil { @@ -652,6 +677,7 @@ func (m Migrator) DropIndex(value interface{}, name string) error { }) } +// HasIndex check has index `name` or not func (m Migrator) HasIndex(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { @@ -669,6 +695,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool { return count > 0 } +// RenameIndex rename index from oldName to newName func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.DB.Exec( @@ -678,6 +705,7 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error }) } +// CurrentDatabase returns current database name func (m Migrator) CurrentDatabase() (name string) { m.DB.Raw("SELECT DATABASE()").Row().Scan(&name) return @@ -781,6 +809,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i return } +// CurrentTable returns current statement's table expression func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} { if stmt.TableExpr != nil { return *stmt.TableExpr diff --git a/tests/go.mod b/tests/go.mod index 35db92e65..0cd036371 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,17 +3,16 @@ module gorm.io/gorm/tests go 1.14 require ( - github.com/denisenkom/go-mssqldb v0.12.0 // indirect github.com/google/uuid v1.3.0 github.com/jackc/pgx/v4 v4.15.0 // indirect github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.11 // indirect golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect - gorm.io/driver/mysql v1.2.3 - gorm.io/driver/postgres v1.2.3 - gorm.io/driver/sqlite v1.2.6 - gorm.io/driver/sqlserver v1.2.1 + gorm.io/driver/mysql v1.3.0 + gorm.io/driver/postgres v1.3.0 + gorm.io/driver/sqlite v1.3.0 + gorm.io/driver/sqlserver v1.3.0 gorm.io/gorm v1.22.5 ) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index aa0a84ab5..5e9c01fa8 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -92,7 +92,7 @@ func TestAutoMigrateSelfReferential(t *testing.T) { } func TestSmartMigrateColumn(t *testing.T) { - fullSupported := map[string]bool{"mysql": true}[DB.Dialector.Name()] + fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()] type UserMigrateColumn struct { ID uint @@ -313,9 +313,15 @@ func TestMigrateIndexes(t *testing.T) { } func TestMigrateColumns(t *testing.T) { + fullSupported := map[string]bool{"sqlite": true, "mysql": true, "postgres": true, "sqlserver": true}[DB.Dialector.Name()] + sqlite := DB.Dialector.Name() == "sqlite" + sqlserver := DB.Dialector.Name() == "sqlserver" + type ColumnStruct struct { gorm.Model Name string + Age int `gorm:"default:18;comment:my age"` + Code string `gorm:"unique"` } DB.Migrator().DropTable(&ColumnStruct{}) @@ -340,10 +346,29 @@ func TestMigrateColumns(t *testing.T) { stmt.Parse(&ColumnStruct2{}) for _, columnType := range columnTypes { - if columnType.Name() == "name" { + switch columnType.Name() { + case "id": + if v, ok := columnType.PrimaryKey(); (fullSupported || ok) && !v { + t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "name": dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { - t.Errorf("column type should be correct, name: %v, length: %v, expects: %v", columnType.Name(), columnType.DatabaseTypeName(), dataType) + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + } + if length, ok := columnType.Length(); ((fullSupported && !sqlite) || ok) && length != 100 { + t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType) + } + case "age": + if v, ok := columnType.DefaultValue(); (fullSupported || ok) && v != "18" { + t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + if v, ok := columnType.Comment(); ((fullSupported && !sqlite && !sqlserver) || ok) && v != "my age" { + t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "code": + if v, ok := columnType.Unique(); (fullSupported || ok) && !v { + t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) } } } From e0b4e0ec8f938ac055e99c5b37e0cdb9bf6e2ad5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 19 Feb 2022 17:08:11 +0800 Subject: [PATCH 04/92] Update auto stale days --- .github/workflows/invalid_question.yml | 4 ++-- .github/workflows/missing_playground.yml | 4 ++-- .github/workflows/stale.yml | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/invalid_question.yml b/.github/workflows/invalid_question.yml index dfd2ddd90..868bcc348 100644 --- a/.github/workflows/invalid_question.yml +++ b/.github/workflows/invalid_question.yml @@ -13,10 +13,10 @@ jobs: uses: actions/stale@v4 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 2 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" + stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-label: "status:stale" days-before-stale: 0 - days-before-close: 2 + days-before-close: 30 remove-stale-when-updated: true only-labels: "type:invalid question" diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml index cdb097de3..3efc90f74 100644 --- a/.github/workflows/missing_playground.yml +++ b/.github/workflows/missing_playground.yml @@ -13,9 +13,9 @@ jobs: uses: actions/stale@v4 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" + stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-label: "status:stale" days-before-stale: 0 - days-before-close: 2 + days-before-close: 30 remove-stale-when-updated: true only-labels: "type:missing reproduction steps" diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index d5419295f..e0be186fa 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -13,9 +13,9 @@ jobs: uses: actions/stale@v4 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "This issue has been automatically marked as stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 30 days" - days-before-stale: 60 - days-before-close: 30 + stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days" + days-before-stale: 360 + days-before-close: 180 stale-issue-label: "status:stale" exempt-issue-labels: 'type:feature,type:with reproduction steps,type:has pull request' stale-pr-label: 'status:stale' From 48ced75d1d8d8aab844ab29787ae97337095b8e1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 19 Feb 2022 23:42:20 +0800 Subject: [PATCH 05/92] Improve support for AutoMigrate --- migrator/column_type.go | 4 ++-- migrator/migrator.go | 24 +++++++++++++++++++++ tests/go.mod | 10 ++++----- tests/migrate_test.go | 47 ++++++++++++++++++++++++++++++----------- 4 files changed, 66 insertions(+), 19 deletions(-) diff --git a/migrator/column_type.go b/migrator/column_type.go index eb8d1b7f8..cc1331b92 100644 --- a/migrator/column_type.go +++ b/migrator/column_type.go @@ -11,7 +11,7 @@ type ColumnType struct { NameValue sql.NullString DataTypeValue sql.NullString ColumnTypeValue sql.NullString - PrimayKeyValue sql.NullBool + PrimaryKeyValue sql.NullBool UniqueValue sql.NullBool AutoIncrementValue sql.NullBool LengthValue sql.NullInt64 @@ -51,7 +51,7 @@ func (ct ColumnType) ColumnType() (columnType string, ok bool) { // PrimaryKey returns the column is primary key or not. func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) { - return ct.PrimayKeyValue.Bool, ct.PrimayKeyValue.Valid + return ct.PrimaryKeyValue.Bool, ct.PrimaryKeyValue.Valid } // AutoIncrement returns the column is auto increment or not. diff --git a/migrator/migrator.go b/migrator/migrator.go index 9695f3129..a50bb3ff8 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -436,6 +436,30 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } + // check unique + if unique, ok := columnType.Unique(); ok && unique != field.Unique { + // not primary key + if !field.PrimaryKey { + alterColumn = true + } + } + + // check default value + if v, ok := columnType.DefaultValue(); ok && v != field.DefaultValue { + // not primary key + if !field.PrimaryKey { + alterColumn = true + } + } + + // check comment + if comment, ok := columnType.Comment(); ok && comment != field.Comment { + // not primary key + if !field.PrimaryKey { + alterColumn = true + } + } + if alterColumn && !field.IgnoreMigration { return m.DB.Migrator().AlterColumn(value, field.Name) } diff --git a/tests/go.mod b/tests/go.mod index 0cd036371..1c1fb2389 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,11 +9,11 @@ require ( github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.11 // indirect golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect - gorm.io/driver/mysql v1.3.0 - gorm.io/driver/postgres v1.3.0 - gorm.io/driver/sqlite v1.3.0 - gorm.io/driver/sqlserver v1.3.0 - gorm.io/gorm v1.22.5 + gorm.io/driver/mysql v1.3.1 + gorm.io/driver/postgres v1.3.1 + gorm.io/driver/sqlite v1.3.1 + gorm.io/driver/sqlserver v1.3.1 + gorm.io/gorm v1.23.0 ) replace gorm.io/gorm => ../ diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 5e9c01fa8..94f562b47 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -45,7 +45,7 @@ func TestMigrate(t *testing.T) { for _, m := range allModels { if !DB.Migrator().HasTable(m) { - t.Fatalf("Failed to create table for %#v---", m) + t.Fatalf("Failed to create table for %#v", m) } } @@ -313,15 +313,16 @@ func TestMigrateIndexes(t *testing.T) { } func TestMigrateColumns(t *testing.T) { - fullSupported := map[string]bool{"sqlite": true, "mysql": true, "postgres": true, "sqlserver": true}[DB.Dialector.Name()] sqlite := DB.Dialector.Name() == "sqlite" sqlserver := DB.Dialector.Name() == "sqlserver" type ColumnStruct struct { gorm.Model - Name string - Age int `gorm:"default:18;comment:my age"` - Code string `gorm:"unique"` + Name string + Age int `gorm:"default:18;comment:my age"` + Code string `gorm:"unique;comment:my code;"` + Code2 string + Code3 string `gorm:"unique"` } DB.Migrator().DropTable(&ColumnStruct{}) @@ -332,13 +333,20 @@ func TestMigrateColumns(t *testing.T) { type ColumnStruct2 struct { gorm.Model - Name string `gorm:"size:100"` + Name string `gorm:"size:100"` + Code string `gorm:"unique;comment:my code2;default:hello"` + Code2 string `gorm:"unique"` + // Code3 string } - if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct2{}, "Name"); err != nil { + if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct{}, "Name"); err != nil { t.Fatalf("no error should happened when alter column, but got %v", err) } + if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil { + t.Fatalf("no error should happened when auto migrate column, but got %v", err) + } + if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil { t.Fatalf("no error should returns for ColumnTypes") } else { @@ -348,7 +356,7 @@ func TestMigrateColumns(t *testing.T) { for _, columnType := range columnTypes { switch columnType.Name() { case "id": - if v, ok := columnType.PrimaryKey(); (fullSupported || ok) && !v { + if v, ok := columnType.PrimaryKey(); !ok || !v { t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) } case "name": @@ -356,20 +364,35 @@ func TestMigrateColumns(t *testing.T) { if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) } - if length, ok := columnType.Length(); ((fullSupported && !sqlite) || ok) && length != 100 { + if length, ok := columnType.Length(); !sqlite && (!ok || length != 100) { t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType) } case "age": - if v, ok := columnType.DefaultValue(); (fullSupported || ok) && v != "18" { + if v, ok := columnType.DefaultValue(); !ok || v != "18" { t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) } - if v, ok := columnType.Comment(); ((fullSupported && !sqlite && !sqlserver) || ok) && v != "my age" { + if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my age") { t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) } case "code": - if v, ok := columnType.Unique(); (fullSupported || ok) && !v { + if v, ok := columnType.Unique(); !ok || !v { t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) } + if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") { + t.Fatalf("column code default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") { + t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "code2": + if v, ok := columnType.Unique(); !sqlserver && (!ok || !v) { + t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "code3": + // TODO + // if v, ok := columnType.Unique(); !ok || v { + // t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + // } } } } From 5edc78116fe46a7d001db52d80a78f97756ac1ad Mon Sep 17 00:00:00 2001 From: sammyrnycreal Date: Mon, 14 Feb 2022 14:13:26 -0500 Subject: [PATCH 06/92] Fixed the use of "or" to be " OR ", to account for words that contain "or" or "and" (e.g., 'score', 'band') in a sql statement as the name of a field. --- clause/where.go | 39 ++++++++++++++++++++++----------------- clause/where_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 17 deletions(-) diff --git a/clause/where.go b/clause/where.go index 20a011362..10b6df856 100644 --- a/clause/where.go +++ b/clause/where.go @@ -4,6 +4,11 @@ import ( "strings" ) +const ( + AndWithSpace = " AND " + OrWithSpace = " OR " +) + // Where where clause type Where struct { Exprs []Expression @@ -26,7 +31,7 @@ func (where Where) Build(builder Builder) { } } - buildExprs(where.Exprs, builder, " AND ") + buildExprs(where.Exprs, builder, AndWithSpace) } func buildExprs(exprs []Expression, builder Builder, joinCond string) { @@ -35,7 +40,7 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) { for idx, expr := range exprs { if idx > 0 { if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { - builder.WriteString(" OR ") + builder.WriteString(OrWithSpace) } else { builder.WriteString(joinCond) } @@ -46,23 +51,23 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) { case OrConditions: if len(v.Exprs) == 1 { if e, ok := v.Exprs[0].(Expr); ok { - sql := strings.ToLower(e.SQL) - wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + sql := strings.ToUpper(e.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) } } case AndConditions: if len(v.Exprs) == 1 { if e, ok := v.Exprs[0].(Expr); ok { - sql := strings.ToLower(e.SQL) - wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + sql := strings.ToUpper(e.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) } } case Expr: - sql := strings.ToLower(v.SQL) - wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + sql := strings.ToUpper(v.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) case NamedExpr: - sql := strings.ToLower(v.SQL) - wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + sql := strings.ToUpper(v.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) } } @@ -110,10 +115,10 @@ type AndConditions struct { func (and AndConditions) Build(builder Builder) { if len(and.Exprs) > 1 { builder.WriteByte('(') - buildExprs(and.Exprs, builder, " AND ") + buildExprs(and.Exprs, builder, AndWithSpace) builder.WriteByte(')') } else { - buildExprs(and.Exprs, builder, " AND ") + buildExprs(and.Exprs, builder, AndWithSpace) } } @@ -131,10 +136,10 @@ type OrConditions struct { func (or OrConditions) Build(builder Builder) { if len(or.Exprs) > 1 { builder.WriteByte('(') - buildExprs(or.Exprs, builder, " OR ") + buildExprs(or.Exprs, builder, OrWithSpace) builder.WriteByte(')') } else { - buildExprs(or.Exprs, builder, " OR ") + buildExprs(or.Exprs, builder, OrWithSpace) } } @@ -156,7 +161,7 @@ func (not NotConditions) Build(builder Builder) { for idx, c := range not.Exprs { if idx > 0 { - builder.WriteString(" AND ") + builder.WriteString(AndWithSpace) } if negationBuilder, ok := c.(NegationExpressionBuilder); ok { @@ -165,8 +170,8 @@ func (not NotConditions) Build(builder Builder) { builder.WriteString("NOT ") e, wrapInParentheses := c.(Expr) if wrapInParentheses { - sql := strings.ToLower(e.SQL) - if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses { + sql := strings.ToUpper(e.SQL) + if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses { builder.WriteByte('(') } } diff --git a/clause/where_test.go b/clause/where_test.go index 272c7b76b..35e3dbeeb 100644 --- a/clause/where_test.go +++ b/clause/where_test.go @@ -66,6 +66,45 @@ func TestWhere(t *testing.T) { "SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)", []interface{}{18, "jinzhu"}, }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?", + []interface{}{"1", 18, 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?", + []interface{}{"1", 18, 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `score` <= ?", + []interface{}{"1", 18, 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{ + clause.And(clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}), + clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})), + }, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `score` <= ?)", + []interface{}{"1", 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, + clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}))}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)", + []interface{}{"1", 100}, + }, } for idx, result := range results { From f3547e00cc786e0b07206c775f3b7fe19164f56f Mon Sep 17 00:00:00 2001 From: Gilad Weiss Date: Sun, 20 Feb 2022 02:33:12 +0200 Subject: [PATCH 07/92] Inherit clone flag (NewDB) on transaction creation (#5012) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Inherit clone flag (NewDB) on transaction creation I find it very reassuring to know that after a finisher API, I get a clean db object for my next queries. If you look at the example in https://gorm.io/docs i’d see many queries running one after the other.. but in reality they wouldn’t work as the they are portrayed and that’s because in default mode NewDB is false and will make all the clauses stay even after a finisher API. My solution is just to have the value of the clone flag in the “parent” db object, be injected to its children transactions. * Fix typo --- finisher_api.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finisher_api.go b/finisher_api.go index d2a8b981e..f994ec318 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -590,7 +590,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er func (db *DB) Begin(opts ...*sql.TxOptions) *DB { var ( // clone statement - tx = db.getInstance().Session(&Session{Context: db.Statement.Context}) + tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1}) opt *sql.TxOptions err error ) From 664c5fb7672863b38080bb2147403b5d67f2593c Mon Sep 17 00:00:00 2001 From: codingxh <94290868+codingxh@users.noreply.github.com> Date: Sun, 20 Feb 2022 19:55:04 +0800 Subject: [PATCH 08/92] strings.replace -> strings.replaceAll (#5095) Co-authored-by: huquan --- logger/sql.go | 8 ++++---- logger/sql_test.go | 2 +- schema/naming.go | 2 +- tests/sql_builder_test.go | 8 ++++---- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index 04a2dbd49..c8b194c3c 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -75,10 +75,10 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a case reflect.Bool: vars[idx] = fmt.Sprintf("%t", reflectValue.Interface()) case reflect.String: - vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper default: if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { - vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper } else { vars[idx] = nullStr } @@ -94,7 +94,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a case float64, float32: vars[idx] = fmt.Sprintf("%.6f", v) case string: - vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper + vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper default: rv := reflect.ValueOf(v) if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() { @@ -111,7 +111,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a return } } - vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper } } } diff --git a/logger/sql_test.go b/logger/sql_test.go index 71aa841af..c5b181a9c 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -31,7 +31,7 @@ func (s ExampleStruct) Value() (driver.Value, error) { } func format(v []byte, escaper string) string { - return escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper + return escaper + strings.ReplaceAll(string(v), escaper, "\\"+escaper) + escaper } func TestExplainSQL(t *testing.T) { diff --git a/schema/naming.go b/schema/naming.go index a4e3a75b6..125094bcf 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -174,7 +174,7 @@ func (ns NamingStrategy) toDBName(name string) string { } func (ns NamingStrategy) toSchemaName(name string) string { - result := strings.Replace(strings.Title(strings.Replace(name, "_", " ", -1)), " ", "", -1) + result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "") for _, initialism := range commonInitialisms { result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1") } diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index 897f687f7..bc917c32d 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -460,16 +460,16 @@ func assertEqualSQL(t *testing.T, expected string, actually string) { func replaceQuoteInSQL(sql string) string { // convert single quote into double quote - sql = strings.Replace(sql, `'`, `"`, -1) + sql = strings.ReplaceAll(sql, `'`, `"`) // convert dialect speical quote into double quote switch DB.Dialector.Name() { case "postgres": - sql = strings.Replace(sql, `"`, `"`, -1) + sql = strings.ReplaceAll(sql, `"`, `"`) case "mysql", "sqlite": - sql = strings.Replace(sql, "`", `"`, -1) + sql = strings.ReplaceAll(sql, "`", `"`) case "sqlserver": - sql = strings.Replace(sql, `'`, `"`, -1) + sql = strings.ReplaceAll(sql, `'`, `"`) } return sql From 7837fb6fa001ef78bc76e66b48445dee7b2db37b Mon Sep 17 00:00:00 2001 From: Qt Date: Sun, 20 Feb 2022 21:19:15 +0800 Subject: [PATCH 09/92] fix typo in TxCommitter interface comment & improve CheckTruth, chek val empty first (#5094) * fix typo in TxCommitter interface comment * improve CheckTruth, chek val empty first --- interfaces.go | 2 +- utils/utils.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/interfaces.go b/interfaces.go index ff0ca60ae..44a85cb51 100644 --- a/interfaces.go +++ b/interfaces.go @@ -50,7 +50,7 @@ type ConnPoolBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) } -// TxCommitter tx commiter +// TxCommitter tx committer type TxCommitter interface { Commit() error Rollback() error diff --git a/utils/utils.go b/utils/utils.go index 28ca0daf3..296917b93 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -39,7 +39,7 @@ func IsValidDBNameChar(c rune) bool { // CheckTruth check string true or not func CheckTruth(vals ...string) bool { for _, val := range vals { - if !strings.EqualFold(val, "false") && val != "" { + if val != "" && !strings.EqualFold(val, "false") { return true } } From b1201fce4efa60b464a1b260869a24d809607f53 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 23 Feb 2022 17:48:13 +0800 Subject: [PATCH 10/92] Fix update with customized time type, close #5101 --- callbacks/update.go | 12 ++++++------ schema/field.go | 8 ++++---- tests/go.mod | 4 ++-- tests/postgres_test.go | 18 +++++++++++++++--- 4 files changed, 27 insertions(+), 15 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 4f07ca304..4a2e5c79b 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -232,10 +232,10 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) } else if field.AutoUpdateTime == schema.UnixMillisecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6}) - } else if field.GORMDataType == schema.Time { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) - } else { + } else if field.AutoUpdateTime == schema.UnixSecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) + } else { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) } } } @@ -264,10 +264,10 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { value = stmt.DB.NowFunc().UnixNano() } else if field.AutoUpdateTime == schema.UnixMillisecond { value = stmt.DB.NowFunc().UnixNano() / 1e6 - } else if field.GORMDataType == schema.Time { - value = stmt.DB.NowFunc() - } else { + } else if field.AutoUpdateTime == schema.UnixSecond { value = stmt.DB.NowFunc().Unix() + } else { + value = stmt.DB.NowFunc() } isZero = false } diff --git a/schema/field.go b/schema/field.go index 319f36934..8c793f930 100644 --- a/schema/field.go +++ b/schema/field.go @@ -293,6 +293,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } + if field.GORMDataType == "" { + field.GORMDataType = field.DataType + } + if val, ok := field.TagSettings["TYPE"]; ok { switch DataType(strings.ToLower(val)) { case Bool, Int, Uint, Float, String, Time, Bytes: @@ -302,10 +306,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if field.GORMDataType == "" { - field.GORMDataType = field.DataType - } - if field.Size == 0 { switch reflect.Indirect(fieldValue).Kind() { case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64: diff --git a/tests/go.mod b/tests/go.mod index 1c1fb2389..cefe6f96f 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,11 +9,11 @@ require ( github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.11 // indirect golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect - gorm.io/driver/mysql v1.3.1 + gorm.io/driver/mysql v1.3.2 gorm.io/driver/postgres v1.3.1 gorm.io/driver/sqlite v1.3.1 gorm.io/driver/sqlserver v1.3.1 - gorm.io/gorm v1.23.0 + gorm.io/gorm v1.23.1 ) replace gorm.io/gorm => ../ diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 85671864f..418b713e5 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -2,6 +2,7 @@ package tests_test import ( "testing" + "time" "github.com/google/uuid" "github.com/lib/pq" @@ -15,9 +16,11 @@ func TestPostgres(t *testing.T) { type Harumph struct { gorm.Model - Name string `gorm:"check:name_checker,name <> ''"` - Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"` - Things pq.StringArray `gorm:"type:text[]"` + Name string `gorm:"check:name_checker,name <> ''"` + Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"` + CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"` + UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"` + Things pq.StringArray `gorm:"type:text[]"` } if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil { @@ -48,6 +51,15 @@ func TestPostgres(t *testing.T) { if err := DB.Where("id = $1", harumph.ID).First(&Harumph{}).Error; err != nil || harumph.Name != "jinzhu" { t.Errorf("No error should happen, but got %v", err) } + + harumph.Name = "jinzhu1" + if err := DB.Save(&harumph).Error; err != nil { + t.Errorf("Failed to update date, got error %v", err) + } + + if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" { + t.Errorf("No error should happen, but got %v", err) + } } type Post struct { From 45ef1da7e4853441e59af06800ed7c672f15bc7c Mon Sep 17 00:00:00 2001 From: Michael Nussbaum Date: Wed, 23 Feb 2022 21:10:20 -0500 Subject: [PATCH 11/92] Fix naming longer then 64 chars with dots in table (#5045) Ensures that foreign key relationships and indexes are given syntactically valid names when their name length exceeds 64 characters and they contained dot characters within the name. This is most often relevant when a Postgres table name is fully qualified by including its schema as part of its name --- schema/naming.go | 3 +-- schema/naming_test.go | 2 +- schema/relationship_test.go | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/schema/naming.go b/schema/naming.go index 125094bcf..47a2b3636 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -3,7 +3,6 @@ package schema import ( "crypto/sha1" "encoding/hex" - "fmt" "regexp" "strings" "unicode/utf8" @@ -95,7 +94,7 @@ func (ns NamingStrategy) formatName(prefix, table, name string) string { h.Write([]byte(formattedName)) bs := h.Sum(nil) - formattedName = fmt.Sprintf("%v%v%v", prefix, table, name)[0:56] + hex.EncodeToString(bs)[:8] + formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8] } return formattedName } diff --git a/schema/naming_test.go b/schema/naming_test.go index 1fdab9a06..3f598c33e 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -193,7 +193,7 @@ func TestFormatNameWithStringLongerThan64Characters(t *testing.T) { ns := NamingStrategy{} formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") - if formattedName != "prefixtablethisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLo180f2c67" { + if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" { t.Errorf("invalid formatted name generated, got %v", formattedName) } } diff --git a/schema/relationship_test.go b/schema/relationship_test.go index e2cf11a91..40ffc3249 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -576,3 +576,39 @@ func TestHasManySameForeignKey(t *testing.T) { References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}}, }) } + +type Author struct { + gorm.Model +} + +type Book struct { + gorm.Model + Author Author + AuthorID uint +} + +func (Book) TableName() string { + return "my_schema.a_very_very_very_very_very_very_very_very_long_table_name" +} + +func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) { + s, err := schema.Parse( + &Book{}, + &sync.Map{}, + schema.NamingStrategy{}, + ) + if err != nil { + t.Fatalf("Failed to parse schema") + } + + expectedConstraintName := "fk_my_schema_a_very_very_very_very_very_very_very_very_l4db13eec" + constraint := s.Relationships.Relations["Author"].ParseConstraint() + + if constraint.Name != expectedConstraintName { + t.Fatalf( + "expected constraint name %s, got %s", + expectedConstraintName, + constraint.Name, + ) + } +} From 3741f258d053c0ac145392b5669c0cc62ddc0f15 Mon Sep 17 00:00:00 2001 From: jing1 Date: Thu, 24 Feb 2022 10:21:27 +0800 Subject: [PATCH 12/92] feat: support gob serialize (#5108) --- schema/serializer.go | 36 ++++++++++++++++++++++++++++++++++-- tests/serializer_test.go | 15 +++++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/schema/serializer.go b/schema/serializer.go index 68597538d..09da6d9ef 100644 --- a/schema/serializer.go +++ b/schema/serializer.go @@ -1,11 +1,12 @@ package schema import ( + "bytes" "context" "database/sql" "database/sql/driver" + "encoding/gob" "encoding/json" - "errors" "fmt" "reflect" "strings" @@ -32,6 +33,7 @@ func GetSerializer(name string) (serializer SerializerInterface, ok bool) { func init() { RegisterSerializer("json", JSONSerializer{}) RegisterSerializer("unixtime", UnixSecondSerializer{}) + RegisterSerializer("gob", GobSerializer{}) } // Serializer field value serializer @@ -83,7 +85,7 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, case string: bytes = []byte(v) default: - return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", dbValue)) + return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue) } err = json.Unmarshal(bytes, fieldValue.Interface()) @@ -123,3 +125,33 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect } return } + +// GobSerializer gob serializer +type GobSerializer struct { +} + +// Scan implements serializer interface +func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { + fieldValue := reflect.New(field.FieldType) + + if dbValue != nil { + var bytesValue []byte + switch v := dbValue.(type) { + case []byte: + bytesValue = v + default: + return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue) + } + decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue)) + err = decoder.Decode(fieldValue.Interface()) + } + field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) + return +} + +// Value implements serializer interface +func (GobSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + buf := new(bytes.Buffer) + err := gob.NewEncoder(buf).Encode(fieldValue) + return buf.Bytes(), err +} diff --git a/tests/serializer_test.go b/tests/serializer_test.go index 3ed733d9f..a8a4e28f8 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -19,11 +19,20 @@ type SerializerStruct struct { Name []byte `gorm:"json"` Roles Roles `gorm:"serializer:json"` Contracts map[string]interface{} `gorm:"serializer:json"` + JobInfo Job `gorm:"type:bytes;serializer:gob"` CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type EncryptedString EncryptedString } type Roles []string + +type Job struct { + Title string + Number int + Location string + IsIntern bool +} + type EncryptedString string func (es *EncryptedString) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) { @@ -56,6 +65,12 @@ func TestSerializer(t *testing.T) { Contracts: map[string]interface{}{"name": "jinzhu", "age": 10}, EncryptedString: EncryptedString("pass"), CreatedTime: createdAt.Unix(), + JobInfo: Job{ + Title: "programmer", + Number: 9920, + Location: "Kenmawr", + IsIntern: false, + }, } if err := DB.Create(&data).Error; err != nil { From 6a18a15c93e17d513687993294e045574117266a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 25 Feb 2022 10:48:23 +0800 Subject: [PATCH 13/92] Refactor check missing where condition --- callbacks/delete.go | 19 +++++++------------ callbacks/helper.go | 16 ++++++++++++++++ callbacks/update.go | 20 +++++++------------- soft_delete.go | 11 ++--------- tests/update_test.go | 2 +- 5 files changed, 33 insertions(+), 35 deletions(-) diff --git a/callbacks/delete.go b/callbacks/delete.go index 1fb5261cb..84f446a3f 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -118,6 +118,12 @@ func Delete(config *Config) func(db *gorm.DB) { return } + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.DeleteClauses { + db.Statement.AddClause(c) + } + } + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(100) db.Statement.AddClauseIfNotExists(clause.Delete{}) @@ -141,22 +147,11 @@ func Delete(config *Config) func(db *gorm.DB) { } db.Statement.AddClauseIfNotExists(clause.From{}) - } - - if db.Statement.Schema != nil { - for _, c := range db.Statement.Schema.DeleteClauses { - db.Statement.AddClause(c) - } - } - if db.Statement.SQL.Len() == 0 { db.Statement.Build(db.Statement.BuildClauses...) } - if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil { - db.AddError(gorm.ErrMissingWhereClause) - return - } + checkMissingWhereConditions(db) if !db.DryRun && db.Error == nil { ok, mode := hasReturning(db, supportReturning) diff --git a/callbacks/helper.go b/callbacks/helper.go index a59e1880f..a5eb047e5 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -104,3 +104,19 @@ func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) { } return false, 0 } + +func checkMissingWhereConditions(db *gorm.DB) { + if !db.AllowGlobalUpdate && db.Error == nil { + where, withCondition := db.Statement.Clauses["WHERE"] + if withCondition { + if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete { + whereClause, _ := where.Expression.(clause.Where) + withCondition = len(whereClause.Exprs) > 1 + } + } + if !withCondition { + db.AddError(gorm.ErrMissingWhereClause) + } + return + } +} diff --git a/callbacks/update.go b/callbacks/update.go index 4a2e5c79b..da03261ec 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -59,6 +59,12 @@ func Update(config *Config) func(db *gorm.DB) { return } + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.UpdateClauses { + db.Statement.AddClause(c) + } + } + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Update{}) @@ -68,22 +74,10 @@ func Update(config *Config) func(db *gorm.DB) { return } - } - - if db.Statement.Schema != nil { - for _, c := range db.Statement.Schema.UpdateClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.Len() == 0 { db.Statement.Build(db.Statement.BuildClauses...) } - if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { - db.AddError(gorm.ErrMissingWhereClause) - return - } + checkMissingWhereConditions(db) if !db.DryRun && db.Error == nil { if ok, mode := hasReturning(db, supportReturning); ok { diff --git a/soft_delete.go b/soft_delete.go index ba6d2118d..6d6462880 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -104,9 +104,7 @@ func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) { func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) { if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { - if _, ok := stmt.Clauses["WHERE"]; stmt.DB.AllowGlobalUpdate || ok { - SoftDeleteQueryClause(sd).ModifyStatement(stmt) - } + SoftDeleteQueryClause(sd).ModifyStatement(stmt) } } @@ -152,12 +150,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { } } - if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok { - stmt.DB.AddError(ErrMissingWhereClause) - } else { - SoftDeleteQueryClause(sd).ModifyStatement(stmt) - } - + SoftDeleteQueryClause(sd).ModifyStatement(stmt) stmt.AddClauseIfNotExists(clause.Update{}) stmt.Build(stmt.DB.Callback().Update().Clauses...) } diff --git a/tests/update_test.go b/tests/update_test.go index b471ba9be..41ea5d27b 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -645,7 +645,7 @@ func TestSave(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true}) stmt := dryDB.Save(&user).Statement - if !regexp.MustCompile(`.id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) { + if !regexp.MustCompile(`.users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) { t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) } From 397b583b8ecc5a31c838db5822fe1003b53a91ef Mon Sep 17 00:00:00 2001 From: chenrui Date: Fri, 25 Feb 2022 22:38:48 +0800 Subject: [PATCH 14/92] fix: query scanner in single column --- scan.go | 12 +++++++++++- tests/query_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/scan.go b/scan.go index 0da12dafb..a1cb582e0 100644 --- a/scan.go +++ b/scan.go @@ -272,7 +272,17 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { } case reflect.Struct, reflect.Ptr: if initialized || rows.Next() { - db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) + if update { + db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) + } else { + elem := reflect.New(reflectValueType) + db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields) + if isPtr { + db.Statement.ReflectValue.Set(elem) + } else { + db.Statement.ReflectValue.Set(elem.Elem()) + } + } } default: db.AddError(rows.Scan(dest)) diff --git a/tests/query_test.go b/tests/query_test.go index d10df1807..6542774a4 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -1158,3 +1158,39 @@ func TestQueryWithTableAndConditionsAndAllFields(t *testing.T) { t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) } } + +type DoubleInt64 struct { + data int64 +} + +func (t *DoubleInt64) Scan(val interface{}) error { + switch v := val.(type) { + case int64: + t.data = v * 2 + return nil + default: + return fmt.Errorf("DoubleInt64 cant not scan with:%v", v) + } +} + +// https://github.com/go-gorm/gorm/issues/5091 +func TestQueryScannerWithSingleColumn(t *testing.T) { + user := User{Name: "scanner_raw_1", Age: 10} + DB.Create(&user) + + var result1 DoubleInt64 + if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Pluck( + "age", &result1).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } + + AssertEqual(t, result1.data, 20) + + var result2 DoubleInt64 + if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Select( + "age").Scan(&result2).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } + + AssertEqual(t, result2.data, 20) +} From f2edda50e11728e7aee6b1d4c961d575f7afbb2d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 25 Feb 2022 10:48:23 +0800 Subject: [PATCH 15/92] Refactor check missing where condition --- callbacks/delete.go | 19 +++++++------------ callbacks/helper.go | 16 ++++++++++++++++ callbacks/update.go | 20 +++++++------------- soft_delete.go | 11 ++--------- tests/update_test.go | 2 +- 5 files changed, 33 insertions(+), 35 deletions(-) diff --git a/callbacks/delete.go b/callbacks/delete.go index 1fb5261cb..84f446a3f 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -118,6 +118,12 @@ func Delete(config *Config) func(db *gorm.DB) { return } + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.DeleteClauses { + db.Statement.AddClause(c) + } + } + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(100) db.Statement.AddClauseIfNotExists(clause.Delete{}) @@ -141,22 +147,11 @@ func Delete(config *Config) func(db *gorm.DB) { } db.Statement.AddClauseIfNotExists(clause.From{}) - } - - if db.Statement.Schema != nil { - for _, c := range db.Statement.Schema.DeleteClauses { - db.Statement.AddClause(c) - } - } - if db.Statement.SQL.Len() == 0 { db.Statement.Build(db.Statement.BuildClauses...) } - if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil { - db.AddError(gorm.ErrMissingWhereClause) - return - } + checkMissingWhereConditions(db) if !db.DryRun && db.Error == nil { ok, mode := hasReturning(db, supportReturning) diff --git a/callbacks/helper.go b/callbacks/helper.go index a59e1880f..a5eb047e5 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -104,3 +104,19 @@ func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) { } return false, 0 } + +func checkMissingWhereConditions(db *gorm.DB) { + if !db.AllowGlobalUpdate && db.Error == nil { + where, withCondition := db.Statement.Clauses["WHERE"] + if withCondition { + if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete { + whereClause, _ := where.Expression.(clause.Where) + withCondition = len(whereClause.Exprs) > 1 + } + } + if !withCondition { + db.AddError(gorm.ErrMissingWhereClause) + } + return + } +} diff --git a/callbacks/update.go b/callbacks/update.go index 4a2e5c79b..da03261ec 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -59,6 +59,12 @@ func Update(config *Config) func(db *gorm.DB) { return } + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.UpdateClauses { + db.Statement.AddClause(c) + } + } + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Update{}) @@ -68,22 +74,10 @@ func Update(config *Config) func(db *gorm.DB) { return } - } - - if db.Statement.Schema != nil { - for _, c := range db.Statement.Schema.UpdateClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.Len() == 0 { db.Statement.Build(db.Statement.BuildClauses...) } - if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { - db.AddError(gorm.ErrMissingWhereClause) - return - } + checkMissingWhereConditions(db) if !db.DryRun && db.Error == nil { if ok, mode := hasReturning(db, supportReturning); ok { diff --git a/soft_delete.go b/soft_delete.go index ba6d2118d..6d6462880 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -104,9 +104,7 @@ func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) { func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) { if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { - if _, ok := stmt.Clauses["WHERE"]; stmt.DB.AllowGlobalUpdate || ok { - SoftDeleteQueryClause(sd).ModifyStatement(stmt) - } + SoftDeleteQueryClause(sd).ModifyStatement(stmt) } } @@ -152,12 +150,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { } } - if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok { - stmt.DB.AddError(ErrMissingWhereClause) - } else { - SoftDeleteQueryClause(sd).ModifyStatement(stmt) - } - + SoftDeleteQueryClause(sd).ModifyStatement(stmt) stmt.AddClauseIfNotExists(clause.Update{}) stmt.Build(stmt.DB.Callback().Update().Clauses...) } diff --git a/tests/update_test.go b/tests/update_test.go index b471ba9be..41ea5d27b 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -645,7 +645,7 @@ func TestSave(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true}) stmt := dryDB.Save(&user).Statement - if !regexp.MustCompile(`.id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) { + if !regexp.MustCompile(`.users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) { t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) } From 68bb5379d91a7f7fae4dc65205db66004f515d0c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 27 Feb 2022 09:09:29 +0800 Subject: [PATCH 16/92] Refactor scan into struct --- scan.go | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/scan.go b/scan.go index a1cb582e0..e83390ca5 100644 --- a/scan.go +++ b/scan.go @@ -68,7 +68,11 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re values[idx] = &sql.RawBytes{} } else if len(columns) == 1 { sch = nil - values[idx] = reflectValue.Interface() + if reflectValue.CanAddr() { + values[idx] = reflectValue.Addr().Interface() + } else { + values[idx] = reflectValue.Interface() + } } else { values[idx] = &sql.RawBytes{} } @@ -272,17 +276,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { } case reflect.Struct, reflect.Ptr: if initialized || rows.Next() { - if update { - db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) - } else { - elem := reflect.New(reflectValueType) - db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields) - if isPtr { - db.Statement.ReflectValue.Set(elem) - } else { - db.Statement.ReflectValue.Set(elem.Elem()) - } - } + db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) } default: db.AddError(rows.Scan(dest)) From 530b0a12b4c63bb2dc7abef2934dc8406f1d0f13 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 27 Feb 2022 22:10:17 +0800 Subject: [PATCH 17/92] Add fast path for ValueOf, ReflectValueOf --- schema/field.go | 70 ++++++++++++++++++++++++++++++------------------- tests/go.mod | 1 + 2 files changed, 44 insertions(+), 27 deletions(-) diff --git a/schema/field.go b/schema/field.go index 8c793f930..826680c5a 100644 --- a/schema/field.go +++ b/schema/field.go @@ -465,24 +465,33 @@ func (field *Field) setupValuerAndSetter() { } // ValueOf returns field's value and if it is zero - field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { - v = reflect.Indirect(v) - for _, fieldIdx := range field.StructField.Index { - if fieldIdx >= 0 { - v = v.Field(fieldIdx) - } else { - v = v.Field(-fieldIdx - 1) - - if !v.IsNil() { - v = v.Elem() + fieldIndex := field.StructField.Index[0] + switch { + case len(field.StructField.Index) == 1 && fieldIndex > 0: + field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) { + fieldValue := reflect.Indirect(value).Field(fieldIndex) + return fieldValue.Interface(), fieldValue.IsZero() + } + default: + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + v = reflect.Indirect(v) + for _, fieldIdx := range field.StructField.Index { + if fieldIdx >= 0 { + v = v.Field(fieldIdx) } else { - return nil, true + v = v.Field(-fieldIdx - 1) + + if !v.IsNil() { + v = v.Elem() + } else { + return nil, true + } } } - } - fv, zero := v.Interface(), v.IsZero() - return fv, zero + fv, zero := v.Interface(), v.IsZero() + return fv, zero + } } if field.Serializer != nil { @@ -509,24 +518,31 @@ func (field *Field) setupValuerAndSetter() { } // ReflectValueOf returns field's reflect value - field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { - v = reflect.Indirect(v) - for idx, fieldIdx := range field.StructField.Index { - if fieldIdx >= 0 { - v = v.Field(fieldIdx) - } else { - v = v.Field(-fieldIdx - 1) + switch { + case len(field.StructField.Index) == 1 && fieldIndex > 0: + field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value { + return reflect.Indirect(value).Field(fieldIndex) + } + default: + field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { + v = reflect.Indirect(v) + for idx, fieldIdx := range field.StructField.Index { + if fieldIdx >= 0 { + v = v.Field(fieldIdx) + } else { + v = v.Field(-fieldIdx - 1) - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) - } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } - if idx < len(field.StructField.Index)-1 { - v = v.Elem() + if idx < len(field.StructField.Index)-1 { + v = v.Elem() + } } } + return v } - return v } fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) { diff --git a/tests/go.mod b/tests/go.mod index cefe6f96f..9e3453b73 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,6 +3,7 @@ module gorm.io/gorm/tests go 1.14 require ( + github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/google/uuid v1.3.0 github.com/jackc/pgx/v4 v4.15.0 // indirect github.com/jinzhu/now v1.1.4 From 43a72b369e670bd91e32784d063608931a59a66e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 27 Feb 2022 22:54:43 +0800 Subject: [PATCH 18/92] Refactor Scan --- scan.go | 104 +++++++++++++++++++++++--------------------------------- 1 file changed, 43 insertions(+), 61 deletions(-) diff --git a/scan.go b/scan.go index e83390ca5..d7b58e03d 100644 --- a/scan.go +++ b/scan.go @@ -50,58 +50,37 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns } } -func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) { - for idx, column := range columns { - if sch == nil { - values[idx] = reflectValue.Interface() - } else if field := sch.LookUpField(column); field != nil && field.Readable { +func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) { + for idx, field := range fields { + if field != nil { values[idx] = field.NewValuePool.Get() defer field.NewValuePool.Put(values[idx]) - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := sch.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - values[idx] = field.NewValuePool.Get() - defer field.NewValuePool.Put(values[idx]) - continue - } + if len(joinFields) == 0 || joinFields[idx][0] == nil { + defer field.Set(db.Statement.Context, reflectValue, values[idx]) } - values[idx] = &sql.RawBytes{} - } else if len(columns) == 1 { - sch = nil + } else if len(fields) == 1 { if reflectValue.CanAddr() { values[idx] = reflectValue.Addr().Interface() } else { values[idx] = reflectValue.Interface() } - } else { - values[idx] = &sql.RawBytes{} } } db.RowsAffected++ db.AddError(rows.Scan(values...)) - if sch != nil { - for idx, column := range columns { - if field := sch.LookUpField(column); field != nil && field.Readable { - field.Set(db.Statement.Context, reflectValue, values[idx]) - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := sch.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - relValue := rel.Field.ReflectValueOf(db.Statement.Context, reflectValue) - - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { - continue - } - - relValue.Set(reflect.New(relValue.Type().Elem())) - } - - field.Set(db.Statement.Context, relValue, values[idx]) - } + for idx, joinField := range joinFields { + if joinField[0] != nil { + relValue := joinField[0].ReflectValueOf(db.Statement.Context, reflectValue) + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { + return } + + relValue.Set(reflect.New(relValue.Type().Elem())) } + joinField[1].Set(db.Statement.Context, relValue, values[idx]) } } } @@ -180,7 +159,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { reflectValue = db.Statement.ReflectValue ) - for reflectValue.Kind() == reflect.Interface { + if reflectValue.Kind() == reflect.Interface { reflectValue = reflectValue.Elem() } @@ -199,35 +178,38 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) } - for idx, column := range columns { - if field := sch.LookUpField(column); field != nil && field.Readable { - fields[idx] = field - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := sch.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - fields[idx] = field - - if len(joinFields) == 0 { - joinFields = make([][2]*schema.Field, len(columns)) - } - joinFields[idx] = [2]*schema.Field{rel.Field, field} - continue - } - } - values[idx] = &sql.RawBytes{} - } else { - values[idx] = &sql.RawBytes{} - } - } - if len(columns) == 1 { - // isPluck + // Is Pluck if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner reflectValueType.Kind() != reflect.Struct || // is not struct sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time sch = nil } } + + // Not Pluck + if sch != nil { + for idx, column := range columns { + if field := sch.LookUpField(column); field != nil && field.Readable { + fields[idx] = field + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := sch.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + fields[idx] = field + + if len(joinFields) == 0 { + joinFields = make([][2]*schema.Field, len(columns)) + } + joinFields[idx] = [2]*schema.Field{rel.Field, field} + continue + } + } + values[idx] = &sql.RawBytes{} + } else { + values[idx] = &sql.RawBytes{} + } + } + } } switch reflectValue.Kind() { @@ -260,7 +242,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { elem = reflect.New(reflectValueType) } - db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields) + db.scanIntoStruct(rows, elem, values, fields, joinFields) if !update { if isPtr { @@ -276,7 +258,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { } case reflect.Struct, reflect.Ptr: if initialized || rows.Next() { - db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) + db.scanIntoStruct(rows, reflectValue, values, fields, joinFields) } default: db.AddError(rows.Scan(dest)) From e2e802b837a234ede6dc122dbb26de965e35e55f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 28 Feb 2022 09:28:19 +0800 Subject: [PATCH 19/92] Refactor Scan --- callbacks/create.go | 6 ++++-- scan.go | 29 ++++++++++++++++------------- tests/go.mod | 2 +- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index b0964e2b6..6e2883f79 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -201,13 +201,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: rValLen := stmt.ReflectValue.Len() - stmt.SQL.Grow(rValLen * 18) - values.Values = make([][]interface{}, rValLen) if rValLen == 0 { stmt.AddError(gorm.ErrEmptySlice) return } + stmt.SQL.Grow(rValLen * 18) + stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns)) + values.Values = make([][]interface{}, rValLen) + defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} for i := 0; i < rValLen; i++ { rv := reflect.Indirect(stmt.ReflectValue.Index(i)) diff --git a/scan.go b/scan.go index d7b58e03d..a4243d12d 100644 --- a/scan.go +++ b/scan.go @@ -54,10 +54,6 @@ func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values for idx, field := range fields { if field != nil { values[idx] = field.NewValuePool.Get() - defer field.NewValuePool.Put(values[idx]) - if len(joinFields) == 0 || joinFields[idx][0] == nil { - defer field.Set(db.Statement.Context, reflectValue, values[idx]) - } } else if len(fields) == 1 { if reflectValue.CanAddr() { values[idx] = reflectValue.Addr().Interface() @@ -70,17 +66,24 @@ func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values db.RowsAffected++ db.AddError(rows.Scan(values...)) - for idx, joinField := range joinFields { - if joinField[0] != nil { - relValue := joinField[0].ReflectValueOf(db.Statement.Context, reflectValue) - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { - return - } + for idx, field := range fields { + if field != nil { + if len(joinFields) == 0 || joinFields[idx][0] == nil { + field.Set(db.Statement.Context, reflectValue, values[idx]) + } else { + relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue) + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { + return + } - relValue.Set(reflect.New(relValue.Type().Elem())) + relValue.Set(reflect.New(relValue.Type().Elem())) + } + joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx]) } - joinField[1].Set(db.Statement.Context, relValue, values[idx]) + + // release data to pool + field.NewValuePool.Put(values[idx]) } } } diff --git a/tests/go.mod b/tests/go.mod index 9e3453b73..c65ea953c 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,7 +8,7 @@ require ( github.com/jackc/pgx/v4 v4.15.0 // indirect github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.4 - github.com/mattn/go-sqlite3 v1.14.11 // indirect + github.com/mattn/go-sqlite3 v1.14.12 // indirect golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect gorm.io/driver/mysql v1.3.2 gorm.io/driver/postgres v1.3.1 From 996b96e81268335b22faf694dfb4674f84177f17 Mon Sep 17 00:00:00 2001 From: lianghuan Date: Mon, 28 Feb 2022 17:12:09 +0800 Subject: [PATCH 20/92] Add TxConnPoolBeginner and Tx interface --- .gitignore | 1 + finisher_api.go | 3 + interfaces.go | 13 +++ prepare_stmt.go | 7 +- tests/connpool_test.go | 181 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 203 insertions(+), 2 deletions(-) create mode 100644 tests/connpool_test.go diff --git a/.gitignore b/.gitignore index e1b9ecea1..45505cc93 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ documents coverage.txt _book .idea +vendor \ No newline at end of file diff --git a/finisher_api.go b/finisher_api.go index f994ec318..5d49ddf94 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -255,6 +255,7 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { } } } + // FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions) func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ @@ -603,6 +604,8 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) } else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + } else if beginner, ok := tx.Statement.ConnPool.(TxConnPoolBeginner); ok { + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) } else { err = ErrInvalidTransaction } diff --git a/interfaces.go b/interfaces.go index 44a85cb51..ed7112f27 100644 --- a/interfaces.go +++ b/interfaces.go @@ -50,12 +50,25 @@ type ConnPoolBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) } +// TxConnPoolBeginner tx conn pool beginner +type TxConnPoolBeginner interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) +} + // TxCommitter tx committer type TxCommitter interface { Commit() error Rollback() error } +// Tx sql.Tx interface +type Tx interface { + ConnPool + Commit() error + Rollback() error + StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt +} + // Valuer gorm valuer interface type Valuer interface { GormValue(context.Context, *DB) clause.Expr diff --git a/prepare_stmt.go b/prepare_stmt.go index 88bec4e95..94282fadb 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -73,6 +73,9 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn if beginner, ok := db.ConnPool.(TxBeginner); ok { tx, err := beginner.BeginTx(ctx, opt) return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err + } else if beginner, ok := db.ConnPool.(TxConnPoolBeginner); ok { + tx, err := beginner.BeginTx(ctx, opt) + return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err } return nil, ErrInvalidTransaction } @@ -115,7 +118,7 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg } type PreparedStmtTX struct { - *sql.Tx + Tx PreparedStmtDB *PreparedStmtDB } @@ -151,7 +154,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { - rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...) + rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() defer tx.PreparedStmtDB.Mux.Unlock() diff --git a/tests/connpool_test.go b/tests/connpool_test.go new file mode 100644 index 000000000..3713ad7cf --- /dev/null +++ b/tests/connpool_test.go @@ -0,0 +1,181 @@ +package tests_test + +import ( + "context" + "database/sql" + "log" + "os" + "reflect" + "testing" + "time" + + "gorm.io/driver/mysql" + "gorm.io/gorm" + "gorm.io/gorm/logger" + . "gorm.io/gorm/utils/tests" +) + +type wrapperTx struct { + *sql.Tx + conn *wrapperConnPool +} + +func (c *wrapperTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + c.conn.got = append(c.conn.got, query) + return c.Tx.PrepareContext(ctx, query) +} + +func (c *wrapperTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + c.conn.got = append(c.conn.got, query) + return c.Tx.ExecContext(ctx, query, args...) +} + +func (c *wrapperTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + c.conn.got = append(c.conn.got, query) + return c.Tx.QueryContext(ctx, query, args...) +} + +func (c *wrapperTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + c.conn.got = append(c.conn.got, query) + return c.Tx.QueryRowContext(ctx, query, args...) +} + +type wrapperConnPool struct { + db *sql.DB + got []string + expect []string +} + +func (c *wrapperConnPool) Ping() error { + return c.db.Ping() +} + +// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction. +// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { +// return c.db.BeginTx(ctx, opts) +// } +// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries. +func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.Tx, error) { + tx, err := c.db.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return &wrapperTx{Tx: tx, conn: c}, nil +} + +func (c *wrapperConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + c.got = append(c.got, query) + return c.db.PrepareContext(ctx, query) +} + +func (c *wrapperConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + c.got = append(c.got, query) + return c.db.ExecContext(ctx, query, args...) +} + +func (c *wrapperConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + c.got = append(c.got, query) + return c.db.QueryContext(ctx, query, args...) +} + +func (c *wrapperConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + c.got = append(c.got, query) + return c.db.QueryRowContext(ctx, query, args...) +} + +func TestConnPoolWrapper(t *testing.T) { + dialect := os.Getenv("GORM_DIALECT") + if dialect != "mysql" { + t.SkipNow() + } + + dbDSN := os.Getenv("GORM_DSN") + if dbDSN == "" { + dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" + } + nativeDB, err := sql.Open("mysql", dbDSN) + if err != nil { + t.Fatalf("Should open db success, but got %v", err) + } + + conn := &wrapperConnPool{ + db: nativeDB, + expect: []string{ + "SELECT VERSION()", + "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + }, + } + + defer func() { + if !reflect.DeepEqual(conn.got, conn.expect) { + t.Errorf("expect %#v but got %#v", conn.expect, conn.got) + } + }() + + l := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{ + SlowThreshold: 200 * time.Millisecond, + LogLevel: logger.Info, + IgnoreRecordNotFoundError: false, + Colorful: true, + }) + + db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}), &gorm.Config{Logger: l}) + if err != nil { + t.Fatalf("Should open db success, but got %v", err) + } + + tx := db.Begin() + user := *GetUser("transaction", Config{}) + + if err = tx.Save(&user).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err = tx.First(&User{}, "name = ?", "transaction").Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + user1 := *GetUser("transaction1-1", Config{}) + + if err = tx.Save(&user1).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err = tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil { + t.Fatalf("Should return the underlying sql.Tx") + } + + tx.Rollback() + + if err = db.First(&User{}, "name = ?", "transaction").Error; err == nil { + t.Fatalf("Should not find record after rollback, but got %v", err) + } + + txDB := db.Where("fake_name = ?", "fake_name") + tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin() + user2 := *GetUser("transaction-2", Config{}) + if err = tx2.Save(&user2).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err = tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + tx2.Commit() + + if err = db.First(&User{}, "name = ?", "transaction-2").Error; err != nil { + t.Fatalf("Should be able to find committed record, but got %v", err) + } +} From 4e523499d191d02e032b126774efd26daa8697a8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 1 Mar 2022 16:48:46 +0800 Subject: [PATCH 21/92] Refactor Tx interface --- finisher_api.go | 9 ++++----- interfaces.go | 8 +------- prepare_stmt.go | 3 --- tests/connpool_test.go | 14 ++------------ 4 files changed, 7 insertions(+), 27 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 5d49ddf94..4b428a59f 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -600,13 +600,12 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { opt = opts[0] } - if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { + switch beginner := tx.Statement.ConnPool.(type) { + case TxBeginner: tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - } else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { + case ConnPoolBeginner: tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - } else if beginner, ok := tx.Statement.ConnPool.(TxConnPoolBeginner); ok { - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - } else { + default: err = ErrInvalidTransaction } diff --git a/interfaces.go b/interfaces.go index ed7112f27..84dc94bb4 100644 --- a/interfaces.go +++ b/interfaces.go @@ -50,11 +50,6 @@ type ConnPoolBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) } -// TxConnPoolBeginner tx conn pool beginner -type TxConnPoolBeginner interface { - BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) -} - // TxCommitter tx committer type TxCommitter interface { Commit() error @@ -64,8 +59,7 @@ type TxCommitter interface { // Tx sql.Tx interface type Tx interface { ConnPool - Commit() error - Rollback() error + TxCommitter StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt } diff --git a/prepare_stmt.go b/prepare_stmt.go index 94282fadb..b062b0d6b 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -73,9 +73,6 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn if beginner, ok := db.ConnPool.(TxBeginner); ok { tx, err := beginner.BeginTx(ctx, opt) return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err - } else if beginner, ok := db.ConnPool.(TxConnPoolBeginner); ok { - tx, err := beginner.BeginTx(ctx, opt) - return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err } return nil, ErrInvalidTransaction } diff --git a/tests/connpool_test.go b/tests/connpool_test.go index 3713ad7cf..fbae22941 100644 --- a/tests/connpool_test.go +++ b/tests/connpool_test.go @@ -3,15 +3,12 @@ package tests_test import ( "context" "database/sql" - "log" "os" "reflect" "testing" - "time" "gorm.io/driver/mysql" "gorm.io/gorm" - "gorm.io/gorm/logger" . "gorm.io/gorm/utils/tests" ) @@ -55,7 +52,7 @@ func (c *wrapperConnPool) Ping() error { // return c.db.BeginTx(ctx, opts) // } // You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries. -func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.Tx, error) { +func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) { tx, err := c.db.BeginTx(ctx, opts) if err != nil { return nil, err @@ -119,14 +116,7 @@ func TestConnPoolWrapper(t *testing.T) { } }() - l := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{ - SlowThreshold: 200 * time.Millisecond, - LogLevel: logger.Info, - IgnoreRecordNotFoundError: false, - Colorful: true, - }) - - db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}), &gorm.Config{Logger: l}) + db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn})) if err != nil { t.Fatalf("Should open db success, but got %v", err) } From 29a8557384b060bf5d99b4b8824cb75c8a8b9917 Mon Sep 17 00:00:00 2001 From: Cao Manh Dat Date: Thu, 3 Mar 2022 09:17:29 +0700 Subject: [PATCH 22/92] ToSQL should enable SkipDefaultTransaction by default --- gorm.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gorm.go b/gorm.go index 7967b0945..aca7cb5ed 100644 --- a/gorm.go +++ b/gorm.go @@ -462,7 +462,7 @@ func (db *DB) Use(plugin Plugin) error { // .First(&User{}) // }) func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { - tx := queryFn(db.Session(&Session{DryRun: true})) + tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true})) stmt := tx.Statement return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) From f961bf1c147113527e486595b0ce342f3c5ba3dd Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 12 Mar 2022 22:28:18 +0800 Subject: [PATCH 23/92] chore(deps): bump actions/checkout from 2 to 3 (#5133) Bumps [actions/checkout](https://github.com/actions/checkout) from 2 to 3. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v2...v3) --- updated-dependencies: - dependency-name: actions/checkout dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/labeler.yml | 2 +- .github/workflows/reviewdog.yml | 2 +- .github/workflows/tests.yml | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index bc1add531..0e8aaa602 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -11,7 +11,7 @@ jobs: name: Label issues and pull requests steps: - name: check out - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: labeler uses: jinzhu/super-labeler-action@develop diff --git a/.github/workflows/reviewdog.yml b/.github/workflows/reviewdog.yml index b252dd7ae..a6542d57e 100644 --- a/.github/workflows/reviewdog.yml +++ b/.github/workflows/reviewdog.yml @@ -6,7 +6,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out code into the Go module directory - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: golangci-lint uses: reviewdog/action-golangci-lint@v2 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 91a0abc9f..3e15427ca 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -24,7 +24,7 @@ jobs: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: go mod package cache uses: actions/cache@v2 @@ -67,7 +67,7 @@ jobs: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: go mod package cache @@ -111,7 +111,7 @@ jobs: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: go mod package cache uses: actions/cache@v2 @@ -154,7 +154,7 @@ jobs: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: go mod package cache uses: actions/cache@v2 From 61b4c31236a8f9792c94240ddb4e236f21bbb9ff Mon Sep 17 00:00:00 2001 From: labulakalia Date: Mon, 14 Mar 2022 21:47:59 +0800 Subject: [PATCH 24/92] fix when index name is "type", parseFieldIndexes will set index TYPE is "TYPE" (#5155) * fix index name is type, parseFieldIndexes will set index TYPE is "TYPE" * check TYPE empty --- schema/index.go | 11 ++++++----- schema/index_test.go | 6 ++++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/schema/index.go b/schema/index.go index 5f775f30f..16d096b76 100644 --- a/schema/index.go +++ b/schema/index.go @@ -89,11 +89,12 @@ func parseFieldIndexes(field *Field) (indexes []Index) { k := strings.TrimSpace(strings.ToUpper(v[0])) if k == "INDEX" || k == "UNIQUEINDEX" { var ( - name string - tag = strings.Join(v[1:], ":") - idx = strings.Index(tag, ",") - settings = ParseTagSetting(tag, ",") - length, _ = strconv.Atoi(settings["LENGTH"]) + name string + tag = strings.Join(v[1:], ":") + idx = strings.Index(tag, ",") + tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",") + settings = ParseTagSetting(tagSetting, ",") + length, _ = strconv.Atoi(settings["LENGTH"]) ) if idx == -1 { diff --git a/schema/index_test.go b/schema/index_test.go index bc6bb8b64..3c4582bb4 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -18,6 +18,7 @@ type UserIndex struct { Age int64 `gorm:"index:profile,expression:ABS(age),option:WITH PARSER parser_name"` OID int64 `gorm:"index:idx_id;index:idx_oid,unique"` MemberNumber string `gorm:"index:idx_id,priority:1"` + Name7 string `gorm:"index:type"` } func TestParseIndex(t *testing.T) { @@ -78,6 +79,11 @@ func TestParseIndex(t *testing.T) { Class: "UNIQUE", Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID"}}}, }, + "type": { + Name: "type", + Type: "", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}}, + }, } indices := user.ParseIndexes() From 6befa0c947e0107f241663e4312a74bddd0a4ffe Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 17 Mar 2022 11:22:25 +0800 Subject: [PATCH 25/92] Refactor preload error check --- callbacks/query.go | 5 +++++ finisher_api.go | 4 ---- tests/count_test.go | 14 +++++++++++--- tests/go.mod | 2 +- 4 files changed, 17 insertions(+), 8 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 03798859d..04f35c7e7 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -186,6 +186,11 @@ func BuildQuerySQL(db *gorm.DB) { func Preload(db *gorm.DB) { if db.Error == nil && len(db.Statement.Preloads) > 0 { + if db.Statement.Schema == nil { + db.AddError(fmt.Errorf("%w when using preload", gorm.ErrModelValueRequired)) + return + } + preloadMap := map[string]map[string][]interface{}{} for name := range db.Statement.Preloads { preloadFields := strings.Split(name, ".") diff --git a/finisher_api.go b/finisher_api.go index 4b428a59f..b4d29b710 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -369,10 +369,6 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Count(count *int64) (tx *DB) { tx = db.getInstance() - if len(tx.Statement.Preloads) > 0 { - tx.AddError(ErrPreloadNotAllowed) - return - } if tx.Statement.Model == nil { tx.Statement.Model = tx.Statement.Dest defer func() { diff --git a/tests/count_test.go b/tests/count_test.go index b63a55fcc..b71e3de54 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -150,8 +150,16 @@ func TestCount(t *testing.T) { Where("name in ?", []string{user1.Name, user2.Name, user3.Name}). Preload("Toys", func(db *gorm.DB) *gorm.DB { return db.Table("toys").Select("name") - }).Count(&count12).Error; err != gorm.ErrPreloadNotAllowed { - t.Errorf("should returns preload not allowed error, but got %v", err) + }).Count(&count12).Error; err == nil { + t.Errorf("error should raise when using preload without schema") + } + + var count13 int64 + if err := DB.Model(User{}). + Where("name in ?", []string{user1.Name, user2.Name, user3.Name}). + Preload("Toys", func(db *gorm.DB) *gorm.DB { + return db.Table("toys").Select("name") + }).Count(&count13).Error; err != nil { + t.Errorf("no error should raise when using count with preload, but got %v", err) } - } diff --git a/tests/go.mod b/tests/go.mod index c65ea953c..4ef7fbe28 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.12 // indirect - golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect + golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd // indirect gorm.io/driver/mysql v1.3.2 gorm.io/driver/postgres v1.3.1 gorm.io/driver/sqlite v1.3.1 From 63ac66b56988e1a22c8a3b41d4f1fbf9a8f5d0bc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 17 Mar 2022 11:34:27 +0800 Subject: [PATCH 26/92] Support default tag for time.Time --- schema/field.go | 5 +++++ tests/default_value_test.go | 18 ++++++++++-------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/schema/field.go b/schema/field.go index 826680c5a..0d7085a90 100644 --- a/schema/field.go +++ b/schema/field.go @@ -259,6 +259,11 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } else if fieldValue.Type().ConvertibleTo(TimePtrReflectType) { field.DataType = Time } + if field.HasDefaultValue && !skipParseDefaultValue && field.DataType == Time { + if field.DefaultValueInterface, err = now.Parse(field.DefaultValue); err != nil { + schema.err = fmt.Errorf("failed to parse default value `%v` for field %v", field.DefaultValue, field.Name) + } + } case reflect.Array, reflect.Slice: if reflect.Indirect(fieldValue).Type().Elem() == ByteReflectType && field.DataType == "" { field.DataType = Bytes diff --git a/tests/default_value_test.go b/tests/default_value_test.go index 5e00b1546..918f0796d 100644 --- a/tests/default_value_test.go +++ b/tests/default_value_test.go @@ -2,6 +2,7 @@ package tests_test import ( "testing" + "time" "gorm.io/gorm" ) @@ -9,12 +10,13 @@ import ( func TestDefaultValue(t *testing.T) { type Harumph struct { gorm.Model - Email string `gorm:"not null;index:,unique"` - Name string `gorm:"notNull;default:foo"` - Name2 string `gorm:"size:233;not null;default:'foo'"` - Name3 string `gorm:"size:233;notNull;default:''"` - Age int `gorm:"default:18"` - Enabled bool `gorm:"default:true"` + Email string `gorm:"not null;index:,unique"` + Name string `gorm:"notNull;default:foo"` + Name2 string `gorm:"size:233;not null;default:'foo'"` + Name3 string `gorm:"size:233;notNull;default:''"` + Age int `gorm:"default:18"` + Created time.Time `gorm:"default:2000-01-02"` + Enabled bool `gorm:"default:true"` } DB.Migrator().DropTable(&Harumph{}) @@ -26,14 +28,14 @@ func TestDefaultValue(t *testing.T) { harumph := Harumph{Email: "hello@gorm.io"} if err := DB.Create(&harumph).Error; err != nil { t.Fatalf("Failed to create data with default value, got error: %v", err) - } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Name3 != "" || harumph.Age != 18 || !harumph.Enabled { + } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Name3 != "" || harumph.Age != 18 || !harumph.Enabled || harumph.Created.Format("20060102") != "20000102" { t.Fatalf("Failed to create data with default value, got: %+v", harumph) } var result Harumph if err := DB.First(&result, "email = ?", "hello@gorm.io").Error; err != nil { t.Fatalf("Failed to find created data, got error: %v", err) - } else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled { + } else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled || result.Created.Format("20060102") != "20000102" { t.Fatalf("Failed to find created data with default data, got %+v", result) } } From f3e2da5ba359f0d672249fc52f54ae41c5a66d3a Mon Sep 17 00:00:00 2001 From: Hasan Date: Thu, 17 Mar 2022 22:51:56 +0800 Subject: [PATCH 27/92] Added offset when scanning the result back to struct, close #5143 commit 9a2058164d44c98d7b586b87bed1757f89d6fad7 Author: Jinzhu Date: Thu Mar 17 22:34:19 2022 +0800 Refactor #5143 commit c259de21768936428c9d89f7b31afb95b8acb36a Author: Hasan Date: Mon Mar 14 20:04:01 2022 +0545 Update scan_test.go commit 09f127b49151a52fbb8b354a03e6610d4f70262f Author: Hasan Date: Mon Mar 14 19:23:47 2022 +0545 Added test for scanning embedded data into structs commit aeaca493cf412def7813d36fd6a68acc832bf79f Author: Hasan Date: Tue Mar 8 04:08:16 2022 +0600 Added offset when scanning the result back to struct --- scan.go | 22 +++++++++++++++++----- tests/go.mod | 2 +- tests/scan_test.go | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 6 deletions(-) diff --git a/scan.go b/scan.go index a4243d12d..89d923543 100644 --- a/scan.go +++ b/scan.go @@ -156,10 +156,11 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { } default: var ( - fields = make([]*schema.Field, len(columns)) - joinFields [][2]*schema.Field - sch = db.Statement.Schema - reflectValue = db.Statement.ReflectValue + fields = make([]*schema.Field, len(columns)) + selectedColumnsMap = make(map[string]int, len(columns)) + joinFields [][2]*schema.Field + sch = db.Statement.Schema + reflectValue = db.Statement.ReflectValue ) if reflectValue.Kind() == reflect.Interface { @@ -194,7 +195,18 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { if sch != nil { for idx, column := range columns { if field := sch.LookUpField(column); field != nil && field.Readable { - fields[idx] = field + if curIndex, ok := selectedColumnsMap[column]; ok { + for fieldIndex, selectField := range sch.Fields[curIndex:] { + if selectField.DBName == column && selectField.Readable { + selectedColumnsMap[column] = curIndex + fieldIndex + 1 + fields[idx] = selectField + break + } + } + } else { + fields[idx] = field + selectedColumnsMap[column] = idx + } } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := sch.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { diff --git a/tests/go.mod b/tests/go.mod index 4ef7fbe28..9dfa26ff0 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,7 +6,7 @@ require ( github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/google/uuid v1.3.0 github.com/jackc/pgx/v4 v4.15.0 // indirect - github.com/jinzhu/now v1.1.4 + github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.12 // indirect golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd // indirect diff --git a/tests/scan_test.go b/tests/scan_test.go index 1a188facf..ec1e652fc 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -10,6 +10,11 @@ import ( . "gorm.io/gorm/utils/tests" ) +type PersonAddressInfo struct { + Person *Person `gorm:"embedded"` + Address *Address `gorm:"embedded"` +} + func TestScan(t *testing.T) { user1 := User{Name: "ScanUser1", Age: 1} user2 := User{Name: "ScanUser2", Age: 10} @@ -156,3 +161,34 @@ func TestScanRows(t *testing.T) { t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name) } } + +func TestScanToEmbedded(t *testing.T) { + person1 := Person{Name: "person 1"} + person2 := Person{Name: "person 2"} + DB.Save(&person1).Save(&person2) + + address1 := Address{Name: "address 1"} + address2 := Address{Name: "address 2"} + DB.Save(&address1).Save(&address2) + + DB.Create(&PersonAddress{PersonID: person1.ID, AddressID: int(address1.ID)}) + DB.Create(&PersonAddress{PersonID: person1.ID, AddressID: int(address2.ID)}) + DB.Create(&PersonAddress{PersonID: person2.ID, AddressID: int(address1.ID)}) + + var personAddressInfoList []*PersonAddressInfo + if err := DB.Select("people.*, addresses.*"). + Table("people"). + Joins("inner join person_addresses on people.id = person_addresses.person_id"). + Joins("inner join addresses on person_addresses.address_id = addresses.id"). + Find(&personAddressInfoList).Error; err != nil { + t.Errorf("Failed to run join query, got error: %v", err) + } + + for _, info := range personAddressInfoList { + if info.Person != nil { + if info.Person.ID == person1.ID && info.Person.Name != person1.Name { + t.Errorf("Failed, expected %v, got %v", person1.Name, info.Person.Name) + } + } + } +} From 2990790fbc4c1a3b38a3a7bde15620623264461d Mon Sep 17 00:00:00 2001 From: Mikhail Faraponov <11322032+moredure@users.noreply.github.com> Date: Thu, 17 Mar 2022 16:54:30 +0200 Subject: [PATCH 28/92] Use WriteByte for single byte operations (#5167) Co-authored-by: Mikhail Faraponov --- clause/limit.go | 2 +- clause/where.go | 4 ++-- statement.go | 4 ++-- utils/tests/dummy_dialecter.go | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/clause/limit.go b/clause/limit.go index 2082f4d98..184f6025d 100644 --- a/clause/limit.go +++ b/clause/limit.go @@ -21,7 +21,7 @@ func (limit Limit) Build(builder Builder) { } if limit.Offset > 0 { if limit.Limit > 0 { - builder.WriteString(" ") + builder.WriteByte(' ') } builder.WriteString("OFFSET ") builder.WriteString(strconv.Itoa(limit.Offset)) diff --git a/clause/where.go b/clause/where.go index 10b6df856..a29401cfe 100644 --- a/clause/where.go +++ b/clause/where.go @@ -72,9 +72,9 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) { } if wrapInParentheses { - builder.WriteString(`(`) + builder.WriteByte('(') expr.Build(builder) - builder.WriteString(`)`) + builder.WriteByte(')') wrapInParentheses = false } else { expr.Build(builder) diff --git a/statement.go b/statement.go index cb4717766..abf646b89 100644 --- a/statement.go +++ b/statement.go @@ -130,7 +130,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { writer.WriteByte('(') for idx, d := range v { if idx > 0 { - writer.WriteString(",") + writer.WriteByte(',') } stmt.QuoteTo(writer, d) } @@ -143,7 +143,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { writer.WriteByte('(') for idx, d := range v { if idx > 0 { - writer.WriteString(",") + writer.WriteByte(',') } stmt.DB.Dialector.QuoteTo(writer, d) } diff --git a/utils/tests/dummy_dialecter.go b/utils/tests/dummy_dialecter.go index 9543f750a..2990c20f5 100644 --- a/utils/tests/dummy_dialecter.go +++ b/utils/tests/dummy_dialecter.go @@ -49,7 +49,7 @@ func (DummyDialector) QuoteTo(writer clause.Writer, str string) { shiftDelimiter = 0 underQuoted = false continuousBacktick = 0 - writer.WriteString("`") + writer.WriteByte('`') } writer.WriteByte(v) continue @@ -74,7 +74,7 @@ func (DummyDialector) QuoteTo(writer clause.Writer, str string) { if continuousBacktick > 0 && !selfQuoted { writer.WriteString("``") } - writer.WriteString("`") + writer.WriteByte('`') } func (DummyDialector) Explain(sql string, vars ...interface{}) string { From 9b9ae325bb1fe6e209823d576e70e5e8e6ceccb2 Mon Sep 17 00:00:00 2001 From: chenrui <631807682@qq.com> Date: Thu, 17 Mar 2022 23:53:31 +0800 Subject: [PATCH 29/92] fix: circular reference save, close #5140 commit 2ac099a37ac7bd74f0a98a6fdc42cc8527404144 Author: Jinzhu Date: Thu Mar 17 23:49:21 2022 +0800 Refactor #5140 commit 6e3ca2d1aa09943dcfb5d9a4b93bea28212f71be Author: a631807682 <631807682@qq.com> Date: Sun Mar 13 12:52:08 2022 +0800 test: add test for LoadOrStoreVisitMap commit 9d5c68e41000fd15dea124797dd5f2656bf6b304 Author: chenrui Date: Thu Mar 10 20:33:47 2022 +0800 chore: add more comment commit bfffefb179c883389b72bef8f04469c0a8418043 Author: chenrui Date: Thu Mar 10 20:28:48 2022 +0800 fix: should check values has been saved instead of rel.Name commit e55cdfa4b3fbcf8b80baf009e8ddb2e40d471494 Author: chenrui Date: Tue Mar 8 17:48:01 2022 +0800 chore: go lint commit fe4715c5bd4ac28950c97dded9848710d8becb88 Author: chenrui Date: Tue Mar 8 17:27:24 2022 +0800 chore: add test comment commit 326862f3f8980482a09d7d1a7f4d1011bb8a7c59 Author: chenrui Date: Tue Mar 8 17:22:33 2022 +0800 fix: circular reference save --- callbacks/associations.go | 41 ++++++++++++++++++++++++++++++------- callbacks/helper.go | 30 +++++++++++++++++++++++++++ callbacks/visit_map_test.go | 36 ++++++++++++++++++++++++++++++++ tests/associations_test.go | 41 +++++++++++++++++++++++++++++++++++++ tests/tests_test.go | 2 +- utils/tests/models.go | 14 +++++++++++++ 6 files changed, 156 insertions(+), 8 deletions(-) create mode 100644 callbacks/visit_map_test.go diff --git a/callbacks/associations.go b/callbacks/associations.go index d6fd21ded..3b204ab60 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -69,7 +69,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { } if elems.Len() > 0 { - if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil { + if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } @@ -82,7 +82,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { rv = rv.Addr() } - if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil { + if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil { setupReferences(db.Statement.ReflectValue, rv) } } @@ -146,7 +146,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) + saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns) } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero { @@ -166,7 +166,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns) + saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns) } } } @@ -237,7 +237,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) + saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns) } } @@ -304,7 +304,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { // optimize elems of reflect value length if elemLen := elems.Len(); elemLen > 0 { if v, ok := selectColumns[rel.Name+".*"]; !ok || v { - saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) + saveAssociations(db, rel, elems, selectColumns, restricted, nil) } for i := 0; i < elemLen; i++ { @@ -341,11 +341,17 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[ return } -func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { +func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { + // stop save association loop + if checkAssociationsSaved(db, rValues) { + return nil + } + var ( selects, omits []string onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns) refName = rel.Name + "." + values = rValues.Interface() ) for name, ok := range selectColumns { @@ -390,3 +396,24 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, return db.AddError(tx.Create(values).Error) } + +// check association values has been saved +// if values kind is Struct, check it has been saved +// if values kind is Slice/Array, check all items have been saved +var visitMapStoreKey = "gorm:saved_association_map" + +func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool { + if visit, ok := db.Get(visitMapStoreKey); ok { + if v, ok := visit.(*visitMap); ok { + if loadOrStoreVisitMap(v, values) { + return true + } + } + } else { + vistMap := make(visitMap) + loadOrStoreVisitMap(&vistMap, values) + db.Set(visitMapStoreKey, &vistMap) + } + + return false +} diff --git a/callbacks/helper.go b/callbacks/helper.go index a5eb047e5..71b67de59 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -1,6 +1,7 @@ package callbacks import ( + "reflect" "sort" "gorm.io/gorm" @@ -120,3 +121,32 @@ func checkMissingWhereConditions(db *gorm.DB) { return } } + +type visitMap = map[reflect.Value]bool + +// Check if circular values, return true if loaded +func loadOrStoreVisitMap(vistMap *visitMap, v reflect.Value) (loaded bool) { + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + switch v.Kind() { + case reflect.Slice, reflect.Array: + loaded = true + for i := 0; i < v.Len(); i++ { + if !loadOrStoreVisitMap(vistMap, v.Index(i)) { + loaded = false + } + } + case reflect.Struct, reflect.Interface: + if v.CanAddr() { + p := v.Addr() + if _, ok := (*vistMap)[p]; ok { + return true + } + (*vistMap)[p] = true + } + } + + return +} diff --git a/callbacks/visit_map_test.go b/callbacks/visit_map_test.go new file mode 100644 index 000000000..b1fb86dbe --- /dev/null +++ b/callbacks/visit_map_test.go @@ -0,0 +1,36 @@ +package callbacks + +import ( + "reflect" + "testing" +) + +func TestLoadOrStoreVisitMap(t *testing.T) { + var vm visitMap + var loaded bool + type testM struct { + Name string + } + + t1 := testM{Name: "t1"} + t2 := testM{Name: "t2"} + t3 := testM{Name: "t3"} + + vm = make(visitMap) + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded { + t.Fatalf("loaded should be false") + } + + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded { + t.Fatalf("loaded should be true") + } + + // t1 already exist but t2 not + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded { + t.Fatalf("loaded should be false") + } + + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded { + t.Fatalf("loaded should be true") + } +} diff --git a/tests/associations_test.go b/tests/associations_test.go index 5ce98c7dc..32f6525b8 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -220,3 +220,44 @@ func TestFullSaveAssociations(t *testing.T) { t.Errorf("Failed to preload AppliesToProduct") } } + +func TestSaveBelongsCircularReference(t *testing.T) { + parent := Parent{} + DB.Create(&parent) + + child := Child{ParentID: &parent.ID, Parent: &parent} + DB.Create(&child) + + parent.FavChildID = child.ID + parent.FavChild = &child + DB.Save(&parent) + + var parent1 Parent + DB.First(&parent1, parent.ID) + AssertObjEqual(t, parent, parent1, "ID", "FavChildID") + + // Save and Updates is the same + DB.Updates(&parent) + DB.First(&parent1, parent.ID) + AssertObjEqual(t, parent, parent1, "ID", "FavChildID") +} + +func TestSaveHasManyCircularReference(t *testing.T) { + parent := Parent{} + DB.Create(&parent) + + child := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference"} + child1 := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference1"} + + parent.Children = []*Child{&child, &child1} + DB.Save(&parent) + + var children []*Child + DB.Where("parent_id = ?", parent.ID).Find(&children) + if len(children) != len(parent.Children) || + children[0].ID != parent.Children[0].ID || + children[1].ID != parent.Children[1].ID { + t.Errorf("circular reference children save not equal children:%v parent.Children:%v", + children, parent.Children) + } +} diff --git a/tests/tests_test.go b/tests/tests_test.go index 11b6f0675..08f4f1932 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -95,7 +95,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { func RunMigrations() { var err error - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) diff --git a/utils/tests/models.go b/utils/tests/models.go index c84f9cae9..22e8e659f 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -80,3 +80,17 @@ type Order struct { Coupon *Coupon CouponID string } + +type Parent struct { + gorm.Model + FavChildID uint + FavChild *Child + Children []*Child +} + +type Child struct { + gorm.Model + Name string + ParentID *uint + Parent *Parent +} From c2e36ebe62a0e79649aff1a539b39ace86bc6bab Mon Sep 17 00:00:00 2001 From: chenrui <631807682@qq.com> Date: Fri, 18 Mar 2022 01:07:49 +0800 Subject: [PATCH 30/92] fix: soft delete for join, close #5132 commit a83023bdfc0dc6eaccc6704b64ff6436c2fe7725 Author: Jinzhu Date: Fri Mar 18 01:05:25 2022 +0800 Refactor #5132 commit 8559f51102c01be6c19913c0bc3a5771721ff1f5 Author: chenrui Date: Mon Mar 7 20:33:12 2022 +0800 fix: should add deleted_at exprs for every joins commit 2b7a1bdcf3eff9d23253173d21e73c1f056f9be4 Author: chenrui Date: Mon Mar 7 14:46:48 2022 +0800 test: move debug flag commit ce13a2a7bc50d2c23678806acf65dbd589827c77 Author: chenrui Date: Mon Mar 7 14:39:56 2022 +0800 fix: soft delete for join.on --- callbacks/query.go | 38 ++++++++++++++++++++++++++------------ tests/helper_test.go | 5 +++++ tests/joins_test.go | 31 +++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 12 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 04f35c7e7..c4c804062 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -145,19 +145,33 @@ func BuildQuerySQL(db *gorm.DB) { } } - if join.On != nil { - onStmt := gorm.Statement{Table: tableAliasName, DB: db} - join.On.Build(&onStmt) - onSQL := onStmt.SQL.String() - vars := onStmt.Vars - for idx, v := range onStmt.Vars { - bindvar := strings.Builder{} - onStmt.Vars = vars[0 : idx+1] - db.Dialector.BindVarTo(&bindvar, &onStmt, v) - onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) + { + onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}} + for _, c := range relation.FieldSchema.QueryClauses { + onStmt.AddClause(c) } - exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) + if join.On != nil { + onStmt.AddClause(join.On) + } + + if cs, ok := onStmt.Clauses["WHERE"]; ok { + if where, ok := cs.Expression.(clause.Where); ok { + where.Build(&onStmt) + + if onSQL := onStmt.SQL.String(); onSQL != "" { + vars := onStmt.Vars + for idx, v := range vars { + bindvar := strings.Builder{} + onStmt.Vars = vars[0 : idx+1] + db.Dialector.BindVarTo(&bindvar, &onStmt, v) + onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) + } + + exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) + } + } + } } joins = append(joins, clause.Join{ @@ -172,8 +186,8 @@ func BuildQuerySQL(db *gorm.DB) { } } - db.Statement.Joins = nil db.Statement.AddClause(clause.From{Joins: joins}) + db.Statement.Joins = nil } else { db.Statement.AddClauseIfNotExists(clause.From{}) } diff --git a/tests/helper_test.go b/tests/helper_test.go index eee34e994..7ee2a5761 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -19,6 +19,7 @@ type Config struct { Team int Languages int Friends int + NamedPet bool } func GetUser(name string, config Config) *User { @@ -65,6 +66,10 @@ func GetUser(name string, config Config) *User { user.Friends = append(user.Friends, GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{})) } + if config.NamedPet { + user.NamedPet = &Pet{Name: name + "_namepet"} + } + return &user } diff --git a/tests/joins_test.go b/tests/joins_test.go index 4c9cffae9..0f02f3f91 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -200,3 +200,34 @@ func TestJoinCount(t *testing.T) { t.Fatalf("result's id, %d, doesn't match user's id, %d", result.ID, user.ID) } } + +func TestJoinWithSoftDeleted(t *testing.T) { + DB = DB.Debug() + + user := GetUser("TestJoinWithSoftDeletedUser", Config{Account: true, NamedPet: true}) + DB.Create(&user) + + var user1 User + DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user1, user.ID) + if user1.NamedPet == nil || user1.Account.ID == 0 { + t.Fatalf("joins NamedPet and Account should not empty:%v", user1) + } + + // Account should empty + DB.Delete(&user1.Account) + + var user2 User + DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user2, user.ID) + if user2.NamedPet == nil || user2.Account.ID != 0 { + t.Fatalf("joins Account should not empty:%v", user2) + } + + // NamedPet should empty + DB.Delete(&user1.NamedPet) + + var user3 User + DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user3, user.ID) + if user3.NamedPet != nil || user2.Account.ID != 0 { + t.Fatalf("joins NamedPet and Account should not empty:%v", user2) + } +} From 5431da8caf09ad19256170df17e2e75eb541f4a5 Mon Sep 17 00:00:00 2001 From: chenrui <631807682@qq.com> Date: Fri, 18 Mar 2022 13:38:46 +0800 Subject: [PATCH 31/92] fix: preload panic when model and dest different close #5130 commit e8307b5ef5273519a32cd8e4fd29250d1c277f6e Author: Jinzhu Date: Fri Mar 18 13:37:22 2022 +0800 Refactor #5130 commit 40cbba49f374c9bae54f80daee16697ae45e905b Author: chenrui Date: Sat Mar 5 17:36:56 2022 +0800 test: fix test fail commit 66d3f078291102a30532b6a9d97c757228a9b543 Author: chenrui Date: Sat Mar 5 17:29:09 2022 +0800 test: drop table and auto migrate commit 7cbf019a930019476a97ac7ac0f5fc186e8d5b42 Author: chenrui Date: Sat Mar 5 15:27:45 2022 +0800 fix: preload panic when model and dest different --- callbacks/preload.go | 56 ++++++++++++++++++------------------- callbacks/query.go | 15 ++++++++-- chainable_api.go | 5 +++- tests/preload_suits_test.go | 2 +- tests/preload_test.go | 18 ++++++++++++ 5 files changed, 63 insertions(+), 33 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index 2363a8cab..888f832d5 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -10,10 +10,9 @@ import ( "gorm.io/gorm/utils" ) -func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) { +func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { var ( - reflectValue = db.Statement.ReflectValue - tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks}) + reflectValue = tx.Statement.ReflectValue relForeignKeys []string relForeignFields []*schema.Field foreignFields []*schema.Field @@ -22,11 +21,6 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload inlineConds []interface{} ) - db.Statement.Settings.Range(func(k, v interface{}) bool { - tx.Statement.Settings.Store(k, v) - return true - }) - if rel.JoinTable != nil { var ( joinForeignFields = make([]*schema.Field, 0, len(rel.References)) @@ -48,14 +42,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields) + joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields) if len(joinForeignValues) == 0 { - return + return nil } joinResults := rel.JoinTable.MakeSlice().Elem() column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues) - db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error) + if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil { + return err + } // convert join identity map to relation identity map fieldValues := make([]interface{}, len(joinForeignFields)) @@ -63,11 +59,11 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload for i := 0; i < joinResults.Len(); i++ { joinIndexValue := joinResults.Index(i) for idx, field := range joinForeignFields { - fieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue) + fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue) } for idx, field := range joinRelForeignFields { - joinFieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue) + joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue) } if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { @@ -76,7 +72,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - _, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, joinResults, joinRelForeignFields) + _, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields) } else { for _, ref := range rel.References { if ref.OwnPrimaryKey { @@ -92,9 +88,9 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - identityMap, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields) + identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields) if len(foreignValues) == 0 { - return + return nil } } @@ -115,7 +111,9 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error) + if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil { + return err + } } fieldValues := make([]interface{}, len(relForeignFields)) @@ -125,17 +123,17 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload case reflect.Struct: switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(db.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) default: - rel.Field.Set(db.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()) + rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()) } case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) default: - rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) } } } @@ -143,18 +141,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload for i := 0; i < reflectResults.Len(); i++ { elem := reflectResults.Index(i) for idx, field := range relForeignFields { - fieldValues[idx], _ = field.ValueOf(db.Statement.Context, elem) + fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem) } datas, ok := identityMap[utils.ToStringKey(fieldValues...)] if !ok { - db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", - elem.Interface())) - continue + return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface()) } for _, data := range datas { - reflectFieldValue := rel.Field.ReflectValueOf(db.Statement.Context, data) + reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data) if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) } @@ -162,14 +158,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload reflectFieldValue = reflect.Indirect(reflectFieldValue) switch reflectFieldValue.Kind() { case reflect.Struct: - rel.Field.Set(db.Statement.Context, data, elem.Interface()) + rel.Field.Set(tx.Statement.Context, data, elem.Interface()) case reflect.Slice, reflect.Array: if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { - rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()) + rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()) } else { - rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) + rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) } } } } + + return tx.Error } diff --git a/callbacks/query.go b/callbacks/query.go index c4c804062..6ba3dd388 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -237,9 +237,20 @@ func Preload(db *gorm.DB) { } sort.Strings(preloadNames) + preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true}) + db.Statement.Settings.Range(func(k, v interface{}) bool { + preloadDB.Statement.Settings.Store(k, v) + return true + }) + + if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil { + return + } + preloadDB.Statement.ReflectValue = db.Statement.ReflectValue + for _, name := range preloadNames { - if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil { - preload(db, rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]) + if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { + db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) } else { db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) } diff --git a/chainable_api.go b/chainable_api.go index 173479d30..38ad5cdee 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -54,9 +54,12 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) { } else if tables := strings.Split(name, "."); len(tables) == 2 { tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} tx.Statement.Table = tables[1] - } else { + } else if name != "" { tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} tx.Statement.Table = name + } else { + tx.Statement.TableExpr = nil + tx.Statement.Table = "" } return } diff --git a/tests/preload_suits_test.go b/tests/preload_suits_test.go index 0ef8890b0..b5b6a70f3 100644 --- a/tests/preload_suits_test.go +++ b/tests/preload_suits_test.go @@ -1335,7 +1335,7 @@ func TestNilPointerSlice(t *testing.T) { } if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) { - t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) + t.Fatalf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) } if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) { diff --git a/tests/preload_test.go b/tests/preload_test.go index adb54ee19..cb4343ec0 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -251,3 +251,21 @@ func TestPreloadGoroutine(t *testing.T) { } wg.Wait() } + +func TestPreloadWithDiffModel(t *testing.T) { + user := *GetUser("preload_with_diff_model", Config{Account: true}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var result struct { + Something string + User + } + + DB.Model(User{}).Preload("Account", clause.Eq{Column: "number", Value: user.Account.Number}).Select( + "users.*, 'yo' as something").First(&result, "name = ?", user.Name) + + CheckUser(t, user, result.User) +} From e6f7da0e0dbc193df883f799a4650d0a86507376 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 18 Mar 2022 14:30:30 +0800 Subject: [PATCH 32/92] Support Variable Relation --- schema/relationship.go | 6 +++++- schema/relationship_test.go | 20 ++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/schema/relationship.go b/schema/relationship.go index eae8ab0b1..b51008979 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -416,6 +416,10 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu } } else { var primaryFields []*Field + var primarySchemaName = primarySchema.Name + if primarySchemaName == "" { + primarySchemaName = relation.FieldSchema.Name + } if len(relation.primaryKeys) > 0 { for _, primaryKey := range relation.primaryKeys { @@ -428,7 +432,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu } for _, primaryField := range primaryFields { - lookUpName := primarySchema.Name + primaryField.Name + lookUpName := primarySchemaName + primaryField.Name if gl == guessBelongs { lookUpName = field.Name + primaryField.Name } diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 40ffc3249..6fffbfcbc 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -491,6 +491,26 @@ func TestEmbeddedRelation(t *testing.T) { } } +func TestVariableRelation(t *testing.T) { + var result struct { + User + } + + checkStructRelation(t, &result, Relation{ + Name: "Account", Type: schema.HasOne, Schema: "", FieldSchema: "Account", + References: []Reference{ + {"ID", "", "UserID", "Account", "", true}, + }, + }) + + checkStructRelation(t, &result, Relation{ + Name: "Company", Type: schema.BelongsTo, Schema: "", FieldSchema: "Company", + References: []Reference{ + {"ID", "Company", "CompanyID", "", "", false}, + }, + }) +} + func TestSameForeignKey(t *testing.T) { type UserAux struct { gorm.Model From 3c00980e01a6a16095b9fafddedd3217ad4b7357 Mon Sep 17 00:00:00 2001 From: ag9920 Date: Fri, 18 Mar 2022 17:12:17 +0800 Subject: [PATCH 33/92] fix: serializer use default valueOf in assignInterfacesToValue, close #5168 commit 58e1b2bffbc216f2862d040fb545a8a486e473b6 Author: Jinzhu Date: Fri Mar 18 17:06:43 2022 +0800 Refactor #5168 commit fb9233011d209174e8223e970f0f732412852908 Author: ag9920 Date: Thu Mar 17 21:23:28 2022 +0800 fix: serializer use default valueOf in assignInterfacesToValue --- schema/field.go | 80 ++++++++++++++++++++++------------------ tests/joins_test.go | 2 - tests/serializer_test.go | 51 ++++++++++++++++++++++++- 3 files changed, 95 insertions(+), 38 deletions(-) diff --git a/schema/field.go b/schema/field.go index 0d7085a90..45ec66e15 100644 --- a/schema/field.go +++ b/schema/field.go @@ -435,39 +435,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { // create valuer, setter when parse struct func (field *Field) setupValuerAndSetter() { // Setup NewValuePool - var fieldValue = reflect.New(field.FieldType).Interface() - if field.Serializer != nil { - field.NewValuePool = &sync.Pool{ - New: func() interface{} { - return &serializer{ - Field: field, - Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface), - } - }, - } - } else if _, ok := fieldValue.(sql.Scanner); !ok { - // set default NewValuePool - switch field.IndirectFieldType.Kind() { - case reflect.String: - field.NewValuePool = stringPool - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - field.NewValuePool = intPool - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - field.NewValuePool = uintPool - case reflect.Float32, reflect.Float64: - field.NewValuePool = floatPool - case reflect.Bool: - field.NewValuePool = boolPool - default: - if field.IndirectFieldType == TimeReflectType { - field.NewValuePool = timePool - } - } - } - - if field.NewValuePool == nil { - field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType)) - } + field.setupNewValuePool() // ValueOf returns field's value and if it is zero fieldIndex := field.StructField.Index[0] @@ -512,7 +480,7 @@ func (field *Field) setupValuerAndSetter() { s = field.Serializer } - return serializer{ + return &serializer{ Field: field, SerializeValuer: s, Destination: v, @@ -943,7 +911,9 @@ func (field *Field) setupValuerAndSetter() { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { if s, ok := v.(*serializer); ok { - if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil { + if s.fieldValue != nil { + err = oldFieldSetter(ctx, value, s.fieldValue) + } else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil { if sameElemType { field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem()) s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) @@ -959,3 +929,43 @@ func (field *Field) setupValuerAndSetter() { } } } + +func (field *Field) setupNewValuePool() { + var fieldValue = reflect.New(field.FieldType).Interface() + if field.Serializer != nil { + field.NewValuePool = &sync.Pool{ + New: func() interface{} { + return &serializer{ + Field: field, + Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface), + } + }, + } + } else if _, ok := fieldValue.(sql.Scanner); !ok { + field.setupDefaultNewValuePool() + } + + if field.NewValuePool == nil { + field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType)) + } +} + +func (field *Field) setupDefaultNewValuePool() { + // set default NewValuePool + switch field.IndirectFieldType.Kind() { + case reflect.String: + field.NewValuePool = stringPool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + field.NewValuePool = intPool + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + field.NewValuePool = uintPool + case reflect.Float32, reflect.Float64: + field.NewValuePool = floatPool + case reflect.Bool: + field.NewValuePool = boolPool + default: + if field.IndirectFieldType == TimeReflectType { + field.NewValuePool = timePool + } + } +} diff --git a/tests/joins_test.go b/tests/joins_test.go index 0f02f3f91..bb5352ef6 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -202,8 +202,6 @@ func TestJoinCount(t *testing.T) { } func TestJoinWithSoftDeleted(t *testing.T) { - DB = DB.Debug() - user := GetUser("TestJoinWithSoftDeletedUser", Config{Account: true, NamedPet: true}) DB.Create(&user) diff --git a/tests/serializer_test.go b/tests/serializer_test.go index a8a4e28f8..ce60280ec 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -42,7 +42,7 @@ func (es *EncryptedString) Scan(ctx context.Context, field *schema.Field, dst re case string: *es = EncryptedString(strings.TrimPrefix(value, "hello")) default: - return fmt.Errorf("unsupported data %v", dbValue) + return fmt.Errorf("unsupported data %#v", dbValue) } return nil } @@ -83,4 +83,53 @@ func TestSerializer(t *testing.T) { } AssertEqual(t, result, data) + +} + +func TestSerializerAssignFirstOrCreate(t *testing.T) { + DB.Migrator().DropTable(&SerializerStruct{}) + if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil { + t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) + } + + createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + + data := SerializerStruct{ + Name: []byte("ag9920"), + Roles: []string{"r1", "r2"}, + Contracts: map[string]interface{}{"name": "jing1", "age": 11}, + EncryptedString: EncryptedString("pass"), + CreatedTime: createdAt.Unix(), + JobInfo: Job{ + Title: "programmer", + Number: 9920, + Location: "Shadyside", + IsIntern: false, + }, + } + + // first time insert record + out := SerializerStruct{} + if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil { + t.Fatalf("failed to FirstOrCreate Assigned data, got error %v", err) + } + + var result SerializerStruct + if err := DB.First(&result, out.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } + AssertEqual(t, result, out) + + //update record + data.Roles = append(data.Roles, "r3") + data.JobInfo.Location = "Gates Hillman Complex" + if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil { + t.Fatalf("failed to FirstOrCreate Assigned data, got error %v", err) + } + if err := DB.First(&result, out.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } + + AssertEqual(t, result.Roles, data.Roles) + AssertEqual(t, result.JobInfo.Location, data.JobInfo.Location) } From d402765f694ade8fd3a0da1b7a2f9d2fa4453957 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Fri, 18 Mar 2022 20:11:23 +0800 Subject: [PATCH 34/92] test: fix utils.AssertEqual (#5172) --- tests/query_test.go | 4 +++- utils/tests/utils.go | 29 +++++++++++++++++------------ 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/tests/query_test.go b/tests/query_test.go index 6542774a4..af2b8d4b4 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -583,7 +583,9 @@ func TestPluck(t *testing.T) { if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name desc").Pluck("name", &names2).Error; err != nil { t.Errorf("got error when pluck name: %v", err) } - AssertEqual(t, names, sort.Reverse(sort.StringSlice(names2))) + + sort.Slice(names2, func(i, j int) bool { return names2[i] < names2[j] }) + AssertEqual(t, names, names2) var ids []int if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("id", &ids).Error; err != nil { diff --git a/utils/tests/utils.go b/utils/tests/utils.go index 817e4b0bd..661d727fd 100644 --- a/utils/tests/utils.go +++ b/utils/tests/utils.go @@ -83,20 +83,22 @@ func AssertEqual(t *testing.T, got, expect interface{}) { } if reflect.ValueOf(got).Kind() == reflect.Struct { - if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { - exported := false - for i := 0; i < reflect.ValueOf(got).NumField(); i++ { - if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { - exported = true - field := reflect.ValueOf(got).Field(i) - t.Run(fieldStruct.Name, func(t *testing.T) { - AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) - }) + if reflect.ValueOf(expect).Kind() == reflect.Struct { + if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { + exported := false + for i := 0; i < reflect.ValueOf(got).NumField(); i++ { + if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { + exported = true + field := reflect.ValueOf(got).Field(i) + t.Run(fieldStruct.Name, func(t *testing.T) { + AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) + }) + } } - } - if exported { - return + if exported { + return + } } } } @@ -107,6 +109,9 @@ func AssertEqual(t *testing.T, got, expect interface{}) { } else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) { expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface() isEqual() + } else { + t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) + return } } } From 540b47571a2c74134c2a8eb02d5a8ef70b0bf8d6 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 18 Mar 2022 20:57:33 +0800 Subject: [PATCH 35/92] Fix update select clause with before/after expressions, close #5164 --- chainable_api.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 38ad5cdee..68b4d1aa5 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -93,7 +93,11 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { return } } - delete(tx.Statement.Clauses, "SELECT") + + if clause, ok := tx.Statement.Clauses["SELECT"]; ok { + clause.Expression = nil + tx.Statement.Clauses["SELECT"] = clause + } case string: if strings.Count(v, "?") >= len(args) && len(args) > 0 { tx.Statement.AddClause(clause.Select{ @@ -123,7 +127,10 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { } } - delete(tx.Statement.Clauses, "SELECT") + if clause, ok := tx.Statement.Clauses["SELECT"]; ok { + clause.Expression = nil + tx.Statement.Clauses["SELECT"] = clause + } } default: tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) From 0097b39a77b9573d63f89c22f3cea0aae103a77f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 20 Mar 2022 08:55:08 +0800 Subject: [PATCH 36/92] Should ignore error when parsing default value for time, close #5176 --- schema/field.go | 4 ++-- tests/go.mod | 2 +- tests/postgres_test.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/schema/field.go b/schema/field.go index 45ec66e15..962918164 100644 --- a/schema/field.go +++ b/schema/field.go @@ -260,8 +260,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DataType = Time } if field.HasDefaultValue && !skipParseDefaultValue && field.DataType == Time { - if field.DefaultValueInterface, err = now.Parse(field.DefaultValue); err != nil { - schema.err = fmt.Errorf("failed to parse default value `%v` for field %v", field.DefaultValue, field.Name) + if t, err := now.Parse(field.DefaultValue); err == nil { + field.DefaultValueInterface = t } } case reflect.Array, reflect.Slice: diff --git a/tests/go.mod b/tests/go.mod index 9dfa26ff0..17e5d3506 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -14,7 +14,7 @@ require ( gorm.io/driver/postgres v1.3.1 gorm.io/driver/sqlite v1.3.1 gorm.io/driver/sqlserver v1.3.1 - gorm.io/gorm v1.23.1 + gorm.io/gorm v1.23.3 ) replace gorm.io/gorm => ../ diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 418b713e5..66b988c3a 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -19,7 +19,7 @@ func TestPostgres(t *testing.T) { Name string `gorm:"check:name_checker,name <> ''"` Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"` CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"` - UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"` + UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"` Things pq.StringArray `gorm:"type:text[]"` } From 2d5cb997ed4d0e8f53fa1662111ad2cb053caf9c Mon Sep 17 00:00:00 2001 From: Jin Date: Sun, 20 Mar 2022 09:02:45 +0800 Subject: [PATCH 37/92] style: fix linter check for NamingStrategy and onConflictOption (#5174) --- callbacks/associations.go | 4 ++-- schema/naming.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 3b204ab60..644ef1855 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -323,7 +323,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { } } -func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) (onConflict clause.OnConflict) { +func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) (onConflict clause.OnConflict) { if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations { onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames)) for _, dbName := range s.PrimaryFieldDBNames { @@ -349,7 +349,7 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Val var ( selects, omits []string - onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns) + onConflict = onConflictOption(db.Statement, rel.FieldSchema, defaultUpdatingColumns) refName = rel.Name + "." values = rValues.Interface() ) diff --git a/schema/naming.go b/schema/naming.go index 47a2b3636..a258beed3 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -85,9 +85,9 @@ func (ns NamingStrategy) IndexName(table, column string) string { } func (ns NamingStrategy) formatName(prefix, table, name string) string { - formattedName := strings.Replace(strings.Join([]string{ + formattedName := strings.ReplaceAll(strings.Join([]string{ prefix, table, name, - }, "_"), ".", "_", -1) + }, "_"), ".", "_") if utf8.RuneCountInString(formattedName) > 64 { h := sha1.New() From d66f37ad322cbda02bb873b5b2f1093296672b49 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 21 Mar 2022 10:50:14 +0800 Subject: [PATCH 38/92] Add Go 1.18 --- .github/workflows/tests.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3e15427ca..ad4c99179 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,7 +13,7 @@ jobs: sqlite: strategy: matrix: - go: ['1.17', '1.16'] + go: ['1.18', '1.17', '1.16'] platform: [ubuntu-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} @@ -39,7 +39,7 @@ jobs: strategy: matrix: dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest'] - go: ['1.17', '1.16'] + go: ['1.18', '1.17', '1.16'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -83,7 +83,7 @@ jobs: strategy: matrix: dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10'] - go: ['1.17', '1.16'] + go: ['1.18', '1.17', '1.16'] platform: [ubuntu-latest] # can not run in macOS and Windows runs-on: ${{ matrix.platform }} @@ -125,7 +125,7 @@ jobs: sqlserver: strategy: matrix: - go: ['1.17', '1.16'] + go: ['1.18', '1.17', '1.16'] platform: [ubuntu-latest] # can not run test in macOS and windows runs-on: ${{ matrix.platform }} From a7b3b5956fad0ae536147a19e89300af0462d74d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 22 Mar 2022 22:42:36 +0800 Subject: [PATCH 39/92] Fix hooks order, close https://github.com/go-gorm/gorm.io/pull/519 --- callbacks/create.go | 15 +++++++++------ callbacks/update.go | 16 ++++++++++------ 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 6e2883f79..0a43cacb8 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -10,6 +10,7 @@ import ( "gorm.io/gorm/utils" ) +// BeforeCreate before create hooks func BeforeCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { @@ -31,6 +32,7 @@ func BeforeCreate(db *gorm.DB) { } } +// Create create hook func Create(config *Config) func(db *gorm.DB) { supportReturning := utils.Contains(config.CreateClauses, "RETURNING") @@ -146,20 +148,21 @@ func Create(config *Config) func(db *gorm.DB) { } } +// AfterCreate after create hooks func AfterCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { - if db.Statement.Schema.AfterSave { - if i, ok := value.(AfterSaveInterface); ok { + if db.Statement.Schema.AfterCreate { + if i, ok := value.(AfterCreateInterface); ok { called = true - db.AddError(i.AfterSave(tx)) + db.AddError(i.AfterCreate(tx)) } } - if db.Statement.Schema.AfterCreate { - if i, ok := value.(AfterCreateInterface); ok { + if db.Statement.Schema.AfterSave { + if i, ok := value.(AfterSaveInterface); ok { called = true - db.AddError(i.AfterCreate(tx)) + db.AddError(i.AfterSave(tx)) } } return called diff --git a/callbacks/update.go b/callbacks/update.go index da03261ec..1964973b1 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -29,6 +29,7 @@ func SetupUpdateReflectValue(db *gorm.DB) { } } +// BeforeUpdate before update hooks func BeforeUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { @@ -51,6 +52,7 @@ func BeforeUpdate(db *gorm.DB) { } } +// Update update hook func Update(config *Config) func(db *gorm.DB) { supportReturning := utils.Contains(config.UpdateClauses, "RETURNING") @@ -99,22 +101,24 @@ func Update(config *Config) func(db *gorm.DB) { } } +// AfterUpdate after update hooks func AfterUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { - if db.Statement.Schema.AfterSave { - if i, ok := value.(AfterSaveInterface); ok { + if db.Statement.Schema.AfterUpdate { + if i, ok := value.(AfterUpdateInterface); ok { called = true - db.AddError(i.AfterSave(tx)) + db.AddError(i.AfterUpdate(tx)) } } - if db.Statement.Schema.AfterUpdate { - if i, ok := value.(AfterUpdateInterface); ok { + if db.Statement.Schema.AfterSave { + if i, ok := value.(AfterSaveInterface); ok { called = true - db.AddError(i.AfterUpdate(tx)) + db.AddError(i.AfterSave(tx)) } } + return called }) } From f92e6747cb12d5a5bc2bf7e0d76cb8e5f69cd637 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 23 Mar 2022 17:24:25 +0800 Subject: [PATCH 40/92] Handle field set value error --- callbacks/associations.go | 14 +++++++------- callbacks/create.go | 18 +++++++++--------- callbacks/preload.go | 14 +++++++------- callbacks/update.go | 2 +- scan.go | 4 ++-- schema/field.go | 5 +++-- statement.go | 8 ++++---- tests/go.mod | 2 +- 8 files changed, 34 insertions(+), 33 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 644ef1855..fd3141cfe 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -159,9 +159,9 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { for _, ref := range rel.References { if ref.OwnPrimaryKey { fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue) - ref.ForeignKey.Set(db.Statement.Context, f, fv) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, fv)) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue)) } assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } @@ -193,9 +193,9 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { for _, ref := range rel.References { if ref.OwnPrimaryKey { pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v) - ref.ForeignKey.Set(db.Statement.Context, elem, pv) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, pv)) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue)) } } @@ -261,12 +261,12 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { for _, ref := range rel.References { if ref.OwnPrimaryKey { fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj) - ref.ForeignKey.Set(db.Statement.Context, joinValue, fv) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv)) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue)) } else { fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem) - ref.ForeignKey.Set(db.Statement.Context, joinValue, fv) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv)) } } joins = reflect.Append(joins, joinValue) diff --git a/callbacks/create.go b/callbacks/create.go index 0a43cacb8..e94b7eca6 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -121,7 +121,7 @@ func Create(config *Config) func(db *gorm.DB) { _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID) + db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } } @@ -133,7 +133,7 @@ func Create(config *Config) func(db *gorm.DB) { } if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID) + db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } } @@ -141,7 +141,7 @@ func Create(config *Config) func(db *gorm.DB) { case reflect.Struct: _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID) + db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) } } } @@ -227,13 +227,13 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero { if field.DefaultValueInterface != nil { values.Values[i][idx] = field.DefaultValueInterface - field.Set(stmt.Context, rv, field.DefaultValueInterface) + stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface)) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { - field.Set(stmt.Context, rv, curTime) + stmt.AddError(field.Set(stmt.Context, rv, curTime)) values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) } } else if field.AutoUpdateTime > 0 && updateTrackTime { - field.Set(stmt.Context, rv, curTime) + stmt.AddError(field.Set(stmt.Context, rv, curTime)) values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) } } @@ -267,13 +267,13 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero { if field.DefaultValueInterface != nil { values.Values[0][idx] = field.DefaultValueInterface - field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface)) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { - field.Set(stmt.Context, stmt.ReflectValue, curTime) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) } } else if field.AutoUpdateTime > 0 && updateTrackTime { - field.Set(stmt.Context, stmt.ReflectValue, curTime) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) } } diff --git a/callbacks/preload.go b/callbacks/preload.go index 888f832d5..ea2570ba3 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -123,17 +123,17 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload case reflect.Struct: switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())) default: - rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface())) } case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())) default: - rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())) } } } @@ -158,12 +158,12 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload reflectFieldValue = reflect.Indirect(reflectFieldValue) switch reflectFieldValue.Kind() { case reflect.Struct: - rel.Field.Set(tx.Statement.Context, data, elem.Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface())) case reflect.Slice, reflect.Array: if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { - rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface())) } else { - rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())) } } } diff --git a/callbacks/update.go b/callbacks/update.go index 1964973b1..01f40509e 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -21,7 +21,7 @@ func SetupUpdateReflectValue(db *gorm.DB) { if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { for _, rel := range db.Statement.Schema.Relationships.BelongsTo { if _, ok := dest[rel.Name]; ok { - rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name]) + db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name])) } } } diff --git a/scan.go b/scan.go index 89d923543..42642ec6d 100644 --- a/scan.go +++ b/scan.go @@ -69,7 +69,7 @@ func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values for idx, field := range fields { if field != nil { if len(joinFields) == 0 || joinFields[idx][0] == nil { - field.Set(db.Statement.Context, reflectValue, values[idx]) + db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) } else { relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue) if relValue.Kind() == reflect.Ptr && relValue.IsNil() { @@ -79,7 +79,7 @@ func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values relValue.Set(reflect.New(relValue.Type().Elem())) } - joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx]) + db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])) } // release data to pool diff --git a/schema/field.go b/schema/field.go index 962918164..3b5cc5c55 100644 --- a/schema/field.go +++ b/schema/field.go @@ -12,6 +12,7 @@ import ( "time" "github.com/jinzhu/now" + "gorm.io/gorm/clause" "gorm.io/gorm/utils" ) @@ -567,8 +568,8 @@ func (field *Field) setupValuerAndSetter() { if v, err = valuer.Value(); err == nil { err = setter(ctx, value, v) } - } else { - return fmt.Errorf("failed to set value %+v to field %s", v, field.Name) + } else if _, ok := v.(clause.Expr); !ok { + return fmt.Errorf("failed to set value %#v to field %s", v, field.Name) } } diff --git a/statement.go b/statement.go index abf646b89..9fcee09c0 100644 --- a/statement.go +++ b/statement.go @@ -562,7 +562,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . switch destValue.Kind() { case reflect.Struct: - field.Set(stmt.Context, destValue, value) + stmt.AddError(field.Set(stmt.Context, destValue, value)) default: stmt.AddError(ErrInvalidData) } @@ -572,10 +572,10 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . case reflect.Slice, reflect.Array: if len(fromCallbacks) > 0 { for i := 0; i < stmt.ReflectValue.Len(); i++ { - field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)) } } else { - field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value)) } case reflect.Struct: if !stmt.ReflectValue.CanAddr() { @@ -583,7 +583,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . return } - field.Set(stmt.Context, stmt.ReflectValue, value) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, value)) } } else { stmt.AddError(ErrInvalidField) diff --git a/tests/go.mod b/tests/go.mod index 17e5d3506..b85ebdadf 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.12 // indirect - golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd // indirect + golang.org/x/crypto v0.0.0-20220321153916-2c7772ba3064 // indirect gorm.io/driver/mysql v1.3.2 gorm.io/driver/postgres v1.3.1 gorm.io/driver/sqlite v1.3.1 From 9a4d10be64738f0c1f7a86841d56e2fe3165e3f0 Mon Sep 17 00:00:00 2001 From: Jin Date: Thu, 24 Mar 2022 09:31:58 +0800 Subject: [PATCH 41/92] style: fix coding typo (#5184) --- migrator/column_type.go | 2 +- tests/main_test.go | 6 ++---- tests/migrate_test.go | 2 +- tests/sql_builder_test.go | 10 +++++----- tests/upsert_test.go | 2 +- 5 files changed, 10 insertions(+), 12 deletions(-) diff --git a/migrator/column_type.go b/migrator/column_type.go index cc1331b92..c6fdd6b2d 100644 --- a/migrator/column_type.go +++ b/migrator/column_type.go @@ -44,7 +44,7 @@ func (ct ColumnType) DatabaseTypeName() string { return ct.SQLColumnType.DatabaseTypeName() } -// ColumnType returns the database type of the column. lke `varchar(16)` +// ColumnType returns the database type of the column. like `varchar(16)` func (ct ColumnType) ColumnType() (columnType string, ok bool) { return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid } diff --git a/tests/main_test.go b/tests/main_test.go index 5b8c7dbb2..997714b9b 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -43,10 +43,8 @@ func TestExceptionsWithInvalidSql(t *testing.T) { func TestSetAndGet(t *testing.T) { if value, ok := DB.Set("hello", "world").Get("hello"); !ok { t.Errorf("Should be able to get setting after set") - } else { - if value.(string) != "world" { - t.Errorf("Setted value should not be changed") - } + } else if value.(string) != "world" { + t.Errorf("Set value should not be changed") } if _, ok := DB.Get("non_existing"); ok { diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 94f562b47..f72c4c085 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -258,7 +258,7 @@ func TestMigrateTable(t *testing.T) { DB.Migrator().DropTable("new_table_structs") if DB.Migrator().HasTable(&NewTableStruct{}) { - t.Fatal("should not found droped table") + t.Fatal("should not found dropped table") } } diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index bc917c32d..a7630271e 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -360,7 +360,7 @@ func TestToSQL(t *testing.T) { }) assertEqualSQL(t, `SELECT * FROM "users" WHERE id = 100 AND "users"."deleted_at" IS NULL ORDER BY age desc LIMIT 10`, sql) - // after model chagned + // after model changed if DB.Statement.DryRun || DB.DryRun { t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") } @@ -426,13 +426,13 @@ func TestToSQL(t *testing.T) { }) assertEqualSQL(t, `UPDATE "users" SET "name"='Foo',"age"=100 WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql) - // after model chagned + // after model changed if DB.Statement.DryRun || DB.DryRun { t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") } } -// assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect speicals. +// assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials. func assertEqualSQL(t *testing.T, expected string, actually string) { t.Helper() @@ -440,7 +440,7 @@ func assertEqualSQL(t *testing.T, expected string, actually string) { expected = replaceQuoteInSQL(expected) actually = replaceQuoteInSQL(actually) - // ignore updated_at value, becase it's generated in Gorm inernal, can't to mock value on update. + // ignore updated_at value, because it's generated in Gorm internal, can't to mock value on update. updatedAtRe := regexp.MustCompile(`(?i)"updated_at"=".+?"`) actually = updatedAtRe.ReplaceAllString(actually, `"updated_at"=?`) expected = updatedAtRe.ReplaceAllString(expected, `"updated_at"=?`) @@ -462,7 +462,7 @@ func replaceQuoteInSQL(sql string) string { // convert single quote into double quote sql = strings.ReplaceAll(sql, `'`, `"`) - // convert dialect speical quote into double quote + // convert dialect special quote into double quote switch DB.Dialector.Name() { case "postgres": sql = strings.ReplaceAll(sql, `"`, `"`) diff --git a/tests/upsert_test.go b/tests/upsert_test.go index c5d196055..f90c4518f 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -319,7 +319,7 @@ func TestUpdateWithMissWhere(t *testing.T) { tx := DB.Session(&gorm.Session{DryRun: true}).Save(&user) if err := tx.Error; err != nil { - t.Fatalf("failed to update user,missing where condtion,err=%+v", err) + t.Fatalf("failed to update user,missing where condition,err=%+v", err) } if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(tx.Statement.SQL.String()) { From 3d7019a7c236890aae9716335c7d5b6dae116d17 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Thu, 24 Mar 2022 09:34:06 +0800 Subject: [PATCH 42/92] fix: throw err if association model miss primary key (#5187) --- association.go | 21 +++++++++++++++------ tests/associations_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/association.go b/association.go index 09e79ca60..dc731ff80 100644 --- a/association.go +++ b/association.go @@ -187,8 +187,11 @@ func (association *Association) Delete(values ...interface{}) error { tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields) - pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs) - conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 { + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + } else { + return ErrPrimaryKeyRequired + } _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields) relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs) @@ -199,8 +202,11 @@ func (association *Association) Delete(values ...interface{}) error { tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) - pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) - conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 { + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + } else { + return ErrPrimaryKeyRequired + } _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) @@ -229,8 +235,11 @@ func (association *Association) Delete(values ...interface{}) error { } _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) - pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs) - conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + if pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(pvalues) > 0 { + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + } else { + return ErrPrimaryKeyRequired + } _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields) relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) diff --git a/tests/associations_test.go b/tests/associations_test.go index 32f6525b8..bc3dac551 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -261,3 +261,27 @@ func TestSaveHasManyCircularReference(t *testing.T) { children, parent.Children) } } + +func TestAssociationError(t *testing.T) { + DB = DB.Debug() + user := *GetUser("TestAssociationError", Config{Pets: 2, Company: true, Account: true, Languages: 2}) + DB.Create(&user) + + var user1 User + DB.Preload("Company").Preload("Pets").Preload("Account").Preload("Languages").First(&user1) + + var emptyUser User + var err error + // belongs to + err = DB.Model(&emptyUser).Association("Company").Delete(&user1.Company) + AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) + // has many + err = DB.Model(&emptyUser).Association("Pets").Delete(&user1.Pets) + AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) + // has one + err = DB.Model(&emptyUser).Association("Account").Delete(&user1.Account) + AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) + // many to many + err = DB.Model(&emptyUser).Association("Languages").Delete(&user1.Languages) + AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) +} From 6d40a8343249e208aa79b938a7b0939a631b6b74 Mon Sep 17 00:00:00 2001 From: qqxhb <30866940+qqxhb@users.noreply.github.com> Date: Thu, 24 Mar 2022 16:30:14 +0800 Subject: [PATCH 43/92] Update README.md add gorm gen --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index a3eabe39a..312a3a593 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. ## Getting Started * GORM Guides [https://gorm.io](https://gorm.io) +* GORM Gen [gorm/gen](https://github.com/go-gorm/gen#gormgen) ## Contributing From 6c827ff2e3ffa0e8b7e4c598031f6af8124a7357 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 28 Mar 2022 19:55:05 +0800 Subject: [PATCH 44/92] chore(deps): bump actions/cache from 2 to 3 (#5196) Bumps [actions/cache](https://github.com/actions/cache) from 2 to 3. - [Release notes](https://github.com/actions/cache/releases) - [Commits](https://github.com/actions/cache/compare/v2...v3) --- updated-dependencies: - dependency-name: actions/cache dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/tests.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ad4c99179..8194e6094 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,7 +27,7 @@ jobs: uses: actions/checkout@v3 - name: go mod package cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} @@ -71,7 +71,7 @@ jobs: - name: go mod package cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} @@ -114,7 +114,7 @@ jobs: uses: actions/checkout@v3 - name: go mod package cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} @@ -157,7 +157,7 @@ jobs: uses: actions/checkout@v3 - name: go mod package cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} From 9dd6ed9c65bcf95e4a4298bcdf1f26670778ba76 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 29 Mar 2022 18:14:29 +0800 Subject: [PATCH 45/92] Scan with Rows interface --- interfaces.go | 10 ++++++++++ scan.go | 4 ++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/interfaces.go b/interfaces.go index 84dc94bb4..32d49605d 100644 --- a/interfaces.go +++ b/interfaces.go @@ -72,3 +72,13 @@ type Valuer interface { type GetDBConnector interface { GetDBConn() (*sql.DB, error) } + +// Rows rows interface +type Rows interface { + Columns() ([]string, error) + ColumnTypes() ([]*sql.ColumnType, error) + Next() bool + Scan(dest ...interface{}) error + Err() error + Close() error +} diff --git a/scan.go b/scan.go index 42642ec6d..c8da13da9 100644 --- a/scan.go +++ b/scan.go @@ -50,7 +50,7 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns } } -func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) { +func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) { for idx, field := range fields { if field != nil { values[idx] = field.NewValuePool.Get() @@ -99,7 +99,7 @@ const ( ) // Scan scan rows into db statement -func Scan(rows *sql.Rows, db *DB, mode ScanMode) { +func Scan(rows Rows, db *DB, mode ScanMode) { var ( columns, _ = rows.Columns() values = make([]interface{}, len(columns)) From ea8509b77704b152380f8097c59e5ae3b57428bb Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 29 Mar 2022 18:48:06 +0800 Subject: [PATCH 46/92] Use defer to close rows to avoid scan panic leak rows --- callbacks/create.go | 4 +++- callbacks/query.go | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index e94b7eca6..0fe1dc93a 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -84,8 +84,10 @@ func Create(config *Config) func(db *gorm.DB) { db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., ) if db.AddError(err) == nil { + defer func() { + db.AddError(rows.Close()) + }() gorm.Scan(rows, db, mode) - db.AddError(rows.Close()) } return diff --git a/callbacks/query.go b/callbacks/query.go index 6ba3dd388..6eda52ef4 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -20,8 +20,10 @@ func Query(db *gorm.DB) { db.AddError(err) return } + defer func() { + db.AddError(rows.Close()) + }() gorm.Scan(rows, db, 0) - db.AddError(rows.Close()) } } } From 8333844f7112192ebd203992a67adf01b51ee8a0 Mon Sep 17 00:00:00 2001 From: ZhangShenao <15201440436@163.com> Date: Thu, 31 Mar 2022 20:57:20 +0800 Subject: [PATCH 47/92] fix variable shadowing (#5212) Co-authored-by: Shenao Zhang --- gorm.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gorm.go b/gorm.go index aca7cb5ed..6a6bb0322 100644 --- a/gorm.go +++ b/gorm.go @@ -124,8 +124,8 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { for _, opt := range opts { if opt != nil { - if err := opt.Apply(config); err != nil { - return nil, err + if applyErr := opt.Apply(config); applyErr != nil { + return nil, applyErr } defer func(opt Option) { if errr := opt.AfterInitialize(db); errr != nil { From cd0315334b0fe555500d6f1870c566093d7daa33 Mon Sep 17 00:00:00 2001 From: Goxiaoy Date: Fri, 1 Apr 2022 08:33:39 +0800 Subject: [PATCH 48/92] fix: context missing in association (#5214) --- association.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/association.go b/association.go index dc731ff80..35e10ddd4 100644 --- a/association.go +++ b/association.go @@ -502,7 +502,7 @@ func (association *Association) buildCondition() *DB { if association.Relationship.JoinTable != nil { if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { - joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} + joinStmt := Statement{DB: tx, Context: tx.Statement.Context, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} for _, queryClause := range association.Relationship.JoinTable.QueryClauses { joinStmt.AddClause(queryClause) } From f7b52bb649ba803ec149a06fec9e9da7b311d36e Mon Sep 17 00:00:00 2001 From: ZhangShenao <15201440436@163.com> Date: Fri, 1 Apr 2022 08:35:16 +0800 Subject: [PATCH 49/92] unify db receiver name (#5215) Co-authored-by: Shenao Zhang --- finisher_api.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index b4d29b710..aa8e2b5ad 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -207,7 +207,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat return tx } -func (tx *DB) assignInterfacesToValue(values ...interface{}) { +func (db *DB) assignInterfacesToValue(values ...interface{}) { for _, value := range values { switch v := value.(type) { case []clause.Expression: @@ -215,40 +215,40 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { if eq, ok := expr.(clause.Eq); ok { switch column := eq.Column.(type) { case string: - if field := tx.Statement.Schema.LookUpField(column); field != nil { - tx.AddError(field.Set(tx.Statement.Context, tx.Statement.ReflectValue, eq.Value)) + if field := db.Statement.Schema.LookUpField(column); field != nil { + db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value)) } case clause.Column: - if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { - tx.AddError(field.Set(tx.Statement.Context, tx.Statement.ReflectValue, eq.Value)) + if field := db.Statement.Schema.LookUpField(column.Name); field != nil { + db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value)) } } } else if andCond, ok := expr.(clause.AndConditions); ok { - tx.assignInterfacesToValue(andCond.Exprs) + db.assignInterfacesToValue(andCond.Exprs) } } case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}: - if exprs := tx.Statement.BuildCondition(value); len(exprs) > 0 { - tx.assignInterfacesToValue(exprs) + if exprs := db.Statement.BuildCondition(value); len(exprs) > 0 { + db.assignInterfacesToValue(exprs) } default: - if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil { + if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil { reflectValue := reflect.Indirect(reflect.ValueOf(value)) switch reflectValue.Kind() { case reflect.Struct: for _, f := range s.Fields { if f.Readable { - if v, isZero := f.ValueOf(tx.Statement.Context, reflectValue); !isZero { - if field := tx.Statement.Schema.LookUpField(f.Name); field != nil { - tx.AddError(field.Set(tx.Statement.Context, tx.Statement.ReflectValue, v)) + if v, isZero := f.ValueOf(db.Statement.Context, reflectValue); !isZero { + if field := db.Statement.Schema.LookUpField(f.Name); field != nil { + db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, v)) } } } } } } else if len(values) > 0 { - if exprs := tx.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 { - tx.assignInterfacesToValue(exprs) + if exprs := db.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 { + db.assignInterfacesToValue(exprs) } return } From 9144969c83829d2f14049a6e4882f785a90b6cf9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 2 Apr 2022 17:17:47 +0800 Subject: [PATCH 50/92] Allow to use tag to disable auto create/update time --- schema/field.go | 4 ++-- tests/associations_test.go | 1 - tests/go.mod | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/schema/field.go b/schema/field.go index 3b5cc5c55..77521ad35 100644 --- a/schema/field.go +++ b/schema/field.go @@ -275,7 +275,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DataType = DataType(dataTyper.GormDataType()) } - if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { + if v, ok := field.TagSettings["AUTOCREATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if field.DataType == Time { field.AutoCreateTime = UnixTime } else if strings.ToUpper(v) == "NANO" { @@ -287,7 +287,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { + if v, ok := field.TagSettings["AUTOUPDATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if field.DataType == Time { field.AutoUpdateTime = UnixTime } else if strings.ToUpper(v) == "NANO" { diff --git a/tests/associations_test.go b/tests/associations_test.go index bc3dac551..e729e979b 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -263,7 +263,6 @@ func TestSaveHasManyCircularReference(t *testing.T) { } func TestAssociationError(t *testing.T) { - DB = DB.Debug() user := *GetUser("TestAssociationError", Config{Pets: 2, Company: true, Account: true, Languages: 2}) DB.Create(&user) diff --git a/tests/go.mod b/tests/go.mod index b85ebdadf..fc6600b72 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.12 // indirect - golang.org/x/crypto v0.0.0-20220321153916-2c7772ba3064 // indirect + golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29 // indirect gorm.io/driver/mysql v1.3.2 gorm.io/driver/postgres v1.3.1 gorm.io/driver/sqlite v1.3.1 From 38a24606da3cd1e312644ef5f8d71e4d0d35554a Mon Sep 17 00:00:00 2001 From: huangcheng1 Date: Sat, 2 Apr 2022 17:27:53 +0800 Subject: [PATCH 51/92] fix: tables lost when joins exists in from clause, close #5218 commit 7f6a603afa26820e187489b5203f93adc513687c Author: Jinzhu Date: Sat Apr 2 17:26:48 2022 +0800 Refactor #5218 commit 95d00e6ff2668233f3eca98aa4917291e3d869bd Author: huangcheng1 Date: Fri Apr 1 16:30:27 2022 +0800 fix: tables lost when joins exists in from clause --- callbacks/query.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 6eda52ef4..fb2bb37ad 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -96,12 +96,12 @@ func BuildQuerySQL(db *gorm.DB) { } // inline joins - joins := []clause.Join{} - if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { - joins = fromClause.Joins + fromClause := clause.From{} + if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { + fromClause = v } - if len(db.Statement.Joins) != 0 || len(joins) != 0 { + if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 { if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil { clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) for idx, dbName := range db.Statement.Schema.DBNames { @@ -111,7 +111,7 @@ func BuildQuerySQL(db *gorm.DB) { for _, join := range db.Statement.Joins { if db.Statement.Schema == nil { - joins = append(joins, clause.Join{ + fromClause.Joins = append(fromClause.Joins, clause.Join{ Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, }) } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { @@ -176,19 +176,19 @@ func BuildQuerySQL(db *gorm.DB) { } } - joins = append(joins, clause.Join{ + fromClause.Joins = append(fromClause.Joins, clause.Join{ Type: clause.LeftJoin, Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, ON: clause.Where{Exprs: exprs}, }) } else { - joins = append(joins, clause.Join{ + fromClause.Joins = append(fromClause.Joins, clause.Join{ Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, }) } } - db.Statement.AddClause(clause.From{Joins: joins}) + db.Statement.AddClause(fromClause) db.Statement.Joins = nil } else { db.Statement.AddClauseIfNotExists(clause.From{}) From 81c4024232c35c3d49907f3ae77c2857a1dd7f63 Mon Sep 17 00:00:00 2001 From: Hasan Date: Thu, 7 Apr 2022 21:56:41 +0600 Subject: [PATCH 52/92] Offset issue resolved for scanning results back into struct (#5227) --- scan.go | 2 +- tests/scan_test.go | 27 +++++++++++++++++++++++++-- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/scan.go b/scan.go index c8da13da9..2ce6bd285 100644 --- a/scan.go +++ b/scan.go @@ -196,7 +196,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) { for idx, column := range columns { if field := sch.LookUpField(column); field != nil && field.Readable { if curIndex, ok := selectedColumnsMap[column]; ok { - for fieldIndex, selectField := range sch.Fields[curIndex:] { + for fieldIndex, selectField := range sch.Fields[curIndex+1:] { if selectField.DBName == column && selectField.Readable { selectedColumnsMap[column] = curIndex + fieldIndex + 1 fields[idx] = selectField diff --git a/tests/scan_test.go b/tests/scan_test.go index ec1e652fc..425c0a299 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -184,11 +184,34 @@ func TestScanToEmbedded(t *testing.T) { t.Errorf("Failed to run join query, got error: %v", err) } + personMatched := false + addressMatched := false + for _, info := range personAddressInfoList { - if info.Person != nil { - if info.Person.ID == person1.ID && info.Person.Name != person1.Name { + if info.Person == nil { + t.Fatalf("Failed, expected not nil, got person nil") + } + if info.Address == nil { + t.Fatalf("Failed, expected not nil, got address nil") + } + if info.Person.ID == person1.ID { + personMatched = true + if info.Person.Name != person1.Name { t.Errorf("Failed, expected %v, got %v", person1.Name, info.Person.Name) } } + if info.Address.ID == address1.ID { + addressMatched = true + if info.Address.Name != address1.Name { + t.Errorf("Failed, expected %v, got %v", address1.Name, info.Address.Name) + } + } + } + + if !personMatched { + t.Errorf("Failed, no person matched") + } + if !addressMatched { + t.Errorf("Failed, no address matched") } } From 0729261b627d0f73ab0e9bccc5b548d5e55fae88 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 8 Apr 2022 14:23:25 +0800 Subject: [PATCH 53/92] Support double ptr for Save --- finisher_api.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/finisher_api.go b/finisher_api.go index aa8e2b5ad..5e4c3c5a5 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -74,6 +74,10 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.Statement.Dest = value reflectValue := reflect.Indirect(reflect.ValueOf(value)) + for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface { + reflectValue = reflect.Indirect(reflectValue) + } + switch reflectValue.Kind() { case reflect.Slice, reflect.Array: if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { From 5c9ef9a8435334236662009c21d95c4bcc15a532 Mon Sep 17 00:00:00 2001 From: Naveen <172697+naveensrinivasan@users.noreply.github.com> Date: Sat, 9 Apr 2022 20:38:43 -0500 Subject: [PATCH 54/92] Set permissions for GitHub actions (#5237) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Restrict the GitHub token permissions only to the required ones; this way, even if the attackers will succeed in compromising your workflow, they won’t be able to do much. - Included permissions for the action. https://github.com/ossf/scorecard/blob/main/docs/checks.md#token-permissions https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#permissions https://docs.github.com/en/actions/using-jobs/assigning-permissions-to-jobs [Keeping your GitHub Actions and workflows secure Part 1: Preventing pwn requests](https://securitylab.github.com/research/github-actions-preventing-pwn-requests/) Signed-off-by: naveensrinivasan <172697+naveensrinivasan@users.noreply.github.com> --- .github/workflows/invalid_question.yml | 6 ++++++ .github/workflows/missing_playground.yml | 6 ++++++ .github/workflows/stale.yml | 6 ++++++ .github/workflows/tests.yml | 3 +++ 4 files changed, 21 insertions(+) diff --git a/.github/workflows/invalid_question.yml b/.github/workflows/invalid_question.yml index 868bcc348..327a70f65 100644 --- a/.github/workflows/invalid_question.yml +++ b/.github/workflows/invalid_question.yml @@ -3,8 +3,14 @@ on: schedule: - cron: "*/10 * * * *" +permissions: + contents: read + jobs: stale: + permissions: + issues: write # for actions/stale to close stale issues + pull-requests: write # for actions/stale to close stale PRs runs-on: ubuntu-latest env: ACTIONS_STEP_DEBUG: true diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml index 3efc90f74..15d3850f4 100644 --- a/.github/workflows/missing_playground.yml +++ b/.github/workflows/missing_playground.yml @@ -3,8 +3,14 @@ on: schedule: - cron: "*/10 * * * *" +permissions: + contents: read + jobs: stale: + permissions: + issues: write # for actions/stale to close stale issues + pull-requests: write # for actions/stale to close stale PRs runs-on: ubuntu-latest env: ACTIONS_STEP_DEBUG: true diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index e0be186fa..c5e0d7ab2 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -3,8 +3,14 @@ on: schedule: - cron: "0 2 * * *" +permissions: + contents: read + jobs: stale: + permissions: + issues: write # for actions/stale to close stale issues + pull-requests: write # for actions/stale to close stale PRs runs-on: ubuntu-latest env: ACTIONS_STEP_DEBUG: true diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8194e6094..8bfb23329 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,6 +8,9 @@ on: branches-ignore: - 'gh-pages' +permissions: + contents: read + jobs: # Label of the container job sqlite: From 41bef26f137fb1633b937482011c2266b4123a41 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 11 Apr 2022 21:37:02 +0800 Subject: [PATCH 55/92] Remove shared sync pool for Scanner compatibility --- schema/field.go | 23 ----------------------- schema/pool.go | 45 +-------------------------------------------- tests/go.mod | 11 +++++------ 3 files changed, 6 insertions(+), 73 deletions(-) diff --git a/schema/field.go b/schema/field.go index 77521ad35..fd8b2e6ad 100644 --- a/schema/field.go +++ b/schema/field.go @@ -932,7 +932,6 @@ func (field *Field) setupValuerAndSetter() { } func (field *Field) setupNewValuePool() { - var fieldValue = reflect.New(field.FieldType).Interface() if field.Serializer != nil { field.NewValuePool = &sync.Pool{ New: func() interface{} { @@ -942,31 +941,9 @@ func (field *Field) setupNewValuePool() { } }, } - } else if _, ok := fieldValue.(sql.Scanner); !ok { - field.setupDefaultNewValuePool() } if field.NewValuePool == nil { field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType)) } } - -func (field *Field) setupDefaultNewValuePool() { - // set default NewValuePool - switch field.IndirectFieldType.Kind() { - case reflect.String: - field.NewValuePool = stringPool - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - field.NewValuePool = intPool - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - field.NewValuePool = uintPool - case reflect.Float32, reflect.Float64: - field.NewValuePool = floatPool - case reflect.Bool: - field.NewValuePool = boolPool - default: - if field.IndirectFieldType == TimeReflectType { - field.NewValuePool = timePool - } - } -} diff --git a/schema/pool.go b/schema/pool.go index f5c73153d..fa62fe223 100644 --- a/schema/pool.go +++ b/schema/pool.go @@ -3,54 +3,11 @@ package schema import ( "reflect" "sync" - "time" ) // sync pools var ( - normalPool sync.Map - stringPool = &sync.Pool{ - New: func() interface{} { - var v string - ptrV := &v - return &ptrV - }, - } - intPool = &sync.Pool{ - New: func() interface{} { - var v int64 - ptrV := &v - return &ptrV - }, - } - uintPool = &sync.Pool{ - New: func() interface{} { - var v uint64 - ptrV := &v - return &ptrV - }, - } - floatPool = &sync.Pool{ - New: func() interface{} { - var v float64 - ptrV := &v - return &ptrV - }, - } - boolPool = &sync.Pool{ - New: func() interface{} { - var v bool - ptrV := &v - return &ptrV - }, - } - timePool = &sync.Pool{ - New: func() interface{} { - var v time.Time - ptrV := &v - return &ptrV - }, - } + normalPool sync.Map poolInitializer = func(reflectType reflect.Type) FieldNewValuePool { v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{ New: func() interface{} { diff --git a/tests/go.mod b/tests/go.mod index fc6600b72..3ac4633eb 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -5,15 +5,14 @@ go 1.14 require ( github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/google/uuid v1.3.0 - github.com/jackc/pgx/v4 v4.15.0 // indirect github.com/jinzhu/now v1.1.5 - github.com/lib/pq v1.10.4 + github.com/lib/pq v1.10.5 github.com/mattn/go-sqlite3 v1.14.12 // indirect - golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29 // indirect - gorm.io/driver/mysql v1.3.2 - gorm.io/driver/postgres v1.3.1 + golang.org/x/crypto v0.0.0-20220408190544-5352b0902921 // indirect + gorm.io/driver/mysql v1.3.3 + gorm.io/driver/postgres v1.3.4 gorm.io/driver/sqlite v1.3.1 - gorm.io/driver/sqlserver v1.3.1 + gorm.io/driver/sqlserver v1.3.2 gorm.io/gorm v1.23.3 ) From 74e07b049c446bd0f1102c9f7c164558648850bd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 11 Apr 2022 22:07:40 +0800 Subject: [PATCH 56/92] Serializer unixtime support ptr of int --- schema/serializer.go | 8 ++++---- tests/serializer_test.go | 3 +++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/schema/serializer.go b/schema/serializer.go index 09da6d9ef..758a6421f 100644 --- a/schema/serializer.go +++ b/schema/serializer.go @@ -108,8 +108,8 @@ type UnixSecondSerializer struct { // Scan implements serializer interface func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { t := sql.NullTime{} - if err = t.Scan(dbValue); err == nil { - err = field.Set(ctx, dst, t.Time) + if err = t.Scan(dbValue); err == nil && t.Valid { + err = field.Set(ctx, dst, t.Time.Unix()) } return @@ -118,8 +118,8 @@ func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect. // Value implements serializer interface func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { switch v := fieldValue.(type) { - case int64, int, uint, uint64, int32, uint32, int16, uint16: - result = time.Unix(reflect.ValueOf(v).Int(), 0) + case int64, int, uint, uint64, int32, uint32, int16, uint16, *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: + result = time.Unix(reflect.Indirect(reflect.ValueOf(v)).Int(), 0) default: err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) } diff --git a/tests/serializer_test.go b/tests/serializer_test.go index ce60280ec..ee14841a9 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -21,6 +21,7 @@ type SerializerStruct struct { Contracts map[string]interface{} `gorm:"serializer:json"` JobInfo Job `gorm:"type:bytes;serializer:gob"` CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type + UpdatedTime *int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type EncryptedString EncryptedString } @@ -58,6 +59,7 @@ func TestSerializer(t *testing.T) { } createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + updatedAt := createdAt.Unix() data := SerializerStruct{ Name: []byte("jinzhu"), @@ -65,6 +67,7 @@ func TestSerializer(t *testing.T) { Contracts: map[string]interface{}{"name": "jinzhu", "age": 10}, EncryptedString: EncryptedString("pass"), CreatedTime: createdAt.Unix(), + UpdatedTime: &updatedAt, JobInfo: Job{ Title: "programmer", Number: 9920, From 6aa6d37fc47a433510ac05e2f01eb33e57d7cb6c Mon Sep 17 00:00:00 2001 From: Filippo Del Moro Date: Wed, 13 Apr 2022 09:47:04 +0200 Subject: [PATCH 57/92] Fix scanIntoStruct (#5241) * Reproduces error case * Fix scanIntoStruct Co-authored-by: Filippo Del Moro --- scan.go | 2 +- tests/joins_test.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/scan.go b/scan.go index 2ce6bd285..ad3734d89 100644 --- a/scan.go +++ b/scan.go @@ -74,7 +74,7 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue) if relValue.Kind() == reflect.Ptr && relValue.IsNil() { if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { - return + continue } relValue.Set(reflect.New(relValue.Type().Elem())) diff --git a/tests/joins_test.go b/tests/joins_test.go index bb5352ef6..4908e5ba4 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -10,12 +10,12 @@ import ( ) func TestJoins(t *testing.T) { - user := *GetUser("joins-1", Config{Company: true, Manager: true, Account: true}) + user := *GetUser("joins-1", Config{Company: true, Manager: true, Account: true, NamedPet: false}) DB.Create(&user) var user2 User - if err := DB.Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil { + if err := DB.Joins("NamedPet").Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil { t.Fatalf("Failed to load with joins, got error: %v", err) } From a65912c5887f850f6262dca68ca8d0dc10ca1bcc Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Wed, 13 Apr 2022 15:52:07 +0800 Subject: [PATCH 58/92] fix: FirstOrCreate RowsAffected (#5250) --- finisher_api.go | 3 +++ tests/create_test.go | 14 ++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/finisher_api.go b/finisher_api.go index 5e4c3c5a5..d35456a6b 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -326,6 +326,9 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { } return tx.Model(dest).Updates(assigns) + } else { + // can not use Find RowsAffected + tx.RowsAffected = 0 } } return tx diff --git a/tests/create_test.go b/tests/create_test.go index 2b23d4409..3730172fd 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -526,3 +526,17 @@ func TestCreateNilPointer(t *testing.T) { t.Fatalf("it is not ErrInvalidValue") } } + +func TestFirstOrCreateRowsAffected(t *testing.T) { + user := User{Name: "TestFirstOrCreateRowsAffected"} + + res := DB.FirstOrCreate(&user, "name = ?", user.Name) + if res.Error != nil || res.RowsAffected != 1 { + t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected) + } + + res = DB.FirstOrCreate(&user, "name = ?", user.Name) + if res.Error != nil || res.RowsAffected != 0 { + t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected) + } +} From 771cbed755b0b61c9b5c00eea54c92b7774a17fc Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 13 Apr 2022 15:52:40 +0800 Subject: [PATCH 59/92] chore(deps): bump actions/stale from 4 to 5 (#5244) Bumps [actions/stale](https://github.com/actions/stale) from 4 to 5. - [Release notes](https://github.com/actions/stale/releases) - [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/stale/compare/v4...v5) --- updated-dependencies: - dependency-name: actions/stale dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/invalid_question.yml | 2 +- .github/workflows/missing_playground.yml | 2 +- .github/workflows/stale.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/invalid_question.yml b/.github/workflows/invalid_question.yml index 327a70f65..aa1812d4f 100644 --- a/.github/workflows/invalid_question.yml +++ b/.github/workflows/invalid_question.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v4 + uses: actions/stale@v5 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml index 15d3850f4..c3c92beb3 100644 --- a/.github/workflows/missing_playground.yml +++ b/.github/workflows/missing_playground.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v4 + uses: actions/stale@v5 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index c5e0d7ab2..af8d36368 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v4 + uses: actions/stale@v5 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days" From ce53ea53ee064d57c8a23eb4c7b5f2deed0eb410 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 13 Apr 2022 15:53:12 +0800 Subject: [PATCH 60/92] chore(deps): bump actions/setup-go from 2 to 3 (#5243) Bumps [actions/setup-go](https://github.com/actions/setup-go) from 2 to 3. - [Release notes](https://github.com/actions/setup-go/releases) - [Commits](https://github.com/actions/setup-go/compare/v2...v3) --- updated-dependencies: - dependency-name: actions/setup-go dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/tests.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8bfb23329..b97da3f45 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -22,7 +22,7 @@ jobs: steps: - name: Set up Go 1.x - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} @@ -65,7 +65,7 @@ jobs: steps: - name: Set up Go 1.x - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} @@ -109,7 +109,7 @@ jobs: steps: - name: Set up Go 1.x - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} @@ -152,7 +152,7 @@ jobs: steps: - name: Set up Go 1.x - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} From d421c67ef59259dc65737a639bee75b568ad5c17 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 14 Apr 2022 10:51:39 +0800 Subject: [PATCH 61/92] Remove ErrRecordNotFound error from log when using Save --- finisher_api.go | 2 +- tests/go.mod | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index d35456a6b..cbe927bfd 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -105,7 +105,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { result := reflect.New(tx.Statement.Schema.ModelType).Interface() - if err := tx.Session(&Session{}).Take(result).Error; errors.Is(err, ErrRecordNotFound) { + if result := tx.Session(&Session{}).Limit(1).Find(result); result.RowsAffected == 0 { return tx.Create(value) } } diff --git a/tests/go.mod b/tests/go.mod index 3ac4633eb..0a3f85f9b 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,7 +8,7 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.5 github.com/mattn/go-sqlite3 v1.14.12 // indirect - golang.org/x/crypto v0.0.0-20220408190544-5352b0902921 // indirect + golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect gorm.io/driver/mysql v1.3.3 gorm.io/driver/postgres v1.3.4 gorm.io/driver/sqlite v1.3.1 From e0ed3ce400c8cb774ad03bd6c1a5028e6c425988 Mon Sep 17 00:00:00 2001 From: ZhangShenao <15201440436@163.com> Date: Thu, 14 Apr 2022 20:32:57 +0800 Subject: [PATCH 62/92] fix spelling mistake (#5256) Co-authored-by: Shenao Zhang --- callbacks/helper.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/callbacks/helper.go b/callbacks/helper.go index 71b67de59..ae9fd8c56 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -125,7 +125,7 @@ func checkMissingWhereConditions(db *gorm.DB) { type visitMap = map[reflect.Value]bool // Check if circular values, return true if loaded -func loadOrStoreVisitMap(vistMap *visitMap, v reflect.Value) (loaded bool) { +func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) { if v.Kind() == reflect.Ptr { v = v.Elem() } @@ -134,17 +134,17 @@ func loadOrStoreVisitMap(vistMap *visitMap, v reflect.Value) (loaded bool) { case reflect.Slice, reflect.Array: loaded = true for i := 0; i < v.Len(); i++ { - if !loadOrStoreVisitMap(vistMap, v.Index(i)) { + if !loadOrStoreVisitMap(visitMap, v.Index(i)) { loaded = false } } case reflect.Struct, reflect.Interface: if v.CanAddr() { p := v.Addr() - if _, ok := (*vistMap)[p]; ok { + if _, ok := (*visitMap)[p]; ok { return true } - (*vistMap)[p] = true + (*visitMap)[p] = true } } From b49ae84780b212f2460938c74ee41a43a46b1834 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sun, 17 Apr 2022 09:58:33 +0800 Subject: [PATCH 63/92] fix: FindInBatches with offset limit (#5255) * fix: FindInBatches with offset limit * fix: break first * fix: FindInBatches Limit zero --- finisher_api.go | 24 ++++++++++++++++++ tests/query_test.go | 62 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+) diff --git a/finisher_api.go b/finisher_api.go index cbe927bfd..0bd8f7d99 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -181,6 +181,21 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat batch int ) + // user specified offset or limit + var totalSize int + if c, ok := tx.Statement.Clauses["LIMIT"]; ok { + if limit, ok := c.Expression.(clause.Limit); ok { + totalSize = limit.Limit + + if totalSize > 0 && batchSize > totalSize { + batchSize = totalSize + } + + // reset to offset to 0 in next batch + tx = tx.Offset(-1).Session(&Session{}) + } + } + for { result := queryDB.Limit(batchSize).Find(dest) rowsAffected += result.RowsAffected @@ -196,6 +211,15 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat break } + if totalSize > 0 { + if totalSize <= int(rowsAffected) { + break + } + if totalSize/batchSize == batch { + batchSize = totalSize % batchSize + } + } + // Optimize for-break resultsValue := reflect.Indirect(reflect.ValueOf(dest)) if result.Statement.Schema.PrioritizedPrimaryField == nil { diff --git a/tests/query_test.go b/tests/query_test.go index af2b8d4b4..f66cf83a4 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -292,6 +292,68 @@ func TestFindInBatches(t *testing.T) { } } +func TestFindInBatchesWithOffsetLimit(t *testing.T) { + users := []User{ + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + } + + DB.Create(&users) + + var ( + sub, results []User + lastBatch int + ) + + // offset limit + if result := DB.Offset(3).Limit(5).Where("name = ?", users[0].Name).FindInBatches(&sub, 2, func(tx *gorm.DB, batch int) error { + results = append(results, sub...) + lastBatch = batch + return nil + }); result.Error != nil || result.RowsAffected != 5 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } + if lastBatch != 3 { + t.Fatalf("incorrect last batch, expected: %v, got: %v", 3, lastBatch) + } + + targetUsers := users[3:8] + for i := 0; i < len(targetUsers); i++ { + AssertEqual(t, results[i], targetUsers[i]) + } + + var sub1 []User + // limit < batchSize + if result := DB.Limit(5).Where("name = ?", users[0].Name).FindInBatches(&sub1, 10, func(tx *gorm.DB, batch int) error { + return nil + }); result.Error != nil || result.RowsAffected != 5 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } + + var sub2 []User + // only offset + if result := DB.Offset(3).Where("name = ?", users[0].Name).FindInBatches(&sub2, 2, func(tx *gorm.DB, batch int) error { + return nil + }); result.Error != nil || result.RowsAffected != 7 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } + + var sub3 []User + if result := DB.Limit(4).Where("name = ?", users[0].Name).FindInBatches(&sub3, 2, func(tx *gorm.DB, batch int) error { + return nil + }); result.Error != nil || result.RowsAffected != 4 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } +} + func TestFindInBatchesWithError(t *testing.T) { if name := DB.Dialector.Name(); name == "sqlserver" { t.Skip("skip sqlserver due to it will raise data race for invalid sql") From 88c26b62ee63863932e001be21e05a4ef43d03c2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 20 Apr 2022 17:21:38 +0800 Subject: [PATCH 64/92] Support Scopes in group conditions --- statement.go | 4 ++++ tests/sql_builder_test.go | 15 +++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/statement.go b/statement.go index 9fcee09c0..d0c691d8e 100644 --- a/statement.go +++ b/statement.go @@ -312,6 +312,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] case clause.Expression: conds = append(conds, v) case *DB: + for _, scope := range v.Statement.scopes { + v = scope(v) + } + if cs, ok := v.Statement.Clauses["WHERE"]; ok { if where, ok := cs.Expression.(clause.Where); ok { if len(where.Exprs) == 1 { diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index a7630271e..a9b920dcc 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -243,6 +243,21 @@ func TestGroupConditions(t *testing.T) { if !strings.HasSuffix(result, expects) { t.Errorf("expects: %v, got %v", expects, result) } + + stmt2 := dryRunDB.Where( + DB.Scopes(NameIn1And2), + ).Or( + DB.Where("pizza = ?", "hawaiian").Where("size = ?", "xlarge"), + ).Find(&Pizza{}).Statement + + execStmt2 := dryRunDB.Exec(`WHERE name in ? OR (pizza = ? AND size = ?)`, []string{"ScopeUser1", "ScopeUser2"}, "hawaiian", "xlarge").Statement + + result2 := DB.Dialector.Explain(stmt2.SQL.String(), stmt2.Vars...) + expects2 := DB.Dialector.Explain(execStmt2.SQL.String(), execStmt2.Vars...) + + if !strings.HasSuffix(result2, expects2) { + t.Errorf("expects: %v, got %v", expects2, result2) + } } func TestCombineStringConditions(t *testing.T) { From 395606ac7ce6c1fcd9bd9c79c16b73cb1bc13bc8 Mon Sep 17 00:00:00 2001 From: glebarez <47985861+glebarez@users.noreply.github.com> Date: Fri, 22 Apr 2022 06:19:33 +0300 Subject: [PATCH 65/92] fix missing error-check in AutoMigrate (#5283) --- migrator/migrator.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index a50bb3ff8..93f4c5d0c 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -99,7 +99,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } else { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { - columnTypes, _ := m.DB.Migrator().ColumnTypes(value) + columnTypes, err := m.DB.Migrator().ColumnTypes(value) + if err != nil { + return err + } for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.FieldsByDBName[dbName] From 9b80fe9e96e6d9132f935a944a150777a3ffdf03 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sun, 24 Apr 2022 09:08:52 +0800 Subject: [PATCH 66/92] fix: stmt.Changed zero value filed behavior (#5281) * fix: stmt.Changed zero value filed behavior * chore: rename var --- statement.go | 9 ++++++--- tests/hooks_test.go | 10 ++++++++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/statement.go b/statement.go index d0c691d8e..ed3e8716d 100644 --- a/statement.go +++ b/statement.go @@ -609,10 +609,10 @@ func (stmt *Statement) Changed(fields ...string) bool { changed := func(field *schema.Field) bool { fieldValue, _ := field.ValueOf(stmt.Context, modelValue) if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if v, ok := stmt.Dest.(map[string]interface{}); ok { - if fv, ok := v[field.Name]; ok { + if mv, mok := stmt.Dest.(map[string]interface{}); mok { + if fv, ok := mv[field.Name]; ok { return !utils.AssertEqual(fv, fieldValue) - } else if fv, ok := v[field.DBName]; ok { + } else if fv, ok := mv[field.DBName]; ok { return !utils.AssertEqual(fv, fieldValue) } } else { @@ -622,6 +622,9 @@ func (stmt *Statement) Changed(fields ...string) bool { } changedValue, zero := field.ValueOf(stmt.Context, destValue) + if v { + return !utils.AssertEqual(changedValue, fieldValue) + } return !zero && !utils.AssertEqual(changedValue, fieldValue) } } diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 0e6ab2fe2..20e8dc18c 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -375,13 +375,19 @@ func TestSetColumn(t *testing.T) { t.Errorf("invalid data after update, got %+v", product) } + // Code changed, price should changed + DB.Model(&product).Select("Name", "Code", "Price").Updates(Product3{Name: "Product New4", Code: ""}) + if product.Name != "Product New4" || product.Price != 320 || product.Code != "" { + t.Errorf("invalid data after update, got %+v", product) + } + DB.Model(&product).UpdateColumns(Product3{Code: "L1215"}) - if product.Price != 270 || product.Code != "L1215" { + if product.Price != 320 || product.Code != "L1215" { t.Errorf("invalid data after update, got %+v", product) } DB.Model(&product).Session(&gorm.Session{SkipHooks: true}).Updates(Product3{Code: "L1216"}) - if product.Price != 270 || product.Code != "L1216" { + if product.Price != 320 || product.Code != "L1216" { t.Errorf("invalid data after update, got %+v", product) } From 3643f856a3edeaa4db7ede87a4bc2928d2aadc09 Mon Sep 17 00:00:00 2001 From: aelmel <5629597+aelmel@users.noreply.github.com> Date: Sun, 24 Apr 2022 04:10:36 +0300 Subject: [PATCH 67/92] check for pointer to pointer value (#5278) * check for pointer to pointer value * revert to Ptr Co-authored-by: Alexei Melnic --- schema/field.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/schema/field.go b/schema/field.go index fd8b2e6ad..d6df65965 100644 --- a/schema/field.go +++ b/schema/field.go @@ -528,6 +528,9 @@ func (field *Field) setupValuerAndSetter() { reflectValType := reflectV.Type() if reflectValType.AssignableTo(field.FieldType) { + if reflectV.Kind() == reflect.Ptr && reflectV.Elem().Kind() == reflect.Ptr { + reflectV = reflect.Indirect(reflectV) + } field.ReflectValueOf(ctx, value).Set(reflectV) return } else if reflectValType.ConvertibleTo(field.FieldType) { From a0cc631272f44a18597c87b7910b660df729303e Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sun, 24 Apr 2022 12:13:27 +0800 Subject: [PATCH 68/92] test: test for postgrs serial column (#5234) * test: test for postgrs sercial column * test: only for postgres * chore: spelling mistake * test: for drop sequence --- tests/migrate_test.go | 62 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index f72c4c085..d6a6c4db2 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -574,3 +574,65 @@ func TestMigrateColumnOrder(t *testing.T) { } } } + +// https://github.com/go-gorm/gorm/issues/5047 +func TestMigrateSerialColumn(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type Event struct { + ID uint `gorm:"primarykey"` + UID uint32 + } + + type Event1 struct { + ID uint `gorm:"primarykey"` + UID uint32 `gorm:"not null;autoIncrement"` + } + + type Event2 struct { + ID uint `gorm:"primarykey"` + UID uint16 `gorm:"not null;autoIncrement"` + } + + var err error + err = DB.Migrator().DropTable(&Event{}) + if err != nil { + t.Errorf("DropTable err:%v", err) + } + + // create sequence + err = DB.Table("events").AutoMigrate(&Event1{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + // delete sequence + err = DB.Table("events").AutoMigrate(&Event{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + // update sequence + err = DB.Table("events").AutoMigrate(&Event1{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + err = DB.Table("events").AutoMigrate(&Event2{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + DB.Table("events").Save(&Event2{}) + DB.Table("events").Save(&Event2{}) + DB.Table("events").Save(&Event2{}) + + events := make([]*Event, 0) + DB.Table("events").Find(&events) + + AssertEqual(t, 3, len(events)) + for _, v := range events { + AssertEqual(t, v.ID, v.UID) + } +} From 0211ac91a2e2cbde5d6212e5f74a7344cb9795db Mon Sep 17 00:00:00 2001 From: Chiung-Ming Huang Date: Mon, 25 Apr 2022 11:39:23 +0800 Subject: [PATCH 69/92] index: add composite id (#5269) * index: add composite id * index: add test cases of composite id * index: improve the comments for the test cases of composite id --- schema/index.go | 26 ++++++++++++++++--- schema/index_test.go | 60 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 3 deletions(-) diff --git a/schema/index.go b/schema/index.go index 16d096b76..5003c7428 100644 --- a/schema/index.go +++ b/schema/index.go @@ -1,6 +1,7 @@ package schema import ( + "fmt" "sort" "strconv" "strings" @@ -31,7 +32,12 @@ func (schema *Schema) ParseIndexes() map[string]Index { for _, field := range schema.Fields { if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" { - for _, index := range parseFieldIndexes(field) { + fieldIndexes, err := parseFieldIndexes(field) + if err != nil { + schema.err = err + break + } + for _, index := range fieldIndexes { idx := indexes[index.Name] idx.Name = index.Name if idx.Class == "" { @@ -82,7 +88,7 @@ func (schema *Schema) LookIndex(name string) *Index { return nil } -func parseFieldIndexes(field *Field) (indexes []Index) { +func parseFieldIndexes(field *Field) (indexes []Index, err error) { for _, value := range strings.Split(field.Tag.Get("gorm"), ";") { if value != "" { v := strings.Split(value, ":") @@ -106,7 +112,20 @@ func parseFieldIndexes(field *Field) (indexes []Index) { } if name == "" { - name = field.Schema.namer.IndexName(field.Schema.Table, field.Name) + subName := field.Name + const key = "COMPOSITE" + if composite, found := settings[key]; found { + if len(composite) == 0 || composite == key { + err = fmt.Errorf( + "The composite tag of %s.%s cannot be empty", + field.Schema.Name, + field.Name) + return + } + subName = composite + } + name = field.Schema.namer.IndexName( + field.Schema.Table, subName) } if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" { @@ -138,5 +157,6 @@ func parseFieldIndexes(field *Field) (indexes []Index) { } } + err = nil return } diff --git a/schema/index_test.go b/schema/index_test.go index 3c4582bb4..1fe31cc1b 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -19,6 +19,36 @@ type UserIndex struct { OID int64 `gorm:"index:idx_id;index:idx_oid,unique"` MemberNumber string `gorm:"index:idx_id,priority:1"` Name7 string `gorm:"index:type"` + + // Composite Index: Flattened structure. + Data0A string `gorm:"index:,composite:comp_id0"` + Data0B string `gorm:"index:,composite:comp_id0"` + + // Composite Index: Nested structure. + Data1A string `gorm:"index:,composite:comp_id1"` + CompIdxLevel1C + + // Composite Index: Unique and priority. + Data2A string `gorm:"index:,unique,composite:comp_id2,priority:2"` + CompIdxLevel2C +} + +type CompIdxLevel1C struct { + CompIdxLevel1B + Data1C string `gorm:"index:,composite:comp_id1"` +} + +type CompIdxLevel1B struct { + Data1B string `gorm:"index:,composite:comp_id1"` +} + +type CompIdxLevel2C struct { + CompIdxLevel2B + Data2C string `gorm:"index:,unique,composite:comp_id2,priority:1"` +} + +type CompIdxLevel2B struct { + Data2B string `gorm:"index:,unique,composite:comp_id2,priority:3"` } func TestParseIndex(t *testing.T) { @@ -84,6 +114,36 @@ func TestParseIndex(t *testing.T) { Type: "", Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}}, }, + "idx_user_indices_comp_id0": { + Name: "idx_user_indices_comp_id0", + Type: "", + Fields: []schema.IndexOption{{ + Field: &schema.Field{Name: "Data0A"}, + }, { + Field: &schema.Field{Name: "Data0B"}, + }}, + }, + "idx_user_indices_comp_id1": { + Name: "idx_user_indices_comp_id1", + Fields: []schema.IndexOption{{ + Field: &schema.Field{Name: "Data1A"}, + }, { + Field: &schema.Field{Name: "Data1B"}, + }, { + Field: &schema.Field{Name: "Data1C"}, + }}, + }, + "idx_user_indices_comp_id2": { + Name: "idx_user_indices_comp_id2", + Class: "UNIQUE", + Fields: []schema.IndexOption{{ + Field: &schema.Field{Name: "Data2C"}, + }, { + Field: &schema.Field{Name: "Data2A"}, + }, { + Field: &schema.Field{Name: "Data2B"}, + }}, + }, } indices := user.ParseIndexes() From 6a6dfdae72574e931ea4f0737637308ef2c34b8f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 26 Apr 2022 17:16:48 +0800 Subject: [PATCH 70/92] Refactor FirstOrCreate, FirstOrInit --- finisher_api.go | 24 ++++++++++++------------ tests/go.mod | 7 +++---- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 0bd8f7d99..663d532bb 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -290,7 +290,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) - if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 { + if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { tx.assignInterfacesToValue(where.Exprs) @@ -312,25 +312,26 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { // FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions) func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { - queryTx := db.Limit(1).Order(clause.OrderByColumn{ + tx = db.getInstance() + queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) - if tx = queryTx.Find(dest, conds...); tx.Error == nil { - if tx.RowsAffected == 0 { - if c, ok := tx.Statement.Clauses["WHERE"]; ok { + if result := queryTx.Find(dest, conds...); result.Error == nil { + if result.RowsAffected == 0 { + if c, ok := result.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { - tx.assignInterfacesToValue(where.Exprs) + result.assignInterfacesToValue(where.Exprs) } } // initialize with attrs, conds - if len(tx.Statement.attrs) > 0 { - tx.assignInterfacesToValue(tx.Statement.attrs...) + if len(db.Statement.attrs) > 0 { + result.assignInterfacesToValue(db.Statement.attrs...) } // initialize with attrs, conds - if len(tx.Statement.assigns) > 0 { - tx.assignInterfacesToValue(tx.Statement.assigns...) + if len(db.Statement.assigns) > 0 { + result.assignInterfacesToValue(db.Statement.assigns...) } return tx.Create(dest) @@ -351,8 +352,7 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { return tx.Model(dest).Updates(assigns) } else { - // can not use Find RowsAffected - tx.RowsAffected = 0 + tx.Error = result.Error } } return tx diff --git a/tests/go.mod b/tests/go.mod index 0a3f85f9b..6a2cf22fd 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,13 +7,12 @@ require ( github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.5 - github.com/mattn/go-sqlite3 v1.14.12 // indirect golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect gorm.io/driver/mysql v1.3.3 - gorm.io/driver/postgres v1.3.4 - gorm.io/driver/sqlite v1.3.1 + gorm.io/driver/postgres v1.3.5 + gorm.io/driver/sqlite v1.3.2 gorm.io/driver/sqlserver v1.3.2 - gorm.io/gorm v1.23.3 + gorm.io/gorm v1.23.4 ) replace gorm.io/gorm => ../ From bd7e42ec651f66539009371675bff38645b9b6b8 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Wed, 27 Apr 2022 21:13:48 +0800 Subject: [PATCH 71/92] fix: AutoMigrate with special table name (#5301) * fix: AutoMigrate with special table name * test: migrate with special table name --- migrator/migrator.go | 3 ++- tests/migrate_test.go | 11 +++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 93f4c5d0c..d49894108 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -759,7 +759,8 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i Statement: &gorm.Statement{DB: m.DB, Dest: value}, } beDependedOn := map[*schema.Schema]bool{} - if err := dep.Parse(value); err != nil { + // support for special table name + if err := dep.ParseWithSpecialTableName(value, m.DB.Statement.Table); err != nil { m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err) } if _, ok := parsedSchemas[dep.Statement.Schema]; ok { diff --git a/tests/migrate_test.go b/tests/migrate_test.go index d6a6c4db2..6576a2bd8 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -636,3 +636,14 @@ func TestMigrateSerialColumn(t *testing.T) { AssertEqual(t, v.ID, v.UID) } } + +// https://github.com/go-gorm/gorm/issues/5300 +func TestMigrateWithSpecialName(t *testing.T) { + DB.AutoMigrate(&Coupon{}) + DB.Table("coupon_product_1").AutoMigrate(&CouponProduct{}) + DB.Table("coupon_product_2").AutoMigrate(&CouponProduct{}) + + AssertEqual(t, true, DB.Migrator().HasTable("coupons")) + AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_1")) + AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_2")) +} From d3488ae6bcee8ccbb1e463a42a048e1958c4c90f Mon Sep 17 00:00:00 2001 From: Heliner <32272517+Heliner@users.noreply.github.com> Date: Sat, 30 Apr 2022 09:50:53 +0800 Subject: [PATCH 72/92] fix: add judge result of auto_migrate (#5306) Co-authored-by: fredhan --- tests/migrate_test.go | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 6576a2bd8..28ee28cb5 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -639,9 +639,19 @@ func TestMigrateSerialColumn(t *testing.T) { // https://github.com/go-gorm/gorm/issues/5300 func TestMigrateWithSpecialName(t *testing.T) { - DB.AutoMigrate(&Coupon{}) - DB.Table("coupon_product_1").AutoMigrate(&CouponProduct{}) - DB.Table("coupon_product_2").AutoMigrate(&CouponProduct{}) + var err error + err = DB.AutoMigrate(&Coupon{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + err = DB.Table("coupon_product_1").AutoMigrate(&CouponProduct{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + err = DB.Table("coupon_product_2").AutoMigrate(&CouponProduct{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } AssertEqual(t, true, DB.Migrator().HasTable("coupons")) AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_1")) From b0104943edf50bba6072d18ca91e949ff8d4e3a2 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sat, 30 Apr 2022 09:57:16 +0800 Subject: [PATCH 73/92] fix: callbcak sort when using multiple plugin (#5304) --- callbacks.go | 8 +++++++- tests/callbacks_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/callbacks.go b/callbacks.go index f344649ea..c060ea709 100644 --- a/callbacks.go +++ b/callbacks.go @@ -246,7 +246,13 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { sortCallback func(*callback) error ) sort.Slice(cs, func(i, j int) bool { - return cs[j].before == "*" || cs[j].after == "*" + if cs[j].before == "*" && cs[i].before != "*" { + return true + } + if cs[j].after == "*" && cs[i].after != "*" { + return true + } + return false }) for _, c := range cs { diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index 02765b8c4..2bf9496b9 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -38,6 +38,7 @@ func c2(*gorm.DB) {} func c3(*gorm.DB) {} func c4(*gorm.DB) {} func c5(*gorm.DB) {} +func c6(*gorm.DB) {} func TestCallbacks(t *testing.T) { type callback struct { @@ -168,3 +169,37 @@ func TestCallbacks(t *testing.T) { } } } + +func TestPluginCallbacks(t *testing.T) { + db, _ := gorm.Open(nil, nil) + createCallback := db.Callback().Create() + + createCallback.Before("*").Register("plugin_1_fn1", c1) + createCallback.After("*").Register("plugin_1_fn2", c2) + + if ok, msg := assertCallbacks(createCallback, []string{"c1", "c2"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + // plugin 2 + createCallback.Before("*").Register("plugin_2_fn1", c3) + if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.After("*").Register("plugin_2_fn2", c4) + if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2", "c4"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + // plugin 3 + createCallback.Before("*").Register("plugin_3_fn1", c5) + if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.After("*").Register("plugin_3_fn2", c6) + if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4", "c6"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } +} From 19b8d37ae8155667d76021e4ca3314bb571756be Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Wed, 4 May 2022 18:57:53 +0800 Subject: [PATCH 74/92] fix: preload with skip hooks (#5310) --- callbacks/query.go | 2 +- tests/hooks_test.go | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index fb2bb37ad..26ee8c348 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -252,7 +252,7 @@ func Preload(db *gorm.DB) { for _, name := range preloadNames { if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { - db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) + db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) } else { db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) } diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 20e8dc18c..8e964fd8a 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -466,8 +466,9 @@ type Product4 struct { type ProductItem struct { gorm.Model - Code string - Product4ID uint + Code string + Product4ID uint + AfterFindCallTimes int } func (pi ProductItem) BeforeCreate(*gorm.DB) error { @@ -477,6 +478,11 @@ func (pi ProductItem) BeforeCreate(*gorm.DB) error { return nil } +func (pi *ProductItem) AfterFind(*gorm.DB) error { + pi.AfterFindCallTimes = pi.AfterFindCallTimes + 1 + return nil +} + func TestFailedToSaveAssociationShouldRollback(t *testing.T) { DB.Migrator().DropTable(&Product4{}, &ProductItem{}) DB.AutoMigrate(&Product4{}, &ProductItem{}) @@ -498,4 +504,13 @@ func TestFailedToSaveAssociationShouldRollback(t *testing.T) { if err := DB.First(&Product4{}, "name = ?", product.Name).Error; err != nil { t.Errorf("should find product, but got error %v", err) } + + var productWithItem Product4 + if err := DB.Session(&gorm.Session{SkipHooks: true}).Preload("Item").First(&productWithItem, "name = ?", product.Name).Error; err != nil { + t.Errorf("should find product, but got error %v", err) + } + + if productWithItem.Item.AfterFindCallTimes != 0 { + t.Fatalf("AfterFind should not be called times:%d", productWithItem.Item.AfterFindCallTimes) + } } From 373bcf7aca01ef76c8ba5c3bc1ff191b020afc7b Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Mon, 9 May 2022 10:07:18 +0800 Subject: [PATCH 75/92] fix: many2many auto migrate (#5322) * fix: many2many auto migrate * fix: uuid ossp --- schema/relationship.go | 6 ++++-- schema/utils.go | 9 +++++++++ tests/migrate_test.go | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 2 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index b51008979..0aa33e518 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -235,7 +235,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Name: joinFieldName, PkgPath: ownField.StructField.PkgPath, Type: ownField.StructField.Type, - Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), + Tag: removeSettingFromTag(appendSettingFromTag(ownField.StructField.Tag, "primaryKey"), + "column", "autoincrement", "index", "unique", "uniqueindex"), }) } @@ -258,7 +259,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Name: joinFieldName, PkgPath: relField.StructField.PkgPath, Type: relField.StructField.Type, - Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), + Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"), + "column", "autoincrement", "index", "unique", "uniqueindex"), }) } diff --git a/schema/utils.go b/schema/utils.go index 2720c5304..acf1a739b 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -2,6 +2,7 @@ package schema import ( "context" + "fmt" "reflect" "regexp" "strings" @@ -59,6 +60,14 @@ func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.Struct return tag } +func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag { + t := tag.Get("gorm") + if strings.Contains(t, value) { + return tag + } + return reflect.StructTag(fmt.Sprintf(`gorm:"%s;%s"`, value, t)) +} + // GetRelationsValues get relations's values from a reflect value func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { for _, rel := range rels { diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 28ee28cb5..f862eda0f 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -657,3 +657,39 @@ func TestMigrateWithSpecialName(t *testing.T) { AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_1")) AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_2")) } + +// https://github.com/go-gorm/gorm/issues/5320 +func TestPrimarykeyID(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type MissPKLanguage struct { + ID string `gorm:"type:uuid;default:uuid_generate_v4()"` + Name string + } + + type MissPKUser struct { + ID string `gorm:"type:uuid;default:uuid_generate_v4()"` + MissPKLanguages []MissPKLanguage `gorm:"many2many:miss_pk_user_languages;"` + } + + var err error + err = DB.Migrator().DropTable(&MissPKUser{}, &MissPKLanguage{}) + if err != nil { + t.Fatalf("DropTable err:%v", err) + } + + DB.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`) + + err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + // patch + err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } +} From f5e77aab2fd3886f8743d6c9da87d5171f31a521 Mon Sep 17 00:00:00 2001 From: black-06 Date: Tue, 17 May 2022 10:59:53 +0800 Subject: [PATCH 76/92] fix: quote index when creating table (#5331) --- migrator/migrator.go | 2 +- tests/migrate_test.go | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index d49894108..757ab9494 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -223,7 +223,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { } createTableSQL += "," - values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) + values = append(values, clause.Column{Name: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) } } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index f862eda0f..12eb8ed05 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -262,6 +262,25 @@ func TestMigrateTable(t *testing.T) { } } +func TestMigrateWithQuotedIndex(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + t.Skip() + } + + type QuotedIndexStruct struct { + gorm.Model + Name string `gorm:"size:255;index:AS"` // AS is one of MySQL reserved words + } + + if err := DB.Migrator().DropTable(&QuotedIndexStruct{}); err != nil { + t.Fatalf("Failed to drop table, got error %v", err) + } + + if err := DB.AutoMigrate(&QuotedIndexStruct{}); err != nil { + t.Fatalf("Failed to auto migrate, but got error %v", err) + } +} + func TestMigrateIndexes(t *testing.T) { type IndexStruct struct { gorm.Model From 7496c3a56eb4a26679a0a47db092e51379a98ff5 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Tue, 17 May 2022 14:13:41 +0800 Subject: [PATCH 77/92] fix: trx in hooks clone stmt (#5338) * fix: trx in hooks * chore: format by gofumpt --- finisher_api.go | 3 +-- tests/transaction_test.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 663d532bb..da4ef8f76 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -589,8 +589,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er } }() } - - err = fc(db.Session(&Session{})) + err = fc(db.Session(&Session{NewDB: db.clone == 1})) } else { tx := db.Begin(opts...) if tx.Error != nil { diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 4e4b61494..0ac04a047 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -367,3 +367,33 @@ func TestTransactionOnClosedConn(t *testing.T) { t.Errorf("should returns error when commit with closed conn, got error %v", err) } } + +func TestTransactionWithHooks(t *testing.T) { + user := GetUser("tTestTransactionWithHooks", Config{Account: true}) + DB.Create(&user) + + var err error + err = DB.Transaction(func(tx *gorm.DB) error { + return tx.Model(&User{}).Limit(1).Transaction(func(tx2 *gorm.DB) error { + return tx2.Scan(&User{}).Error + }) + }) + + if err != nil { + t.Error(err) + } + + // method with hooks + err = DB.Transaction(func(tx1 *gorm.DB) error { + // callMethod do + tx2 := tx1.Find(&User{}).Session(&gorm.Session{NewDB: true}) + // trx in hooks + return tx2.Transaction(func(tx3 *gorm.DB) error { + return tx3.Where("user_id", user.ID).Delete(&Account{}).Error + }) + }) + + if err != nil { + t.Error(err) + } +} From 540fb49bcbe07ee56c7a8a449a5504f40f50abc1 Mon Sep 17 00:00:00 2001 From: Clark McCauley Date: Sun, 22 May 2022 01:16:01 -0600 Subject: [PATCH 78/92] Fixed #5355 - Named variables don't work when followed by Windows CRLF line endings (#5356) * Fixed #5355. * Fixed unit test to test both CRLF and CR line endings --- clause/expression.go | 2 +- clause/expression_test.go | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/clause/expression.go b/clause/expression.go index dde00b1d7..92ac7f223 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -127,7 +127,7 @@ func (expr NamedExpr) Build(builder Builder) { if v == '@' && !inName { inName = true name = []byte{} - } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' || v == ';' { + } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' { if inName { if nv, ok := namedMap[string(name)]; ok { builder.AddVar(builder, nv) diff --git a/clause/expression_test.go b/clause/expression_test.go index 4826db381..aaede61c5 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -94,6 +94,16 @@ func TestNamedExpr(t *testing.T) { Vars: []interface{}{sql.Named("name", "jinzhu")}, Result: "name1 = ? AND name2 = ?;", ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, + }, { + SQL: "name1 = @name1\r\n AND name2 = @name2", + Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}}, + Result: "name1 = ?\r\n AND name2 = ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, + }, { + SQL: "name1 = @name1\r AND name2 = @name2", + Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}}, + Result: "name1 = ?\r AND name2 = ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, }, { SQL: "?", Vars: []interface{}{clause.Column{Table: "table", Name: "col"}}, From 7d1a92d60e7df38fdc2f3e42ff1cc7842aefdf18 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sun, 22 May 2022 16:12:28 +0800 Subject: [PATCH 79/92] test: test for skip prepared when auto migrate (#5350) --- tests/migrate_test.go | 36 ++++++++++++++++++++++++++++++++++++ tests/tests_test.go | 11 ++++++++--- 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 12eb8ed05..2b5d7ecd3 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/schema" . "gorm.io/gorm/utils/tests" @@ -712,3 +713,38 @@ func TestPrimarykeyID(t *testing.T) { t.Fatalf("AutoMigrate err:%v", err) } } + +func TestInvalidCachedPlan(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{}) + if err != nil { + t.Errorf("Open err:%v", err) + } + + type Object1 struct{} + type Object2 struct { + Field1 string + } + type Object3 struct { + Field2 string + } + db.Migrator().DropTable("objects") + + err = db.Table("objects").AutoMigrate(&Object1{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").AutoMigrate(&Object2{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").AutoMigrate(&Object3{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } +} diff --git a/tests/tests_test.go b/tests/tests_test.go index 08f4f1932..dcba3cbf5 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -17,6 +17,11 @@ import ( ) var DB *gorm.DB +var ( + mysqlDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" + postgresDSN = "user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" + sqlserverDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" +) func init() { var err error @@ -49,13 +54,13 @@ func OpenTestConnection() (db *gorm.DB, err error) { case "mysql": log.Println("testing mysql...") if dbDSN == "" { - dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" + dbDSN = mysqlDSN } db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{}) case "postgres": log.Println("testing postgres...") if dbDSN == "" { - dbDSN = "user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" + dbDSN = postgresDSN } db, err = gorm.Open(postgres.New(postgres.Config{ DSN: dbDSN, @@ -72,7 +77,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { // GO log.Println("testing sqlserver...") if dbDSN == "" { - dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" + dbDSN = sqlserverDSN } db, err = gorm.Open(sqlserver.Open(dbDSN), &gorm.Config{}) default: From 7e13b03bd4e57a554d3daa2774d3f58102ac30d9 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sat, 28 May 2022 22:18:07 +0800 Subject: [PATCH 80/92] fix: duplicate column scan (#5369) * fix: duplicate column scan * fix: dup filed in inconsistent schema and database * chore[ci skip]: gofumpt style * chore[ci skip]: fix typo --- scan.go | 17 ++++++++++++----- tests/scan_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/scan.go b/scan.go index ad3734d89..a611a9ce8 100644 --- a/scan.go +++ b/scan.go @@ -193,14 +193,21 @@ func Scan(rows Rows, db *DB, mode ScanMode) { // Not Pluck if sch != nil { + schFieldsCount := len(sch.Fields) for idx, column := range columns { if field := sch.LookUpField(column); field != nil && field.Readable { if curIndex, ok := selectedColumnsMap[column]; ok { - for fieldIndex, selectField := range sch.Fields[curIndex+1:] { - if selectField.DBName == column && selectField.Readable { - selectedColumnsMap[column] = curIndex + fieldIndex + 1 - fields[idx] = selectField - break + fields[idx] = field // handle duplicate fields + offset := curIndex + 1 + // handle sch inconsistent with database + // like Raw(`...`).Scan + if schFieldsCount > offset { + for fieldIndex, selectField := range sch.Fields[offset:] { + if selectField.DBName == column && selectField.Readable { + selectedColumnsMap[column] = curIndex + fieldIndex + 1 + fields[idx] = selectField + break + } } } } else { diff --git a/tests/scan_test.go b/tests/scan_test.go index 425c0a299..6f2e9f54d 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -214,4 +214,29 @@ func TestScanToEmbedded(t *testing.T) { if !addressMatched { t.Errorf("Failed, no address matched") } + + personDupField := Person{ID: person1.ID} + if err := DB.Select("people.id, people.*"). + First(&personDupField).Error; err != nil { + t.Errorf("Failed to run join query, got error: %v", err) + } + AssertEqual(t, person1, personDupField) + + user := User{ + Name: "TestScanToEmbedded_1", + Manager: &User{ + Name: "TestScanToEmbedded_1_m1", + Manager: &User{Name: "TestScanToEmbedded_1_m1_m1"}, + }, + } + DB.Create(&user) + + type UserScan struct { + ID uint + Name string + ManagerID *uint + } + var user2 UserScan + err := DB.Raw("SELECT * FROM users INNER JOIN users Manager ON users.manager_id = Manager.id WHERE users.id = ?", user.ID).Scan(&user2).Error + AssertEqual(t, err, nil) } From dc1ae394f329340cb4475b037fe9f98bdbf7176d Mon Sep 17 00:00:00 2001 From: "t-inagaki@hum_op" Date: Sat, 28 May 2022 23:18:43 +0900 Subject: [PATCH 81/92] fixed FirstOrCreate not handled error when table is not exists (#5367) * fixed FirstOrCreate not handled error when table is not exists * delete useless part --- finisher_api.go | 4 ++-- tests/create_test.go | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index da4ef8f76..7a3f27bae 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -351,9 +351,9 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { } return tx.Model(dest).Updates(assigns) - } else { - tx.Error = result.Error } + } else { + tx.Error = result.Error } return tx } diff --git a/tests/create_test.go b/tests/create_test.go index 3730172fd..274a7f486 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -476,6 +476,13 @@ func TestOmitWithCreate(t *testing.T) { CheckUser(t, result2, user2) } +func TestFirstOrCreateNotExistsTable(t *testing.T) { + company := Company{Name: "first_or_create_if_not_exists_table"} + if err := DB.Table("not_exists").FirstOrCreate(&company).Error; err == nil { + t.Errorf("not exists table, but err is nil") + } +} + func TestFirstOrCreateWithPrimaryKey(t *testing.T) { company := Company{ID: 100, Name: "company100_with_primarykey"} DB.FirstOrCreate(&company) From 93986de8e43bc9af6864621c9a4855f0f860cde2 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sat, 28 May 2022 23:09:13 +0800 Subject: [PATCH 82/92] fix: migrate column default value (#5359) Co-authored-by: Jinzhu --- migrator/migrator.go | 16 ++++- tests/migrate_test.go | 136 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 149 insertions(+), 3 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 757ab9494..4acc9df60 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -448,10 +448,20 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } // check default value - if v, ok := columnType.DefaultValue(); ok && v != field.DefaultValue { - // not primary key - if !field.PrimaryKey { + if !field.PrimaryKey { + dv, dvNotNull := columnType.DefaultValue() + if dvNotNull && field.DefaultValueInterface == nil { + // defalut value -> null + alterColumn = true + } else if !dvNotNull && field.DefaultValueInterface != nil { + // null -> default value alterColumn = true + } else if dv != field.DefaultValue { + // default value not equal + // not both null + if !(field.DefaultValueInterface == nil && !dvNotNull) { + alterColumn = true + } } } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 2b5d7ecd3..9e7caec9f 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "fmt" "math/rand" "reflect" "strings" @@ -714,6 +715,141 @@ func TestPrimarykeyID(t *testing.T) { } } +func TestUniqueColumn(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + return + } + + type UniqueTest struct { + ID string `gorm:"primary_key"` + Name string `gorm:"unique"` + } + + type UniqueTest2 struct { + ID string `gorm:"primary_key"` + Name string `gorm:"unique;default:NULL"` + } + + type UniqueTest3 struct { + ID string `gorm:"primary_key"` + Name string `gorm:"unique;default:''"` + } + + type UniqueTest4 struct { + ID string `gorm:"primary_key"` + Name string `gorm:"unique;default:'123'"` + } + + var err error + err = DB.Migrator().DropTable(&UniqueTest{}) + if err != nil { + t.Errorf("DropTable err:%v", err) + } + + err = DB.AutoMigrate(&UniqueTest{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + // null -> null + err = DB.AutoMigrate(&UniqueTest{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + ct, err := findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok := ct.DefaultValue() + AssertEqual(t, "", value) + AssertEqual(t, false, ok) + + // null -> null + err = DB.Table("unique_tests").AutoMigrate(&UniqueTest2{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + // not trigger alert column + AssertEqual(t, true, DB.Migrator().HasIndex(&UniqueTest{}, "name")) + AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_1")) + AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_2")) + + ct, err = findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok = ct.DefaultValue() + AssertEqual(t, "", value) + AssertEqual(t, false, ok) + + // null -> empty string + err = DB.Table("unique_tests").AutoMigrate(&UniqueTest3{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + ct, err = findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok = ct.DefaultValue() + AssertEqual(t, "", value) + AssertEqual(t, true, ok) + + // empty string -> 123 + err = DB.Table("unique_tests").AutoMigrate(&UniqueTest4{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + ct, err = findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok = ct.DefaultValue() + AssertEqual(t, "123", value) + AssertEqual(t, true, ok) + + // 123 -> null + err = DB.Table("unique_tests").AutoMigrate(&UniqueTest2{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + ct, err = findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok = ct.DefaultValue() + AssertEqual(t, "", value) + AssertEqual(t, false, ok) + +} + +func findColumnType(dest interface{}, columnName string) ( + foundColumn gorm.ColumnType, err error) { + columnTypes, err := DB.Migrator().ColumnTypes(dest) + if err != nil { + err = fmt.Errorf("ColumnTypes err:%v", err) + return + } + + for _, c := range columnTypes { + if c.Name() == columnName { + foundColumn = c + break + } + } + return +} + func TestInvalidCachedPlan(t *testing.T) { if DB.Dialector.Name() != "postgres" { return From f4e9904b02dab5c2f675d9c661ae1c1a8654a768 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 1 Jun 2022 10:26:09 +0800 Subject: [PATCH 83/92] chore(deps): bump gorm.io/driver/mysql from 1.3.3 to 1.3.4 in /tests (#5385) Bumps [gorm.io/driver/mysql](https://github.com/go-gorm/mysql) from 1.3.3 to 1.3.4. - [Release notes](https://github.com/go-gorm/mysql/releases) - [Commits](https://github.com/go-gorm/mysql/compare/v1.3.3...v1.3.4) --- updated-dependencies: - dependency-name: gorm.io/driver/mysql dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/go.mod b/tests/go.mod index 6a2cf22fd..bd668420b 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,7 +8,7 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.5 golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect - gorm.io/driver/mysql v1.3.3 + gorm.io/driver/mysql v1.3.4 gorm.io/driver/postgres v1.3.5 gorm.io/driver/sqlite v1.3.2 gorm.io/driver/sqlserver v1.3.2 From d01de7232b46987e239ef19a89d9ab192f453894 Mon Sep 17 00:00:00 2001 From: Bexanderthebex Date: Wed, 1 Jun 2022 11:50:57 +0800 Subject: [PATCH 84/92] enhancement: Avoid calling reflect.New() when passing in slice of values to `Scan()` (#5388) * fix: reduce allocations when slice of values * chore[test]: Add benchmark for scan * chore[test]: add bench for scan slice * chore[test]: add bench for slice pointer and improve tests * chore[test]: make sure database is empty when doing slice tests * fix[test]: correct sql delete statement * enhancement: skip new if rows affected = 0 --- scan.go | 7 ++++++- tests/benchmark_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/scan.go b/scan.go index a611a9ce8..1bb51560a 100644 --- a/scan.go +++ b/scan.go @@ -237,6 +237,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: var elem reflect.Value + recyclableStruct := reflect.New(reflectValueType) if !update || reflectValue.Len() == 0 { update = false @@ -261,7 +262,11 @@ func Scan(rows Rows, db *DB, mode ScanMode) { } } } else { - elem = reflect.New(reflectValueType) + if isPtr && db.RowsAffected > 0 { + elem = reflect.New(reflectValueType) + } else { + elem = recyclableStruct + } } db.scanIntoStruct(rows, elem, values, fields, joinFields) diff --git a/tests/benchmark_test.go b/tests/benchmark_test.go index d897a6341..22d15898e 100644 --- a/tests/benchmark_test.go +++ b/tests/benchmark_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "fmt" "testing" . "gorm.io/gorm/utils/tests" @@ -24,6 +25,45 @@ func BenchmarkFind(b *testing.B) { } } +func BenchmarkScan(b *testing.B) { + user := *GetUser("scan", Config{}) + DB.Create(&user) + + var u User + b.ResetTimer() + for x := 0; x < b.N; x++ { + DB.Raw("select * from users where id = ?", user.ID).Scan(&u) + } +} + +func BenchmarkScanSlice(b *testing.B) { + DB.Exec("delete from users") + for i := 0; i < 10_000; i++ { + user := *GetUser(fmt.Sprintf("scan-%d", i), Config{}) + DB.Create(&user) + } + + var u []User + b.ResetTimer() + for x := 0; x < b.N; x++ { + DB.Raw("select * from users").Scan(&u) + } +} + +func BenchmarkScanSlicePointer(b *testing.B) { + DB.Exec("delete from users") + for i := 0; i < 10_000; i++ { + user := *GetUser(fmt.Sprintf("scan-%d", i), Config{}) + DB.Create(&user) + } + + var u []*User + b.ResetTimer() + for x := 0; x < b.N; x++ { + DB.Raw("select * from users").Scan(&u) + } +} + func BenchmarkUpdate(b *testing.B) { user := *GetUser("find", Config{}) DB.Create(&user) From 8d457146283e0a4197c26a559bedb1938767b78e Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Tue, 14 Jun 2022 13:48:50 +0800 Subject: [PATCH 85/92] fix: reset null value in slice (#5417) * fix: reset null value in slice * fix: can not set field in-place in join --- scan.go | 17 ++++++---- schema/field.go | 10 ++++++ tests/query_test.go | 77 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 98 insertions(+), 6 deletions(-) diff --git a/scan.go b/scan.go index 1bb51560a..6250fb576 100644 --- a/scan.go +++ b/scan.go @@ -66,18 +66,23 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int db.RowsAffected++ db.AddError(rows.Scan(values...)) + joinedSchemaMap := make(map[*schema.Field]interface{}, 0) for idx, field := range fields { if field != nil { if len(joinFields) == 0 || joinFields[idx][0] == nil { db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) } else { - relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue) - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { - continue - } + joinSchema := joinFields[idx][0] + relValue := joinSchema.ReflectValueOf(db.Statement.Context, reflectValue) + if relValue.Kind() == reflect.Ptr { + if _, ok := joinedSchemaMap[joinSchema]; !ok { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { + continue + } - relValue.Set(reflect.New(relValue.Type().Elem())) + relValue.Set(reflect.New(relValue.Type().Elem())) + joinedSchemaMap[joinSchema] = nil + } } db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])) } diff --git a/schema/field.go b/schema/field.go index d6df65965..981f56f2a 100644 --- a/schema/field.go +++ b/schema/field.go @@ -587,6 +587,8 @@ func (field *Field) setupValuerAndSetter() { case **bool: if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetBool(**data) + } else { + field.ReflectValueOf(ctx, value).SetBool(false) } case bool: field.ReflectValueOf(ctx, value).SetBool(data) @@ -606,6 +608,8 @@ func (field *Field) setupValuerAndSetter() { case **int64: if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetInt(**data) + } else { + field.ReflectValueOf(ctx, value).SetInt(0) } case int64: field.ReflectValueOf(ctx, value).SetInt(data) @@ -670,6 +674,8 @@ func (field *Field) setupValuerAndSetter() { case **uint64: if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetUint(**data) + } else { + field.ReflectValueOf(ctx, value).SetUint(0) } case uint64: field.ReflectValueOf(ctx, value).SetUint(data) @@ -722,6 +728,8 @@ func (field *Field) setupValuerAndSetter() { case **float64: if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetFloat(**data) + } else { + field.ReflectValueOf(ctx, value).SetFloat(0) } case float64: field.ReflectValueOf(ctx, value).SetFloat(data) @@ -766,6 +774,8 @@ func (field *Field) setupValuerAndSetter() { case **string: if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetString(**data) + } else { + field.ReflectValueOf(ctx, value).SetString("") } case string: field.ReflectValueOf(ctx, value).SetString(data) diff --git a/tests/query_test.go b/tests/query_test.go index f66cf83a4..253d84092 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -1258,3 +1258,80 @@ func TestQueryScannerWithSingleColumn(t *testing.T) { AssertEqual(t, result2.data, 20) } + +func TestQueryResetNullValue(t *testing.T) { + type QueryResetItem struct { + ID string `gorm:"type:varchar(5)"` + Name string + } + + type QueryResetNullValue struct { + ID int + Name string `gorm:"default:NULL"` + Flag bool `gorm:"default:NULL"` + Number1 int64 `gorm:"default:NULL"` + Number2 uint64 `gorm:"default:NULL"` + Number3 float64 `gorm:"default:NULL"` + Now *time.Time `gorm:"defalut:NULL"` + Item1Id string + Item1 *QueryResetItem `gorm:"references:ID"` + Item2Id string + Item2 *QueryResetItem `gorm:"references:ID"` + } + + DB.Migrator().DropTable(&QueryResetNullValue{}, &QueryResetItem{}) + DB.AutoMigrate(&QueryResetNullValue{}, &QueryResetItem{}) + + now := time.Now() + q1 := QueryResetNullValue{ + Name: "name", + Flag: true, + Number1: 100, + Number2: 200, + Number3: 300.1, + Now: &now, + Item1: &QueryResetItem{ + ID: "u_1_1", + Name: "item_1_1", + }, + Item2: &QueryResetItem{ + ID: "u_1_2", + Name: "item_1_2", + }, + } + + q2 := QueryResetNullValue{ + Item1: &QueryResetItem{ + ID: "u_2_1", + Name: "item_2_1", + }, + Item2: &QueryResetItem{ + ID: "u_2_2", + Name: "item_2_2", + }, + } + + var err error + err = DB.Create(&q1).Error + if err != nil { + t.Errorf("failed to create:%v", err) + } + + err = DB.Create(&q2).Error + if err != nil { + t.Errorf("failed to create:%v", err) + } + + var qs []QueryResetNullValue + err = DB.Joins("Item1").Joins("Item2").Find(&qs).Error + if err != nil { + t.Errorf("failed to find:%v", err) + } + + if len(qs) != 2 { + t.Fatalf("find count not equal:%d", len(qs)) + } + + AssertEqual(t, q1, qs[0]) + AssertEqual(t, q2, qs[1]) +} From 1305f637f834baa13c514df915157a51d86b4f28 Mon Sep 17 00:00:00 2001 From: qqxhb <30866940+qqxhb@users.noreply.github.com> Date: Fri, 17 Jun 2022 11:00:57 +0800 Subject: [PATCH 86/92] feat: add method GetIndexes (#5436) * feat: add method GetIndexes * feat: add default impl for Index interface * feat: fmt --- migrator.go | 10 ++++++++++ migrator/index.go | 43 +++++++++++++++++++++++++++++++++++++++++++ migrator/migrator.go | 6 ++++++ 3 files changed, 59 insertions(+) create mode 100644 migrator/index.go diff --git a/migrator.go b/migrator.go index 524438770..34e888f2b 100644 --- a/migrator.go +++ b/migrator.go @@ -51,6 +51,15 @@ type ColumnType interface { DefaultValue() (value string, ok bool) } +type Index interface { + Table() string + Name() string + Columns() []string + PrimaryKey() (isPrimaryKey bool, ok bool) + Unique() (unique bool, ok bool) + Option() string +} + // Migrator migrator interface type Migrator interface { // AutoMigrate @@ -90,4 +99,5 @@ type Migrator interface { DropIndex(dst interface{}, name string) error HasIndex(dst interface{}, name string) bool RenameIndex(dst interface{}, oldName, newName string) error + GetIndexes(dst interface{}) ([]Index, error) } diff --git a/migrator/index.go b/migrator/index.go new file mode 100644 index 000000000..fe686e5af --- /dev/null +++ b/migrator/index.go @@ -0,0 +1,43 @@ +package migrator + +import "database/sql" + +// Index implements gorm.Index interface +type Index struct { + TableName string + NameValue string + ColumnList []string + PrimaryKeyValue sql.NullBool + UniqueValue sql.NullBool + OptionValue string +} + +// Table return the table name of the index. +func (idx Index) Table() string { + return idx.TableName +} + +// Name return the name of the index. +func (idx Index) Name() string { + return idx.NameValue +} + +// Columns return the columns fo the index +func (idx Index) Columns() []string { + return idx.ColumnList +} + +// PrimaryKey returns the index is primary key or not. +func (idx Index) PrimaryKey() (isPrimaryKey bool, ok bool) { + return idx.PrimaryKeyValue.Bool, idx.PrimaryKeyValue.Valid +} + +// Unique returns whether the index is unique or not. +func (idx Index) Unique() (unique bool, ok bool) { + return idx.UniqueValue.Bool, idx.UniqueValue.Valid +} + +// Option return the optional attribute fo the index +func (idx Index) Option() string { + return idx.OptionValue +} diff --git a/migrator/migrator.go b/migrator/migrator.go index 4acc9df60..f20bf5134 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -3,6 +3,7 @@ package migrator import ( "context" "database/sql" + "errors" "fmt" "reflect" "regexp" @@ -854,3 +855,8 @@ func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} { } return clause.Table{Name: stmt.Table} } + +// GetIndexes return Indexes []gorm.Index and execErr error +func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) { + return nil, errors.New("not support") +} From a70af2a4c0d7bd66d76999f142a9babb438e53d7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 20 Jun 2022 15:35:29 +0800 Subject: [PATCH 87/92] Fix Select with digits in column name --- statement.go | 2 +- statement_test.go | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/statement.go b/statement.go index ed3e8716d..850af6cbc 100644 --- a/statement.go +++ b/statement.go @@ -650,7 +650,7 @@ func (stmt *Statement) Changed(fields ...string) bool { return false } -var nameMatcher = regexp.MustCompile(`^[\W]?(?:[a-z_]+?)[\W]?\.[\W]?([a-z_]+?)[\W]?$`) +var nameMatcher = regexp.MustCompile(`^[\W]?(?:[a-z_0-9]+?)[\W]?\.[\W]?([a-z_0-9]+?)[\W]?$`) // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { diff --git a/statement_test.go b/statement_test.go index 3f099d611..a89cc7d29 100644 --- a/statement_test.go +++ b/statement_test.go @@ -37,10 +37,14 @@ func TestWhereCloneCorruption(t *testing.T) { func TestNameMatcher(t *testing.T) { for k, v := range map[string]string{ - "table.name": "name", - "`table`.`name`": "name", - "'table'.'name'": "name", - "'table'.name": "name", + "table.name": "name", + "`table`.`name`": "name", + "'table'.'name'": "name", + "'table'.name": "name", + "table1.name_23": "name_23", + "`table_1`.`name23`": "name23", + "'table23'.'name_1'": "name_1", + "'table23'.name1": "name1", } { if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 2 || matches[1] != v { t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v) From 93f28bc116526ba4decdd969a7b2b0b245ad70f1 Mon Sep 17 00:00:00 2001 From: Joe Date: Fri, 24 Jun 2022 10:33:39 +0800 Subject: [PATCH 88/92] use callback to handle transaction - make transaction have before and after hooks, so plugin can have hack before or after transaction --- callbacks.go | 37 +++++++++++++++++++++++++++++++------ finisher_api.go | 16 +--------------- 2 files changed, 32 insertions(+), 21 deletions(-) diff --git a/callbacks.go b/callbacks.go index c060ea709..1b4e58ea1 100644 --- a/callbacks.go +++ b/callbacks.go @@ -2,6 +2,7 @@ package gorm import ( "context" + "database/sql" "errors" "fmt" "reflect" @@ -15,12 +16,13 @@ import ( func initializeCallbacks(db *DB) *callbacks { return &callbacks{ processors: map[string]*processor{ - "create": {db: db}, - "query": {db: db}, - "update": {db: db}, - "delete": {db: db}, - "row": {db: db}, - "raw": {db: db}, + "create": {db: db}, + "query": {db: db}, + "update": {db: db}, + "delete": {db: db}, + "row": {db: db}, + "raw": {db: db}, + "transaction": {db: db}, }, } } @@ -72,6 +74,29 @@ func (cs *callbacks) Raw() *processor { return cs.processors["raw"] } +func (cs *callbacks) Transaction() *processor { + return cs.processors["transaction"] +} + +func (p *processor) Begin(tx *DB, opt *sql.TxOptions) *DB { + var err error + + switch beginner := tx.Statement.ConnPool.(type) { + case TxBeginner: + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + case ConnPoolBeginner: + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + default: + err = ErrInvalidTransaction + } + + if err != nil { + tx.AddError(err) + } + + return tx +} + func (p *processor) Execute(db *DB) *DB { // call scopes for len(db.Statement.scopes) > 0 { diff --git a/finisher_api.go b/finisher_api.go index 7a3f27bae..3e406c1cc 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -619,27 +619,13 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { // clone statement tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1}) opt *sql.TxOptions - err error ) if len(opts) > 0 { opt = opts[0] } - switch beginner := tx.Statement.ConnPool.(type) { - case TxBeginner: - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - case ConnPoolBeginner: - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - default: - err = ErrInvalidTransaction - } - - if err != nil { - tx.AddError(err) - } - - return tx + return tx.callbacks.Transaction().Begin(tx, opt) } // Commit commit a transaction From 3e6ab990431c48a816676c9efbe1d0952ffb4a28 Mon Sep 17 00:00:00 2001 From: wws <32982278+wuweishuo@users.noreply.github.com> Date: Sat, 25 Jun 2022 16:32:47 +0800 Subject: [PATCH 89/92] fix:serializer contain field panic (#5461) --- schema/field.go | 2 +- tests/serializer_test.go | 43 +++++++++++++++++++++++++++++++++------- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/schema/field.go b/schema/field.go index 981f56f2a..d4dfbd6f7 100644 --- a/schema/field.go +++ b/schema/field.go @@ -950,7 +950,7 @@ func (field *Field) setupNewValuePool() { New: func() interface{} { return &serializer{ Field: field, - Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface), + Serializer: field.Serializer, } }, } diff --git a/tests/serializer_test.go b/tests/serializer_test.go index ee14841a9..80e015ffa 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -16,13 +16,14 @@ import ( type SerializerStruct struct { gorm.Model - Name []byte `gorm:"json"` - Roles Roles `gorm:"serializer:json"` - Contracts map[string]interface{} `gorm:"serializer:json"` - JobInfo Job `gorm:"type:bytes;serializer:gob"` - CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type - UpdatedTime *int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type - EncryptedString EncryptedString + Name []byte `gorm:"json"` + Roles Roles `gorm:"serializer:json"` + Contracts map[string]interface{} `gorm:"serializer:json"` + JobInfo Job `gorm:"type:bytes;serializer:gob"` + CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type + UpdatedTime *int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type + CustomSerializerString string `gorm:"serializer:custom"` + EncryptedString EncryptedString } type Roles []string @@ -52,7 +53,32 @@ func (es EncryptedString) Value(ctx context.Context, field *schema.Field, dst re return "hello" + string(es), nil } +type CustomSerializer struct { + prefix []byte +} + +func NewCustomSerializer(prefix string) *CustomSerializer { + return &CustomSerializer{prefix: []byte(prefix)} +} + +func (c *CustomSerializer) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) { + switch value := dbValue.(type) { + case []byte: + err = field.Set(ctx, dst, bytes.TrimPrefix(value, c.prefix)) + case string: + err = field.Set(ctx, dst, strings.TrimPrefix(value, string(c.prefix))) + default: + err = fmt.Errorf("unsupported data %#v", dbValue) + } + return err +} + +func (c *CustomSerializer) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + return fmt.Sprintf("%s%s", c.prefix, fieldValue), nil +} + func TestSerializer(t *testing.T) { + schema.RegisterSerializer("custom", NewCustomSerializer("hello")) DB.Migrator().DropTable(&SerializerStruct{}) if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil { t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) @@ -74,6 +100,7 @@ func TestSerializer(t *testing.T) { Location: "Kenmawr", IsIntern: false, }, + CustomSerializerString: "world", } if err := DB.Create(&data).Error; err != nil { @@ -90,6 +117,7 @@ func TestSerializer(t *testing.T) { } func TestSerializerAssignFirstOrCreate(t *testing.T) { + schema.RegisterSerializer("custom", NewCustomSerializer("hello")) DB.Migrator().DropTable(&SerializerStruct{}) if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil { t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) @@ -109,6 +137,7 @@ func TestSerializerAssignFirstOrCreate(t *testing.T) { Location: "Shadyside", IsIntern: false, }, + CustomSerializerString: "world", } // first time insert record From 235c093bb97d37cdfa34103b59eabacfde9b2a42 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Wed, 29 Jun 2022 10:07:42 +0800 Subject: [PATCH 90/92] fix(MigrateColumn):declared different type without length (#5465) --- migrator/migrator.go | 11 +++++++---- tests/migrate_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index f20bf5134..87ac77451 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -15,7 +15,6 @@ import ( ) var ( - regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`) regFullDataType = regexp.MustCompile(`[^\d]*(\d+)[^\d]?`) ) @@ -404,11 +403,16 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error // MigrateColumn migrate column func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { // found, smart migrate - fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL) + fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) realDataType := strings.ToLower(columnType.DatabaseTypeName()) alterColumn := false + // check type + if !field.PrimaryKey && !strings.HasPrefix(fullDataType, realDataType) { + alterColumn = true + } + // check size if length, ok := columnType.Length(); length != int64(field.Size) { if length > 0 && field.Size > 0 { @@ -416,9 +420,8 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } else { // has size in data type and not equal // Since the following code is frequently called in the for loop, reg optimization is needed here - matches := regRealDataType.FindAllStringSubmatch(realDataType, -1) matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) - if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) && + if !field.PrimaryKey && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) { alterColumn = true } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 9e7caec9f..0bbef382a 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -884,3 +884,42 @@ func TestInvalidCachedPlan(t *testing.T) { t.Errorf("AutoMigrate err:%v", err) } } + +func TestDifferentTypeWithoutDeclaredLength(t *testing.T) { + type DiffType struct { + ID uint + Name string `gorm:"type:varchar(20)"` + } + + type DiffType1 struct { + ID uint + Name string `gorm:"type:text"` + } + + var err error + DB.Migrator().DropTable(&DiffType{}) + + err = DB.AutoMigrate(&DiffType{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + ct, err := findColumnType(&DiffType{}, "name") + if err != nil { + t.Errorf("findColumnType err:%v", err) + } + + AssertEqual(t, "varchar", strings.ToLower(ct.DatabaseTypeName())) + + err = DB.Table("diff_types").AutoMigrate(&DiffType1{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + ct, err = findColumnType(&DiffType{}, "name") + if err != nil { + t.Errorf("findColumnType err:%v", err) + } + + AssertEqual(t, "text", strings.ToLower(ct.DatabaseTypeName())) +} From 2cb4088456eaa845d6e89eeb69fb57d565a72cc2 Mon Sep 17 00:00:00 2001 From: Joe Date: Fri, 1 Jul 2022 14:37:38 +0800 Subject: [PATCH 91/92] ignore AddError return error --- callbacks.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/callbacks.go b/callbacks.go index 1b4e58ea1..f835e5049 100644 --- a/callbacks.go +++ b/callbacks.go @@ -91,7 +91,7 @@ func (p *processor) Begin(tx *DB, opt *sql.TxOptions) *DB { } if err != nil { - tx.AddError(err) + _ = tx.AddError(err) } return tx From c74bc57add435a4fa0de1cd0eb65f11f62fe1dfd Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Fri, 1 Jul 2022 15:12:15 +0800 Subject: [PATCH 92/92] fix: association many2many duplicate elem (#5473) * fix: association many2many duplicate elem * chore: gofumpt style --- callbacks/associations.go | 27 ++++++++++++++++++++------- tests/associations_many2many_test.go | 27 +++++++++++++++++++++++++++ tests/migrate_test.go | 4 ++-- tests/serializer_test.go | 3 +-- 4 files changed, 50 insertions(+), 11 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index fd3141cfe..4a50e6c24 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -253,6 +253,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { fieldType = reflect.PtrTo(fieldType) } elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10) objs := []reflect.Value{} @@ -272,19 +273,31 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { joins = reflect.Append(joins, joinValue) } + identityMap := map[string]bool{} appendToElems := func(v reflect.Value) { if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) - for i := 0; i < f.Len(); i++ { elem := f.Index(i) - + if !isPtr { + elem = elem.Addr() + } objs = append(objs, v) - if isPtr { - elems = reflect.Append(elems, elem) - } else { - elems = reflect.Append(elems, elem.Addr()) + elems = reflect.Append(elems, elem) + + relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) + for _, pf := range rel.FieldSchema.PrimaryFields { + if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok { + relPrimaryValues = append(relPrimaryValues, pfv) + } } + + cacheKey := utils.ToStringKey(relPrimaryValues) + if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { + identityMap[cacheKey] = true + distinctElems = reflect.Append(distinctElems, elem) + } + } } } @@ -304,7 +317,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { // optimize elems of reflect value length if elemLen := elems.Len(); elemLen > 0 { if v, ok := selectColumns[rel.Name+".*"]; !ok || v { - saveAssociations(db, rel, elems, selectColumns, restricted, nil) + saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) } for i := 0; i < elemLen; i++ { diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index 28b441bd8..7b45befb6 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -324,3 +325,29 @@ func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) { DB.Model(&users).Association("Team").Clear() AssertAssociationCount(t, users, "Team", 0, "After Clear") } + +func TestDuplicateMany2ManyAssociation(t *testing.T) { + user1 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{ + {Code: "TestDuplicateMany2ManyAssociation-language-1"}, + {Code: "TestDuplicateMany2ManyAssociation-language-2"}, + }} + + user2 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{ + {Code: "TestDuplicateMany2ManyAssociation-language-1"}, + {Code: "TestDuplicateMany2ManyAssociation-language-3"}, + }} + users := []*User{&user1, &user2} + var err error + err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error + AssertEqual(t, nil, err) + + var findUser1 User + err = DB.Preload("Languages").Where("id = ?", user1.ID).First(&findUser1).Error + AssertEqual(t, nil, err) + AssertEqual(t, user1, findUser1) + + var findUser2 User + err = DB.Preload("Languages").Where("id = ?", user2.ID).First(&findUser2).Error + AssertEqual(t, nil, err) + AssertEqual(t, user2, findUser2) +} diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 0bbef382a..3d6a78589 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -830,11 +830,11 @@ func TestUniqueColumn(t *testing.T) { value, ok = ct.DefaultValue() AssertEqual(t, "", value) AssertEqual(t, false, ok) - } func findColumnType(dest interface{}, columnName string) ( - foundColumn gorm.ColumnType, err error) { + foundColumn gorm.ColumnType, err error, +) { columnTypes, err := DB.Migrator().ColumnTypes(dest) if err != nil { err = fmt.Errorf("ColumnTypes err:%v", err) diff --git a/tests/serializer_test.go b/tests/serializer_test.go index 80e015ffa..7232f9df4 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -113,7 +113,6 @@ func TestSerializer(t *testing.T) { } AssertEqual(t, result, data) - } func TestSerializerAssignFirstOrCreate(t *testing.T) { @@ -152,7 +151,7 @@ func TestSerializerAssignFirstOrCreate(t *testing.T) { } AssertEqual(t, result, out) - //update record + // update record data.Roles = append(data.Roles, "r3") data.JobInfo.Location = "Gates Hillman Complex" if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil {