您的位置 首页 golang

04GORM源码解读

简介

GORM 源码解读, 基于 v1.9.11 版本.

查询

上一节中, 我们已经探究过了模型是如何定义的, 以及数据表是如何创建的.
这次, 看一下查询是如何实现的.

查询涉及到很大的一块内容, 因为要支持各种类型的方法.
先看一下官方文档中提供的最简单的几个查询方法.

// 根据主键查询第一条记录db.First(&user)//// SELECT * FROM users ORDER BY id LIMIT 1;// 随机获取一条记录db.Take(&user)//// SELECT * FROM users LIMIT 1;// 根据主键查询最后一条记录db.Last(&user)//// SELECT * FROM users ORDER BY id DESC LIMIT 1;// 查询所有的记录db.Find(&users)//// SELECT * FROM users;// 查询指定的某条记录(仅当主键为整型时可用)db.First(&user, 10)//// SELECT * FROM users WHERE id = 10;

First 方法为例, 看一下它的实现:

// First find first record that match given conditions, order by primary keyfunc (s *DB) First(out interface{}, where ...interface{}) *DB {    newScope := s.NewScope(out)    newScope.Search.Limit(1)    return newScope.Set("gorm:order_by_primary_key", "ASC").        inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db}

First 方法从数据库中获取第一条数据, 以 primary key 升序排序.

前面介绍过, 具体的数据库操作实现是依靠 callbacks 的. 这里用到了 callbacks.queries.

在默认的 callbacks 中, 注册了三个不同的 query 回调函数.

// Define callbacks for queryingfunc init() {    DefaultCallback.Query().Register("gorm:query", queryCallback)    DefaultCallback.Query().Register("gorm:preload", preloadCallback)    DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback)}

查询流程

先来看一下最主要的 queryCallback 函数.

// queryCallback used to query data from databasefunc queryCallback(scope *Scope) {    if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {        return    }    //we are only preloading relations, dont touch base model    if _, skip := scope.InstanceGet("gorm:only_preload"); skip {        return    }    defer scope.trace(scope.db.nowFunc())    var (        isSlice, isPtr bool        resultType     reflect.Type        results        = scope.IndirectValue()    )    if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {        if primaryField := scope.PrimaryField(); primaryField != nil {            scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy))        }    }    if value, ok := scope.Get("gorm:query_destination"); ok {        results = indirect(reflect.ValueOf(value))    }    if kind := results.Kind(); kind == reflect.Slice {        isSlice = true        resultType = results.Type().Elem()        results.Set(reflect.MakeSlice(results.Type(), 0, 0))        if resultType.Kind() == reflect.Ptr {            isPtr = true            resultType = resultType.Elem()        }    } else if kind != reflect.Struct {        scope.Err(errors.New("unsupported destination, should be slice or struct"))        return    }    scope.prepareQuerySQL()    if !scope.HasError() {        scope.db.RowsAffected = 0        if str, ok := scope.Get("gorm:query_option"); ok {            scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))        }        if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {            defer rows.Close()            columns, _ := rows.Columns()            for rows.Next() {                scope.db.RowsAffected++                elem := results                if isSlice {                    elem = reflect.New(resultType).Elem()                }                scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())                if isSlice {                    if isPtr {                        results.Set(reflect.Append(results, elem.Addr()))                    } else {                        results.Set(reflect.Append(results, elem))                    }                }            }            if err := rows.Err(); err != nil {                scope.Err(err)            } else if scope.db.RowsAffected == 0 && !isSlice {                scope.Err(ErrRecordNotFound)            }        }    }}

核心的步骤在于 scope.prepareQuerySQL() 构建 SQL 语句.
然后通过 rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...), 执行了数据库查询.

那么查询到的结果是如何传递的, 传递给谁呢?

函数的开头定义了 results = scope.IndirectValue(), 这就是最终查询结果的归属地.

results 只能是结构体或者是结构体的切片.

if kind := results.Kind(); kind == reflect.Slice {  isSlice = true  resultType = results.Type().Elem()  results.Set(reflect.MakeSlice(results.Type(), 0, 0))  if resultType.Kind() == reflect.Ptr {    isPtr = true    resultType = resultType.Elem()  }} else if kind != reflect.Struct {  scope.Err(errors.New("unsupported destination, should be slice or struct"))  return}

