stash-box/pkg/api/server.go

327 lines
8.0 KiB
Go

package api
import (
"context"
"crypto/tls"
"embed"
"errors"
"fmt"
"html/template"
"io/fs"
"io/ioutil"
"net/http"
"net/http/pprof"
"path"
"runtime/debug"
"strconv"
"strings"
"github.com/klauspost/compress/flate"
gqlHandler "github.com/99designs/gqlgen/graphql/handler"
gqlExtension "github.com/99designs/gqlgen/graphql/handler/extension"
gqlTransport "github.com/99designs/gqlgen/graphql/handler/transport"
gqlPlayground "github.com/99designs/gqlgen/graphql/playground"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/rs/cors"
"github.com/stashapp/stash-box/pkg/dataloader"
"github.com/stashapp/stash-box/pkg/logger"
"github.com/stashapp/stash-box/pkg/manager/config"
"github.com/stashapp/stash-box/pkg/manager/paths"
"github.com/stashapp/stash-box/pkg/models"
"github.com/stashapp/stash-box/pkg/user"
)
// Build metadata reported by printVersion and GetVersion. Nothing in this file
// assigns them; presumably they are injected at link time (e.g. via
// -ldflags "-X ...") — TODO confirm against the build scripts.
var (
	version    string
	buildstamp string
	githash    string
	buildtype  string
)

// APIKeyHeader is the request header from which authenticateHandler reads
// the caller's API key.
const APIKeyHeader = "ApiKey"
// getUserAndRoles loads the user identified by userID and the roles assigned
// to it, using the given repository. If either lookup fails, both results are
// nil and the lookup error is returned unchanged.
func getUserAndRoles(fac models.Repo, userID string) (*models.User, []models.RoleEnum, error) {
	var (
		found    *models.User
		assigned []models.RoleEnum
		err      error
	)

	if found, err = user.Get(fac, userID); err != nil {
		return nil, nil, err
	}
	if assigned, err = user.GetRoles(fac, userID); err != nil {
		return nil, nil, err
	}

	return found, assigned, nil
}
// authenticateHandler returns middleware that resolves the caller of each
// request and stores the resulting user (possibly nil) and roles in the
// request context under user.ContextUser and user.ContextRoles.
//
// The caller is identified by the APIKeyHeader header when it is present,
// and otherwise by the session (getSessionUserID). Any error while resolving
// the user ID or loading the user and roles is answered with 500 Internal
// Server Error, with the error text as the body. If an API key was supplied
// and the resolved user's stored key differs from it, the request is answered
// with 401 Unauthorized.
func authenticateHandler() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()
			apiKey := r.Header.Get(APIKeyHeader)

			var (
				userID string
				err    error
			)
			if apiKey == "" {
				// no API key: fall back to the session
				userID, err = getSessionUserID(w, r)
			} else {
				userID, err = user.GetUserIDFromAPIKey(apiKey)
			}

			var (
				currentUser *models.User
				roles       []models.RoleEnum
			)
			if err == nil {
				currentUser, roles, err = getUserAndRoles(getRepo(ctx), userID)
			}

			if err != nil {
				w.WriteHeader(http.StatusInternalServerError)
				if _, writeErr := w.Write([]byte(err.Error())); writeErr != nil {
					logger.Error(writeErr)
				}
				return
			}

			// the key that resolved the user must also be the key stored on it
			keyMismatch := apiKey != "" && currentUser != nil && currentUser.APIKey != apiKey
			if keyMismatch {
				w.WriteHeader(http.StatusUnauthorized)
				return
			}

			// TODO - increment api key counters
			ctx = context.WithValue(ctx, user.ContextUser, currentUser)
			ctx = context.WithValue(ctx, user.ContextRoles, roles)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// redirect answers every request with a 308 Permanent Redirect to the same
