mirror of
https://github.com/stashapp/stash-box.git
synced 2026-02-16 14:32:06 -06:00
311 lines
6.5 KiB
Go
311 lines
6.5 KiB
Go
package database
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"github.com/gofrs/uuid"
|
|
"github.com/jmoiron/sqlx"
|
|
)
|
|
|
|
type QueryBuilder struct {
|
|
Table Table
|
|
Body string
|
|
Distinct bool
|
|
|
|
whereClauses []string
|
|
havingClauses []string
|
|
args []interface{}
|
|
|
|
SortAndPagination string
|
|
}
|
|
|
|
func NewQueryBuilder(table Table) *QueryBuilder {
|
|
ret := &QueryBuilder{
|
|
Table: table,
|
|
Distinct: false,
|
|
}
|
|
|
|
tableName := table.Name()
|
|
ret.Body = "SELECT " + tableName + ".* FROM " + tableName + " "
|
|
|
|
return ret
|
|
}
|
|
|
|
func NewDeleteQueryBuilder(table Table) *QueryBuilder {
|
|
ret := &QueryBuilder{
|
|
Table: table,
|
|
Distinct: false,
|
|
}
|
|
|
|
tableName := table.Name()
|
|
ret.Body = "DELETE FROM " + tableName + " "
|
|
|
|
return ret
|
|
}
|
|
|
|
func (qb *QueryBuilder) AddJoin(joinTable Table, on string) {
|
|
qb.Body += " JOIN " + joinTable.Name() + " ON " + on
|
|
qb.Distinct = true
|
|
}
|
|
|
|
func (qb *QueryBuilder) AddWhere(clauses ...string) {
|
|
qb.whereClauses = append(qb.whereClauses, clauses...)
|
|
}
|
|
|
|
func (qb *QueryBuilder) Eq(column string, arg interface{}) {
|
|
qb.AddWhere(column + " = ?")
|
|
qb.AddArg(arg)
|
|
}
|
|
|
|
func (qb *QueryBuilder) NotEq(column string, arg interface{}) {
|
|
qb.AddWhere(column + " != ?")
|
|
qb.AddArg(arg)
|
|
}
|
|
|
|
func (qb *QueryBuilder) IsNull(column string) {
|
|
qb.AddWhere(column + " is NULL")
|
|
}
|
|
|
|
func (qb *QueryBuilder) IsNotNull(column string) {
|
|
qb.AddWhere(column + " is not NULL")
|
|
}
|
|
|
|
func (qb *QueryBuilder) AddHaving(clauses ...string) {
|
|
if len(clauses) == 1 && clauses[0] == "" {
|
|
return
|
|
}
|
|
qb.havingClauses = append(qb.havingClauses, clauses...)
|
|
}
|
|
|
|
func (qb *QueryBuilder) AddArg(args ...interface{}) {
|
|
qb.args = append(qb.args, args...)
|
|
}
|
|
|
|
func (qb QueryBuilder) buildBody() string {
|
|
body := qb.Body
|
|
|
|
if len(qb.whereClauses) > 0 {
|
|
body = body + " WHERE " + strings.Join(qb.whereClauses, " AND ") // TODO handle AND or OR
|
|
}
|
|
if qb.Distinct {
|
|
body = body + " GROUP BY " + qb.Table.Name() + ".id "
|
|
}
|
|
if len(qb.havingClauses) > 0 {
|
|
body = body + " HAVING " + strings.Join(qb.havingClauses, " AND ") // TODO handle AND or OR
|
|
}
|
|
|
|
return body
|
|
}
|
|
|
|
func (qb QueryBuilder) buildCountQuery() string {
|
|
return "SELECT COUNT(*) as count FROM (" + qb.buildBody() + ") as temp"
|
|
}
|
|
|
|
func (qb QueryBuilder) buildQuery() string {
|
|
return qb.buildBody() + qb.SortAndPagination
|
|
}
|
|
|
|
// optionalValue is implemented by field types that can report whether they
// hold a value. sqlGenKeysCreate and sqlGenKeys consult IsValid to decide
// whether such a field's column is emitted.
type optionalValue interface {
	// IsValid reports whether the value is present.
	IsValid() bool
}
|
|
|
|
func ensureTx(tx *sqlx.Tx) {
|
|
if tx == nil {
|
|
panic("must use a transaction")
|
|
}
|
|
}
|
|
|
|
func getByID(tx *sqlx.Tx, table string, id uuid.UUID, object interface{}) error {
|
|
query := tx.Rebind(`SELECT * FROM ` + table + ` WHERE id = ? LIMIT 1`)
|
|
return tx.Get(object, query, id)
|
|
}
|
|
|
|
func insertObject(tx *sqlx.Tx, table string, object interface{}, conflictHandling *string) error {
|
|
ensureTx(tx)
|
|
fields, values := sqlGenKeysCreate(object)
|
|
|
|
conflictClause := ""
|
|
if conflictHandling != nil {
|
|
conflictClause = *conflictHandling
|
|
}
|
|
|
|
_, err := tx.NamedExec(
|
|
`INSERT INTO `+table+` (`+fields+`)
|
|
VALUES (`+values+`)
|
|
`+conflictClause+`
|
|
`,
|
|
object,
|
|
)
|
|
|
|
return err
|
|
}
|
|
|
|
func updateObjectByID(tx *sqlx.Tx, table string, object interface{}, updateEmptyValues bool) error {
|
|
ensureTx(tx)
|
|
_, err := tx.NamedExec(
|
|
`UPDATE `+table+` SET `+sqlGenKeys(object, updateEmptyValues)+` WHERE `+table+`.id = :id`,
|
|
object,
|
|
)
|
|
|
|
return err
|
|
}
|
|
|
|
func executeDeleteQuery(tableName string, id uuid.UUID, tx *sqlx.Tx) error {
|
|
ensureTx(tx)
|
|
idColumnName := getColumn(tableName, "id")
|
|
query := tx.Rebind(`DELETE FROM ` + tableName + ` WHERE ` + idColumnName + ` = ?`)
|
|
_, err := tx.Exec(query, id)
|
|
return err
|
|
}
|
|
|
|
func softDeleteObjectByID(tx *sqlx.Tx, table string, id uuid.UUID) error {
|
|
ensureTx(tx)
|
|
idColumnName := getColumn(table, "id")
|
|
query := tx.Rebind(`UPDATE ` + table + ` SET deleted=TRUE WHERE ` + idColumnName + ` = ?`)
|
|
_, err := tx.Exec(query, id)
|
|
return err
|
|
}
|
|
|
|
func deleteObjectsByColumn(tx *sqlx.Tx, table string, column string, value interface{}) error {
|
|
ensureTx(tx)
|
|
query := tx.Rebind(`DELETE FROM ` + table + ` WHERE ` + column + ` = ?`)
|
|
_, err := tx.Exec(query, value)
|
|
return err
|
|
}
|
|
|
|
// getColumn qualifies columnName with tableName, e.g. ("scenes", "id") ->
// "scenes.id". No quoting is applied.
func getColumn(tableName string, columnName string) string {
	return strings.Join([]string{tableName, columnName}, ".")
}
|
|
|
|
func sqlGenKeysCreate(i interface{}) (string, string) {
|
|
var fields []string
|
|
var values []string
|
|
|
|
addPlaceholder := func(key string) {
|
|
fields = append(fields, dialect.FieldQuote(key))
|
|
values = append(values, ":"+key)
|
|
}
|
|
|
|
v := reflect.ValueOf(i)
|
|
for i := 0; i < v.NumField(); i++ {
|
|
//get key for struct tag
|
|
rawKey := v.Type().Field(i).Tag.Get("db")
|
|
key := strings.Split(rawKey, ",")[0]
|
|
switch t := v.Field(i).Interface().(type) {
|
|
case string:
|
|
if t != "" {
|
|
addPlaceholder(key)
|
|
}
|
|
case int, int64, float64:
|
|
if t != 0 {
|
|
addPlaceholder(key)
|
|
}
|
|
case uuid.UUID:
|
|
if t != uuid.Nil {
|
|
addPlaceholder(key)
|
|
}
|
|
case bool:
|
|
addPlaceholder(key)
|
|
case optionalValue:
|
|
if t.IsValid() {
|
|
addPlaceholder(key)
|
|
}
|
|
case sql.NullString:
|
|
if t.Valid {
|
|
addPlaceholder(key)
|
|
}
|
|
case sql.NullBool:
|
|
if t.Valid {
|
|
addPlaceholder(key)
|
|
}
|
|
case sql.NullInt64:
|
|
if t.Valid {
|
|
addPlaceholder(key)
|
|
}
|
|
case uuid.NullUUID:
|
|
if t.Valid {
|
|
addPlaceholder(key)
|
|
}
|
|
case sql.NullFloat64:
|
|
if t.Valid {
|
|
addPlaceholder(key)
|
|
}
|
|
default:
|
|
reflectValue := reflect.ValueOf(t)
|
|
isNil := reflectValue.IsNil()
|
|
if !isNil {
|
|
addPlaceholder(key)
|
|
}
|
|
}
|
|
}
|
|
return strings.Join(fields, ", "), strings.Join(values, ", ")
|
|
}
|
|
|
|
func sqlGenKeys(i interface{}, partial bool) string {
|
|
var query []string
|
|
|
|
addKey := func(key string) {
|
|
query = append(query, fmt.Sprintf("%s=:%s", dialect.FieldQuote(key), key))
|
|
}
|
|
|
|
v := reflect.ValueOf(i)
|
|
for i := 0; i < v.NumField(); i++ {
|
|
//get key for struct tag
|
|
rawKey := v.Type().Field(i).Tag.Get("db")
|
|
key := strings.Split(rawKey, ",")[0]
|
|
if key == "id" {
|
|
continue
|
|
}
|
|
switch t := v.Field(i).Interface().(type) {
|
|
case string:
|
|
if partial || t != "" {
|
|
addKey(key)
|
|
}
|
|
case uuid.UUID:
|
|
if partial || t != uuid.Nil {
|
|
addKey(key)
|
|
}
|
|
case int, int64, float64:
|
|
if partial || t != 0 {
|
|
addKey(key)
|
|
}
|
|
case bool:
|
|
addKey(key)
|
|
case optionalValue:
|
|
if partial || t.IsValid() {
|
|
addKey(key)
|
|
}
|
|
case sql.NullString:
|
|
if partial || t.Valid {
|
|
addKey(key)
|
|
}
|
|
case sql.NullBool:
|
|
if partial || t.Valid {
|
|
addKey(key)
|
|
}
|
|
case sql.NullInt64:
|
|
if partial || t.Valid {
|
|
addKey(key)
|
|
}
|
|
case uuid.NullUUID:
|
|
if partial || t.Valid {
|
|
addKey(key)
|
|
}
|
|
case sql.NullFloat64:
|
|
if partial || t.Valid {
|
|
addKey(key)
|
|
}
|
|
default:
|
|
reflectValue := reflect.ValueOf(t)
|
|
isNil := reflectValue.IsNil()
|
|
if !isNil {
|
|
addKey(key)
|
|
}
|
|
}
|
|
}
|
|
return strings.Join(query, ", ")
|
|
}
|