具体如何处理查询到的结果是在下面这部分代码中:

columns, _ := rows.Columns()for rows.Next() {  scope.db.RowsAffected++  elem := results  if isSlice {    elem = reflect.New(resultType).Elem()  }  scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())  if isSlice {    if isPtr {      results.Set(reflect.Append(results, elem.Addr()))    } else {      results.Set(reflect.Append(results, elem))    }  }}

这部分代码的核心语句在于 scope.scan, 看一下这个方法的定义:

func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {    var (        ignored            interface{}        values             = make([]interface{}, len(columns))        selectFields       []*Field        selectedColumnsMap = map[string]int{}        resetFields        = map[int]*Field{}    )    for index, column := range columns {        values[index] = &ignored        selectFields = fields        offset := 0        if idx, ok := selectedColumnsMap[column]; ok {            offset = idx + 1            selectFields = selectFields[offset:]        }        for fieldIndex, field := range selectFields {            if field.DBName == column {                if field.Field.Kind() == reflect.Ptr {                    values[index] = field.Field.Addr().Interface()                } else {                    reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type))                    reflectValue.Elem().Set(field.Field.Addr())                    values[index] = reflectValue.Interface()                    resetFields[index] = field                }                selectedColumnsMap[column] = offset + fieldIndex                if field.IsNormal {                    break                }            }        }    }    scope.Err(rows.Scan(values...))    for index, field := range resetFields {        if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() {            field.Field.Set(v)        }    }}

就和它的名字暗示的那样, 实际上就是调用了 rows.Scan(values...), 将查询到的数据复制到对应的字段中.

由此, 我们就了解了查询时的主要流程了.

前面专注于流程, 略过了构建 SQL 语句的细节, 来仔细看看 prepareQuerySQL 方法.

构建查询 SQL 语句

func (scope *Scope) prepareQuerySQL() {    if scope.Search.raw {        scope.Raw(scope.CombinedConditionSql())    } else {        scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql()))    }    return}

内部分支中都使用到了 scope.Raw, 看一下它的实现:

// Raw set raw sqlfunc (scope *Scope) Raw(sql string) *Scope {    scope.SQL = strings.Replace(sql, "$$$", "?", -1)    return scope}

它的作用是将获取到的 sql 语句赋值到 scope.SQL 字段上, 其中替换了所有的 $$$?.

回到 prepareQuerySQL 上来, 重要的部分是其实是 Raw 的参数.
if 的后半部分更好理解点, 就是构建了 SELECT 表达式.

SELECT 表达式需要三个变量, 字段名, 表名, 条件.

将每个都看一下吧.

func (scope *Scope) selectSQL() string {    if len(scope.Search.selects) == 0 {        if len(scope.Search.joinConditions) > 0 {            return fmt.Sprintf("%v.*", scope.QuotedTableName())        }        return "*"    }    return scope.buildSelectQuery(scope.Search.selects)}func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) {    switch value := clause["query"].(type) {    case string:        str = value    case []string:        str = strings.Join(value, ", ")    }    args := clause["args"].([]interface{})    replacements := []string{}    for _, arg := range args {        switch reflect.ValueOf(arg).Kind() {        case reflect.Slice:            values := reflect.ValueOf(arg)            var tempMarks []string            for i := 0; i < values.Len(); i++ {                tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))            }            replacements = append(replacements, strings.Join(tempMarks, ","))        default:            if valuer, ok := interface{}(arg).(driver.Valuer); ok {                arg, _ = valuer.Value()            }            replacements = append(replacements, scope.AddToVars(arg))        }    }    buff := bytes.NewBuffer([]byte{})    i := 0    for pos, char := range str {        if str[pos] == '?' {            buff.WriteString(replacements[i])            i++        } else {            buff.WriteRune(char)        }    }    str = buff.String()    return}

scope.Search.selects 为空的时候, 比较简单.
只要根据是否有连表查询, 返回 table.**.

buildSelectQuery 就是根据 scope.Search.selects 构建查询字段名.

前面半部分一看就明白.

switch value := clause["query"].(type) {case string:  str = valuecase []string:  str = strings.Join(value, ", ")}

