background/db/sql_builder.go

542 lines
11 KiB
Go
Raw Normal View History

2019-01-25 09:11:15 +00:00
package db
import (
"database/sql"
"errors"
"strconv"
"strings"
"log"
"fmt"
"reflect"
"math/big"
"git.jiaxianghudong.com/go/utils"
)
// Statement kinds stored in SB.t. Numbering starts at 1, so a zero-valued
// SB.t matches none of them.
const (
	TYPE_INSERT = iota + 1
	TYPE_DELETE
	TYPE_UPDATE
	TYPE_SELECT
	TYPE_INSERTUPDATE
)

// WrapSymbol is written around column names in generated SQL (a MySQL
// backtick by default).
var WrapSymbol = "`"

// DBType is the package-wide database type; Exec consults it before reading
// LastInsertId.
var DBType = "mysql"
// SB (SQL builder) accumulates the parts of a single INSERT, DELETE, UPDATE,
// INSERT ... ON DUPLICATE KEY UPDATE or SELECT statement; ToSql renders it.
type SB struct {
	db *Database // database handle used by Exec and the Query* methods
	t  int       // statement kind, one of the TYPE_* constants

	field string // SELECT column list
	table string // target table
	where string // WHERE condition, inserted verbatim
	group string // GROUP BY clause
	order string // ORDER BY clause
	limit string // LIMIT clause in "offset,count" form

	values  SBValues // INSERT columns / UPDATE assignments
	values2 SBValues // ON DUPLICATE KEY UPDATE assignments (InsertUpdate only)

	ignore  bool // emit INSERT IGNORE
	fullsql bool // Exec inlines the arguments into the SQL text instead of binding them
	debug   bool // log statements before running them
	unsafe  bool // when true, DELETE/UPDATE without a WHERE clause are allowed

	args []interface{} // placeholder values collected by ToSql
}
// SBResult reports the outcome of SB.Exec.
type SBResult struct {
	Success  bool   // whether the statement executed successfully
	Code     int    // error code (NOTE(review): SB.Exec never populates it)
	Msg      string // error message when the build or execution failed
	LastID   int64  // last inserted ID (INSERT only)
	Affected int64  // number of affected rows (DELETE / UPDATE / InsertUpdate)
	Sql      string // SQL built by ToSql, in placeholder form
}
// SBValues maps column names to the values bound for them (INSERT columns or
// UPDATE assignments).
type SBValues map[string]interface{}

// IncVal, stored as a value in SBValues, makes an UPDATE assignment render an
// inline increment (field = base + Val) instead of a bound placeholder.
type IncVal struct {
	Val       int64
	BaseField string // empty means "increment the current field itself"
}

// Add stores val under key, replacing any previous value.
func (v SBValues) Add(key string, val interface{}) {
	v[key] = val
}

// Del removes key; removing an absent key is a no-op.
func (v SBValues) Del(key string) {
	delete(v, key)
}

// IsExist reports whether key is present (even when its value is nil).
func (v SBValues) IsExist(key string) bool {
	_, found := v[key]
	return found
}

// Get returns the raw value stored under key, or nil when key is absent.
func (v SBValues) Get(key string) interface{} {
	return v[key]
}

// GetString returns the value under key when it is stored as a string, else "".
func (v SBValues) GetString(key string) string {
	s, _ := v[key].(string)
	return s
}

// GetInt returns the value under key when it is stored as exactly int, else 0.
// Other integer types (int64, uint, ...) also yield 0.
func (v SBValues) GetInt(key string) int {
	n, _ := v[key].(int)
	return n
}

// GetUint returns the value under key when it is stored as exactly uint, else 0.
func (v SBValues) GetUint(key string) uint {
	n, _ := v[key].(uint)
	return n
}

