2021-09-23 17:06:09 +00:00

464 lines
13 KiB
Go

package privateinternetaccess
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/provider/utils"
"github.com/qdm12/golibs/format"
)
// Sentinel errors wrapped by PortForward and KeepPortForward.
var (
// ErrGatewayIPIsNil is returned by PortForward when the gateway argument is nil.
ErrGatewayIPIsNil = errors.New("gateway IP address is nil")
// ErrServerNameEmpty is returned by PortForward when the serverName argument is empty.
ErrServerNameEmpty = errors.New("server name is empty")
// ErrCreateHTTPClient wraps a failure of newHTTPClient.
ErrCreateHTTPClient = errors.New("cannot create custom HTTP client")
// ErrReadSavedPortForwardData wraps a failure to read the saved port forwarding data file.
ErrReadSavedPortForwardData = errors.New("cannot read saved port forwarded data")
// ErrRefreshPortForwardData wraps a failure of refreshPIAPortForwardData.
ErrRefreshPortForwardData = errors.New("cannot refresh port forward data")
// ErrBindPort wraps a failure of bindPort.
ErrBindPort = errors.New("cannot bind port")
)
// PortForward obtains a VPN server side port forwarded from PIA.
//
// It validates its arguments, then reuses the port forwarding data saved at
// p.portForwardPath when that data holds a port and is not expired, or
// refreshes it otherwise, and finally binds the port through the gateway.
//
// It returns ErrGatewayIPIsNil or ErrServerNameEmpty for invalid arguments.
// It logs an error and returns (0, nil) when the server found for serverName
// is not flagged as supporting port forwarding. Other failures are returned
// wrapped in ErrCreateHTTPClient, ErrReadSavedPortForwardData,
// ErrRefreshPortForwardData or ErrBindPort.
func (p *PIA) PortForward(ctx context.Context, client *http.Client,
	logger utils.Logger, gateway net.IP, serverName string) (
	port uint16, err error) {
	// Validate the arguments before the server lookup, so an empty server
	// name is reported as such rather than as an unsupported server.
	// NOTE(review): this assumes PIAServerWhereName yields a zero-value
	// server when nothing matches — confirm against its implementation.
	switch {
	case gateway == nil:
		return 0, ErrGatewayIPIsNil
	case serverName == "":
		return 0, ErrServerNameEmpty
	}

	server := constants.PIAServerWhereName(p.servers, serverName)
	if !server.PortForward {
		logger.Error("The server " + serverName +
			" (region " + server.Region + ") does not support port forwarding")
		return 0, nil
	}

	privateIPClient, err := newHTTPClient(serverName)
	if err != nil {
		return 0, fmt.Errorf("%w: %s", ErrCreateHTTPClient, err)
	}

	data, err := readPIAPortForwardData(p.portForwardPath)
	if err != nil {
		return 0, fmt.Errorf("%w: %s", ErrReadSavedPortForwardData, err)
	}

	// A zero port means no usable data was saved (for example no file yet).
	dataFound := data.Port > 0
	durationToExpiration := data.Expiration.Sub(p.timeNow())
	expired := durationToExpiration <= 0

	if dataFound {
		logger.Info("Found saved forwarded port data for port " + strconv.Itoa(int(data.Port)))
		if expired {
			logger.Warn("Forwarded port data expired on " +
				data.Expiration.Format(time.RFC1123) + ", getting another one")
		} else {
			logger.Info("Forwarded port data expires in " + format.FriendlyDuration(durationToExpiration))
		}
	}

	if !dataFound || expired {
		data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
			p.portForwardPath, p.authFilePath)
		if err != nil {
			return 0, fmt.Errorf("%w: %s", ErrRefreshPortForwardData, err)
		}
		durationToExpiration = data.Expiration.Sub(p.timeNow())
	}
	logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration))

	// First time binding
	if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
		return 0, fmt.Errorf("%w: %s", ErrBindPort, err)
	}

	return data.Port, nil
}
var (
// ErrPortForwardedExpired is returned by KeepPortForward when its expiry
// timer, set from the saved data expiration time, fires.
ErrPortForwardedExpired = errors.New("port forwarded data expired")
)
// KeepPortForward keeps the forwarded port bound by calling bindPort every
// 15 minutes with the port forwarding data saved at p.portForwardPath.
//
// It blocks until one of the following happens:
//   - ctx is done, in which case ctx.Err() is returned;
//   - a bind attempt fails, in which case the error wraps ErrBindPort;
//   - the saved data expires, in which case the error wraps
//     ErrPortForwardedExpired.
//
// Failures to create the HTTP client or to read the saved data are returned
// wrapped in ErrCreateHTTPClient and ErrReadSavedPortForwardData.
// The client and port arguments are not used by this implementation.
func (p *PIA) KeepPortForward(ctx context.Context, client *http.Client,
	port uint16, gateway net.IP, serverName string) (err error) {
	privateIPClient, err := newHTTPClient(serverName)
	if err != nil {
		return fmt.Errorf("%w: %s", ErrCreateHTTPClient, err)
	}

	data, err := readPIAPortForwardData(p.portForwardPath)
	if err != nil {
		return fmt.Errorf("%w: %s", ErrReadSavedPortForwardData, err)
	}

	// Both timers are local and never reused after returning, so stopping
	// them on every return path is enough; their channels need no draining.
	durationToExpiration := data.Expiration.Sub(p.timeNow())
	expiryTimer := time.NewTimer(durationToExpiration)
	defer expiryTimer.Stop()

	const keepAlivePeriod = 15 * time.Minute
	// Timer behaving as a ticker: it is re-armed only once a bind completed,
	// so the period is counted from the end of the previous bind.
	keepAliveTimer := time.NewTimer(keepAlivePeriod)
	defer keepAliveTimer.Stop()

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-keepAliveTimer.C:
			if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
				return fmt.Errorf("%w: %s", ErrBindPort, err)
			}
			keepAliveTimer.Reset(keepAlivePeriod)
		case <-expiryTimer.C:
			return fmt.Errorf("%w: on %s", ErrPortForwardedExpired,
				data.Expiration.Format(time.RFC1123))
		}
	}
}
// Sentinel errors wrapped by refreshPIAPortForwardData.
var (
// ErrFetchToken wraps a failure of fetchToken.
ErrFetchToken = errors.New("cannot fetch token")
// ErrFetchPortForwarding wraps a failure of fetchPortForwardData.
ErrFetchPortForwarding = errors.New("cannot fetch port forwarding data")
// ErrPersistPortForwarding wraps a failure of writePIAPortForwardData.
ErrPersistPortForwarding = errors.New("cannot persist port forwarding data")
)
// refreshPIAPortForwardData obtains fresh port forwarding data and saves it.
//
// It runs three steps in order and stops at the first failure:
//  1. fetch a token using client and the credentials at authFilePath;
//  2. fetch the port, signature and expiration using privateIPClient and gateway;
//  3. write the resulting data to portForwardPath.
//
// The returned error wraps ErrFetchToken, ErrFetchPortForwarding or
// ErrPersistPortForwarding depending on the failing step. The data returned
// alongside an error holds whatever was obtained up to that step.
func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
	gateway net.IP, portForwardPath, authFilePath string) (data piaPortForwardData, err error) {
	token, err := fetchToken(ctx, client, authFilePath)
	data.Token = token
	if err != nil {
		return data, fmt.Errorf("%w: %s", ErrFetchToken, err)
	}

	forwardedPort, signature, expiry, err := fetchPortForwardData(
		ctx, privateIPClient, gateway, data.Token)
	data.Port, data.Signature, data.Expiration = forwardedPort, signature, expiry
	if err != nil {
		return data, fmt.Errorf("%w: %s", ErrFetchPortForwarding, err)
	}

	writeErr := writePIAPortForwardData(portForwardPath, data)
	if writeErr != nil {
		return data, fmt.Errorf("%w: %s", ErrPersistPortForwarding, writeErr)
	}

	return data, nil
}
// piaPayload is the JSON document that unpackPayload decodes from, and
// packPayload encodes to, a base64 payload string.
// NOTE(review): encoding/json writes fields in declaration order, so the
// field order and tags below fix the bytes produced by packPayload.
// Presumably these must match what the server signed — confirm before
// reordering or renaming anything here.
type piaPayload struct {
// Token is stored under the "token" key.
Token string `json:"token"`
// Port is stored under the "port" key.
Port uint16 `json:"port"`
// Expiration is stored under the "expires_at" key.
Expiration time.Time `json:"expires_at"`
}
// piaPortForwardData is the port forwarding state saved as JSON by
// writePIAPortForwardData and loaded by readPIAPortForwardData.
type piaPortForwardData struct {
// Port is the forwarded port; PortForward treats 0 as "no saved data".
Port uint16 `json:"port"`
// Token is the token obtained by fetchToken; bindPort packs it in the payload.
Token string `json:"token"`
// Signature is the signature returned with the payload by fetchPortForwardData;
// bindPort sends it back as the "signature" query parameter.
Signature string `json:"signature"`
// Expiration is the time after which the data is considered expired.
Expiration time.Time `json:"expires_at"`
}
// readPIAPortForwardData loads the JSON encoded port forwarding data stored
// at portForwardPath.
//
// A missing file is not an error: the zero value is returned with a nil
// error. Any other open error, a JSON decoding error, or an error closing
// the file is returned as is; a decoding error takes precedence over a
// closing error.
func readPIAPortForwardData(portForwardPath string) (data piaPortForwardData, err error) {
	file, err := os.Open(portForwardPath)
	switch {
	case os.IsNotExist(err):
		return data, nil
	case err != nil:
		return data, err
	}

	decodeErr := json.NewDecoder(file).Decode(&data)
	closeErr := file.Close()
	if decodeErr != nil {
		return data, decodeErr
	}
	return data, closeErr
}
// writePIAPortForwardData saves data as JSON at portForwardPath, creating
// the file with mode 0644 if needed and truncating any previous content.
//
// It returns the open error, the JSON encoding error or the error closing
// the file; an encoding error takes precedence over a closing error.
func writePIAPortForwardData(portForwardPath string, data piaPortForwardData) (err error) {
	const (
		flags = os.O_CREATE | os.O_TRUNC | os.O_WRONLY
		perms = 0644
	)
	file, err := os.OpenFile(portForwardPath, flags, perms)
	if err != nil {
		return err
	}

	encodeErr := json.NewEncoder(file).Encode(data)
	closeErr := file.Close()
	if encodeErr != nil {
		return encodeErr
	}
	return closeErr
}
// unpackPayload decodes a standard base64 encoded JSON payload and returns
// the port, token and expiration it contains.
//
// On failure it returns zero values together with the base64 or JSON error,
// wrapped with the offending payload or decoded data.
func unpackPayload(payload string) (port uint16, token string, expiration time.Time, err error) {
	decoded, err := base64.StdEncoding.DecodeString(payload)
	if err != nil {
		err = fmt.Errorf("%w: for payload: %s", err, payload)
		return 0, "", expiration, err
	}

	var content piaPayload
	if err = json.Unmarshal(decoded, &content); err != nil {
		err = fmt.Errorf("%w: for data: %s", err, string(decoded))
		return 0, "", expiration, err
	}

	port, token, expiration = content.Port, content.Token, content.Expiration
	return port, token, expiration, nil
}
// packPayload serializes the port, token and expiration as a piaPayload JSON
// document and returns it encoded in standard base64.
// It returns an empty payload and the error if the JSON encoding fails.
func packPayload(port uint16, token string, expiration time.Time) (payload string, err error) {
	serialized, err := json.Marshal(&piaPayload{
		Token:      token,
		Port:       port,
		Expiration: expiration,
	})
	if err != nil {
		return "", err
	}

	return base64.StdEncoding.EncodeToString(serialized), nil
}
// Sentinel errors returned by fetchToken.
var (
// errGetCredentials wraps a failure of getOpenvpnCredentials.
errGetCredentials = errors.New("cannot get username and password")
// errEmptyToken is returned when the decoded response has an empty token.
errEmptyToken = errors.New("token received is empty")
)
// fetchToken requests a token from
// https://privateinternetaccess.com/gtoken/generateToken using HTTP basic
// authentication with the credentials read from authFilePath.
//
// It returns an error wrapping errGetCredentials if the credentials cannot
// be read, a redacted transport error, an error built by makeNOKStatusError
// for a non 200 status, an error wrapping ErrUnmarshalResponse if the body
// is not the expected JSON, or errEmptyToken if the token field is empty.
func fetchToken(ctx context.Context, client *http.Client,
	authFilePath string) (token string, err error) {
	username, password, err := getOpenvpnCredentials(authFilePath)
	if err != nil {
		return "", fmt.Errorf("%w: %s", errGetCredentials, err)
	}

	// Empty values are not registered: substituting an empty string would
	// match everywhere in the message.
	const maxSubstitutions = 2
	errSubstitutions := make(map[string]string, maxSubstitutions)
	if username != "" {
		errSubstitutions[username] = "<username>"
	}
	if password != "" {
		errSubstitutions[password] = "<password>"
	}

	tokenURL := url.URL{
		Scheme: "https",
		Host:   "privateinternetaccess.com",
		Path:   "/gtoken/generateToken",
	}
	request, err := http.NewRequestWithContext(ctx, http.MethodGet, tokenURL.String(), nil)
	if err != nil {
		return "", replaceInErr(err, errSubstitutions)
	}
	// Credentials go in the Authorization header instead of the URL userinfo
	// so they can never appear in an error message containing the URL.
	request.SetBasicAuth(username, password)

	response, err := client.Do(request)
	if err != nil {
		return "", replaceInErr(err, errSubstitutions)
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		// The response body is included in the error, so redact it as well.
		return "", makeNOKStatusError(response, errSubstitutions)
	}

	decoder := json.NewDecoder(response.Body)
	var result struct {
		Token string `json:"token"`
	}
	if err := decoder.Decode(&result); err != nil {
		return "", fmt.Errorf("%w: %s", ErrUnmarshalResponse, err)
	}

	if result.Token == "" {
		return "", errEmptyToken
	}
	return result.Token, nil
}
var (
	// errAuthFileRead wraps a failure to read the authentication file.
	errAuthFileRead = errors.New("cannot read OpenVPN authentication file")
	// errAuthFileMalformed is returned when the file has fewer than two lines.
	errAuthFileMalformed = errors.New("authentication file is malformed")
)