重点是遇到参数时如何处理, 也就是后半段代码.

args := clause["args"].([]interface{})replacements := []string{}for _, arg := range args {  switch reflect.ValueOf(arg).Kind() {  case reflect.Slice:    values := reflect.ValueOf(arg)    var tempMarks []string    for i := 0; i < values.Len(); i++ {      tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))    }    replacements = append(replacements, strings.Join(tempMarks, ","))  default:    if valuer, ok := interface{}(arg).(driver.Valuer); ok {      arg, _ = valuer.Value()    }    replacements = append(replacements, scope.AddToVars(arg))  }}buff := bytes.NewBuffer([]byte{})i := 0for pos, char := range str {  if str[pos] == '?' {    buff.WriteString(replacements[i])    i++  } else {    buff.WriteRune(char)  }}

主要的过程是遍历 args := clause["args"].([]interface{}),
创建了一个 replacements 切片. 然后将 str 中所有的 ?,
替换为了对应的字段.

到此, 构建 SELECT 字段的过程就结束了.

获取表名的过程相对简单, 直接展示代码吧:

// QuotedTableName return quoted table namefunc (scope *Scope) QuotedTableName() (name string) {    if scope.search != nil && len(scope.Search.tableName) > 0 {        if strings.Contains(scope.Search.tableName, " ") {            return scope.Search.tableName        }        return scope.Quote(scope.Search.tableName)    }    return scope.Quote(scope.TableName())}

条件语句

更多的关注点在于如何构建筛选条件, 即 CombinedConditionSql 方法.

// CombinedConditionSql return combined condition sqlfunc (scope *Scope) CombinedConditionSql() string {    joinSQL := scope.joinsSQL()    whereSQL := scope.whereSQL()    if scope.Search.raw {        whereSQL = strings.TrimSuffix(strings.TrimPrefix(whereSQL, "WHERE ("), ")")    }    return joinSQL + whereSQL + scope.groupSQL() +        scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL()}

短小的代码中是精简的逻辑, 条件语句有很多模块, 这里总共有 6 个子句.
都看一遍吧, 看完之后应该对如何构建条件语句不会陌生了.

func (scope *Scope) joinsSQL() string {    var joinConditions []string    for _, clause := range scope.Search.joinConditions {        if sql := scope.buildCondition(clause, true); sql != "" {            joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")"))        }    }    return strings.Join(joinConditions, " ") + " "}

创建 joinSQL 的过程中主要用到了 buildCondition, 继续深入:

