wip: 初始阶段

This commit is contained in:
loveuer
2024-10-31 17:56:26 +08:00
commit 9ba2aace6a
16 changed files with 680 additions and 0 deletions

View File

@ -0,0 +1,9 @@
package sqlType
import "errors"
var (
ErrConvertScanVal = errors.New("convert scan val to str err")
ErrInvalidScanVal = errors.New("scan val invalid")
ErrConvertVal = errors.New("convert err")
)

76
internal/sqlType/jsonb.go Normal file
View File

@ -0,0 +1,76 @@
package sqlType
import (
"database/sql/driver"
"encoding/json"
"github.com/jackc/pgtype"
)
type JSONB struct {
Val pgtype.JSONB
Valid bool
}
func NewJSONB(v interface{}) JSONB {
j := new(JSONB)
j.Val = pgtype.JSONB{}
if err := j.Val.Set(v); err == nil {
j.Valid = true
return *j
}
return *j
}
func (j *JSONB) Set(value interface{}) error {
if err := j.Val.Set(value); err != nil {
j.Valid = false
return err
}
j.Valid = true
return nil
}
func (j *JSONB) Bind(model interface{}) error {
return j.Val.AssignTo(model)
}
func (j *JSONB) Scan(value interface{}) error {
j.Val = pgtype.JSONB{}
if value == nil {
j.Valid = false
return nil
}
j.Valid = true
return j.Val.Scan(value)
}
func (j JSONB) Value() (driver.Value, error) {
if j.Valid {
return j.Val.Value()
}
return nil, nil
}
func (j JSONB) MarshalJSON() ([]byte, error) {
if j.Valid {
return j.Val.MarshalJSON()
}
return json.Marshal(nil)
}
func (j *JSONB) UnmarshalJSON(b []byte) error {
if string(b) == "null" {
j.Valid = false
return j.Val.UnmarshalJSON(b)
}
return j.Val.UnmarshalJSON(b)
}

View File

@ -0,0 +1,71 @@
package sqlType
import (
"database/sql/driver"
"fmt"
"strconv"
"strings"
"github.com/spf13/cast"
)
type NumSlice[T ~int | ~int64 | ~uint | ~uint64] []T
func (n *NumSlice[T]) Scan(val interface{}) error {
str, ok := val.(string)
if !ok {
return ErrConvertScanVal
}
length := len(str)
if length <= 0 {
*n = make(NumSlice[T], 0)
return nil
}
if str[0] != '{' || str[length-1] != '}' {
return ErrInvalidScanVal
}
str = str[1 : length-1]
if len(str) == 0 {
*n = make(NumSlice[T], 0)
return nil
}
numStrs := strings.Split(str, ",")
nums := make([]T, len(numStrs))
for idx := range numStrs {
num, err := cast.ToInt64E(strings.TrimSpace(numStrs[idx]))
if err != nil {
return fmt.Errorf("%w: can't convert to %T", ErrConvertVal, T(0))
}
nums[idx] = T(num)
}
*n = nums
return nil
}
func (n NumSlice[T]) Value() (driver.Value, error) {
if n == nil {
return "{}", nil
}
if len(n) == 0 {
return "{}", nil
}
ss := make([]string, 0, len(n))
for idx := range n {
ss = append(ss, strconv.Itoa(int(n[idx])))
}
s := strings.Join(ss, ", ")
return fmt.Sprintf("{%s}", s), nil
}

View File

@ -0,0 +1,107 @@
package sqlType
import (
"bytes"
"database/sql/driver"
"encoding/json"
)
type StrSlice []string
func (s *StrSlice) Scan(val interface{}) error {
str, ok := val.(string)
if !ok {
return ErrConvertScanVal
}
if len(str) < 2 {
return nil
}
bs := make([]byte, 0, 128)
bss := make([]byte, 0, 2*len(str))
quoteCount := 0
for idx := 1; idx < len(str)-1; idx++ {
quote := str[idx]
switch quote {
case 44:
if quote == 44 && str[idx-1] != 92 && quoteCount == 0 {
if len(bs) > 0 {
if !(bs[0] == 34 && bs[len(bs)-1] == 34) {
bs = append([]byte{34}, bs...)
bs = append(bs, 34)
}
bss = append(bss, bs...)
bss = append(bss, 44)
}
bs = bs[:0]
} else {
bs = append(bs, quote)
}
case 34:
if str[idx-1] != 92 {
quoteCount = (quoteCount + 1) % 2
}
bs = append(bs, quote)
default:
bs = append(bs, quote)
}
}
if len(bs) > 0 {
if !(bs[0] == 34 && bs[len(bs)-1] == 34) {
bs = append([]byte{34}, bs...)
bs = append(bs, 34)
}
bss = append(bss, bs...)
} else {
if len(bss) > 2 {
bss = bss[:len(bss)-2]
}
}
bss = append([]byte{'['}, append(bss, ']')...)
if err := json.Unmarshal(bss, s); err != nil {
return err
}
return nil
}
func (s StrSlice) Value() (driver.Value, error) {
if s == nil {
return "{}", nil
}
if len(s) == 0 {
return "{}", nil
}
buf := &bytes.Buffer{}
encoder := json.NewEncoder(buf)
encoder.SetEscapeHTML(false)
if err := encoder.Encode(s); err != nil {
return "{}", err
}
bs := buf.Bytes()
bs[0] = '{'
if bs[len(bs)-1] == 10 {
bs = bs[:len(bs)-1]
}
bs[len(bs)-1] = '}'
return string(bs), nil
}