Gorm 改造指针对象

gorm object support pointer

foreversmart write on 2021-02-02
最近项目中正好用到 Gorm 在使用 Gorm 的过程中发现在进行对象查询的时候不支持指针。具体的文档如下:
如果我们强行使用指针进行查询代码如下:
var user *User // doesn't work db.First(&user)
Go
会得到 "invalid value" 的错误
我们继续探究下其实 gorm 并不是在查询的时候不支持指针对象,只是目前不支持没有初始化的空指针对象的绑定赋值。比如下面这段代码就可以正常的工作:
user := &User{} // works db.First(&user)
Go
在使用 Gorm 进行数据库查询操作的过程中,会有大量的使用指针但是指针没有初始化的场景,如果不使用指针则对于大对象会有大量的拷贝工作,而如果使用指针我们需要进行大量的初始化操作,写起来很繁琐而且很容易漏洞,并且只能在运行时发现容易导致线上问题。所以这个需求还是非常有意义的,然后就想 Gorm 提了 issue。但作者认为设计就是这样的,其实我也不是为了报 Bug,所以想这个需求也不复杂就干脆先改源码实现,然后 直接向 Gorm 提 PR。
第一步就是 Gorm 的源码阅读
接口层代码都在 finisher_api.go 中
可以看到这里面关于读操作的 api 大概有下面个:
Find() FindInBatches() First() FirstOrCreate() FirstOrInit() Last() Pluck() Scan()
Go
上面的函数 API 大概分为三类:
调用 tx.callbacks.Query().Execute(tx) 进行回调查询操作然后将返回的结果扔到 Scan 扫描方法将行结果反序列化到目标 Struct中;比如 Find(), First(), Last(), Pluck()。
调用基础方法实现查询操作的方法如:FindInBatches(), FirstOrCreate(),FirstOrInit()。它们基本都是调用,上面的方法实现功能上的封装。
Scan 方法调用 Rows() 获取行让后通过 ScanRows 进行处理,ScanRows 最终调用了 Scan 函数这里的 Scan 方法和 Scan 函数是不同的,而 Rows 也会调用 Execute 方法执行回调。
所以总的来说我们只需要关心的函数只有两个一个是 Excuete() 方法,另一个是 Scan 函数
首先我们来看 func (p *processor) Execute(db *DB) {} 这个方法 :
首先我们来看下 Execute 做了什么事情
// gorm/callbacks.go func (p *processor) Execute(db *DB) { curTime := time.Now() stmt := db.Statement if stmt.Model == nil { stmt.Model = stmt.Dest } else if stmt.Dest == nil { stmt.Dest = stmt.Model } if stmt.Model != nil { if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" { db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err)) } else { db.AddError(err) } } } // 前面这块代码基本是就是确定 mysql 查询语句的模型 // 如果没有传入模型,有语法糖通过 dest 目标渲染的模型来确定查询中的模型 // 反之同理 // 如果目标渲染的对象为非空指针 if stmt.Dest != nil { // 获取反射获取目标对象的值 stmt.ReflectValue = reflect.ValueOf(stmt.Dest) // 如果目标对象的值是一个指针,循环解开指针 for stmt.ReflectValue.Kind() == reflect.Ptr { stmt.ReflectValue = stmt.ReflectValue.Elem() } // 当目标对象的值是零值时,报错 if !stmt.ReflectValue.IsValid() { db.AddError(fmt.Errorf("invalid value")) } } // 执行回调函数 for _, f := range p.fns { f(db) } db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected }, db.Error) if !stmt.DB.DryRun { stmt.SQL.Reset() stmt.Vars = nil } }
Go
对于查询执行器 tx.callbacks.Query() 我们来看下它是怎么创建和初始化的
获取的代码:
// gorm/callbacks.go func (cs *callbacks) Query() *processor { return cs.processors["query"] }
Go
初始化代码:
// gorm/callbacks.go func initializeCallbacks(db *DB) *callbacks { return &callbacks{ processors: map[string]*processor{ "create": {db: db}, "query": {db: db}, "update": {db: db}, "delete": {db: db}, "row": {db: db}, "raw": {db: db}, }, }
Go
注册 Callback 函数
// gorm/callbacks/callbacks.go func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { enableTransaction := func(db *gorm.DB) bool { return !db.SkipDefaultTransaction } queryCallback := db.Callback().Query() queryCallback.Register("gorm:query", Query) queryCallback.Register("gorm:preload", Preload) queryCallback.Register("gorm:after_query", AfterQuery) db.Callback().Row().Register("gorm:row", RowQuery) db.Callback().Raw().Register("gorm:raw", RawExec) }
Go
最终执行查询的函数:
func Query(db *gorm.DB) { if db.Error == nil { BuildQuerySQL(db) if !db.DryRun && db.Error == nil { // 执行查询语句获返回的行结果 rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err != nil { db.AddError(err) return } defer rows.Close() gorm.Scan(rows, db, false) } } }
Go
我们可以发现上述 GORM 执行一个查询的过程中关于目标对象处理的逻辑是下面这段代码:
// 如果目标渲染的对象为非空指针 if stmt.Dest != nil { // 获取反射获取目标对象的值 stmt.ReflectValue = reflect.ValueOf(stmt.Dest) // 如果目标对象的值是一个指针,循环解开指针 for stmt.ReflectValue.Kind() == reflect.Ptr { stmt.ReflectValue = stmt.ReflectValue.Elem() } // 当目标对象的值是零值时,报错 if !stmt.ReflectValue.IsValid() { db.AddError(fmt.Errorf("invalid value")) } }
Go
我们只需要对于零指针的情况做初始化处理就可以解决我们的需求
for stmt.ReflectValue.Kind() == reflect.Ptr { // 在循环解开指针的过程如果发现值的类型是指针,且他的值是 NIL 空指针 if stmt.ReflectValue.IsNil() { // 则我们将他的值初始化为零值 stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem())) break } stmt.ReflectValue = stmt.ReflectValue.Elem() }
Go
其中 stmt.ReflectValue.Type().Elem() 当前的类型,因为当前的类型是指针所以通过 Elem() 获取其指向的类型。
接下来我们在看下 func Scan(rows *sql.Rows, db *DB, initialized bool) {} 这个函数是怎么将行结果反序列化到具体的目标 Struct 变量中的:
//gorm/scan.go func Scan(rows *sql.Rows, db *DB, initialized bool) { columns, _ := rows.Columns() values := make([]interface{}, len(columns)) db.RowsAffected = 0 switch dest := db.Statement.Dest.(type) { case map[string]interface{}, *map[string]interface{}: if initialized || rows.Next() { columnTypes, _ := rows.ColumnTypes() prepareValues(values, db, columnTypes, columns) db.RowsAffected++ db.AddError(rows.Scan(values...)) mapValue, ok := dest.(map[string]interface{}) if !ok { if v, ok := dest.(*map[string]interface{}); ok { mapValue = *v } } scanIntoMap(mapValue, values, columns) } case *[]map[string]interface{}: columnTypes, _ := rows.ColumnTypes() for initialized || rows.Next() { prepareValues(values, db, columnTypes, columns) initialized = false db.RowsAffected++ db.AddError(rows.Scan(values...)) mapValue := map[string]interface{}{} scanIntoMap(mapValue, values, columns) *dest = append(*dest, mapValue) } case *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64, *uintptr, *float32, *float64, *bool, *string, *time.Time, *sql.NullInt32, *sql.NullInt64, *sql.NullFloat64, *sql.NullBool, *sql.NullString, *sql.NullTime: for initialized || rows.Next() { initialized = false db.RowsAffected++ db.AddError(rows.Scan(dest)) } default: Schema := db.Statement.Schema switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: var ( reflectValueType = db.Statement.ReflectValue.Type().Elem() isPtr = reflectValueType.Kind() == reflect.Ptr fields = make([]*schema.Field, len(columns)) joinFields [][2]*schema.Field ) if isPtr { reflectValueType = reflectValueType.Elem() } db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20)) if Schema != nil { if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct { Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) } for idx, column := range columns { if field := Schema.LookUpField(column); field != nil && field.Readable { fields[idx] = field } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { fields[idx] = field if len(joinFields) == 0 { joinFields = make([][2]*schema.Field, len(columns)) } joinFields[idx] = [2]*schema.Field{rel.Field, field} continue } } values[idx] = &sql.RawBytes{} } else { values[idx] = &sql.RawBytes{} } } } // pluck values into slice of data isPluck := false if len(fields) == 1 { if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok || // is scanner reflectValueType.Kind() != reflect.Struct || // is not struct Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time isPluck = true } } for initialized || rows.Next() { initialized = false db.RowsAffected++ elem := reflect.New(reflectValueType) if isPluck { db.AddError(rows.Scan(elem.Interface())) } else { for idx, field := range fields { if field != nil { values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() } } db.AddError(rows.Scan(values...)) for idx, field := range fields { if len(joinFields) != 0 && joinFields[idx][0] != nil { value := reflect.ValueOf(values[idx]).Elem() relValue := joinFields[idx][0].ReflectValueOf(elem) if relValue.Kind() == reflect.Ptr && relValue.IsNil() { if value.IsNil() { continue } relValue.Set(reflect.New(relValue.Type().Elem())) } field.Set(relValue, values[idx]) } else if field != nil { field.Set(elem, values[idx]) } } } if isPtr { db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem)) } else { db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Elem())) } } case reflect.Struct, reflect.Ptr: if db.Statement.ReflectValue.Type() != Schema.ModelType { Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) } if initialized || rows.Next() { for idx, column := range columns { if field := Schema.LookUpField(column); field != nil && field.Readable { values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() continue } } values[idx] = &sql.RawBytes{} } else { values[idx] = &sql.RawBytes{} } } db.RowsAffected++ db.AddError(rows.Scan(values...)) for idx, column := range columns { if field := Schema.LookUpField(column); field != nil && field.Readable { field.Set(db.Statement.ReflectValue, values[idx]) } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) value := reflect.ValueOf(values[idx]).Elem() if relValue.Kind() == reflect.Ptr && relValue.IsNil() { if value.IsNil() { continue } relValue.Set(reflect.New(relValue.Type().Elem())) } field.Set(relValue, values[idx]) } } } } } } } if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound { db.AddError(ErrRecordNotFound) } }
Go
源码稍微有点长,不过大致的逻辑就是根据不同类型的 Dest 怎么样通过的放射的方式将行结果的值赋值给 Dest struct 的变量。这些类型里面唯独不支持指针类型,而我们在 Execute 方法中改造的逻辑会将 Dest 的 Value 设置为一个目标类型零值的指针。所以在 Scan 函数中我们需要增加一个 case
case reflect.Struct, reflect.Ptr: 而 Ptr struct 和 struct 的反射处理字段逻辑是一样的。现在我们要实现的需求就已经完成了, 在单测中增加相关的单测测试通过就可以提 PR 了。
PS:在阅读和调试第三方代码的时候可以将第三方代码的源码拉到本地,然后在 go mod 文件中修改依赖为我们本地的代码。这样就可以方便的调试和阅读第三方代码了。
replace gorm.io/gorm => /Users/lib/gorm
Go
相关的 issue 和 PR

「真诚赞赏,手留余香」

Foreversmart

真诚赞赏,手留余香

使用微信扫描二维码完成支付