func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) (str string) {    var (        quotedTableName  = scope.QuotedTableName()        quotedPrimaryKey = scope.Quote(scope.PrimaryKey())        equalSQL         = "="        inSQL            = "IN"    )    // If building not conditions    if !include {        equalSQL = "<>"        inSQL = "NOT IN"    }    switch value := clause["query"].(type) {    case sql.NullInt64:        return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value.Int64)    case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:        return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value)    case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:        if !include && reflect.ValueOf(value).Len() == 0 {            return        }        str = fmt.Sprintf("(%v.%v %s (?))", quotedTableName, quotedPrimaryKey, inSQL)        clause["args"] = []interface{}{value}    case string:        if isNumberRegexp.MatchString(value) {            return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, scope.AddToVars(value))        }        if value != "" {            if !include {                if comparisonRegexp.MatchString(value) {                    str = fmt.Sprintf("NOT (%v)", value)                } else {                    str = fmt.Sprintf("(%v.%v NOT IN (?))", quotedTableName, scope.Quote(value))                }            } else {                str = fmt.Sprintf("(%v)", value)            }        }    case map[string]interface{}:        var sqls []string        for key, value := range value {            if value != nil {                sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(key), equalSQL, scope.AddToVars(value)))            } else {                if !include {                    sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", quotedTableName, scope.Quote(key)))                } else {                    sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", quotedTableName, scope.Quote(key)))                }            }        }        return strings.Join(sqls, " AND ")    case interface{}:        var sqls []string        newScope := scope.New(value)        if len(newScope.Fields()) == 0 {            scope.Err(fmt.Errorf("invalid query condition: %v", value))            return        }        scopeQuotedTableName := newScope.QuotedTableName()        for _, field := range newScope.Fields() {            if !field.IsIgnored && !field.IsBlank {                sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface())))            }        }        return strings.Join(sqls, " AND ")    default:        scope.Err(fmt.Errorf("invalid query condition: %v", value))        return    }    replacements := []string{}    args := clause["args"].([]interface{})    for _, arg := range args {        var err error        switch reflect.ValueOf(arg).Kind() {        case reflect.Slice: // For where("id in (?)", []int64{1,2})            if scanner, ok := interface{}(arg).(driver.Valuer); ok {                arg, err = scanner.Value()                replacements = append(replacements, scope.AddToVars(arg))            } else if b, ok := arg.([]byte); ok {                replacements = append(replacements, scope.AddToVars(b))            } else if as, ok := arg.([][]interface{}); ok {                var tempMarks []string                for _, a := range as {                    var arrayMarks []string                    for _, v := range a {                        arrayMarks = append(arrayMarks, scope.AddToVars(v))                    }                    if len(arrayMarks) > 0 {                        tempMarks = append(tempMarks, fmt.Sprintf("(%v)", strings.Join(arrayMarks, ",")))                    }                }                if len(tempMarks) > 0 {                    replacements = append(replacements, strings.Join(tempMarks, ","))                }            } else if values := reflect.ValueOf(arg); values.Len() > 0 {                var tempMarks []string                for i := 0; i < values.Len(); i++ {                    tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))                }                replacements = append(replacements, strings.Join(tempMarks, ","))            } else {                replacements = append(replacements, scope.AddToVars(Expr("NULL")))            }        default:            if valuer, ok := interface{}(arg).(driver.Valuer); ok {                arg, err = valuer.Value()            }            replacements = append(replacements, scope.AddToVars(arg))        }        if err != nil {            scope.Err(err)        }    }    buff := bytes.NewBuffer([]byte{})    i := 0    for _, s := range str {        if s == '?' && len(replacements) > i {            buff.WriteString(replacements[i])            i++        } else {            buff.WriteRune(s)        }    }    str = buff.String()    return}

开头是一个精妙的选择, 基于 include, 实现了 not 条件.

var (  quotedTableName  = scope.QuotedTableName()  quotedPrimaryKey = scope.Quote(scope.PrimaryKey())  equalSQL         = "="  inSQL            = "IN")// If building not conditionsif !include {  equalSQL = "<>"  inSQL = "NOT IN"}

中间是一个 switch value := clause["query"].(type) 选择.
在这个 switch 选择中, 大部分的条件都会直接返回.
剩余的部分, 则会构建 str 字符串变量.

而这会继续进入到结尾部分, 这部分的代码和我们上面看过的非常类似,
就是根据 clause["args"] 构建 replacements 切片,
用来替换 str 变量中的 ?.

接着看下一个 whereSQL 方法.

func (scope *Scope) whereSQL() (sql string) {    var (        quotedTableName                                = scope.QuotedTableName()        deletedAtField, hasDeletedAtField              = scope.FieldByName("DeletedAt")        primaryConditions, andConditions, orConditions []string    )    if !scope.Search.Unscoped && hasDeletedAtField {        sql := fmt.Sprintf("%v.%v IS NULL", quotedTableName, scope.Quote(deletedAtField.DBName))        primaryConditions = append(primaryConditions, sql)    }    if !scope.PrimaryKeyZero() {        for _, field := range scope.PrimaryFields() {            sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))            primaryConditions = append(primaryConditions, sql)        }    }    for _, clause := range scope.Search.whereConditions {        if sql := scope.buildCondition(clause, true); sql != "" {            andConditions = append(andConditions, sql)        }    }    for _, clause := range scope.Search.orConditions {        if sql := scope.buildCondition(clause, true); sql != "" {            orConditions = append(orConditions, sql)        }    }    for _, clause := range scope.Search.notConditions {        if sql := scope.buildCondition(clause, false); sql != "" {            andConditions = append(andConditions, sql)        }    }    orSQL := strings.Join(orConditions, " OR ")    combinedSQL := strings.Join(andConditions, " AND ")    if len(combinedSQL) > 0 {        if len(orSQL) > 0 {            combinedSQL = combinedSQL + " OR " + orSQL        }    } else {        combinedSQL = orSQL    }    if len(primaryConditions) > 0 {        sql = "WHERE " + strings.Join(primaryConditions, " AND ")        if len(combinedSQL) > 0 {            sql = sql + " AND (" + combinedSQL + ")"        }    } else if len(combinedSQL) > 0 {        sql = "WHERE " + combinedSQL    }    return}

