mirror of
https://github.com/stashapp/stash.git
synced 2025-12-11 04:43:03 -06:00
275 lines
7.3 KiB
Go
275 lines
7.3 KiB
Go
package graphql
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
|
|
"github.com/vektah/gqlparser/ast"
|
|
"github.com/vektah/gqlparser/gqlerror"
|
|
)
|
|
|
|
type (
	// Resolver is a function that, given a context, produces a result value
	// or an error.
	Resolver func(ctx context.Context) (res interface{}, err error)

	// FieldMiddleware wraps a Resolver: it receives the context together with
	// the next Resolver and returns a result/error pair of its own.
	FieldMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error)

	// RequestMiddleware wraps a request-level handler that yields the raw
	// response bytes.
	RequestMiddleware func(ctx context.Context, next func(ctx context.Context) []byte) []byte

	// ComplexityLimitFunc returns an int for the given context
	// (presumably the complexity limit to apply; confirm against callers).
	ComplexityLimitFunc func(ctx context.Context) int
)
|
|
|
|
type RequestContext struct {
|
|
RawQuery string
|
|
Variables map[string]interface{}
|
|
Doc *ast.QueryDocument
|
|
|
|
ComplexityLimit int
|
|
OperationComplexity int
|
|
DisableIntrospection bool
|
|
|
|
// ErrorPresenter will be used to generate the error
|
|
// message from errors given to Error().
|
|
ErrorPresenter ErrorPresenterFunc
|
|
Recover RecoverFunc
|
|
ResolverMiddleware FieldMiddleware
|
|
DirectiveMiddleware FieldMiddleware
|
|
RequestMiddleware RequestMiddleware
|
|
Tracer Tracer
|
|
|
|
errorsMu sync.Mutex
|
|
Errors gqlerror.List
|
|
extensionsMu sync.Mutex
|
|
Extensions map[string]interface{}
|
|
}
|
|
|
|
func DefaultResolverMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
|
|
return next(ctx)
|
|
}
|
|
|
|
func DefaultDirectiveMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
|
|
return next(ctx)
|
|
}
|
|
|
|
// DefaultRequestMiddleware is a pass-through RequestMiddleware: it calls next
// with the unchanged context and returns exactly the bytes next produced.
func DefaultRequestMiddleware(ctx context.Context, next func(ctx context.Context) []byte) []byte {
	response := next(ctx)
	return response
}
|
|
|
|
func NewRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *RequestContext {
|
|
return &RequestContext{
|
|
Doc: doc,
|
|
RawQuery: query,
|
|
Variables: variables,
|
|
ResolverMiddleware: DefaultResolverMiddleware,
|
|
DirectiveMiddleware: DefaultDirectiveMiddleware,
|
|
RequestMiddleware: DefaultRequestMiddleware,
|
|
Recover: DefaultRecover,
|
|
ErrorPresenter: DefaultErrorPresenter,
|
|
Tracer: &NopTracer{},
|
|
}
|
|
}
|
|
|
|
// key is the unexported type used for this package's context keys, so values
// stored under it cannot collide with keys defined by other packages.
type key string