// GetInt64 returns the value under key when it is stored as exactly int64, else 0.
func (v SBValues) GetInt64(key string) int64 {
	n, _ := v[key].(int64)
	return n
}
// FullSql returns str with each "?" placeholder replaced, in order, by the SQL
// literal form of the corresponding argument (the "bound" statement).
//
// Strings are wrapped in single quotes, with backslashes escaped as `\\` and
// single quotes doubled (`''`), so the original text is preserved. nil becomes
// NULL. Integers, floats, *big.Int and bool are written in decimal/textual form.
//
// A str without any "?" is returned unchanged. Otherwise an error is returned
// when the number of placeholders differs from len(args), or when an argument
// has an unsupported type.
//
// NOTE(review): every "?" counts as a placeholder, including one inside a
// quoted literal in str.
func FullSql(str string, args ...interface{}) (string, error) {
	if !strings.Contains(str, "?") {
		return str, nil
	}
	segments := strings.Split(str, "?")
	// Placeholders sit strictly between segments. A mismatch used to drop
	// placeholders or tack extra arguments onto the end of the statement.
	if placeholders := len(segments) - 1; placeholders != len(args) {
		return "", fmt.Errorf("sql placeholder count mismatch: %d placeholders, %d arguments (sql: %s)", placeholders, len(args), str)
	}

	var b strings.Builder
	b.Grow(len(str) + 16*len(args))
	for i, arg := range args {
		b.WriteString(segments[i])
		switch v := arg.(type) {
		case int:
			b.WriteString(strconv.Itoa(v))
		case int8:
			b.WriteString(strconv.FormatInt(int64(v), 10))
		case int16:
			b.WriteString(strconv.FormatInt(int64(v), 10))
		case int32:
			b.WriteString(strconv.FormatInt(int64(v), 10))
		case int64:
			b.WriteString(strconv.FormatInt(v, 10))
		case uint:
			b.WriteString(strconv.FormatUint(uint64(v), 10))
		case uint8:
			b.WriteString(strconv.FormatUint(uint64(v), 10))
		case uint16:
			b.WriteString(strconv.FormatUint(uint64(v), 10))
		case uint32:
			b.WriteString(strconv.FormatUint(uint64(v), 10))
		case uint64:
			b.WriteString(strconv.FormatUint(v, 10))
		case float32:
			b.WriteString(strconv.FormatFloat(float64(v), 'f', -1, 32))
		case float64:
			b.WriteString(strconv.FormatFloat(v, 'f', -1, 64))
		case *big.Int:
			b.WriteString(v.String())
		case bool:
			b.WriteString(strconv.FormatBool(v))
		case string:
			// MySQL's default sql_mode treats backslash as an escape character,
			// so it is escaped too. Doubling the quote keeps the value intact
			// (stripping it, as before, silently corrupted data).
			escaped := strings.Replace(v, `\`, `\\`, -1)
			escaped = strings.Replace(escaped, "'", "''", -1)
			b.WriteString("'")
			b.WriteString(escaped)
			b.WriteString("'")
		case nil:
			b.WriteString("NULL")
		default:
			return "", fmt.Errorf("invalid sql argument type: %v => %v (sql: %s)", reflect.TypeOf(v).String(), v, str)
		}
	}
	b.WriteString(segments[len(segments)-1])
	return b.String(), nil
}
// ToSql builds the SQL statement described by the builder. It collects the
// values bound to the generated "?" placeholders into q.args (see GetArgs);
// q.args is reset on every call.
//
// When returnFullSql is given as a single true value, the placeholders are
// substituted with q.args via FullSql and the literal SQL is returned.
//
// An error is returned when the statement cannot be built:
//   - an empty table for INSERT, DELETE, UPDATE or InsertUpdate;
//   - empty values for INSERT or UPDATE;
//   - empty values2 for InsertUpdate;
//   - a DELETE or UPDATE without a WHERE clause, unless Unsafe() was called.
//
// Column order follows Go map iteration, so it may differ between calls.
// q.args always matches the rendered order.
func (q *SB) ToSql(returnFullSql ...bool) (string, error) {
	q.args = make([]interface{}, 0)

	// insertColumns renders " (`a`,`b`) VALUES (?,?)" from q.values. It
	// appends the values to q.args in the same order as the columns.
	insertColumns := func() string {
		fields := make([]string, 0, len(q.values))
		marks := make([]string, 0, len(q.values))
		for k, v := range q.values {
			fields = append(fields, WrapSymbol+k+WrapSymbol)
			marks = append(marks, "?")
			q.args = append(q.args, v)
		}
		return " (" + strings.Join(fields, ",") + ") VALUES (" + strings.Join(marks, ",") + ")"
	}

	var str string
	switch q.t {
	case TYPE_INSERT:
		if q.table == "" {
			return "", errors.New("table cannot be empty.")
		}
		if len(q.values) == 0 {
			return "", errors.New("values cannot be empty.")
		}
		if q.ignore {
			str = "INSERT IGNORE INTO " + q.table
		} else {
			str = "INSERT INTO " + q.table
		}
		str += insertColumns()
	case TYPE_DELETE:
		if q.table == "" {
			return "", errors.New("table cannot be empty.")
		}
		if q.where == "" && !q.unsafe {
			return "", errors.New("deleting all data is not safe.")
		}
		// NOTE(review): "DELETE t FROM t" is MySQL's multiple-table DELETE form.
		// It is kept unchanged for compatibility with existing callers.
		str = "DELETE " + q.table + " FROM " + q.table
		if q.where != "" {
			str += " WHERE " + q.where
		}
	case TYPE_UPDATE:
		if q.table == "" {
			return "", errors.New("table cannot be empty.")
		}
		if q.where == "" && !q.unsafe {
			return "", errors.New("updating all data is not safe.")
		}
		if len(q.values) == 0 {
			return "", errors.New("values cannot be empty.")
		}
		str = "UPDATE " + q.table + " SET " + strings.TrimPrefix(q.buildUpdateParams(q.values), ",")
		if q.where != "" {
			str += " WHERE " + q.where
		}
	case TYPE_INSERTUPDATE:
		if q.table == "" {
			return "", errors.New("table cannot be empty.")
		}
		if len(q.values2) == 0 {
			return "", errors.New("values2 cannot be empty.")
		}
		str = "INSERT INTO " + q.table + insertColumns()
		// Rendered after the INSERT columns so the UPDATE placeholders' values
		// follow the INSERT values in q.args.
		str += " ON DUPLICATE KEY UPDATE " + strings.TrimPrefix(q.buildUpdateParams(q.values2), ",")
	case TYPE_SELECT:
		str = "SELECT " + q.field
		if q.table != "" {
			str += " FROM " + q.table
		}
		if q.where != "" {
			str += " WHERE " + q.where
		}
		if q.group != "" {
			str += " GROUP BY " + q.group
		}
		if q.order != "" {
			str += " ORDER BY " + q.order
		}
		if q.limit != "" {
			// Without a database handle, fall back to the package-wide type
			// instead of dereferencing nil.
			dbType := DBType
			if q.db != nil {
				dbType = q.db.Type
			}
			// LIMIT is only emitted for MySQL; an empty type counts as MySQL.
			if dbType == "" || dbType == "mysql" {
				str += " LIMIT " + q.limit
			}
		}
	}

	if len(returnFullSql) == 1 && returnFullSql[0] {
		return FullSql(str, q.args...)
	}
	return str, nil
}
// buildUpdateParams renders the assignment list of an UPDATE SET or ON
// DUPLICATE KEY UPDATE clause. Every entry carries a leading comma
// (",`k`=?"), which callers strip from the first one.
//
// Plain values become "?" placeholders and are appended to q.args in
// emission order. IncVal values are rendered inline as "`k`=`k`+n" or
// "`k`=BaseField+n" and bind no argument.
func (q *SB) buildUpdateParams(vals SBValues) string {
	var b strings.Builder
	for k, v := range vals {
		b.WriteString(",")
		b.WriteString(WrapSymbol + k + WrapSymbol + "=")

		iv, isInc := v.(IncVal)
		if !isInc {
			b.WriteString("?")
			q.args = append(q.args, v)
			continue
		}

		base := iv.BaseField
		if base == "" {
			// Quote the column exactly like the left-hand side so reserved
			// words (e.g. `order`) remain valid.
			base = WrapSymbol + k + WrapSymbol
		}
		b.WriteString(base)
		if iv.Val >= 0 {
			b.WriteString("+")
		}
		// A negative value already carries its own "-" sign.
		b.WriteString(strconv.FormatInt(iv.Val, 10))
	}
	return b.String()
}
// DB sets the database handle the builder runs against.
func (q *SB) DB(db *Database) *SB {
	q.db = db
	return q
}

// From sets the target table: the FROM clause of SELECT/DELETE, or the table
// of INSERT/UPDATE.
func (q *SB) From(table string) *SB {
	q.table = table
	return q
}

// Table is an alias of From.
func (q *SB) Table(name string) *SB {
	q.table = name
	return q
}

// Where sets the WHERE condition. It is inserted verbatim (not escaped), so
// pass untrusted values through "?" placeholders rather than concatenation.
func (q *SB) Where(cond string) *SB {
	q.where = cond
	return q
}

// Group sets the GROUP BY clause, inserted verbatim.
func (q *SB) Group(cols string) *SB {
	q.group = cols
	return q
}

// Order sets the ORDER BY clause, inserted verbatim.
func (q *SB) Order(cols string) *SB {
	q.order = cols
	return q
}
// Limit sets the LIMIT clause in MySQL "offset,count" form. The optional
// offset defaults to 0, and only the first offset value is used. ToSql emits
// the clause for MySQL only.
func (q *SB) Limit(count int, offset ...int) *SB {
	start := 0
	if len(offset) > 0 {
		start = offset[0]
	}
	q.limit = utils.Itoa(start) + "," + utils.Itoa(count)
	return q
}
// Unsafe controls the safety check that rejects DELETE and UPDATE statements
// without a WHERE clause.
//   - Unsafe() or Unsafe(true) disables the check.
//   - Unsafe(false) re-enables it.
//   - Passing more than one value also disables it.
func (q *SB) Unsafe(allow ...bool) *SB {
	q.unsafe = len(allow) != 1 || allow[0]
	return q
}

// Debug controls logging of each statement and its arguments before it runs.
//   - Debug() or Debug(true) turns logging on.
//   - Debug(false) turns it off.
//   - Passing more than one value turns it on.
func (q *SB) Debug(on ...bool) *SB {
	q.debug = len(on) != 1 || on[0]
	return q
}
// Value replaces the primary value set: INSERT columns or UPDATE assignments.
// The map is used as-is, so later AddValue calls also modify m.
func (q *SB) Value(m SBValues) *SB {
	q.values = m
	return q
}

// Value2 replaces the secondary value set: the ON DUPLICATE KEY UPDATE
// assignments of an InsertUpdate statement. The map is used as-is.
func (q *SB) Value2(m SBValues) *SB {
	q.values2 = m
	return q
}

// AddValue sets key to val in the primary value set. The map is allocated on
// first use, so builders created without one (Delete, Select, or after
// Value(nil)) no longer panic on a nil-map write.
func (q *SB) AddValue(key string, val interface{}) *SB {
	if q.values == nil {
		q.values = SBValues{}
	}
	q.values.Add(key, val)
	return q
}

// AddValue2 sets key to val in the secondary (ON DUPLICATE KEY UPDATE) value
// set, allocating the map on first use. This matters because Insert and
// Update create no values2 map.
func (q *SB) AddValue2(key string, val interface{}) *SB {
	if q.values2 == nil {
		q.values2 = SBValues{}
	}
	q.values2.Add(key, val)
	return q
}
// NewValues returns an empty, ready-to-use SBValues.
func NewValues() SBValues {
	return SBValues{}
}

// Insert starts an INSERT builder on the package-level Obj database.
// Insert(true) produces INSERT IGNORE.
func Insert(ignore ...bool) *SB {
	return &SB{
		t:      TYPE_INSERT,
		db:     Obj,
		ignore: len(ignore) == 1 && ignore[0],
		values: SBValues{},
		args:   make([]interface{}, 0),
	}
}

// Delete starts a DELETE builder on the package-level Obj database.
func Delete() *SB {
	return &SB{t: TYPE_DELETE, db: Obj}
}

// Update starts an UPDATE builder on the package-level Obj database.
func Update() *SB {
	return &SB{t: TYPE_UPDATE, db: Obj, values: SBValues{}, args: make([]interface{}, 0)}
}

// InsertUpdate starts an INSERT ... ON DUPLICATE KEY UPDATE builder (MySQL
// only). Value/AddValue supply the inserted columns, and Value2/AddValue2
// supply the update assignments.
func InsertUpdate() *SB {
	return &SB{t: TYPE_INSERTUPDATE, db: Obj, values: SBValues{}, values2: SBValues{}, args: make([]interface{}, 0)}
}

// Select starts a SELECT builder on the package-level Obj database.
//   - With no arguments it selects "*".
//   - With one argument, that string is the column list.
//   - With several, they are joined with "," (previously extra arguments were
//     silently ignored and "*" was used).
func Select(str ...string) *SB {
	fields := "*"
	if len(str) > 0 {
		fields = strings.Join(str, ",")
	}
	return &SB{t: TYPE_SELECT, db: Obj, field: fields}
}
// GetArgs returns the placeholder values collected by the most recent ToSql call.
func (q *SB) GetArgs() []interface{} {
	return q.args
}

// FullSql controls whether Exec inlines the bound arguments into the SQL text,
// using the package-level FullSql, instead of passing them to the driver.
//   - With no value, or more than one, the mode is enabled.
//   - With exactly one value, that value is used.
func (q *SB) FullSql(yes ...bool) *SB {
	q.fullsql = len(yes) != 1 || yes[0]
	return q
}
// Exec builds the INSERT / DELETE / UPDATE / InsertUpdate statement with ToSql
// and executes it on q.db. The returned SBResult always carries the built SQL.
//
// args are appended after the values ToSql collected (q.args), so they bind to
// placeholders that follow the generated ones (e.g. "?" in the WHERE clause).
// In FullSql mode the arguments are inlined into the text instead of being
// bound.
//
// On success:
//   - Success is true.
//   - For INSERT, LastID is filled when DBType is "mysql" and the driver
//     reports it.
//   - For DELETE / UPDATE / InsertUpdate, Affected is filled when the driver
//     reports it.
//
// Any build or execution error is reported through Msg.
func (q *SB) Exec(args ...interface{}) *SBResult {
	sbRet := &SBResult{}

	var err error
	sbRet.Sql, err = q.ToSql()
	if err != nil {
		sbRet.Msg = err.Error()
		return sbRet
	}
	if q.debug {
		log.Println("\n\tSQL prepare statement:\n\t", sbRet.Sql, "\n\tMap args:\n\t", q.args, "\n\tParams:\n\t", args)
	}
	// Report a missing handle instead of panicking on a nil receiver.
	if q.db == nil {
		sbRet.Msg = "database is not set"
		return sbRet
	}

	allArgs := make([]interface{}, 0, len(q.args)+len(args))
	allArgs = append(allArgs, q.args...)
	allArgs = append(allArgs, args...)

	var ret sql.Result
	if q.fullsql {
		var sqlStr string
		if sqlStr, err = FullSql(sbRet.Sql, allArgs...); err == nil {
			ret, err = q.db.Exec(sqlStr)
		}
	} else {
		ret, err = q.db.Exec(sbRet.Sql, allArgs...)
	}
	if err != nil {
		sbRet.Msg = err.Error()
		return sbRet
	}

	sbRet.Success = true
	switch q.t {
	case TYPE_INSERT:
		if DBType == "mysql" {
			// Best effort: a driver error here does not fail the statement.
			if last, idErr := ret.LastInsertId(); idErr == nil {
				sbRet.LastID = last
			}
		}
	case TYPE_DELETE, TYPE_UPDATE, TYPE_INSERTUPDATE:
		// Best effort, like LastInsertId above.
		if aff, affErr := ret.RowsAffected(); affErr == nil {
			sbRet.Affected = aff
		}
	}
	return sbRet
}
// Query builds the statement and returns every matching row through
// q.db.Select. args bind to "?" placeholders the caller wrote (e.g. in Where).
// The builder's fullsql flag is not consulted here.
func (q *SB) Query(args ...interface{}) (Results, error) {
	sqlText, buildErr := q.ToSql()
	if buildErr != nil {
		return nil, buildErr
	}
	if q.debug {
		log.Println("\n\tSQL prepare statement:\n\t", sqlText, "\n\tParams:\n\t", args)
	}
	return q.db.Select(sqlText, args...)
}

// QueryOne forces LIMIT 0,1 on the builder (this mutates it) and returns the
// single row through q.db.SelectOne.
func (q *SB) QueryOne(args ...interface{}) (OneRow, error) {
	q.Limit(1, 0)
	sqlText, buildErr := q.ToSql()
	if buildErr != nil {
		return nil, buildErr
	}
	if q.debug {
		log.Println("\n\tSQL prepare statement:\n\t", sqlText, "\n\tParams:\n\t", args)
	}
	return q.db.SelectOne(sqlText, args...)
}

// QueryAllRow builds the statement and returns the raw *sql.Rows from
// q.db.Query. The caller must close the rows.
func (q *SB) QueryAllRow(args ...interface{}) (*sql.Rows, error) {
	sqlText, buildErr := q.ToSql()
	if buildErr != nil {
		return nil, buildErr
	}
	if q.debug {
		log.Println("\n\tSQL prepare statement:\n\t", sqlText, "\n\tParams:\n\t", args)
	}
	return q.db.Query(sqlText, args...)
}

// QueryRow builds the statement and returns the single *sql.Row from
// q.db.QueryRow.
//
// NOTE(review): when ToSql fails, the error is discarded and nil is returned.
// Calling Scan on that nil *sql.Row panics, so callers must check for nil.
func (q *SB) QueryRow(args ...interface{}) *sql.Row {
	sqlText, buildErr := q.ToSql()
	if buildErr != nil {
		return nil
	}
	if q.debug {
		log.Println("\n\tSQL prepare statement:\n\t", sqlText, "\n\tParams:\n\t", args)
	}
	return q.db.QueryRow(sqlText, args...)
}