主要构建了三个部分, primaryConditions, andConditions, orConditions.

if !scope.Search.Unscoped && hasDeletedAtField {  sql := fmt.Sprintf("%v.%v IS NULL", quotedTableName, scope.Quote(deletedAtField.DBName))  primaryConditions = append(primaryConditions, sql)}if !scope.PrimaryKeyZero() {  for _, field := range scope.PrimaryFields() {    sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))    primaryConditions = append(primaryConditions, sql)  }}

前面两个 if 构建了 primaryConditions 条件.

for _, clause := range scope.Search.whereConditions {  if sql := scope.buildCondition(clause, true); sql != "" {    andConditions = append(andConditions, sql)  }}for _, clause := range scope.Search.orConditions {  if sql := scope.buildCondition(clause, true); sql != "" {    orConditions = append(orConditions, sql)  }}for _, clause := range scope.Search.notConditions {  if sql := scope.buildCondition(clause, false); sql != "" {    andConditions = append(andConditions, sql)  }}

然后三个 for 循环都使用了 buildCondition 方法.
注意到 scope.Search.notConditions 是算在 andConditions 中的.

orSQL := strings.Join(orConditions, " OR ")combinedSQL := strings.Join(andConditions, " AND ")if len(combinedSQL) > 0 {  if len(orSQL) > 0 {    combinedSQL = combinedSQL + " OR " + orSQL  }} else {  combinedSQL = orSQL}

结合 orConditionsandConditions 生成了条件语句.

if len(primaryConditions) > 0 {  sql = "WHERE " + strings.Join(primaryConditions, " AND ")  if len(combinedSQL) > 0 {    sql = sql + " AND (" + combinedSQL + ")"  }} else if len(combinedSQL) > 0 {  sql = "WHERE " + combinedSQL}return

最后, 结合 primaryConditions 生成最终的 WHERE 子句.

接着看另一个:

func (scope *Scope) groupSQL() string {    if len(scope.Search.group) == 0 {        return ""    }    return " GROUP BY " + scope.Search.group}

GROUP BY 子句比较简单, 直接就能构建.

继续:

func (scope *Scope) havingSQL() string {    if len(scope.Search.havingConditions) == 0 {        return ""    }    var andConditions []string    for _, clause := range scope.Search.havingConditions {        if sql := scope.buildCondition(clause, true); sql != "" {            andConditions = append(andConditions, sql)        }    }    combinedSQL := strings.Join(andConditions, " AND ")    if len(combinedSQL) == 0 {        return ""    }    return " HAVING " + combinedSQL}

HAVING 子句也不算难, 构建完条件之后用 AND 连接, 然后在最前面加上 HAVING 就行了.

继续:

func (scope *Scope) orderSQL() string {    if len(scope.Search.orders) == 0 || scope.Search.ignoreOrderQuery {        return ""    }    var orders []string    for _, order := range scope.Search.orders {        if str, ok := order.(string); ok {            orders = append(orders, scope.quoteIfPossible(str))        } else if expr, ok := order.(*expr); ok {            exp := expr.expr            for _, arg := range expr.args {                exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)            }            orders = append(orders, exp)        }    }    return " ORDER BY " + strings.Join(orders, ",")}

结构也是类似, 遍历 scope.Search.orders 切片, order 有两种不同的类型, 字符串或者 expr 结构体.
后者用于处理带参数的情况.

最后还有一个 limitAndOffsetSQL 方法:

func (scope *Scope) limitAndOffsetSQL() string {    return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)}

这直接调用了具体数据库驱动中的 LimitAndOffsetSQL 方法.

看两个具体的实现, 一个是通用中的实现, 另一个是 mysql 中的实现.

func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {    if limit != nil {        if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {            sql += fmt.Sprintf(" LIMIT %d", parsedLimit)        }    }    if offset != nil {        if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {            sql += fmt.Sprintf(" OFFSET %d", parsedOffset)        }    }    return}

