// 数据库工具包 package db import ( "database/sql" "errors" "fmt" "reflect" "strconv" "sync" "time" "background/logs" ) // 数据容器抽象对象定义 type Database struct { Type string // 用来给SqlBuilder进行一些特殊的判断 (空值或mysql 皆表示这是一个MySQL实例) DB *sql.DB } // SQL异步执行队列定义 type queueList struct { list []*QueueItem //队列列表 sleeping chan bool loop chan bool lock sync.RWMutex quit chan bool quited bool } // SQL异步执行队列子元素定义 type QueueItem struct { DB *Database //数据库对象 Query string //SQL语句字符串 Params []interface{} //参数列表 } // 缓存数据对象定义 type cache struct { data map[string]map[string]interface{} } func (this *cache) Init() { this.data["default"] = make(map[string]interface{}) } // 设置缓存 func (this *cache) Set(key string, value interface{}, args ...string) { var group string if len(args) > 0 { group = args[0] if _, exist := this.data[group]; !exist { this.data[group] = make(map[string]interface{}) } } else { group = "default" } this.data[group][key] = value } // 获取缓存数据 func (this *cache) Get(key string, args ...string) interface{} { var group string if len(args) > 0 { group = args[0] } else { group = "default" } if g, exist := this.data[group]; exist { if v, ok := g[key]; ok { return v } } return nil } // 删除缓存数据 func (this *cache) Del(key string, args ...string) { var group string if len(args) > 0 { group = args[0] } else { group = "default" } if g, exist := this.data[group]; exist { if _, ok := g[key]; ok { delete(this.data[group], key) } } } var ( lastError error Cache *cache queue *queueList Obj *Database ) func init() { Cache = &cache{data: make(map[string]map[string]interface{})} Cache.Init() queue = &queueList{} go queue.Start() } // 关闭数据库连接 func (this *Database) Close() { this.DB.Close() } // 获取最后发生的错误字符串 func LastErr() string { if lastError != nil { return lastError.Error() } return "" } // 执行语句 func (this *Database) Exec(query string, args ...interface{}) (sql.Result, error) { return this.DB.Exec(query, args...) } // 查询单条记录 func (this *Database) Query(query string, args ...interface{}) (*sql.Rows, error) { return this.DB.Query(query, args...) } // 查询单条记录 func (this *Database) QueryRow(query string, args ...interface{}) *sql.Row { return this.DB.QueryRow(query, args...) } // Query2 查询实体集合 // obj 为接收数据的实体指针 func (this *Database) Query2(sql string, obj interface{}, args ...interface{}) error { var tagMap map[string]int var tp, tps reflect.Type var n, i int var err error var ret reflect.Value // 检测val参数是否为我们所想要的参数 tp = reflect.TypeOf(obj) if reflect.Ptr != tp.Kind() { return errors.New("is not pointer") } if reflect.Slice != tp.Elem().Kind() { return errors.New("is not slice pointer") } tp = tp.Elem() tps = tp.Elem() if reflect.Struct != tps.Kind() { return errors.New("is not struct slice pointer") } tagMap = make(map[string]int) n = tps.NumField() for i = 0; i < n; i++ { tag := tps.Field(i).Tag.Get("sql") if len(tag) > 0 { tagMap[tag] = i + 1 } } // 执行查询 ret, err = this.queryAndReflect(sql, tagMap, tp, args...) if nil != err { return err } // 返回结果 reflect.ValueOf(obj).Elem().Set(ret) return nil } // queryAndReflect 查询并将结果反射成实体集合 func (this *Database) queryAndReflect(sql string, tagMap map[string]int, tpSlice reflect.Type, args ...interface{}) (reflect.Value, error) { var ret reflect.Value // 执行sql语句 rows, err := this.DB.Query(sql, args...) if nil != err { return reflect.Value{}, err } defer rows.Close() // 开始枚举结果 cols, err := rows.Columns() if nil != err { return reflect.Value{}, err } ret = reflect.MakeSlice(tpSlice, 0, 50) // 构建接收队列 scan := make([]interface{}, len(cols)) row := make([]interface{}, len(cols)) for r := range row { scan[r] = &row[r] } for rows.Next() { feild := reflect.New(tpSlice.Elem()).Elem() // 取得结果 err = rows.Scan(scan...) // 开始遍历结果 for i := 0; i < len(cols); i++ { n := tagMap[cols[i]] - 1 if n < 0 { continue } switch feild.Type().Field(n).Type.Kind() { case reflect.Bool: if nil != row[i] { feild.Field(n).SetBool("false" != string(row[i].([]byte))) } else { feild.Field(n).SetBool(false) } case reflect.String: if nil != row[i] { feild.Field(n).SetString(string(row[i].([]byte))) } else { feild.Field(n).SetString("") } case reflect.Float32: if nil != row[i] { //log.Println(row[i].(float32)) switch reflect.TypeOf(row[i]).Kind() { case reflect.Slice: v, e := strconv.ParseFloat(string(row[i].([]byte)), 0) if nil == e { feild.Field(n).SetFloat(float64(v)) //feild.Field(n).SetFloat(float64(row[i].(float32))) } break case reflect.Float64: feild.Field(n).SetFloat(float64(row[i].(float32))) } } else { feild.Field(n).SetFloat(0) } case reflect.Float64: if nil != row[i] { //log.Println(row[i].(float32)) //v, e := strconv.ParseFloat(string(row[i].([]byte)), 0) //if nil == e { feild.Field(n).SetFloat(row[i].(float64)) //} } else { feild.Field(n).SetFloat(0) } case reflect.Int8: fallthrough case reflect.Int16: fallthrough case reflect.Int32: fallthrough case reflect.Int64: fallthrough case reflect.Int: if nil != row[i] { byRow, ok := row[i].([]byte) if ok { v, e := strconv.ParseInt(string(byRow), 10, 64) if nil == e { feild.Field(n).SetInt(v) } } else { v, e := strconv.ParseInt(fmt.Sprint(row[i]), 10, 64) if nil == e { feild.Field(n).SetInt(v) } } } else { feild.Field(n).SetInt(0) } } } ret = reflect.Append(ret, feild) } return ret, nil } // 执行UPDATE语句并返回受影响的行数 // 返回0表示没有出错, 但没有被更新的行 // 返回-1表示出错 func (this *Database) Update(query string, args ...interface{}) (int64, error) { ret, err := this.Exec(query, args...) if err != nil { return -1, err } aff, err := ret.RowsAffected() if err != nil { return -1, err } return aff, nil } // 执行DELETE语句并返回受影响的行数 // 返回0表示没有出错, 但没有被删除的行 // 返回-1表示出错 func (this *Database) Delete(query string, args ...interface{}) (int64, error) { return this.Update(query, args...) } func GenSql(obj interface{}) (string, error) { ret := "" typ := reflect.TypeOf(obj).Kind() if typ != reflect.Struct { return (""), errors.New("not a struct") } value := obj.(reflect.Value) num := value.NumField() for i := 0; i < num; i++ { if i == 0 { ret += "(" } switch value.Field(i).Type().Kind() { case reflect.String: str := value.Field(i).Interface().(string) if str[0] != '"' { ret += "\"" str += "\"" ret += str } else { ret += value.Field(i).Interface().(string) } case reflect.Int: ret += fmt.Sprintf("%d", value.Field(i).Interface().(int)) case reflect.Int8: ret += fmt.Sprintf("%d", value.Field(i).Interface().(int8)) case reflect.Int32: ret += fmt.Sprintf("%d", value.Field(i).Interface().(int32)) case reflect.Int64: ret += fmt.Sprintf("%d", value.Field(i).Interface().(int64)) case reflect.Int16: ret += fmt.Sprintf("%d", value.Field(i).Interface().(int16)) case reflect.Bool: if value.Field(i).Interface().(bool) { ret += fmt.Sprintf("true") } else { ret += fmt.Sprintf("false") } case reflect.Float32: ret += fmt.Sprintf("%x", value.Field(i).Interface().(float32)) case reflect.Float64: ret += fmt.Sprintf("true", value.Field(i).Interface().(float64)) } if i == num-1 { ret += ")" } else { ret += "," } } return ret, nil } func (this *Database) InsertObejct(tb_name string, obj interface{}) (int64, error) { var tagMap map[int]string var tp, tps reflect.Type var n, i int // 检测val参数是否为我们所想要的参数 tp = reflect.TypeOf(obj) if reflect.Ptr != tp.Kind() { return 0, errors.New("is not pointer") } if reflect.Slice != tp.Elem().Kind() { return 0, errors.New("is not slice pointer") } tp = tp.Elem() tps = tp.Elem() value := reflect.ValueOf(obj).Elem() if reflect.Struct != tps.Kind() { return 0, errors.New("is not struct slice pointer") } for z := 0; z < value.Len(); z++ { tagMap = make(map[int]string) n = tps.NumField() var query_struct string for i = 0; i < n; i++ { tag := tps.Field(i).Tag.Get("sql") if len(tag) > 0 { tagMap[i] = tag } if i == 0 { query_struct += "(" } query_struct += tagMap[i] if i == n-1 { query_struct += ")" } else { query_struct += "," } } vs, e := GenSql(value.Index(z)) if nil != e { logs.Error(e.Error()) } query := "insert into " + tb_name + query_struct + "values " + vs _, e = this.Insert(query) if e != nil { logs.Error(e.Error()) } } return 0, nil } // 执行INSERT语句并返回最后生成的自增ID // 返回0表示没有出错, 但没生成自增ID // 返回-1表示出错 func (this *Database) Insert(query string, args ...interface{}) (int64, error) { ret, err := this.Exec(query, args...) if err != nil { return -1, err } last, err := ret.LastInsertId() if err != nil { return -1, err } return last, nil } type OneRow map[string]string type Results []OneRow // 判断字段是否存在 func (row OneRow) Exist(field string) bool { if _, ok := row[field]; ok { return true } return false } // 获取指定字段的值 func (row OneRow) Get(field string) string { if v, ok := row[field]; ok { return v } return "" } // 获取指定字段的整数值, 注意, 如果该字段不存在则会返回0 func (row OneRow) GetInt(field string) int { if v, ok := row[field]; ok { return Atoi(v) } return 0 } // 获取指定字段的整数值, 注意, 如果该字段不存在则会返回0 func (row OneRow) GetInt64(field string) int64 { if v, ok := row[field]; ok { return Atoi64(v) } return 0 } // 设置值 func (row OneRow) Set(key, val string) { row[key] = val } // 查询不定字段的结果集 func (this *Database) Select(query string, args ...interface{}) (Results, error) { rows, err := this.DB.Query(query, args...) if err != nil { return nil, err } defer rows.Close() cols, err := rows.Columns() if err != nil { return nil, err } colNum := len(cols) rawValues := make([][]byte, colNum) scans := make([]interface{}, len(cols)) //query.Scan的参数,因为每次查询出来的列是不定长的,所以传入长度固定当次查询的长度 // 将每行数据填充到[][]byte里 for i := range rawValues { scans[i] = &rawValues[i] } results := make(Results, 0) for rows.Next() { err = rows.Scan(scans...) if err != nil { return nil, err } row := make(map[string]string) for k, raw := range rawValues { key := cols[k] /*if raw == nil { row[key] = "\\N" } else {*/ row[key] = string(raw) //} } results = append(results, row) } return results, nil } // 查询一行不定字段的结果 func (this *Database) SelectOne(query string, args ...interface{}) (OneRow, error) { ret, err := this.Select(query, args...) if err != nil { return nil, err } if len(ret) > 0 { return ret[0], nil } return make(OneRow), nil } // 队列入栈 func (this *queueList) Push(item *QueueItem) { this.lock.Lock() this.list = append(this.list, item) this.lock.Unlock() } // 队列出栈 func (this *queueList) Pop() chan *QueueItem { item := make(chan *QueueItem) go func() { defer close(item) for { switch { case len(this.list) == 0: timeout := time.After(time.Second * 2) select { case <-this.quit: this.quited = true return case <-timeout: //log.Println("SQL Queue polling") } default: this.lock.Lock() i := this.list[0] this.list = this.list[1:] this.lock.Unlock() select { case item <- i: return case <-this.quit: this.quited = true return } } } }() return item } // 执行开始执行 func (this *queueList) Start() { for { if this.quited { return } c := this.Pop() item := <-c item.DB.Exec(item.Query, item.Params...) } } // 停止队列 func (this *queueList) Stop() { this.quit <- true } // 向Sql队列中插入一条执行语句 func (this *Database) Queue(query string, args ...interface{}) { item := &QueueItem{ DB: this, Query: query, Params: args, } queue.Push(item) }