// host and request URI over https. Start installs it as the handler of the
// plain-HTTP listener on port 80 when HTTP upgrade is enabled.
//
// The target is built from the escaped request URI (path plus raw query)
// rather than the decoded URL.Path. As a result, percent-encoded characters
// such as %20, %2F or %3F survive the redirect unchanged. With the decoded
// path they would become raw spaces, extra path segments or a spurious query
// string.
func redirect(w http.ResponseWriter, req *http.Request) {
	target := "https://" + req.Host + req.URL.RequestURI()
	http.Redirect(w, req, target, http.StatusPermanentRedirect)
}
// Start builds the HTTP router and starts the listeners in background
// goroutines. It returns without waiting for them; a listener failure is
// reported through logger.Fatal.
//
// The router applies CORS, per-request repository, authentication, panic
// recovery, compression, trailing-slash stripping and base-URL middleware. It
// serves the GraphQL endpoint (plus the playground outside production), the
// login/logout session handlers, the image routes and the embedded web app.
//
// If makeTLSConfig returns a config, the server listens with TLS. In that case
// it optionally also runs an HTTP->HTTPS redirect on port 80 when
// config.GetHTTPUpgrade() is set. Otherwise it listens on plain HTTP. When a
// profiler port is configured, net/http/pprof handlers are served on that port
// as well.
func Start(rfp RepoProvider, ui embed.FS) {
	r := chi.NewRouter()

	var corsConfig *cors.Cors
	if config.GetIsProduction() {
		corsConfig = cors.AllowAll()
	} else {
		// development: accept any origin and allow credentialed requests
		corsConfig = cors.New(cors.Options{
			AllowOriginFunc:  func(origin string) bool { return true },
			AllowCredentials: true,
			AllowedHeaders:   []string{"*"},
		})
	}

	r.Use(corsConfig.Handler)
	r.Use(repoMiddleware(rfp))
	r.Use(authenticateHandler())
	r.Use(middleware.Recoverer)
	compressor := middleware.NewCompressor(flate.DefaultCompression)
	r.Use(compressor.Handler)
	r.Use(middleware.StripSlashes)
	r.Use(BaseURLMiddleware)

	// recoverFunc logs a panic recovered by the GraphQL server together with
	// the current stack trace, and converts it into an error.
	recoverFunc := func(ctx context.Context, err interface{}) error {
		logger.Error(err)
		debug.PrintStack()
		message := fmt.Sprintf("Internal system error. Error <%v>", err)
		return errors.New(message)
	}

	gqlConfig := models.Config{
		Resolvers: NewResolver(getRepo),
		Directives: models.DirectiveRoot{
			IsUserOwner: IsUserOwnerDirective,
			HasRole:     HasRoleDirective,
		},
	}
	gqlSrv := gqlHandler.New(models.NewExecutableSchema(gqlConfig))
	gqlSrv.SetRecoverFunc(recoverFunc)
	gqlSrv.AddTransport(gqlTransport.Options{})
	gqlSrv.AddTransport(gqlTransport.GET{})
	gqlSrv.AddTransport(gqlTransport.POST{})
	gqlSrv.AddTransport(gqlTransport.MultipartForm{})
	gqlSrv.Use(gqlExtension.Introspection{})

	r.Handle("/graphql", dataloader.Middleware(rfp.Repo())(gqlSrv))
	if !config.GetIsProduction() {
		r.Handle("/playground", gqlPlayground.Handler("GraphQL playground", "/graphql"))
	}

	index := getIndex(ui)

	// session handlers: GET /login serves the web app, other methods log in
	r.HandleFunc("/login", func(w http.ResponseWriter, req *http.Request) {
		if req.Method == http.MethodGet {
			_, _ = w.Write(index)
			return
		}
		handleLogin(w, req)
	})
	r.HandleFunc("/logout", handleLogout)

	r.Mount("/image", imageRoutes{}.Routes())

	// The embedded UI is fixed at build time, so the sub-filesystem and file
	// server are built once here instead of on every request.
	uiRoot, err := fs.Sub(ui, "frontend/build")
	if err != nil {
		panic(err)
	}
	uiFileServer := http.FileServer(http.FS(uiRoot))

	// Serve the web app: extension-less and .html paths get the rendered
	// index; everything else is served from the embedded build directory.
	r.HandleFunc("/*", func(w http.ResponseWriter, req *http.Request) {
		ext := path.Ext(req.URL.Path)
		if ext == ".html" || ext == "" {
			_, _ = w.Write(index)
			return
		}
		if isStatic, _ := path.Match("/static/*/*", req.URL.Path); isStatic {
			w.Header().Add("Cache-Control", "max-age=604800000")
		}
		uiFileServer.ServeHTTP(w, req)
	})

	if profilerPort := config.GetProfilerPort(); profilerPort != nil {
		port := *profilerPort
		go func() {
			mux := http.NewServeMux()
			mux.HandleFunc("/", pprof.Index)
			mux.HandleFunc("/cmdline", pprof.Cmdline)
			mux.HandleFunc("/profile", pprof.Profile)
			mux.HandleFunc("/symbol", pprof.Symbol)
			mux.HandleFunc("/trace", pprof.Trace)
			mux.Handle("/allocs", pprof.Handler("allocs"))
			mux.Handle("/block", pprof.Handler("block"))
			mux.Handle("/goroutine", pprof.Handler("goroutine"))
			mux.Handle("/heap", pprof.Handler("heap"))
			mux.Handle("/mutex", pprof.Handler("mutex"))
			mux.Handle("/threadcreate", pprof.Handler("threadcreate"))
			logger.Infof("profiler is running at http://localhost:%d/", port)
			logger.Fatal(http.ListenAndServe(":"+strconv.Itoa(port), mux))
		}()
	}

	address := config.GetHost() + ":" + strconv.Itoa(config.GetPort())

	if tlsConfig := makeTLSConfig(); tlsConfig != nil {
		httpsServer := &http.Server{
			Addr:      address,
			Handler:   r,
			TLSConfig: tlsConfig,
		}
		if config.GetHTTPUpgrade() {
			go func() {
				logger.Fatal(http.ListenAndServe(config.GetHost()+":80", http.HandlerFunc(redirect)))
			}()
		}
		go func() {
			printVersion()
			// address goes through %s so a '%' in the configured host cannot
			// be misread as a formatting verb
			logger.Infof("stash-box is running on HTTPS at https://%s/", address)
			// certificates come from TLSConfig, so no cert/key files are passed
			logger.Fatal(httpsServer.ListenAndServeTLS("", ""))
		}()
		return
	}

	server := &http.Server{
		Addr:    address,
		Handler: r,
	}
	go func() {
		printVersion()
		logger.Infof("stash-box is running on HTTP at http://%s/", address)
		logger.Fatal(server.ListenAndServe())
	}()
}
func printVersion() {
versionString := version
if buildtype != "OFFICIAL" {
versionString += " (" + githash + ")"
}
fmt.Printf("stash-box version: %s - %s\n", versionString, buildstamp)
}
// GetVersion returns the build metadata held in the package-level variables,
// in the order: version, git commit hash, build timestamp.
func GetVersion() (string, string, string) {
	ver, hash, stamp := version, githash, buildstamp
	return ver, hash, stamp
}
// makeTLSConfig loads the SSL certificate and key from paths.GetSSLCert() and
// paths.GetSSLKey() and returns a server TLS config. It returns nil when TLS
// cannot be enabled, in which case Start falls back to plain HTTP.
//
// A missing certificate file means TLS is simply not configured and is not
// reported. Every other failure is logged, because it indicates that TLS was
// intended but is misconfigured. Such failures are: an unreadable certificate,
// a missing or unreadable key, or a key pair that does not load.
func makeTLSConfig() *tls.Config {
	certPath := paths.GetSSLCert()
	certPEM, err := ioutil.ReadFile(certPath)
	if err != nil {
		if !errors.Is(err, fs.ErrNotExist) {
			logger.Error(fmt.Errorf("reading SSL certificate %q: %w", certPath, err))
		}
		return nil
	}

	keyPath := paths.GetSSLKey()
	keyPEM, err := ioutil.ReadFile(keyPath)
	if err != nil {
		logger.Error(fmt.Errorf("reading SSL key %q: %w", keyPath, err))
		return nil
	}

	cert, err := tls.X509KeyPair(certPEM, keyPEM)
	if err != nil {
		logger.Error(fmt.Errorf("loading SSL key pair (%q, %q): %w", certPath, keyPath, err))
		return nil
	}

	return &tls.Config{
		Certificates: []tls.Certificate{cert},
		// TLS 1.0/1.1 are deprecated (RFC 8996); pin the floor explicitly
		// rather than relying on the toolchain's default.
		MinVersion: tls.VersionTLS12,
	}
}
// contextKey is the unexported type for keys this package stores in a request
// context. Because the type is unexported, other packages cannot construct a
// colliding key.
type contextKey struct {
	name string
}