直接将 limit 和 offset 解析为 int 类型, 然后连接对应的关键字即可.

接着看一下 mysql 中的实现:

func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {    if limit != nil {        if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {            sql += fmt.Sprintf(" LIMIT %d", parsedLimit)            if offset != nil {                if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {                    sql += fmt.Sprintf(" OFFSET %d", parsedOffset)                }            }        }    }    return}

两者的区别在于 offset 的嵌套, mysql 中 offset 必须和 limit 一起使用.

就这样, CombinedConditionSql 中的所有子句都看完了.
说到底其实也没什么魔法, 不过是根据不同的条件, 构建不同的 SQL 语句.

小结

一路从 First 深入到查询的内部细节. 在了解了底层细节之后, 其他类似的方法也就不难理解了.

// Take return a record that match given conditions, the order will depend on the database implementationfunc (s *DB) Take(out interface{}, where ...interface{}) *DB {    newScope := s.NewScope(out)    newScope.Search.Limit(1)    return newScope.inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db}// Last find last record that match given conditions, order by primary keyfunc (s *DB) Last(out interface{}, where ...interface{}) *DB {    newScope := s.NewScope(out)    newScope.Search.Limit(1)    return newScope.Set("gorm:order_by_primary_key", "DESC").        inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db}// Find find records that match given conditionsfunc (s *DB) Find(out interface{}, where ...interface{}) *DB {    return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db}

search 结构体

前面的过程中, 我们只看到了最简单的查询是如何产生的.
在这个过程中, 没有仔细研究查询条件是如何存储的.

看一下如何使用 Where 方法添加查询条件.

// Get first matched recorddb.Where("name = ?", "jinzhu").First(&user)//// SELECT * FROM users WHERE name = 'jinzhu' limit 1;// Get all matched recordsdb.Where("name = ?", "jinzhu").Find(&users)//// SELECT * FROM users WHERE name = 'jinzhu';

上面的例子来自于官方文档. GORM 使用链式调用的风格, 可以串联多个 Where 方法, 或是其他的查询条件.

// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#queryfunc (s *DB) Where(query interface{}, args ...interface{}) *DB {    return s.clone().search.Where(query, args...).db}

上面是 Where 方法的代码, 在它的源码附近有很多类似的的方法.

// Or filter records that match before conditions or this one, similar to `Where`func (s *DB) Or(query interface{}, args ...interface{}) *DB {    return s.clone().search.Or(query, args...).db}// Not filter records that don't match current conditions, similar to `Where`func (s *DB) Not(query interface{}, args ...interface{}) *DB {    return s.clone().search.Not(query, args...).db}

可以很容易的发现, 这一切的源头都是 search 对象.

结构体 DB 定义的时候, 有个字段就是 search:

search            *search

search 的定义

这就是用于存储查询条件的地方. 它的定义如下:

type search struct {    db               *DB    whereConditions  []map[string]interface{}    orConditions     []map[string]interface{}    notConditions    []map[string]interface{}    havingConditions []map[string]interface{}    joinConditions   []map[string]interface{}    initAttrs        []interface{}    assignAttrs      []interface{}    selects          map[string]interface{}    omits            []string    orders           []interface{}    preload          []searchPreload    offset           interface{}    limit            interface{}    group            string    tableName        string    raw              bool    Unscoped         bool    ignoreOrderQuery bool}type searchPreload struct {    schema     string    conditions []interface{}}

这里有很多类型为 []map[string]interface{} 的字段, 结合前面关于条件查询的代码, 就能回忆起这就是存储各种条件的地方.

另一些字段比如 offsetlimit 也很容易明白它的作用.

search 的方法

search 下有很多方法, 虽然方法数量比较多, 但基本都很短, 总共也就一百行出头.

func (s *search) clone() *search {    clone := *s    return &clone}

这个克隆方法有点独特, 似乎什么也没做, 也可能是我见识少.

func (s *search) Where(query interface{}, values ...interface{}) *search {    s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values})    return s}func (s *search) Not(query interface{}, values ...interface{}) *search {    s.notConditions = append(s.notConditions, map[string]interface{}{"query": query, "args": values})    return s}func (s *search) Or(query interface{}, values ...interface{}) *search {    s.orConditions = append(s.orConditions, map[string]interface{}{"query": query, "args": values})    return s}

