iot_server/internal/pkg/packets/packets.go

681 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package packets
import (
"bufio"
"bytes"
"encoding/binary"
"errors"
"io"
"unicode/utf8"
"github.com/winc-link/hummingbird/internal/pkg/codes"
)
// Sentinel errors exported by this package.
var (
// ErrInvalUTF8String is returned by DecodeUTF8String when the input is
// shorter than its length prefix requires or the payload fails ValidUTF8.
ErrInvalUTF8String = errors.New("invalid utf-8 string")
)
// Version identifies an MQTT protocol version (see Version31, Version311 and
// Version5). It is an alias of byte.
type Version = byte
// QoS is an alias of byte used for quality-of-service values.
type QoS = byte
// version2protoName maps each supported Version to its protocol name bytes:
// "MQIsdp" for Version31 and "MQTT" for Version311 and Version5.
// NOTE(review): presumably this is the protocol name carried in CONNECT;
// its users are not visible in this part of the file.
var version2protoName = map[Version][]byte{
Version31: {'M', 'Q', 'I', 's', 'd', 'p'},
Version311: {'M', 'Q', 'T', 'T'},
Version5: {'M', 'Q', 'T', 'T'},
}
// Supported protocol versions and the packet size limit.
const (
// Version31 is the version byte for MQTT 3.1.
Version31 Version = 0x03
// Version311 is the version byte for MQTT 3.1.1.
Version311 Version = 0x04
// Version5 is the version byte for MQTT 5.
Version5 Version = 0x05
// MaximumSize is the maximum packet size of an MQTT packet. It is one more
// than the largest remaining length (268435455) that the remain-length
// helpers in this file accept.
MaximumSize = 268435456
)
// MQTT control packet types. The values run from 0 (RESERVED) to 15 (AUTH)
// and correspond to FixHeader.PacketType, i.e. the upper four bits of the
// first byte of a packet. NewPacket dispatches on these values.
const (
RESERVED = iota
CONNECT
CONNACK
PUBLISH
PUBACK
PUBREC
PUBREL
PUBCOMP
SUBSCRIBE
SUBACK
UNSUBSCRIBE
UNSUBACK
PINGREQ
PINGRESP
DISCONNECT
AUTH
)
// QoS levels and the subscribe failure marker.
const (
// Qos0 is QoS level 0.
Qos0 uint8 = 0x00
// Qos1 is QoS level 1.
Qos1 uint8 = 0x01
// Qos2 is QoS level 2.
Qos2 uint8 = 0x02
// SubscribeFailure marks a failed subscription.
// NOTE(review): presumably the SUBACK return code; usage is not shown here.
SubscribeFailure = 0x80
)
// Flag values for the FixHeader.Flags nibble.
// NOTE(review): judging by the names, FlagSubscribe, FlagUnsubscribe and
// FlagPubrel are the fixed flag values of those packet types and FlagReserved
// applies to the remaining ones; the code that checks them is not shown here.
const (
FlagReserved = 0
FlagSubscribe = 2
FlagUnsubscribe = 2
FlagPubrel = 2
)
// PacketID is the type of a packet identifier. It is an alias of uint16.
type PacketID = uint16
// Largest and smallest packet identifiers; 0 lies outside the range.
const (
MaxPacketID PacketID = 65535
MinPacketID PacketID = 1
)
// PayloadFormat describes how a payload should be interpreted. It is an alias
// of byte.
type PayloadFormat = byte
// Known payload formats.
const (
// PayloadFormatBytes (0) denotes a payload of raw bytes.
PayloadFormatBytes PayloadFormat = 0
// PayloadFormatString (1) denotes a string payload.
PayloadFormatString PayloadFormat = 1
)
// IsVersion3X reports whether v is one of the MQTT 3.x versions, i.e.
// Version31 or Version311.
func IsVersion3X(v Version) bool {
	switch v {
	case Version31, Version311:
		return true
	default:
		return false
	}
}
// IsVersion5 reports whether v equals Version5.
func IsVersion5(v Version) bool {
return v == Version5
}
// FixHeader represents the fixed header of an MQTT packet.
type FixHeader struct {
// PacketType is the control packet type: the upper four bits of the first
// header byte.
PacketType byte
// Flags holds the lower four bits of the first header byte.
Flags byte
// RemainLength is the decoded remaining-length value that follows the first
// header byte.
RemainLength int
}
// Packet defines the interface for structs intended to hold
// decoded MQTT packets, either after being read or before being
// written.
type Packet interface {
// Pack encodes the packet struct into bytes and writes them to w.
Pack(w io.Writer) error
// Unpack reads the packet bytes from r and decodes them into the packet struct.
Unpack(r io.Reader) error
// String is mainly used in logging, debugging and testing.
String() string
}
// Topic represents an MQTT topic: a name together with the embedded
// subscription options.
type Topic struct {
SubOptions
// Name is the topic name.
Name string
}
// SubOptions holds the subscription options of a subscription.
// For details: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Subscription_Options
type SubOptions struct {
// Qos is the QoS level of the subscription.
// 0 = At most once delivery
// 1 = At least once delivery
// 2 = Exactly once delivery
Qos uint8
// RetainHandling specifies whether retained messages are sent when the subscription is established.
// 0 = Send retained messages at the time of the subscribe
// 1 = Send retained messages at subscribe only if the subscription does not currently exist
// 2 = Do not send retained messages at the time of the subscribe
RetainHandling byte
// NoLocal is the No Local option.
// If the value is 1, Application Messages MUST NOT be forwarded to a connection with a ClientID equal to the ClientID of the publishing connection.
NoLocal bool
// RetainAsPublished is the Retain As Published option.
// If 1, Application Messages forwarded using this subscription keep the RETAIN flag they were published with.
// If 0, Application Messages forwarded using this subscription have the RETAIN flag set to 0. Retained messages sent when the subscription is established have the RETAIN flag set to 1.
RetainAsPublished bool
}
// Reader reads data from a bufio.Reader and creates MQTT packet instances.
type Reader struct {
// bufr is the buffered source that packets are read from.
bufr *bufio.Reader
// version is the protocol version handed to NewPacket. NewReader initialises
// it to Version311; ReadPacket replaces it with the version of a decoded
// CONNECT packet and SetVersion overrides it explicitly.
version Version
}
// Writer encodes MQTT packets into bytes and writes them to a bufio.Writer.
// Written data stays buffered until Flush (or WriteAndFlush) is called.
type Writer struct {
// bufw is the buffered destination that packets are written to.
bufw *bufio.Writer
}
// Flush writes any buffered data to the underlying io.Writer and returns the
// error reported by the buffered writer, if any.
func (w *Writer) Flush() error {
return w.bufw.Flush()
}
// ReadWriter wraps a Reader and a Writer; the methods of both are promoted
// through embedding.
type ReadWriter struct {
*Reader
*Writer
}
// NewReader returns a new Reader that reads packets from r.
//
// If r already is a *bufio.Reader it is used directly; otherwise r is wrapped
// in a bufio.Reader with a 2048-byte buffer. The protocol version of the
// returned Reader starts as Version311.
func NewReader(r io.Reader) *Reader {
	buffered, ok := r.(*bufio.Reader)
	if !ok {
		buffered = bufio.NewReaderSize(r, 2048)
	}
	return &Reader{bufr: buffered, version: Version311}
}
// SetVersion sets the protocol version that the Reader passes to NewPacket
// when decoding subsequent packets.
func (r *Reader) SetVersion(version Version) {
r.version = version
}
// NewWriter returns a new Writer that writes packets to w.
//
// If w already is a *bufio.Writer it is used directly; otherwise w is wrapped
// in a bufio.Writer with a 2048-byte buffer.
func NewWriter(w io.Writer) *Writer {
	buffered, ok := w.(*bufio.Writer)
	if !ok {
		buffered = bufio.NewWriterSize(w, 2048)
	}
	return &Writer{bufw: buffered}
}
// ReadPacket reads one packet from the Reader and returns it.
//
// It reads the first header byte, then the remaining length, builds the
// FixHeader and hands decoding of the rest to NewPacket. When the decoded
// packet is a *Connect, the Reader adopts that packet's Version for all
// subsequent reads. If any error occurs it returns nil and the error.
func (r *Reader) ReadPacket() (Packet, error) {
	head, err := r.bufr.ReadByte()
	if err != nil {
		return nil, err
	}
	remain, err := EncodeRemainLength(r.bufr)
	if err != nil {
		return nil, err
	}
	// The upper nibble of the first byte is the packet type, the lower nibble
	// carries the flags.
	header := &FixHeader{
		PacketType:   head >> 4,
		Flags:        head & 15,
		RemainLength: remain,
	}
	pkt, err := NewPacket(header, r.version, r.bufr)
	if err != nil {
		return nil, err
	}
	if connect, ok := pkt.(*Connect); ok {
		r.version = connect.Version
	}
	return pkt, nil
}
// WritePacket encodes packet into the Writer's buffer and returns the error
// reported by packet.Pack, if any.
// Call Flush after WritePacket to flush buffered data to the underlying io.Writer.
func (w *Writer) WritePacket(packet Packet) error {
	return packet.Pack(w.bufw)
}
// WriteRaw writes the raw bytes b into the Writer's buffer and returns the
// write error, if any.
// Call Flush after WriteRaw to flush buffered data to the underlying io.Writer.
func (w *Writer) WriteRaw(b []byte) error {
	_, err := w.bufw.Write(b)
	return err
}
// WriteAndFlush encodes packet into the Writer's buffer and then flushes the
// buffer to the underlying io.Writer. If encoding fails nothing is flushed and
// the encoding error is returned.
func (w *Writer) WriteAndFlush(packet Packet) error {
	if err := packet.Pack(w.bufw); err != nil {
		return err
	}
	return w.bufw.Flush()
}
// Pack encodes the FixHeader into bytes and writes them to w with a single
// Write call.
//
// The first byte carries PacketType in its upper four bits and Flags in the
// lower bits; it is followed by the encoded RemainLength. An error from
// DecodeRemainLength is returned before anything is written.
func (fh *FixHeader) Pack(w io.Writer) error {
	remain, err := DecodeRemainLength(fh.RemainLength)
	if err != nil {
		return err
	}
	out := make([]byte, 0, 1+len(remain))
	out = append(out, fh.PacketType<<4|fh.Flags)
	out = append(out, remain...)
	_, err = w.Write(out)
	return err
}
// DecodeRemainLength converts length into its MQTT variable byte integer
// representation (1 to 4 bytes, least significant 7-bit group first, with the
// top bit of every byte except the last set).
//
// Despite its name this function ENCODES an int into bytes; the reverse
// direction is EncodeRemainLength.
//
// It returns codes.ErrMalformed when length is negative or greater than
// 268435455, the largest value representable in four bytes.
func DecodeRemainLength(length int) ([]byte, error) {
	if length < 0 || length > 268435455 {
		return nil, codes.ErrMalformed
	}
	result := make([]byte, 0, 4)
	for {
		digit := byte(length % 128)
		length /= 128
		// If there is more data to encode, set the top bit of this byte.
		if length > 0 {
			digit |= 0x80
		}
		result = append(result, digit)
		if length == 0 {
			return result, nil
		}
	}
}
// EncodeRemainLength reads an MQTT variable byte integer from r and returns
// its value. If the encoding is malformed it returns an error.
//
// Despite its name this function DECODES bytes into an int; the reverse
// direction is DecodeRemainLength.
//
// At most four bytes are consumed. A fifth continuation is rejected with
// codes.ErrMalformed, which also bounds how much input a hostile peer can make
// this function consume.
//
// Error handling:
//   - a read error other than io.EOF is returned unchanged;
//   - io.EOF before the first byte yields (0, nil), i.e. an exhausted reader is
//     treated as an absent, zero-valued length (existing lenient behaviour,
//     kept for callers that rely on it);
//   - io.EOF after a byte announcing a continuation yields io.ErrUnexpectedEOF.
func EncodeRemainLength(r io.ByteReader) (int, error) {
	var vbi uint32
	// Four 7-bit groups at shifts 0, 7, 14 and 21 are the most the format allows.
	for shift := uint(0); shift <= 21; shift += 7 {
		digit, err := r.ReadByte()
		if err != nil {
			if err != io.EOF {
				return 0, err
			}
			if shift == 0 {
				return 0, nil
			}
			return 0, io.ErrUnexpectedEOF
		}
		vbi |= uint32(digit&127) << shift
		if digit&128 == 0 {
			return int(vbi), nil
		}
	}
	// The fourth byte still had its continuation bit set.
	return 0, codes.ErrMalformed
}
// EncodeUTF8String prefixes buf with its length as a big-endian uint16 and
// returns the resulting bytes together with their total size (2 + len(buf)).
//
// The content of buf is copied as is; it is not checked for being valid UTF-8.
// codes.ErrMalformed is returned when buf is longer than 65535 bytes.
func EncodeUTF8String(buf []byte) (b []byte, size int, err error) {
	n := len(buf)
	if n > 65535 {
		return nil, 0, codes.ErrMalformed
	}
	out := make([]byte, 2+n)
	binary.BigEndian.PutUint16(out, uint16(n))
	copy(out[2:], buf)
	return out, len(out), nil
}
// readUint16 consumes two bytes from r and returns them as a big-endian
// uint16. If fewer than two bytes are available nothing is consumed and
// codes.ErrMalformed is returned.
func readUint16(r *bytes.Buffer) (uint16, error) {
	const width = 2
	if r.Len() >= width {
		return binary.BigEndian.Uint16(r.Next(width)), nil
	}
	return 0, codes.ErrMalformed
}
// writeUint16 appends i to w in big-endian order (high byte first).
func writeUint16(w *bytes.Buffer, i uint16) {
	w.Write([]byte{byte(i >> 8), byte(i)})
}
// writeUint32 appends i to w in big-endian order (most significant byte first).
func writeUint32(w *bytes.Buffer, i uint32) {
	w.Write([]byte{byte(i >> 24), byte(i >> 16), byte(i >> 8), byte(i)})
}
// readUint32 consumes four bytes from r and returns them as a big-endian
// uint32. If fewer than four bytes are available nothing is consumed and
// codes.ErrMalformed is returned.
func readUint32(r *bytes.Buffer) (uint32, error) {
	const width = 4
	if r.Len() >= width {
		return binary.BigEndian.Uint32(r.Next(width)), nil
	}
	return 0, codes.ErrMalformed
}
// readBinary reads a length-prefixed byte sequence from r. It delegates to
// readUTF8String with UTF-8 validation switched off, so any byte content is
// accepted.
func readBinary(r *bytes.Buffer) (b []byte, err error) {
return readUTF8String(false, r)
}
// readUTF8String reads a length-prefixed string from r: a big-endian uint16
// length followed by that many bytes.
//
// When mustUTF8 is true the payload must also pass ValidUTF8. The returned
// slice comes from r.Next, so it aliases the buffer's storage.
// codes.ErrMalformed is returned when the prefix or payload is truncated or
// when validation fails; the bytes read up to that point stay consumed.
func readUTF8String(mustUTF8 bool, r *bytes.Buffer) (b []byte, err error) {
	const prefixLen = 2
	if r.Len() < prefixLen {
		return nil, codes.ErrMalformed
	}
	n := int(binary.BigEndian.Uint16(r.Next(prefixLen)))
	if n > r.Len() {
		return nil, codes.ErrMalformed
	}
	data := r.Next(n)
	if mustUTF8 && !ValidUTF8(data) {
		return nil, codes.ErrMalformed
	}
	return data, nil
}
// writeUTF8String appends s to w, preceded by its length as a big-endian
// uint16.
//
// The length of s must not exceed 65535: the length is converted to uint16
// without a check, so a longer s gets a truncated prefix.
func writeUTF8String(w *bytes.Buffer, s []byte) {
	n := uint16(len(s))
	w.Write([]byte{byte(n >> 8), byte(n)})
	w.Write(s)
}
// writeBinary appends b to w, preceded by its length as a big-endian uint16.
//
// The length of b must not exceed 65535: the length is converted to uint16
// without a check, so a longer b gets a truncated prefix.
func writeBinary(w *bytes.Buffer, b []byte) {
	n := uint16(len(b))
	w.Write([]byte{byte(n >> 8), byte(n)})
	w.Write(b)
}
// DecodeUTF8String extracts a length-prefixed UTF-8 string from the start of
// buf: a big-endian uint16 length followed by that many bytes.
//
// It returns the payload (a sub-slice of buf, not a copy), the number of bytes
// consumed (length + 2) and an error. ErrInvalUTF8String is returned when buf
// is shorter than the prefix or the announced length, or when the payload
// fails ValidUTF8.
func DecodeUTF8String(buf []byte) (b []byte, size int, err error) {
	if len(buf) < 2 {
		return nil, 0, ErrInvalUTF8String
	}
	end := 2 + int(binary.BigEndian.Uint16(buf))
	if end > len(buf) {
		return nil, 0, ErrInvalUTF8String
	}
	body := buf[2:end]
	if !ValidUTF8(body) {
		return nil, 0, ErrInvalUTF8String
	}
	return body, end, nil
}
// NewPacket decodes the packet described by fh from r and returns it.
//
// It dispatches on fh.PacketType to the matching constructor. The PUBREL,
// PINGREQ, PINGRESP and AUTH constructors do not take the protocol version;
// all others receive version. Any packet type without a case here, including
// RESERVED, yields codes.ErrProtocol.
func NewPacket(fh *FixHeader, version Version, r io.Reader) (Packet, error) {
	// Cases are listed in ascending packet type order.
	switch fh.PacketType {
	case CONNECT:
		return NewConnectPacket(fh, version, r)
	case CONNACK:
		return NewConnackPacket(fh, version, r)
	case PUBLISH:
		return NewPublishPacket(fh, version, r)
	case PUBACK:
		return NewPubackPacket(fh, version, r)
	case PUBREC:
		return NewPubrecPacket(fh, version, r)
	case PUBREL:
		return NewPubrelPacket(fh, r)
	case PUBCOMP:
		return NewPubcompPacket(fh, version, r)
	case SUBSCRIBE:
		return NewSubscribePacket(fh, version, r)
	case SUBACK:
		return NewSubackPacket(fh, version, r)
	case UNSUBSCRIBE:
		return NewUnsubscribePacket(fh, version, r)
	case UNSUBACK:
		return NewUnsubackPacket(fh, version, r)
	case PINGREQ:
		return NewPingreqPacket(fh, r)
	case PINGRESP:
		return NewPingrespPacket(fh, r)
	case DISCONNECT:
		return NewDisConnectPackets(fh, version, r)
	case AUTH:
		return NewAuthPacket(fh, r)
	}
	return nil, codes.ErrProtocol
}
// ValidUTF8 reports whether p is well-formed UTF-8 that is acceptable as an
// MQTT UTF-8 string.
//
// It returns false when p contains:
//   - an invalid UTF-8 sequence (this includes encoded surrogates),
//   - U+0000 or another control character in U+0001..U+001F [MQTT-1.5.3-2],
//   - a character in U+007F..U+009F.
//
// A correctly encoded U+FFFD (replacement character) is valid UTF-8 and is
// accepted; only the (RuneError, 1) result that utf8.DecodeRune reports for a
// broken encoding is rejected. An empty p is valid.
func ValidUTF8(p []byte) bool {
	for len(p) > 0 {
		ru, size := utf8.DecodeRune(p)
		switch {
		case ru == utf8.RuneError && size == 1:
			// Width 1 distinguishes a decoding failure from a real U+FFFD,
			// which is three bytes wide.
			return false
		case ru <= '\u001f':
			return false
		case ru >= '\u007f' && ru <= '\u009f':
			return false
		}
		p = p[size:]
	}
	return true
}
// ValidTopicName reports whether p is a valid topic name, i.e. contains no
// wildcard characters ('+' or '#') [MQTT-4.7.1-1].
//
// When mustUTF8 is true, p must additionally be well-formed UTF-8. A correctly
// encoded U+FFFD is well-formed and accepted; only the (RuneError, 1) result
// that utf8.DecodeRune reports for a broken encoding is rejected.
//
// An empty p is reported as valid; callers that need a non-empty topic name
// must check the length themselves.
func ValidTopicName(mustUTF8 bool, p []byte) bool {
	for len(p) > 0 {
		ru, size := utf8.DecodeRune(p)
		if mustUTF8 && ru == utf8.RuneError && size == 1 {
			return false
		}
		// Wildcards are not allowed in a topic name.
		if size == 1 && (p[0] == '+' || p[0] == '#') {
			return false
		}
		p = p[size:]
	}
	return true
}
// ValidV5Topic reports whether p is a valid MQTT v5 topic filter, including
// the shared-subscription form "$share/{ShareName}/{filter}".
//
// For a filter that does not start with "$share/" the result is that of
// ValidTopicFilter(true, p). For a shared subscription:
//   - the share name must be non-empty, well-formed UTF-8 and must not contain
//     '+' or '#';
//   - the share name must be followed by '/' and a filter that satisfies
//     ValidTopicFilter.
//
// A correctly encoded U+FFFD in the share name is well-formed UTF-8 and is
// accepted; only a broken encoding is rejected. An empty p is invalid.
func ValidV5Topic(p []byte) bool {
	if len(p) == 0 {
		return false
	}
	if !bytes.HasPrefix(p, []byte("$share/")) {
		return ValidTopicFilter(true, p)
	}
	// "$share/" + at least one share-name byte + '/' is 9 bytes; p[7] == '/'
	// would mean an empty share name.
	if len(p) < 9 || p[7] == '/' {
		return false
	}
	for name := p[7:]; len(name) > 0; {
		ru, size := utf8.DecodeRune(name)
		if ru == utf8.RuneError && size == 1 {
			return false
		}
		if size == 1 {
			switch name[0] {
			case '/':
				// End of the share name: the rest is the actual topic filter.
				return ValidTopicFilter(true, name[1:])
			case '+', '#':
				return false
			}
		}
		name = name[size:]
	}
	// No '/' terminated the share name, so there is no filter part.
	return false
}
// ValidTopicFilter reports whether p is a valid topic filter
// [MQTT-4.7.1-2] [MQTT-4.7.1-3].
//
// Rules enforced:
//   - p must not be empty;
//   - '#' must be the last character and must either be the whole filter or
//     directly follow '/';
//   - '+' must occupy an entire level: it must be at the start or directly
//     follow '/', and must be at the end or directly precede '/'. This also
//     holds for a '+' at the very start, so "+abc" is invalid;
//   - when mustUTF8 is true, p must be well-formed UTF-8. A correctly encoded
//     U+FFFD is accepted; only a broken encoding is rejected.
func ValidTopicFilter(mustUTF8 bool, p []byte) bool {
	if len(p) == 0 {
		return false
	}
	var prev byte // first byte of the previous character
	atStart := true
	for len(p) > 0 {
		ru, size := utf8.DecodeRune(p)
		if mustUTF8 && ru == utf8.RuneError && size == 1 {
			return false
		}
		// '+' and '#' are ASCII, so matching on p[0] implies a one-byte rune.
		switch p[0] {
		case '#':
			if len(p) != 1 {
				return false
			}
			if !atStart && prev != '/' {
				return false
			}
		case '+':
			if !atStart && prev != '/' {
				return false
			}
			if len(p) > 1 && p[1] != '/' {
				return false
			}
		}
		prev = p[0]
		atStart = false
		p = p[size:]
	}
	return true
}
// TopicMatch reports whether topic matches topicFilter.
//
// Both slices are walked byte by byte: spos indexes topicFilter and tpos
// indexes topic. '+' in the filter skips topic bytes up to the next '/', '#'
// matches everything that remains. A filter ending in "/#" also matches the
// parent level itself (e.g. topic "foo" matches filter "foo/#").
//
// It returns false when either slice is empty, and when exactly one of the two
// starts with '$', so wildcard filters never match '$'-prefixed topics.
//
// NOTE(review): multilevelWildcard is only assigned immediately before a
// `return true`, so the check after the loop can never change the result; the
// code following the loop always returns false.
func TopicMatch(topic []byte, topicFilter []byte) bool {
var spos int
var tpos int
var sublen int
var topiclen int
var multilevelWildcard bool // whether a multi-level wildcard was encountered
sublen = len(topicFilter)
topiclen = len(topic)
if sublen == 0 || topiclen == 0 {
return false
}
if (topicFilter[0] == '$' && topic[0] != '$') || (topic[0] == '$' && topicFilter[0] != '$') {
return false
}
for {
//e.g. ([]byte("foo/bar"),[]byte("foo/+/#")
if spos < sublen && tpos <= topiclen {
if tpos != topiclen && topicFilter[spos] == topic[tpos] { // filter (subscription) and topic (publication) bytes are equal at the current positions
if tpos == topiclen-1 { // reached the last byte of topic
/* Check for e.g. foo matching foo/# */
if spos == sublen-3 && topicFilter[spos+1] == '/' && topicFilter[spos+2] == '#' {
return true
}
}
spos++
tpos++
if spos == sublen && tpos == topiclen { // both fully consumed with identical content: match
return true
} else if tpos == topiclen && spos == sublen-1 && topicFilter[spos] == '+' {
// the filter is one byte longer than the topic and that extra byte is '+', e.g. sub: foo/+ , topic: foo/
if spos > 0 && topicFilter[spos-1] != '/' {
return false
}
spos++
return true
}
} else {
if topicFilter[spos] == '+' { // filter and topic differ at this position (or topic is exhausted)
spos++
for { // advance tpos to the next level separator in topic
if tpos < topiclen && topic[tpos] != '/' {
tpos++
} else {
break
}
}
if tpos == topiclen && spos == sublen { // both fully consumed: match
return true
}
} else if topicFilter[spos] == '#' {
multilevelWildcard = true
return true
} else {
/* Check for e.g. foo/bar matching foo/+/# */
if spos > 0 && spos+2 == sublen && tpos == topiclen && topicFilter[spos-1] == '+' && topicFilter[spos] == '/' && topicFilter[spos+1] == '#' {
multilevelWildcard = true
return true
}
return false
}
}
} else {
break
}
}
if !multilevelWildcard && (tpos < topiclen || spos < sublen) {
return false
}
return false
}
// TotalBytes returns the total size in bytes of packet p: the fixed header
// (one type/flags byte plus the 1 to 4 bytes needed to encode RemainLength)
// plus RemainLength itself.
//
// It returns 0 when p is not one of the packet types listed below or when its
// FixHeader is nil. When RemainLength exceeds 268435455 no header bytes are
// counted and only RemainLength is returned.
func TotalBytes(p Packet) uint32 {
	var fh *FixHeader
	switch pkt := p.(type) {
	case *Connect:
		fh = pkt.FixHeader
	case *Connack:
		fh = pkt.FixHeader
	case *Publish:
		fh = pkt.FixHeader
	case *Puback:
		fh = pkt.FixHeader
	case *Pubrec:
		fh = pkt.FixHeader
	case *Pubrel:
		fh = pkt.FixHeader
	case *Pubcomp:
		fh = pkt.FixHeader
	case *Subscribe:
		fh = pkt.FixHeader
	case *Suback:
		fh = pkt.FixHeader
	case *Unsubscribe:
		fh = pkt.FixHeader
	case *Unsuback:
		fh = pkt.FixHeader
	case *Pingreq:
		fh = pkt.FixHeader
	case *Pingresp:
		fh = pkt.FixHeader
	case *Disconnect:
		fh = pkt.FixHeader
	case *Auth:
		fh = pkt.FixHeader
	}
	if fh == nil {
		return 0
	}
	// Size of the fixed header: type/flags byte + encoded remaining length.
	var fixed uint32
	switch remain := fh.RemainLength; {
	case remain <= 127:
		fixed = 2
	case remain <= 16383:
		fixed = 3
	case remain <= 2097151:
		fixed = 4
	case remain <= 268435455:
		fixed = 5
	}
	return fixed + uint32(fh.RemainLength)
}