// getOpenvpnCredentials reads the username and password from the file at
// authFilePath, where the username is the first line and the password the
// second line. Both Unix (\n) and Windows (\r\n) line endings are accepted.
//
// It returns an error wrapping errAuthFileRead if the file cannot be read,
// or errAuthFileMalformed if it does not contain at least two lines.
func getOpenvpnCredentials(authFilePath string) (
	username, password string, err error) {
	authData, err := os.ReadFile(authFilePath)
	if err != nil {
		return "", "", fmt.Errorf("%w: %s", errAuthFileRead, err)
	}

	lines := strings.Split(string(authData), "\n")
	const minLines = 2
	if len(lines) < minLines {
		return "", "", fmt.Errorf("%w: only %d lines exist", errAuthFileMalformed, len(lines))
	}

	// Only the carriage return is stripped: other trailing characters such
	// as spaces may be a legitimate part of a credential.
	username = strings.TrimSuffix(lines[0], "\r")
	password = strings.TrimSuffix(lines[1], "\r")
	return username, password, nil
}
// Sentinel errors wrapped by fetchPortForwardData.
var (
// errGetSignaturePayload wraps a failure to build or send the /getSignature request.
errGetSignaturePayload = errors.New("cannot obtain signature payload")
// errUnpackPayload wraps a failure of unpackPayload.
errUnpackPayload = errors.New("cannot unpack payload data")
)
// fetchPortForwardData calls https://<gateway>:19999/getSignature with the
// token as query parameter and returns the port and expiration unpacked from
// the response payload, together with the response signature.
//
// It returns an error wrapping errGetSignaturePayload if the request cannot
// be built or sent, an error built by makeNOKStatusError for a non 200
// status, an error wrapping ErrUnmarshalResponse if the body is not the
// expected JSON, an error wrapping ErrBadResponse if the status field is not
// "OK", or an error wrapping errUnpackPayload if the payload is invalid.
// The token is redacted from error messages.
func fetchPortForwardData(ctx context.Context, client *http.Client, gateway net.IP, token string) (
	port uint16, signature string, expiration time.Time, err error) {
	// The token is sent query-escaped in the URL, so both its raw and its
	// escaped forms are redacted. An empty token is not registered since
	// substituting an empty string would match everywhere in the message.
	const maxSubstitutions = 2
	errSubstitutions := make(map[string]string, maxSubstitutions)
	if token != "" {
		errSubstitutions[token] = "<token>"
		errSubstitutions[url.QueryEscape(token)] = "<token>"
	}

	queryParams := make(url.Values)
	queryParams.Add("token", token)
	signatureURL := url.URL{
		Scheme:   "https",
		Host:     net.JoinHostPort(gateway.String(), "19999"),
		Path:     "/getSignature",
		RawQuery: queryParams.Encode(),
	}

	request, err := http.NewRequestWithContext(ctx, http.MethodGet, signatureURL.String(), nil)
	if err != nil {
		err = replaceInErr(err, errSubstitutions)
		return 0, "", expiration, fmt.Errorf("%w: %s", errGetSignaturePayload, err)
	}

	response, err := client.Do(request)
	if err != nil {
		err = replaceInErr(err, errSubstitutions)
		return 0, "", expiration, fmt.Errorf("%w: %s", errGetSignaturePayload, err)
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		return 0, "", expiration, makeNOKStatusError(response, errSubstitutions)
	}

	decoder := json.NewDecoder(response.Body)
	var data struct {
		Status    string `json:"status"`
		Payload   string `json:"payload"`
		Signature string `json:"signature"`
	}
	if err := decoder.Decode(&data); err != nil {
		return 0, "", expiration, fmt.Errorf("%w: %s", ErrUnmarshalResponse, err)
	}

	if data.Status != "OK" {
		return 0, "", expiration, fmt.Errorf("%w: status is: %s", ErrBadResponse, data.Status)
	}

	// The token contained in the payload is discarded.
	port, _, expiration, err = unpackPayload(data.Payload)
	if err != nil {
		return 0, "", expiration, fmt.Errorf("%w: %s", errUnpackPayload, err)
	}
	return port, data.Signature, expiration, nil
}
// Sentinel errors wrapped when exchanging data with the HTTP endpoints.
var (
// ErrSerializePayload wraps a failure of packPayload in bindPort.
ErrSerializePayload = errors.New("cannot serialize payload")
// ErrUnmarshalResponse wraps a failure to decode a JSON response body.
ErrUnmarshalResponse = errors.New("cannot unmarshal response")
// ErrBadResponse is wrapped when a response status field is not "OK".
ErrBadResponse = errors.New("bad response received")
)
// bindPort calls https://<gateway>:19999/bindPort with the payload packed
// from data (port, token, expiration) and data.Signature as query parameters.
//
// It returns an error wrapping ErrSerializePayload if the payload cannot be
// packed, a redacted error if the request cannot be built or sent, an error
// built by makeNOKStatusError for a non 200 status, an error wrapping
// ErrUnmarshalResponse if the body is not the expected JSON, or an error
// wrapping ErrBadResponse if the status field is not "OK".
// The payload and the signature are redacted from error messages.
func bindPort(ctx context.Context, client *http.Client, gateway net.IP, data piaPortForwardData) (err error) {
	payload, err := packPayload(data.Port, data.Token, data.Expiration)
	if err != nil {
		return fmt.Errorf("%w: %s", ErrSerializePayload, err)
	}

	queryParams := make(url.Values)
	queryParams.Add("payload", payload)
	queryParams.Add("signature", data.Signature)
	bindPortURL := url.URL{
		Scheme:   "https",
		Host:     net.JoinHostPort(gateway.String(), "19999"),
		Path:     "/bindPort",
		RawQuery: queryParams.Encode(),
	}

	// Secrets are sent query-escaped in the URL (base64 commonly contains
	// '+', '/' and '='), so both their raw and escaped forms are redacted.
	// Empty values are not registered since substituting an empty string
	// would match everywhere in the message.
	secrets := []struct{ value, placeholder string }{
		{value: payload, placeholder: "<payload>"},
		{value: data.Signature, placeholder: "<signature>"},
	}
	errSubstitutions := make(map[string]string, 2*len(secrets))
	for _, secret := range secrets {
		if secret.value == "" {
			continue
		}
		errSubstitutions[secret.value] = secret.placeholder
		errSubstitutions[url.QueryEscape(secret.value)] = secret.placeholder
	}

	request, err := http.NewRequestWithContext(ctx, http.MethodGet, bindPortURL.String(), nil)
	if err != nil {
		return replaceInErr(err, errSubstitutions)
	}

	response, err := client.Do(request)
	if err != nil {
		return replaceInErr(err, errSubstitutions)
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		return makeNOKStatusError(response, errSubstitutions)
	}

	decoder := json.NewDecoder(response.Body)
	var responseData struct {
		Status  string `json:"status"`
		Message string `json:"message"`
	}
	if err := decoder.Decode(&responseData); err != nil {
		// The URL carries the payload and signature: redact it before use.
		redactedURL := replaceInString(bindPortURL.String(), errSubstitutions)
		return fmt.Errorf("%w: from %s: %s", ErrUnmarshalResponse, redactedURL, err)
	}

	if responseData.Status != "OK" {
		return fmt.Errorf("%w: %s: %s", ErrBadResponse, responseData.Status, responseData.Message)
	}

	return nil
}
// replaceInErr strips sensitive information from an error: it applies the
// substitutions to the error message with replaceInString and returns a new
// error made of the resulting text only. The original error is not wrapped,
// so it cannot be retrieved from the returned error.
// err must not be nil.
func replaceInErr(err error, substitutions map[string]string) error {
	redactedMessage := replaceInString(err.Error(), substitutions)
	redactedErr := errors.New(redactedMessage) //nolint:goerr113
	return redactedErr
}
// replaceInString is used to remove sensitive information.
// It replaces every occurrence of each key of substitutions found in s with
// the value associated to that key, and returns the result.
// Empty keys are ignored, since replacing the empty string would insert the
// replacement between every character of s.
// A nil or empty substitutions map leaves s unchanged.
// Substitutions are applied in map iteration order, which is unspecified, so
// the result is only deterministic when no key overlaps another key or a
// replacement value.
func replaceInString(s string, substitutions map[string]string) string {
	for old, replacement := range substitutions {
		if old == "" {
			continue
		}
		s = strings.ReplaceAll(s, old, replacement)
	}
	return s
}
// ErrHTTPStatusCodeNotOK is wrapped by makeNOKStatusError.
var ErrHTTPStatusCodeNotOK = errors.New("HTTP status code is not OK")

// makeNOKStatusError builds an error wrapping ErrHTTPStatusCodeNotOK which
// contains the request URL, the response status and the response body.
// The body is shortened by removing new lines and collapsing repeated
// spaces, and the substitutions are applied to both the URL and the body to
// remove sensitive information. It reads response.Body but does not close it.
func makeNOKStatusError(response *http.Response, substitutions map[string]string) (err error) {
	requestURL := replaceInString(response.Request.URL.String(), substitutions)

	// The body is only used to enrich the error message, so a read error is
	// deliberately ignored and whatever was read is used.
	b, _ := io.ReadAll(response.Body)
	shortenMessage := strings.ReplaceAll(string(b), "\n", "")
	// Collapse runs of spaces into a single space.
	for strings.Contains(shortenMessage, "  ") {
		shortenMessage = strings.ReplaceAll(shortenMessage, "  ", " ")
	}
	shortenMessage = replaceInString(shortenMessage, substitutions)

	return fmt.Errorf("%w: %s: %s: response received: %s",
		ErrHTTPStatusCodeNotOK, requestURL, response.Status, shortenMessage)
}