// BaseURLCtxKey is the request-context key under which BaseURLMiddleware
// stores the request's base URL ("scheme://host") as a string.
var BaseURLCtxKey = &contextKey{"BaseURL"}

// BaseURLMiddleware stores the externally visible base URL of each request,
// scheme + "://" + r.Host, in the request context under BaseURLCtxKey.
//
// The scheme is "https" in any of these cases, and "http" otherwise:
//   - the request arrived over TLS;
//   - its URL carries an https scheme;
//   - it is an HTTP/2 request;
//   - a proxy set X-Forwarded-Proto: https.
//
// The r.TLS check is required because net/http leaves URL.Scheme empty for
// most server requests. Without it, HTTP/1.1 requests served directly over
// TLS would be reported as "http".
func BaseURLMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		scheme := "http"
		if r.TLS != nil ||
			r.URL.Scheme == "https" ||
			r.Proto == "HTTP/2.0" ||
			r.Header.Get("X-Forwarded-Proto") == "https" {
			scheme = "https"
		}

		baseURL := scheme + "://" + r.Host
		ctx := context.WithValue(r.Context(), BaseURLCtxKey, baseURL)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
func getIndex(ui embed.FS) []byte {
indexFile, err := ui.ReadFile("frontend/build/index.html")
if err != nil {
panic(error.Error(err))
}
tmpl := template.Must(template.New("index").Parse(string(indexFile)))
title := template.HTMLEscapeString(config.GetTitle())
output := new(strings.Builder)
if err := tmpl.Execute(output, template.HTML(title)); err != nil {
panic(error.Error(err))
}
return []byte(output.String())
}