最近项目中正好用到 Gorm 在使用 Gorm 的过程中发现在进行对象查询的时候不支持指针。具体的文档如下:
如果我们强行使用指针进行查询代码如下:
var user *User
// doesn't work
db.First(&user)
Go
会得到 "invalid value" 的错误
我们继续探究下其实 gorm 并不是在查询的时候不支持指针对象,只是目前不支持没有初始化的空指针对象的绑定赋值。比如下面这段代码就可以正常的工作:
user := &User{}
// works
db.First(&user)
Go
在使用 Gorm 进行数据库查询操作的过程中,会有大量的使用指针但是指针没有初始化的场景,如果不使用指针则对于大对象会有大量的拷贝工作,而如果使用指针我们需要进行大量的初始化操作,写起来很繁琐而且很容易漏洞,并且只能在运行时发现容易导致线上问题。所以这个需求还是非常有意义的,然后就想 Gorm 提了 issue。但作者认为设计就是这样的,其实我也不是为了报 Bug,所以想这个需求也不复杂就干脆先改源码实现,然后 直接向 Gorm 提 PR。
第一步就是 Gorm 的源码阅读
接口层代码都在 finisher_api.go 中
可以看到这里面关于读操作的 api 大概有下面个:
Find()
FindInBatches()
First()
FirstOrCreate()
FirstOrInit()
Last()
Pluck()
Scan()
Go
上面的函数 API 大概分为三类:
调用 tx.callbacks.Query().Execute(tx) 进行回调查询操作然后将返回的结果扔到 Scan 扫描方法将行结果反序列化到目标 Struct中;比如 Find(), First(), Last(), Pluck()。
调用基础方法实现查询操作的方法如:FindInBatches(), FirstOrCreate(),FirstOrInit()。它们基本都是调用,上面的方法实现功能上的封装。
Scan 方法调用 Rows() 获取行让后通过 ScanRows 进行处理,ScanRows 最终调用了 Scan 函数这里的 Scan 方法和 Scan 函数是不同的,而 Rows 也会调用 Execute 方法执行回调。
所以总的来说我们只需要关心的函数只有两个一个是 Excuete() 方法,另一个是 Scan 函数
首先我们来看 func (p *processor) Execute(db *DB) {} 这个方法 :
首先我们来看下 Execute 做了什么事情
// gorm/callbacks.go
func (p *processor) Execute(db *DB) {
curTime := time.Now()
stmt := db.Statement
if stmt.Model == nil {
stmt.Model = stmt.Dest
} else if stmt.Dest == nil {
stmt.Dest = stmt.Model
}
if stmt.Model != nil {
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" {
db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
} else {
db.AddError(err)
}
}
}
// 前面这块代码基本是就是确定 mysql 查询语句的模型
// 如果没有传入模型,有语法糖通过 dest 目标渲染的模型来确定查询中的模型
// 反之同理
// 如果目标渲染的对象为非空指针
if stmt.Dest != nil {
// 获取反射获取目标对象的值
stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
// 如果目标对象的值是一个指针,循环解开指针
for stmt.ReflectValue.Kind() == reflect.Ptr {
stmt.ReflectValue = stmt.ReflectValue.Elem()
}
// 当目标对象的值是零值时,报错
if !stmt.ReflectValue.IsValid() {
db.AddError(fmt.Errorf("invalid value"))
}
}
// 执行回调函数
for _, f := range p.fns {
f(db)
}
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
}, db.Error)
if !stmt.DB.DryRun {
stmt.SQL.Reset()
stmt.Vars = nil
}
}
Go
对于查询执行器 tx.callbacks.Query() 我们来看下它是怎么创建和初始化的
获取的代码:
// gorm/callbacks.go
func (cs *callbacks) Query() *processor {
return cs.processors["query"]
}
Go
初始化代码:
// gorm/callbacks.go
func initializeCallbacks(db *DB) *callbacks {
return &callbacks{
processors: map[string]*processor{
"create": {db: db},
"query": {db: db},
"update": {db: db},
"delete": {db: db},
"row": {db: db},
"raw": {db: db},
},
}
Go
注册 Callback 函数
// gorm/callbacks/callbacks.go
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
enableTransaction := func(db *gorm.DB) bool {
return !db.SkipDefaultTransaction
}
queryCallback := db.Callback().Query()
queryCallback.Register("gorm:query", Query)
queryCallback.Register("gorm:preload", Preload)
queryCallback.Register("gorm:after_query", AfterQuery)
db.Callback().Row().Register("gorm:row", RowQuery)
db.Callback().Raw().Register("gorm:raw", RawExec)
}
Go
最终执行查询的函数:
func Query(db *gorm.DB) {
if db.Error == nil {
BuildQuerySQL(db)
if !db.DryRun && db.Error == nil {
// 执行查询语句获返回的行结果
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil {
db.AddError(err)
return
}
defer rows.Close()
gorm.Scan(rows, db, false)
}
}
}
Go
我们可以发现上述 GORM 执行一个查询的过程中关于目标对象处理的逻辑是下面这段代码:
// 如果目标渲染的对象为非空指针
if stmt.Dest != nil {
// 获取反射获取目标对象的值
stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
// 如果目标对象的值是一个指针,循环解开指针
for stmt.ReflectValue.Kind() == reflect.Ptr {
stmt.ReflectValue = stmt.ReflectValue.Elem()
}
// 当目标对象的值是零值时,报错
if !stmt.ReflectValue.IsValid() {
db.AddError(fmt.Errorf("invalid value"))
}
}
Go
我们只需要对于零指针的情况做初始化处理就可以解决我们的需求
for stmt.ReflectValue.Kind() == reflect.Ptr {
// 在循环解开指针的过程如果发现值的类型是指针,且他的值是 NIL 空指针
if stmt.ReflectValue.IsNil() {
// 则我们将他的值初始化为零值
stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem()))
break
}
stmt.ReflectValue = stmt.ReflectValue.Elem()
}
Go
其中 stmt.ReflectValue.Type().Elem() 当前的类型,因为当前的类型是指针所以通过 Elem() 获取其指向的类型。
接下来我们在看下 func Scan(rows *sql.Rows, db *DB, initialized bool) {} 这个函数是怎么将行结果反序列化到具体的目标 Struct 变量中的:
//gorm/scan.go
func Scan(rows *sql.Rows, db *DB, initialized bool) {
columns, _ := rows.Columns()
values := make([]interface{}, len(columns))
db.RowsAffected = 0
switch dest := db.Statement.Dest.(type) {
case map[string]interface{}, *map[string]interface{}:
if initialized || rows.Next() {
columnTypes, _ := rows.ColumnTypes()
prepareValues(values, db, columnTypes, columns)
db.RowsAffected++
db.AddError(rows.Scan(values...))
mapValue, ok := dest.(map[string]interface{})
if !ok {
if v, ok := dest.(*map[string]interface{}); ok {
mapValue = *v
}
}
scanIntoMap(mapValue, values, columns)
}
case *[]map[string]interface{}:
columnTypes, _ := rows.ColumnTypes()
for initialized || rows.Next() {
prepareValues(values, db, columnTypes, columns)
initialized = false
db.RowsAffected++
db.AddError(rows.Scan(values...))
mapValue := map[string]interface{}{}
scanIntoMap(mapValue, values, columns)
*dest = append(*dest, mapValue)
}
case *int, *int8, *int16, *int32, *int64,
*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
*float32, *float64,
*bool, *string, *time.Time,
*sql.NullInt32, *sql.NullInt64, *sql.NullFloat64,
*sql.NullBool, *sql.NullString, *sql.NullTime:
for initialized || rows.Next() {
initialized = false
db.RowsAffected++
db.AddError(rows.Scan(dest))
}
default:
Schema := db.Statement.Schema
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
var (
reflectValueType = db.Statement.ReflectValue.Type().Elem()
isPtr = reflectValueType.Kind() == reflect.Ptr
fields = make([]*schema.Field, len(columns))
joinFields [][2]*schema.Field
)
if isPtr {
reflectValueType = reflectValueType.Elem()
}
db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20))
if Schema != nil {
if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct {
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
}
for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
fields[idx] = field
if len(joinFields) == 0 {
joinFields = make([][2]*schema.Field, len(columns))
}
joinFields[idx] = [2]*schema.Field{rel.Field, field}
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
}
}
}
// pluck values into slice of data
isPluck := false
if len(fields) == 1 {
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok || // is scanner
reflectValueType.Kind() != reflect.Struct || // is not struct
Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
isPluck = true
}
}
for initialized || rows.Next() {
initialized = false
db.RowsAffected++
elem := reflect.New(reflectValueType)
if isPluck {
db.AddError(rows.Scan(elem.Interface()))
} else {
for idx, field := range fields {
if field != nil {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
}
}
db.AddError(rows.Scan(values...))
for idx, field := range fields {
if len(joinFields) != 0 && joinFields[idx][0] != nil {
value := reflect.ValueOf(values[idx]).Elem()
relValue := joinFields[idx][0].ReflectValueOf(elem)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if value.IsNil() {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem()))
}
field.Set(relValue, values[idx])
} else if field != nil {
field.Set(elem, values[idx])
}
}
}
if isPtr {
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem))
} else {
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Elem()))
}
}
case reflect.Struct, reflect.Ptr:
if db.Statement.ReflectValue.Type() != Schema.ModelType {
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
}
if initialized || rows.Next() {
for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
}
}
db.RowsAffected++
db.AddError(rows.Scan(values...))
for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable {
field.Set(db.Statement.ReflectValue, values[idx])
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
value := reflect.ValueOf(values[idx]).Elem()
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if value.IsNil() {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem()))
}
field.Set(relValue, values[idx])
}
}
}
}
}
}
}
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound {
db.AddError(ErrRecordNotFound)
}
}
Go
源码稍微有点长,不过大致的逻辑就是根据不同类型的 Dest 怎么样通过的放射的方式将行结果的值赋值给 Dest struct 的变量。这些类型里面唯独不支持指针类型,而我们在 Execute 方法中改造的逻辑会将 Dest 的 Value 设置为一个目标类型零值的指针。所以在 Scan 函数中我们需要增加一个 case
case reflect.Struct, reflect.Ptr: 而 Ptr struct 和 struct 的反射处理字段逻辑是一样的。现在我们要实现的需求就已经完成了, 在单测中增加相关的单测测试通过就可以提 PR 了。
PS:在阅读和调试第三方代码的时候可以将第三方代码的源码拉到本地,然后在 go mod 文件中修改依赖为我们本地的代码。这样就可以方便的调试和阅读第三方代码了。
replace gorm.io/gorm => /Users/lib/gorm
Go
相关的 issue 和 PR