// Context keys under which the RequestContext and the ResolverContext are stored.
const (
	request  key = "request_context"
	resolver key = "resolver_context"
)
|
|
|
|
func GetRequestContext(ctx context.Context) *RequestContext {
|
|
if val, ok := ctx.Value(request).(*RequestContext); ok {
|
|
return val
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context {
|
|
return context.WithValue(ctx, request, rc)
|
|
}
|
|
|
|
type ResolverContext struct {
|
|
Parent *ResolverContext
|
|
// The name of the type this field belongs to
|
|
Object string
|
|
// These are the args after processing, they can be mutated in middleware to change what the resolver will get.
|
|
Args map[string]interface{}
|
|
// The raw field
|
|
Field CollectedField
|
|
// The index of array in path.
|
|
Index *int
|
|
// The result object of resolver
|
|
Result interface{}
|
|
// IsMethod indicates if the resolver is a method
|
|
IsMethod bool
|
|
}
|
|
|
|
func (r *ResolverContext) Path() []interface{} {
|
|
var path []interface{}
|
|
for it := r; it != nil; it = it.Parent {
|
|
if it.Index != nil {
|
|
path = append(path, *it.Index)
|
|
} else if it.Field.Field != nil {
|
|
path = append(path, it.Field.Alias)
|
|
}
|
|
}
|
|
|
|
// because we are walking up the chain, all the elements are backwards, do an inplace flip.
|
|
for i := len(path)/2 - 1; i >= 0; i-- {
|
|
opp := len(path) - 1 - i
|
|
path[i], path[opp] = path[opp], path[i]
|
|
}
|
|
|
|
return path
|
|
}
|
|
|
|
func GetResolverContext(ctx context.Context) *ResolverContext {
|
|
if val, ok := ctx.Value(resolver).(*ResolverContext); ok {
|
|
return val
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func WithResolverContext(ctx context.Context, rc *ResolverContext) context.Context {
|
|
rc.Parent = GetResolverContext(ctx)
|
|
return context.WithValue(ctx, resolver, rc)
|
|
}
|
|
|
|
// This is just a convenient wrapper method for CollectFields
|
|
func CollectFieldsCtx(ctx context.Context, satisfies []string) []CollectedField {
|
|
resctx := GetResolverContext(ctx)
|
|
return CollectFields(GetRequestContext(ctx), resctx.Field.Selections, satisfies)
|
|
}
|
|
|
|
// CollectAllFields returns a slice of all GraphQL field names that were selected for the current resolver context.
|
|
// The slice will contain the unique set of all field names requested regardless of fragment type conditions.
|
|
func CollectAllFields(ctx context.Context) []string {
|
|
resctx := GetResolverContext(ctx)
|
|
collected := CollectFields(GetRequestContext(ctx), resctx.Field.Selections, nil)
|
|
uniq := make([]string, 0, len(collected))
|
|
Next:
|
|
for _, f := range collected {
|
|
for _, name := range uniq {
|
|
if name == f.Name {
|
|
continue Next
|
|
}
|
|
}
|
|
uniq = append(uniq, f.Name)
|
|
}
|
|
return uniq
|
|
}
|
|
|
|
// Errorf sends an error string to the client, passing it through the formatter.
|
|
func (c *RequestContext) Errorf(ctx context.Context, format string, args ...interface{}) {
|
|
c.errorsMu.Lock()
|
|
defer c.errorsMu.Unlock()
|
|
|
|
c.Errors = append(c.Errors, c.ErrorPresenter(ctx, fmt.Errorf(format, args...)))
|
|
}
|
|
|
|
// Error sends an error to the client, passing it through the formatter.
|
|
func (c *RequestContext) Error(ctx context.Context, err error) {
|
|
c.errorsMu.Lock()
|
|
defer c.errorsMu.Unlock()
|
|
|
|
c.Errors = append(c.Errors, c.ErrorPresenter(ctx, err))
|
|
}
|
|
|
|
// HasError returns true if the current field has already errored
|
|
func (c *RequestContext) HasError(rctx *ResolverContext) bool {
|
|
c.errorsMu.Lock()
|
|
defer c.errorsMu.Unlock()
|
|
path := rctx.Path()
|
|
|
|
for _, err := range c.Errors {
|
|
if equalPath(err.Path, path) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// GetErrors returns a list of errors that occurred in the current field
|
|
func (c *RequestContext) GetErrors(rctx *ResolverContext) gqlerror.List {
|
|
c.errorsMu.Lock()
|
|
defer c.errorsMu.Unlock()
|
|
path := rctx.Path()
|
|
|
|
var errs gqlerror.List
|
|
for _, err := range c.Errors {
|
|
if equalPath(err.Path, path) {
|
|
errs = append(errs, err)
|
|
}
|
|
}
|
|
return errs
|
|
}
|
|
|
|
// equalPath reports whether a and b have the same length and pairwise equal
// elements. Elements are compared with ==, so both the dynamic type and the
// value must match; a nil slice and an empty slice are considered equal.
func equalPath(a []interface{}, b []interface{}) bool {
	if len(a) != len(b) {
		return false
	}

	for i, segment := range a {
		if segment != b[i] {
			return false
		}
	}

	return true
}
|
|
|
|
// AddError is a convenience method for adding an error to the current response
|
|
func AddError(ctx context.Context, err error) {
|
|
GetRequestContext(ctx).Error(ctx, err)
|
|
}
|
|
|
|
// AddErrorf is a convenience method for adding an error to the current response
|
|
func AddErrorf(ctx context.Context, format string, args ...interface{}) {
|
|
GetRequestContext(ctx).Errorf(ctx, format, args...)
|
|
}
|
|
|
|
// RegisterExtension registers an extension, returns error if extension has already been registered
|
|
func (c *RequestContext) RegisterExtension(key string, value interface{}) error {
|
|
c.extensionsMu.Lock()
|
|
defer c.extensionsMu.Unlock()
|
|
|
|
if c.Extensions == nil {
|
|
c.Extensions = make(map[string]interface{})
|
|
}
|
|
|
|
if _, ok := c.Extensions[key]; ok {
|
|
return fmt.Errorf("extension already registered for key %s", key)
|
|
}
|
|
|
|
c.Extensions[key] = value
|
|
return nil
|
|
}
|
|
|
|
// ChainFieldMiddleware add chain by FieldMiddleware
|
|
func ChainFieldMiddleware(handleFunc ...FieldMiddleware) FieldMiddleware {
|
|
n := len(handleFunc)
|
|
|
|
if n > 1 {
|
|
lastI := n - 1
|
|
return func(ctx context.Context, next Resolver) (interface{}, error) {
|
|
var (
|
|
chainHandler Resolver
|
|
curI int
|
|
)
|
|
chainHandler = func(currentCtx context.Context) (interface{}, error) {
|
|
if curI == lastI {
|
|
return next(currentCtx)
|
|
}
|
|
curI++
|
|
res, err := handleFunc[curI](currentCtx, chainHandler)
|
|
curI--
|
|
return res, err
|
|
|
|
}
|
|
return handleFunc[0](ctx, chainHandler)
|
|
}
|
|
}
|
|
|
|
if n == 1 {
|
|
return handleFunc[0]
|
|
}
|
|
|
|
return func(ctx context.Context, next Resolver) (interface{}, error) {
|
|
return next(ctx)
|
|
}
|
|
}
|