上面这些方法都是用参数构建成一个 map 然后推入对应的切片中, 考虑到链式调用, 返回了本身.

func (s *search) Attrs(attrs ...interface{}) *search {    s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...))    return s}func (s *search) Assign(attrs ...interface{}) *search {    s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...))    return s}func toSearchableMap(attrs ...interface{}) (result interface{}) {    if len(attrs) > 1 {        if str, ok := attrs[0].(string); ok {            result = map[string]interface{}{str: attrs[1]}        }    } else if len(attrs) == 1 {        if attr, ok := attrs[0].(map[string]interface{}); ok {            result = attr        }        if attr, ok := attrs[0].(interface{}); ok {            result = attr        }    }    return}

这两个方法也是类似, 并使用了 toSearchableMap 转换参数.

func (s *search) Order(value interface{}, reorder ...bool) *search {    if len(reorder) > 0 && reorder[0] {        s.orders = []interface{}{}    }    if value != nil && value != "" {        s.orders = append(s.orders, value)    }    return s}

看到这个可能有点疑惑, 可以从文档和注释中获取解释.

// Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions//     db.Order("name DESC")//     db.Order("name DESC", true) // reorder//     db.Order(gorm.Expr("name = ? DESC", "first")) // sql expressionfunc (s *DB) Order(value interface{}, reorder ...bool) *DB {    return s.clone().search.Order(value, reorder...).db}

第二个参数用于判断是否覆盖前面的排序条件.

可能有点奇怪的是为什么 reorder 是可变参数, 不知为了兼容或者是历史遗留.

另一点是不能理解 []interface{}{}, 这其实可以分为两部分, []interface{} 是类型, {} 构造了一个空的该类型实例.

func (s *search) Select(query interface{}, args ...interface{}) *search {    s.selects = map[string]interface{}{"query": query, "args": args}    return s}func (s *search) Omit(columns ...string) *search {    s.omits = columns    return s}func (s *search) Limit(limit interface{}) *search {    s.limit = limit    return s}func (s *search) Offset(offset interface{}) *search {    s.offset = offset    return s}

这几个就是替换型的了, 每次调用都只会保存最新值.

func (s *search) Group(query string) *search {    s.group = s.getInterfaceAsSQL(query)    return s}func (s *search) getInterfaceAsSQL(value interface{}) (str string) {    switch value.(type) {    case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:        str = fmt.Sprintf("%v", value)    default:        s.db.AddError(ErrInvalidSQL)    }    if str == "-1" {        return ""    }    return}

getInterfaceAsSQL 的一个特性是使用 -1 会重置.

func (s *search) Having(query interface{}, values ...interface{}) *search {    if val, ok := query.(*expr); ok {        s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args})    } else {        s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values})    }    return s}func (s *search) Joins(query string, values ...interface{}) *search {    s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values})    return s}

这其实也比较类似前面看过的, 就不多解释了.

func (s *search) Preload(schema string, values ...interface{}) *search {    var preloads []searchPreload    for _, preload := range s.preload {        if preload.schema != schema {            preloads = append(preloads, preload)        }    }    preloads = append(preloads, searchPreload{schema, values})    s.preload = preloads    return s}

Preload 需要防止重复, 所以开头会重新遍历一遍已经存在的 schema.

func (s *search) Raw(b bool) *search {    s.raw = b    return s}func (s *search) unscoped() *search {    s.Unscoped = true    return s}func (s *search) Table(name string) *search {    s.tableName = name    return s}

最后几个方法也没什么特殊的.

小结

search 结构体还是挺简单的, 定义加方法总共也就一百多行.
但用处却不小, 查询相关的条件都是存储在这里的.

总结

这部分主要查看了 SQL 查询是如何发生的, 并在这个过程中探索了各种查询子句是如何实现的. 同时, 也研究了一下 search 结构体和它的作用.


文章来源:智云一二三科技

文章标题:04GORM源码解读

文章地址:https://www.zhihuclub.com/7436.shtml

关于作者: 智云科技

热门文章

网站地图