diff --git a/go.mod b/go.mod index 0367805..a6a8335 100644 --- a/go.mod +++ b/go.mod @@ -49,7 +49,7 @@ require ( github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.617 github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/sms v1.0.617 github.com/tencentyun/cos-go-sdk-v5 v0.7.39 - github.com/winc-link/edge-driver-proto v0.0.0-20230208100708-287ba270a685 + github.com/winc-link/edge-driver-proto v0.0.0-20231023113502-daf15ee41883 github.com/xuri/excelize/v2 v2.5.0 go.uber.org/atomic v1.9.0 go.uber.org/zap v1.21.0 diff --git a/internal/hummingbird/core/application/dmi/docker/manager.go b/internal/hummingbird/core/application/dmi/docker/manager.go index 400da50..71717ea 100644 --- a/internal/hummingbird/core/application/dmi/docker/manager.go +++ b/internal/hummingbird/core/application/dmi/docker/manager.go @@ -50,10 +50,17 @@ type DockerManager struct { } type CustomParams struct { - env []string - runtime string - net string - mnt []string + user string + cpuShares int64 + memory int64 + memorySwap int64 + dns []string + dnsSearch []string + restart string + env []string + runtime string + mnt []string + port []string } // 镜像信息 @@ -184,7 +191,7 @@ func (dm *DockerManager) ContainerStart(imageRepo string, containerName string, //exposedPorts, portMap := dm.makeExposedPorts(exposePorts) //resourceDevices := dm.makeMountDevices(mountDevices) binds := make([]string, 0) - binds = append(binds, "/etc/localtime:/etc/localtime:ro") // 挂载时区 + //binds = append(binds, "/etc/localtime:/etc/localtime:ro") // 挂载时区 var thisRunMode container.NetworkMode if instanceType == constants.CloudInstance { @@ -205,16 +212,26 @@ func (dm *DockerManager) ContainerStart(imageRepo string, containerName string, dm.lc.Infof("binds: %+v", binds) dm.lc.Infof("Image:%+v", dm.ImageMap[imageRepo]) dm.lc.Infof("thisRunMode:%+v", string(thisRunMode)) - _, cErr := dm.cli.ContainerCreate(dm.ctx, &container.Config{ - Image: imageRepo, - Env: dockerCustomParams.env, + + portMap := generateExposedPorts(dockerCustomParams.port) + restartPolicy := generateRestartPolicy(dockerCustomParams.restart) + resources := generateResources(dockerCustomParams.cpuShares, dockerCustomParams.memory, dockerCustomParams.memorySwap) + + var _, cErr = dm.cli.ContainerCreate(dm.ctx, &container.Config{ + OpenStdin: true, + Tty: true, + User: dockerCustomParams.user, + Image: imageRepo, + Env: dockerCustomParams.env, }, &container.HostConfig{ - Binds: binds, - NetworkMode: thisRunMode, - RestartPolicy: container.RestartPolicy{ - MaximumRetryCount: 10, - }, - Runtime: dockerCustomParams.runtime, + DNS: dockerCustomParams.dns, + DNSSearch: dockerCustomParams.dnsSearch, + Resources: resources, + Binds: binds, + PortBindings: portMap, + NetworkMode: thisRunMode, + RestartPolicy: restartPolicy, + Runtime: dockerCustomParams.runtime, }, &network.NetworkingConfig{}, nil, containerName) if cErr != nil { return "", cErr @@ -237,46 +254,75 @@ func (dm *DockerManager) ContainerStart(imageRepo string, containerName string, return "", errort.NewCommonEdgeX(errort.DefaultSystemError, "GetContainerRunStatus Fail", err) } if status != constants.ContainerRunStatusRunning { - err = errort.NewCommonEdgeX(errort.ContainerRunFail, fmt.Sprintf("%s container status %s", containerName, status), nil) + err = errort.NewCommonEdgeX(errort.ContainerRunFail, fmt.Sprintf("%s container status %s please check the log for specific details", containerName, status), nil) + return } if thisRunMode.IsHost() { ip = constants.HostAddress } else { ip, err = dm.GetContainerIp(containerName) } - return } -// 端口导出组装 -func (dm *DockerManager) makeExposedPorts(exposePorts []int) (nat.PortSet, nat.PortMap) { - portMap := make(nat.PortMap) - exposedPorts := make(nat.PortSet, 0) - var empty struct{} - for _, p := range exposePorts { - tmpPort, _ := nat.NewPort("tcp", strconv.Itoa(p)) - portMap[tmpPort] = []nat.PortBinding{ - { - HostIP: "", - HostPort: strconv.Itoa(p), - }, +func generateRestartPolicy(restart string) container.RestartPolicy { + if restart != "" { + ls := strings.Split(restart, ":") + if len(ls) == 2 { + maximumRetryCount, err := strconv.Atoi(ls[1]) + if err != nil { + maximumRetryCount = 0 + } + return container.RestartPolicy{ + Name: ls[0], + MaximumRetryCount: maximumRetryCount, + } + } else if len(ls) == 1 { + return container.RestartPolicy{ + Name: ls[0], + } } - exposedPorts[tmpPort] = empty } - return exposedPorts, portMap + return container.RestartPolicy{} } -// 挂载设备组装 -func (dm *DockerManager) makeMountDevices(devices []string) container.Resources { - resourceDevices := make([]container.DeviceMapping, 0) - for _, v := range devices { - resourceDevices = append(resourceDevices, container.DeviceMapping{ - PathOnHost: v, - PathInContainer: v, - CgroupPermissions: "rwm", - }) +func generateResources(cpuShares, memory, memorySwap int64) container.Resources { + return container.Resources{ + CPUShares: cpuShares, + Memory: memory, + MemorySwap: memorySwap, } - return container.Resources{Devices: resourceDevices} +} + +// makeExposedPorts pots => [8080:8080/tcp 8090:8090/udp] +func generateExposedPorts(ports []string) nat.PortMap { + portMap := make(nat.PortMap) + for _, port := range ports { + + proto := "tcp" + + sp := strings.Split(port, ":") + if len(sp) != 2 { + return portMap + } + + parsePortRange := strings.Split(sp[1], "/") + + if len(parsePortRange) == 2 { + proto = strings.ToLower(parsePortRange[1]) + if proto != "tcp" && proto != "udp" { + continue + } + } + tmpPort, _ := nat.NewPort(proto, parsePortRange[0]) + portMap[tmpPort] = []nat.PortBinding{ + { + HostIP: "0.0.0.0", + HostPort: sp[0], + }, + } + } + return portMap } func (dm *DockerManager) ContainerStop(containerIdOrName string) error { @@ -498,16 +544,17 @@ func (dm *DockerManager) GetAuthToken(username string, password string, serverAd // 自定义docker启动参数解析 func (dm *DockerManager) ParseCustomParams(cmd string) (CustomParams, error) { - runMode := dm.dcm.DockerManageConfig.DockerRunMode - if !utils.InStringSlice(runMode, []string{ - constants.NetworkModeHost, constants.NetworkModeBridge, - }) { - runMode = constants.NetworkModeHost - } params := CustomParams{ - runtime: "", - env: []string{}, - net: runMode, + user: "", + cpuShares: 0, + memory: 0, + dns: []string{}, + dnsSearch: []string{}, + restart: "", + env: []string{}, + runtime: "", + port: []string{}, + mnt: []string{}, } if cmd == "" { return params, nil @@ -518,32 +565,25 @@ func (dm *DockerManager) ParseCustomParams(cmd string) (CustomParams, error) { for _, v := range strArr { args = append(args, strings.Split(v, " ")...) } - f := flag.NewFlagSet("edge-flag", flag.ContinueOnError) - f.StringVarP(¶ms.runtime, "runtime", "", "", "") - f.StringVarP(¶ms.net, "net", "", runMode, "") + f := flag.NewFlagSet("", flag.ContinueOnError) + f.StringVarP(¶ms.user, "user", "u", "", "") + f.Int64VarP(¶ms.cpuShares, "cpu-shares", "c", 0, "") + f.Int64VarP(¶ms.memory, "memory", "m", 0, "") + f.Int64VarP(¶ms.memorySwap, "memory-swap", "", -1, "") + f.StringArrayVarP(¶ms.dns, "dns", "", []string{}, "") + f.StringArrayVarP(¶ms.dnsSearch, "dns-search", "", []string{}, "") + f.StringVarP(¶ms.restart, "restart", "", "on-failure:10", "") f.StringArrayVarP(¶ms.env, "env", "e", []string{}, "") - f.StringArrayVarP(¶ms.mnt, "mnt", "v", []string{}, "") + f.StringVarP(¶ms.runtime, "runtime", "", "", "") + f.StringArrayVarP(¶ms.port, "publish", "p", []string{}, "") + f.StringArrayVarP(¶ms.mnt, "volume", "v", []string{}, "") + err := f.Parse(args) if err != nil { return CustomParams{}, errort.NewCommonErr(errort.DockerParamsParseErr, fmt.Errorf("parse docker params err:%v", err)) } - - // 内部限制只支持host和bridge两种模式 - if !utils.InStringSlice(params.net, []string{ - constants.NetworkModeHost, constants.NetworkModeBridge, - }) { - params.net = runMode - } return params, nil -} -// 自定义docker启动参数解析 -func (dm *DockerManager) ParseCustomParamsIsRunBridge(cmd string) (bool, error) { - params, err := dm.ParseCustomParams(cmd) - if err != nil { - return false, err - } - return container.NetworkMode(params.net).IsBridge(), nil } func (dm *DockerManager) GetAllImagesIds() []string { diff --git a/internal/models/device.go b/internal/models/device.go index 95f1313..70508a4 100644 --- a/internal/models/device.go +++ b/internal/models/device.go @@ -72,6 +72,7 @@ func (table *Device) TransformToDriverDevice() *driverdevice.Device { driverDevice.Description = table.Description driverDevice.Status = table.Status.TransformToDriverDeviceStatus() driverDevice.ProductId = table.ProductId + driverDevice.Secret = table.Secret driverDevice.Platform = table.Platform.TransformToDriverDevicePlatform() return driverDevice } diff --git a/internal/pkg/jwtauth/jwtauth.go b/internal/pkg/jwtauth/jwtauth.go new file mode 100644 index 0000000..8900262 --- /dev/null +++ b/internal/pkg/jwtauth/jwtauth.go @@ -0,0 +1,750 @@ +package jwtauth + +import ( + "crypto/rsa" + "errors" + "io/ioutil" + "net/http" + "strings" + "time" + + "github.com/dgrijalva/jwt-go" + "github.com/gin-gonic/gin" +) + +const JwtPayloadKey = "JWT_PAYLOAD" + +type MapClaims map[string]interface{} + +// GinJWTMiddleware provides a Json-Web-Token authentication implementation. On failure, a 401 HTTP response +// is returned. On success, the wrapped middleware is called, and the userID is made available as +// c.Get("userID").(string). +// Users can get a token by posting a json request to LoginHandler. The token then needs to be passed in +// the Authentication header. Example: Authorization:Bearer XXX_TOKEN_XXX +type GinJWTMiddleware struct { + // Realm name to display to the user. Required. + Realm string + + // signing algorithm - possible values are HS256, HS384, HS512 + // Optional, default is HS256. + SigningAlgorithm string + + // Secret key used for signing. Required. + Key []byte + + // Duration that a jwt token is valid. Optional, defaults to one hour. + Timeout time.Duration + + // This field allows clients to refresh their token until MaxRefresh has passed. + // Note that clients can refresh their token in the last moment of MaxRefresh. + // This means that the maximum validity timespan for a token is TokenTime + MaxRefresh. + // Optional, defaults to 0 meaning not refreshable. + MaxRefresh time.Duration + + // Callback function that should perform the authentication of the user based on login info. + // Must return user data as user identifier, it will be stored in Claim Array. Required. + // Check error (e) to determine the appropriate error message. + Authenticator func(c *gin.Context) (interface{}, error) + + // Callback function that should perform the authorization of the authenticated user. Called + // only after an authentication success. Must return true on success, false on failure. + // Optional, default to success. + Authorizator func(data interface{}, c *gin.Context) bool + + // Callback function that will be called during login. + // Using this function it is possible to add additional payload data to the webtoken. + // The data is then made available during requests via c.Get("JWT_PAYLOAD"). + // Note that the payload is not encrypted. + // The attributes mentioned on jwt.io can't be used as keys for the map. + // Optional, by default no additional data will be set. + PayloadFunc func(data interface{}) MapClaims + + // User can define own Unauthorized func. + Unauthorized func(*gin.Context, int, string) + + // User can define own LoginResponse func. + LoginResponse func(*gin.Context, int, string, time.Time) + + // User can define own RefreshResponse func. + RefreshResponse func(*gin.Context, int, string, time.Time) + + // Set the identity handler function + IdentityHandler func(*gin.Context) interface{} + + // Set the identity key + IdentityKey string + + // username + NiceKey string + + DataScopeKey string + + // rolekey + RKey string + + // roleId + RoleIdKey string + + RoleKey string + + // roleName + RoleNameKey string + + // TokenLookup is a string in the form of ":" that is used + // to extract token from the request. + // Optional. Default value "header:Authorization". + // Possible values: + // - "header:" + // - "query:" + // - "cookie:" + TokenLookup string + + // TokenHeadName is a string in the header. Default value is "Bearer" + TokenHeadName string + + // TimeFunc provides the current time. You can override it to use another time value. This is useful for testing or if your server uses a different time zone than your tokens. + TimeFunc func() time.Time + + // HTTP Status messages for when something in the JWT middleware fails. + // Check error (e) to determine the appropriate error message. + HTTPStatusMessageFunc func(e error, c *gin.Context) string + + // Private key file for asymmetric algorithms + PrivKeyFile string + + // Public key file for asymmetric algorithms + PubKeyFile string + + // Private key + privKey *rsa.PrivateKey + + // Public key + pubKey *rsa.PublicKey + + // Optionally return the token as a cookie + SendCookie bool + + // Allow insecure cookies for development over http + SecureCookie bool + + // Allow cookies to be accessed client side for development + CookieHTTPOnly bool + + // Allow cookie domain change for development + CookieDomain string + + // SendAuthorization allow return authorization header for every request + SendAuthorization bool + + // Disable abort() of context. + DisabledAbort bool + + // CookieName allow cookie name change for development + CookieName string +} + +var ( + // ErrMissingSecretKey indicates Secret key is required + ErrMissingSecretKey = errors.New("secret key is required") + + // ErrForbidden when HTTP status 403 is given + ErrForbidden = errors.New("you don't have permission to access this resource") + + // ErrMissingAuthenticatorFunc indicates Authenticator is required + ErrMissingAuthenticatorFunc = errors.New("ginJWTMiddleware.Authenticator func is undefined") + + // ErrMissingLoginValues indicates a user tried to authenticate without username or password + ErrMissingLoginValues = errors.New("missing Username or Password or Code") + + // ErrFailedAuthentication indicates authentication failed, could be faulty username or password + ErrFailedAuthentication = errors.New("incorrect Username or Password") + + // ErrFailedTokenCreation indicates JWT Token failed to create, reason unknown + ErrFailedTokenCreation = errors.New("failed to create JWT Token") + + // ErrExpiredToken indicates JWT token has expired. Can't refresh. + ErrExpiredToken = errors.New("token is expired") + + // ErrEmptyAuthHeader can be thrown if authing with a HTTP header, the Auth header needs to be set + ErrEmptyAuthHeader = errors.New("auth header is empty") + + // ErrMissingExpField missing exp field in token + ErrMissingExpField = errors.New("missing exp field") + + // ErrWrongFormatOfExp field must be float64 format + ErrWrongFormatOfExp = errors.New("exp must be float64 format") + + // ErrInvalidAuthHeader indicates auth header is invalid, could for example have the wrong Realm name + ErrInvalidAuthHeader = errors.New("auth header is invalid") + + // ErrEmptyQueryToken can be thrown if authing with URL Query, the query token variable is empty + ErrEmptyQueryToken = errors.New("query token is empty") + + // ErrEmptyCookieToken can be thrown if authing with a cookie, the token cokie is empty + ErrEmptyCookieToken = errors.New("cookie token is empty") + + // ErrEmptyParamToken can be thrown if authing with parameter in path, the parameter in path is empty + ErrEmptyParamToken = errors.New("parameter token is empty") + + // ErrInvalidSigningAlgorithm indicates signing algorithm is invalid, needs to be HS256, HS384, HS512, RS256, RS384 or RS512 + ErrInvalidSigningAlgorithm = errors.New("invalid signing algorithm") + + ErrInvalidVerificationode = errors.New("验证码错误") + + // ErrNoPrivKeyFile indicates that the given private key is unreadable + ErrNoPrivKeyFile = errors.New("private key file unreadable") + + // ErrNoPubKeyFile indicates that the given public key is unreadable + ErrNoPubKeyFile = errors.New("public key file unreadable") + + // ErrInvalidPrivKey indicates that the given private key is invalid + ErrInvalidPrivKey = errors.New("private key invalid") + + // ErrInvalidPubKey indicates the the given public key is invalid + ErrInvalidPubKey = errors.New("public key invalid") + + // IdentityKey default identity key + IdentityKey = "identity" + + // NiceKey 昵称 + NiceKey = "nice" + DataScopeKey = "datascope" + + RKey = "r" + + // RoleIdKey 角色id Old + RoleIdKey = "roleid" + + // RoleKey 角色名称 Old + RoleKey = "rolekey" + + // RoleNameKey 角色名称 Old + RoleNameKey = "rolename" + + // RoleIdKey 部门id + DeptId = "deptId" + + // RoleKey 部门名称 + DeptName = "deptName" +) + +// New for check error with GinJWTMiddleware +func New(m *GinJWTMiddleware) (*GinJWTMiddleware, error) { + if err := m.MiddlewareInit(); err != nil { + return nil, err + } + + return m, nil +} + +func (mw *GinJWTMiddleware) readKeys() error { + err := mw.privateKey() + if err != nil { + return err + } + err = mw.publicKey() + if err != nil { + return err + } + return nil +} + +func (mw *GinJWTMiddleware) privateKey() error { + keyData, err := ioutil.ReadFile(mw.PrivKeyFile) + if err != nil { + return ErrNoPrivKeyFile + } + key, err := jwt.ParseRSAPrivateKeyFromPEM(keyData) + if err != nil { + return ErrInvalidPrivKey + } + mw.privKey = key + return nil +} + +func (mw *GinJWTMiddleware) publicKey() error { + keyData, err := ioutil.ReadFile(mw.PubKeyFile) + if err != nil { + return ErrNoPubKeyFile + } + key, err := jwt.ParseRSAPublicKeyFromPEM(keyData) + if err != nil { + return ErrInvalidPubKey + } + mw.pubKey = key + return nil +} + +func (mw *GinJWTMiddleware) usingPublicKeyAlgo() bool { + switch mw.SigningAlgorithm { + case "RS256", "RS512", "RS384": + return true + } + return false +} + +// MiddlewareInit initialize jwt configs. +func (mw *GinJWTMiddleware) MiddlewareInit() error { + + if mw.TokenLookup == "" { + mw.TokenLookup = "header:Authorization" + } + + if mw.SigningAlgorithm == "" { + mw.SigningAlgorithm = "HS256" + } + + if mw.TimeFunc == nil { + mw.TimeFunc = time.Now + } + + mw.TokenHeadName = strings.TrimSpace(mw.TokenHeadName) + if len(mw.TokenHeadName) == 0 { + mw.TokenHeadName = "Bearer" + } + + if mw.Authorizator == nil { + mw.Authorizator = func(data interface{}, c *gin.Context) bool { + return true + } + } + + if mw.Unauthorized == nil { + mw.Unauthorized = func(c *gin.Context, code int, message string) { + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": message, + }) + } + } + + if mw.LoginResponse == nil { + mw.LoginResponse = func(c *gin.Context, code int, token string, expire time.Time) { + c.JSON(http.StatusOK, gin.H{ + "code": http.StatusOK, + "token": token, + "expire": expire.Format(time.RFC3339), + }) + } + } + + if mw.RefreshResponse == nil { + mw.RefreshResponse = func(c *gin.Context, code int, token string, expire time.Time) { + c.JSON(http.StatusOK, gin.H{ + "code": http.StatusOK, + "token": token, + "expire": expire.Format(time.RFC3339), + }) + } + } + + if mw.IdentityKey == "" { + mw.IdentityKey = IdentityKey + } + + if mw.IdentityHandler == nil { + mw.IdentityHandler = func(c *gin.Context) interface{} { + claims := ExtractClaims(c) + return claims + } + } + + if mw.HTTPStatusMessageFunc == nil { + mw.HTTPStatusMessageFunc = func(e error, c *gin.Context) string { + return e.Error() + } + } + + if mw.Realm == "" { + mw.Realm = "gin jwt" + } + + if mw.CookieName == "" { + mw.CookieName = "jwt" + } + + if mw.usingPublicKeyAlgo() { + return mw.readKeys() + } + + if mw.Key == nil { + return ErrMissingSecretKey + } + return nil +} + +// MiddlewareFunc makes GinJWTMiddleware implement the Middleware interface. +func (mw *GinJWTMiddleware) MiddlewareFunc() gin.HandlerFunc { + return func(c *gin.Context) { + mw.middlewareImpl(c) + } +} + +func (mw *GinJWTMiddleware) middlewareImpl(c *gin.Context) { + claims, err := mw.GetClaimsFromJWT(c) + if err != nil { + mw.unauthorized(c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(err, c)) + return + } + + if claims["exp"] == nil { + mw.unauthorized(c, http.StatusBadRequest, mw.HTTPStatusMessageFunc(ErrMissingExpField, c)) + return + } + + if _, ok := claims["exp"].(float64); !ok { + mw.unauthorized(c, http.StatusBadRequest, mw.HTTPStatusMessageFunc(ErrWrongFormatOfExp, c)) + return + } + if int64(claims["exp"].(float64)) < mw.TimeFunc().Unix() { + mw.unauthorized(c, 6401, mw.HTTPStatusMessageFunc(ErrExpiredToken, c)) + return + } + + c.Set(JwtPayloadKey, claims) + identity := mw.IdentityHandler(c) + + if identity != nil { + c.Set(mw.IdentityKey, identity) + } + + if !mw.Authorizator(identity, c) { + mw.unauthorized(c, http.StatusForbidden, mw.HTTPStatusMessageFunc(ErrForbidden, c)) + return + } + + c.Next() +} + +// GetClaimsFromJWT get claims from JWT token +func (mw *GinJWTMiddleware) GetClaimsFromJWT(c *gin.Context) (MapClaims, error) { + token, err := mw.ParseToken(c) + + if err != nil { + return nil, err + } + + if mw.SendAuthorization { + if v, ok := c.Get("JWT_TOKEN"); ok { + c.Header("Authorization", mw.TokenHeadName+" "+v.(string)) + } + } + + claims := MapClaims{} + for key, value := range token.Claims.(jwt.MapClaims) { + claims[key] = value + } + + return claims, nil +} + +// LoginHandler can be used by clients to get a jwt token. +// Payload needs to be json in the form of {"username": "USERNAME", "password": "PASSWORD"}. +// Reply will be of the form {"token": "TOKEN"}. +func (mw *GinJWTMiddleware) LoginHandler(c *gin.Context) { + if mw.Authenticator == nil { + mw.unauthorized(c, http.StatusInternalServerError, mw.HTTPStatusMessageFunc(ErrMissingAuthenticatorFunc, c)) + return + } + + data, err := mw.Authenticator(c) + + if err != nil { + mw.unauthorized(c, 400, mw.HTTPStatusMessageFunc(err, c)) + return + } + + // Create the token + token := jwt.New(jwt.GetSigningMethod(mw.SigningAlgorithm)) + claims := token.Claims.(jwt.MapClaims) + + if mw.PayloadFunc != nil { + for key, value := range mw.PayloadFunc(data) { + claims[key] = value + } + } + + expire := mw.TimeFunc().Add(mw.Timeout) + claims["exp"] = expire.Unix() + claims["orig_iat"] = mw.TimeFunc().Unix() + tokenString, err := mw.signedString(token) + + if err != nil { + mw.unauthorized(c, http.StatusOK, mw.HTTPStatusMessageFunc(ErrFailedTokenCreation, c)) + return + } + + // set cookie + if mw.SendCookie { + maxage := int(expire.Unix() - time.Now().Unix()) + c.SetCookie( + mw.CookieName, + tokenString, + maxage, + "/", + mw.CookieDomain, + mw.SecureCookie, + mw.CookieHTTPOnly, + ) + } + + mw.LoginResponse(c, http.StatusOK, tokenString, expire) +} + +func (mw *GinJWTMiddleware) signedString(token *jwt.Token) (string, error) { + var tokenString string + var err error + if mw.usingPublicKeyAlgo() { + tokenString, err = token.SignedString(mw.privKey) + } else { + tokenString, err = token.SignedString(mw.Key) + } + return tokenString, err +} + +// RefreshHandler can be used to refresh a token. The token still needs to be valid on refresh. +// Shall be put under an endpoint that is using the GinJWTMiddleware. +// Reply will be of the form {"token": "TOKEN"}. +func (mw *GinJWTMiddleware) RefreshHandler(c *gin.Context) { + tokenString, expire, err := mw.RefreshToken(c) + if err != nil { + mw.unauthorized(c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(err, c)) + return + } + + mw.RefreshResponse(c, http.StatusOK, tokenString, expire) +} + +// RefreshToken refresh token and check if token is expired +func (mw *GinJWTMiddleware) RefreshToken(c *gin.Context) (string, time.Time, error) { + claims, err := mw.CheckIfTokenExpire(c) + if err != nil { + return "", time.Now(), err + } + + // Create the token + newToken := jwt.New(jwt.GetSigningMethod(mw.SigningAlgorithm)) + newClaims := newToken.Claims.(jwt.MapClaims) + + for key := range claims { + newClaims[key] = claims[key] + } + + expire := mw.TimeFunc().Add(mw.Timeout) + newClaims["exp"] = expire.Unix() + newClaims["orig_iat"] = mw.TimeFunc().Unix() + tokenString, err := mw.signedString(newToken) + + if err != nil { + return "", time.Now(), err + } + + // set cookie + if mw.SendCookie { + maxage := int(expire.Unix() - time.Now().Unix()) + c.SetCookie( + mw.CookieName, + tokenString, + maxage, + "/", + mw.CookieDomain, + mw.SecureCookie, + mw.CookieHTTPOnly, + ) + } + + return tokenString, expire, nil +} + +// CheckIfTokenExpire check if token expire +func (mw *GinJWTMiddleware) CheckIfTokenExpire(c *gin.Context) (jwt.MapClaims, error) { + token, err := mw.ParseToken(c) + + if err != nil { + // If we receive an error, and the error is anything other than a single + // ValidationErrorExpired, we want to return the error. + // If the error is just ValidationErrorExpired, we want to continue, as we can still + // refresh the token if it's within the MaxRefresh time. + // (see https://github.com/appleboy/gin-jwt/issues/176) + validationErr, ok := err.(*jwt.ValidationError) + if !ok || validationErr.Errors != jwt.ValidationErrorExpired { + return nil, err + } + } + + claims := token.Claims.(jwt.MapClaims) + + origIat := int64(claims["orig_iat"].(float64)) + + if origIat < mw.TimeFunc().Add(-mw.MaxRefresh).Unix() { + return nil, ErrExpiredToken + } + + return claims, nil +} + +// TokenGenerator method that clients can use to get a jwt token. +func (mw *GinJWTMiddleware) TokenGenerator(data interface{}) (string, time.Time, error) { + token := jwt.New(jwt.GetSigningMethod(mw.SigningAlgorithm)) + claims := token.Claims.(jwt.MapClaims) + + if mw.PayloadFunc != nil { + for key, value := range mw.PayloadFunc(data) { + claims[key] = value + } + } + + expire := mw.TimeFunc().UTC().Add(mw.Timeout) + claims["exp"] = expire.Unix() + claims["orig_iat"] = mw.TimeFunc().Unix() + tokenString, err := mw.signedString(token) + if err != nil { + return "", time.Time{}, err + } + + return tokenString, expire, nil +} + +func (mw *GinJWTMiddleware) jwtFromHeader(c *gin.Context, key string) (string, error) { + authHeader := c.Request.Header.Get(key) + + if authHeader == "" { + return "", ErrEmptyAuthHeader + } + + parts := strings.SplitN(authHeader, " ", 2) + if !(len(parts) == 2 && parts[0] == mw.TokenHeadName) { + return "", ErrInvalidAuthHeader + } + + return parts[1], nil +} + +func (mw *GinJWTMiddleware) jwtFromQuery(c *gin.Context, key string) (string, error) { + token := c.Query(key) + + if token == "" { + return "", ErrEmptyQueryToken + } + + return token, nil +} + +func (mw *GinJWTMiddleware) jwtFromCookie(c *gin.Context, key string) (string, error) { + cookie, _ := c.Cookie(key) + + if cookie == "" { + return "", ErrEmptyCookieToken + } + + return cookie, nil +} + +func (mw *GinJWTMiddleware) jwtFromParam(c *gin.Context, key string) (string, error) { + token := c.Param(key) + + if token == "" { + return "", ErrEmptyParamToken + } + + return token, nil +} + +// ParseToken parse jwt token from gin context +func (mw *GinJWTMiddleware) ParseToken(c *gin.Context) (*jwt.Token, error) { + var token string + var err error + + methods := strings.Split(mw.TokenLookup, ",") + for _, method := range methods { + if len(token) > 0 { + break + } + parts := strings.Split(strings.TrimSpace(method), ":") + k := strings.TrimSpace(parts[0]) + v := strings.TrimSpace(parts[1]) + switch k { + case "header": + token, err = mw.jwtFromHeader(c, v) + case "query": + token, err = mw.jwtFromQuery(c, v) + case "cookie": + token, err = mw.jwtFromCookie(c, v) + case "param": + token, err = mw.jwtFromParam(c, v) + } + } + + if err != nil { + return nil, err + } + + return jwt.Parse(token, func(t *jwt.Token) (interface{}, error) { + if jwt.GetSigningMethod(mw.SigningAlgorithm) != t.Method { + return nil, ErrInvalidSigningAlgorithm + } + if mw.usingPublicKeyAlgo() { + return mw.pubKey, nil + } + c.Set("JWT_TOKEN", token) + + return mw.Key, nil + }) +} + +// ParseTokenString parse jwt token string +func (mw *GinJWTMiddleware) ParseTokenString(token string) (*jwt.Token, error) { + return jwt.Parse(token, func(t *jwt.Token) (interface{}, error) { + if jwt.GetSigningMethod(mw.SigningAlgorithm) != t.Method { + return nil, ErrInvalidSigningAlgorithm + } + if mw.usingPublicKeyAlgo() { + return mw.pubKey, nil + } + + return mw.Key, nil + }) +} + +func (mw *GinJWTMiddleware) unauthorized(c *gin.Context, code int, message string) { + c.Header("WWW-Authenticate", "JWT realm="+mw.Realm) + if !mw.DisabledAbort { + c.Abort() + } + + mw.Unauthorized(c, code, message) +} + +// ExtractClaims help to extract the JWT claims +func ExtractClaims(c *gin.Context) MapClaims { + claims, exists := c.Get(JwtPayloadKey) + if !exists { + return make(MapClaims) + } + + return claims.(MapClaims) +} + +// ExtractClaimsFromToken help to extract the JWT claims from token +func ExtractClaimsFromToken(token *jwt.Token) MapClaims { + if token == nil { + return make(MapClaims) + } + + claims := MapClaims{} + for key, value := range token.Claims.(jwt.MapClaims) { + claims[key] = value + } + + return claims +} + +// GetToken help to get the JWT token string +func GetToken(c *gin.Context) string { + token, exists := c.Get("JWT_TOKEN") + if !exists { + return "" + } + + return token.(string) +} diff --git a/internal/pkg/jwtauth/user/user.go b/internal/pkg/jwtauth/user/user.go new file mode 100644 index 0000000..9bd6d37 --- /dev/null +++ b/internal/pkg/jwtauth/user/user.go @@ -0,0 +1,92 @@ +package user + +import ( + "fmt" + "time" + + "github.com/gin-gonic/gin" + jwt "github.com/winc-link/hummingbird/internal/pkg/jwtauth" +) + +func ExtractClaims(c *gin.Context) jwt.MapClaims { + claims, exists := c.Get(jwt.JwtPayloadKey) + if !exists { + return make(jwt.MapClaims) + } + + return claims.(jwt.MapClaims) +} + +func Get(c *gin.Context, key string) interface{} { + data := ExtractClaims(c) + if data[key] != nil { + return data[key] + } + fmt.Println(time.Now().Format("2006-01-02 15:04:05") + " [WARING] " + c.Request.Method + " " + c.Request.URL.Path + " Get 缺少 " + key) + return nil +} + +func GetUserId(c *gin.Context) string { + data := ExtractClaims(c) + if data["identity"] != nil { + return (data["identity"]).(string) + } + fmt.Println(time.Now().Format("2006-01-02 15:04:05") + " [WARING] " + c.Request.Method + " " + c.Request.URL.Path + " GetUserId 缺少 identity") + return "" +} + +func GetUserIdStr(c *gin.Context) string { + data := ExtractClaims(c) + if data["identity"] != nil { + return (data["identity"]).(string) + } + fmt.Println(time.Now().Format("2006-01-02 15:04:05") + " [WARING] " + c.Request.Method + " " + c.Request.URL.Path + " GetUserIdStr 缺少 identity") + return "" +} + +func GetUserName(c *gin.Context) string { + data := ExtractClaims(c) + if data["nice"] != nil { + return (data["nice"]).(string) + } + fmt.Println(time.Now().Format("2006-01-02 15:04:05") + " [WARING] " + c.Request.Method + " " + c.Request.URL.Path + " GetUserName 缺少 nice") + return "" +} + +func GetRoleName(c *gin.Context) string { + data := ExtractClaims(c) + if data["rolekey"] != nil { + return (data["rolekey"]).(string) + } + fmt.Println(time.Now().Format("2006-01-02 15:04:05") + " [WARING] " + c.Request.Method + " " + c.Request.URL.Path + " GetRoleName 缺少 rolekey") + return "" +} + +func GetRoleId(c *gin.Context) string { + data := ExtractClaims(c) + if data["roleid"] != nil { + i := (data["roleid"]).(string) + return i + } + fmt.Println(time.Now().Format("2006-01-02 15:04:05") + " [WARING] " + c.Request.Method + " " + c.Request.URL.Path + " GetRoleId 缺少 roleid") + return "" +} + +func GetDeptId(c *gin.Context) string { + data := ExtractClaims(c) + if data["deptid"] != nil { + i := (data["deptid"]).(string) + return i + } + fmt.Println(time.Now().Format("2006-01-02 15:04:05") + " [WARING] " + c.Request.Method + " " + c.Request.URL.Path + " GetDeptId 缺少 deptid") + return "" +} + +func GetDeptName(c *gin.Context) string { + data := ExtractClaims(c) + if data["deptkey"] != nil { + return (data["deptkey"]).(string) + } + fmt.Println(time.Now().Format("2006-01-02 15:04:05") + " [WARING] " + c.Request.Method + " " + c.Request.URL.Path + " GetDeptName 缺少 deptkey") + return "" +} diff --git a/internal/pkg/utils/uuid.go b/internal/pkg/utils/uuid.go index cb8f56e..f0e0c55 100644 --- a/internal/pkg/utils/uuid.go +++ b/internal/pkg/utils/uuid.go @@ -15,8 +15,6 @@ func GenUUID() string { } func RandomNum() string { - //n, _ := rand.Int(rand.Reader, big.NewInt(100000000)) - //return n.String() return fmt.Sprintf("%08v", rand2.New(rand2.NewSource(time.Now().UnixNano())).Int31n(100000000)) } diff --git a/manifest/docker/docker-compose.yml b/manifest/docker/docker-compose.yml index b28986f..4d47bdd 100644 --- a/manifest/docker/docker-compose.yml +++ b/manifest/docker/docker-compose.yml @@ -31,7 +31,7 @@ services: networks: - hummingbird hummingbird-core: - image: registry.cn-shanghai.aliyuncs.com/winc-link/hummingbird-core:v1.0 + image: registry.cn-shanghai.aliyuncs.com/winc-link/hummingbird-core:1.0 container_name: hummingbird-core restart: always hostname: hummingbird-core