diff --git a/Dockerfile b/Dockerfile index 435f0249..5395c553 100644 --- a/Dockerfile +++ b/Dockerfile @@ -207,7 +207,7 @@ ENV VPN_SERVICE_PROVIDER=pia \ UPDATER_PERIOD=0 \ UPDATER_MIN_RATIO=0.8 \ UPDATER_VPN_SERVICE_PROVIDERS= \ - UPDATER_PROTONVPN_USERNAME= \ + UPDATER_PROTONVPN_EMAIL= \ UPDATER_PROTONVPN_PASSWORD= \ # Public IP PUBLICIP_FILE="/tmp/gluetun/ip" \ diff --git a/internal/cli/update.go b/internal/cli/update.go index 2acad8ae..eaad2d48 100644 --- a/internal/cli/update.go +++ b/internal/cli/update.go @@ -38,7 +38,7 @@ type UpdaterLogger interface { func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) error { options := settings.Updater{} var endUserMode, maintainerMode, updateAll bool - var csvProviders, ipToken, protonUsername, protonPassword string + var csvProviders, ipToken, protonUsername, protonEmail, protonPassword string flagSet := flag.NewFlagSet("update", flag.ExitOnError) flagSet.BoolVar(&endUserMode, "enduser", false, "Write results to /gluetun/servers.json (for end users)") flagSet.BoolVar(&maintainerMode, "maintainer", false, @@ -50,7 +50,9 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e flagSet.BoolVar(&updateAll, "all", false, "Update servers for all VPN providers") flagSet.StringVar(&csvProviders, "providers", "", "CSV string of VPN providers to update server data for") flagSet.StringVar(&ipToken, "ip-token", "", "IP data service token (e.g. ipinfo.io) to use") - flagSet.StringVar(&protonUsername, "proton-username", "", "Username to use to authenticate with Proton") + flagSet.StringVar(&protonUsername, "proton-username", "", + "(Retro-compatibility) Username to use to authenticate with Proton. Use -proton-email instead.") // v4 remove this + flagSet.StringVar(&protonEmail, "proton-email", "", "Email to use to authenticate with Proton") flagSet.StringVar(&protonPassword, "proton-password", "", "Password to use to authenticate with Proton") if err := flagSet.Parse(args); err != nil { return err @@ -70,7 +72,12 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e } if slices.Contains(options.Providers, providers.Protonvpn) { - options.ProtonUsername = &protonUsername + if protonEmail == "" && protonUsername != "" { + protonEmail = protonUsername + "@protonmail.com" + logger.Warn("use -proton-email instead of -proton-username in the future. " + + "This assumes the email is " + protonEmail + " and may not work.") + } + options.ProtonEmail = &protonEmail options.ProtonPassword = &protonPassword } diff --git a/internal/configuration/settings/errors.go b/internal/configuration/settings/errors.go index 91c6b9ed..b3329a5b 100644 --- a/internal/configuration/settings/errors.go +++ b/internal/configuration/settings/errors.go @@ -37,7 +37,7 @@ var ( ErrSystemTimezoneNotValid = errors.New("timezone is not valid") ErrUpdaterPeriodTooSmall = errors.New("VPN server data updater period is too small") ErrUpdaterProtonPasswordMissing = errors.New("proton password is missing") - ErrUpdaterProtonUsernameMissing = errors.New("proton username is missing") + ErrUpdaterProtonEmailMissing = errors.New("proton email is missing") ErrVPNProviderNameNotValid = errors.New("VPN provider name is not valid") ErrVPNTypeNotValid = errors.New("VPN type is not valid") ErrWireguardAllowedIPNotSet = errors.New("allowed IP is not set") diff --git a/internal/configuration/settings/updater.go b/internal/configuration/settings/updater.go index c3564550..70a86995 100644 --- a/internal/configuration/settings/updater.go +++ b/internal/configuration/settings/updater.go @@ -32,8 +32,8 @@ type Updater struct { // Providers is the list of VPN service providers // to update server information for. Providers []string - // ProtonUsername is the username to authenticate with the Proton API. - ProtonUsername *string + // ProtonEmail is the email to authenticate with the Proton API. + ProtonEmail *string // ProtonPassword is the password to authenticate with the Proton API. ProtonPassword *string } @@ -58,11 +58,11 @@ func (u Updater) Validate() (err error) { } if provider == providers.Protonvpn { - authenticatedAPI := *u.ProtonUsername != "" || *u.ProtonPassword != "" + authenticatedAPI := *u.ProtonEmail != "" || *u.ProtonPassword != "" if authenticatedAPI { switch { - case *u.ProtonUsername == "": - return fmt.Errorf("%w", ErrUpdaterProtonUsernameMissing) + case *u.ProtonEmail == "": + return fmt.Errorf("%w", ErrUpdaterProtonEmailMissing) case *u.ProtonPassword == "": return fmt.Errorf("%w", ErrUpdaterProtonPasswordMissing) } @@ -79,7 +79,7 @@ func (u *Updater) copy() (copied Updater) { DNSAddress: u.DNSAddress, MinRatio: u.MinRatio, Providers: gosettings.CopySlice(u.Providers), - ProtonUsername: gosettings.CopyPointer(u.ProtonUsername), + ProtonEmail: gosettings.CopyPointer(u.ProtonEmail), ProtonPassword: gosettings.CopyPointer(u.ProtonPassword), } } @@ -92,7 +92,7 @@ func (u *Updater) overrideWith(other Updater) { u.DNSAddress = gosettings.OverrideWithComparable(u.DNSAddress, other.DNSAddress) u.MinRatio = gosettings.OverrideWithComparable(u.MinRatio, other.MinRatio) u.Providers = gosettings.OverrideWithSlice(u.Providers, other.Providers) - u.ProtonUsername = gosettings.OverrideWithPointer(u.ProtonUsername, other.ProtonUsername) + u.ProtonEmail = gosettings.OverrideWithPointer(u.ProtonEmail, other.ProtonEmail) u.ProtonPassword = gosettings.OverrideWithPointer(u.ProtonPassword, other.ProtonPassword) } @@ -110,7 +110,7 @@ func (u *Updater) SetDefaults(vpnProvider string) { } // Set these to empty strings to avoid nil pointer panics - u.ProtonUsername = gosettings.DefaultPointer(u.ProtonUsername, "") + u.ProtonEmail = gosettings.DefaultPointer(u.ProtonEmail, "") u.ProtonPassword = gosettings.DefaultPointer(u.ProtonPassword, "") } @@ -129,7 +129,7 @@ func (u Updater) toLinesNode() (node *gotree.Node) { node.Appendf("Minimum ratio: %.1f", u.MinRatio) node.Appendf("Providers to update: %s", strings.Join(u.Providers, ", ")) if slices.Contains(u.Providers, providers.Protonvpn) { - node.Appendf("Proton API username: %s", *u.ProtonUsername) + node.Appendf("Proton API email: %s", *u.ProtonEmail) node.Appendf("Proton API password: %s", gosettings.ObfuscateKey(*u.ProtonPassword)) } @@ -154,12 +154,7 @@ func (u *Updater) read(r *reader.Reader) (err error) { u.Providers = r.CSV("UPDATER_VPN_SERVICE_PROVIDERS") - u.ProtonUsername = r.Get("UPDATER_PROTONVPN_USERNAME") - if u.ProtonUsername != nil { - // Enforce to use the username not the email address - *u.ProtonUsername = strings.TrimSuffix(*u.ProtonUsername, "@protonmail.com") - *u.ProtonUsername = strings.TrimSuffix(*u.ProtonUsername, "@proton.me") - } + u.ProtonEmail = r.Get("UPDATER_PROTONVPN_EMAIL", reader.RetroKeys("UPDATER_PROTONVPN_USERNAME")) u.ProtonPassword = r.Get("UPDATER_PROTONVPN_PASSWORD") return nil diff --git a/internal/provider/protonvpn/provider.go b/internal/provider/protonvpn/provider.go index 64704e23..3e646bc2 100644 --- a/internal/provider/protonvpn/provider.go +++ b/internal/provider/protonvpn/provider.go @@ -18,12 +18,12 @@ type Provider struct { func New(storage common.Storage, randSource rand.Source, client *http.Client, updaterWarner common.Warner, - username, password string, + email, password string, ) *Provider { return &Provider{ storage: storage, randSource: randSource, - Fetcher: updater.New(client, updaterWarner, username, password), + Fetcher: updater.New(client, updaterWarner, email, password), } } diff --git a/internal/provider/protonvpn/updater/api.go b/internal/provider/protonvpn/updater/api.go index 006f6739..2d4cfaf0 100644 --- a/internal/provider/protonvpn/updater/api.go +++ b/internal/provider/protonvpn/updater/api.go @@ -76,7 +76,7 @@ func (c *apiClient) setHeaders(request *http.Request, cookie cookie) { // authenticate performs the full Proton authentication flow // to obtain an authenticated cookie (uid, token and session ID). -func (c *apiClient) authenticate(ctx context.Context, username, password string, +func (c *apiClient) authenticate(ctx context.Context, email, password string, ) (authCookie cookie, err error) { sessionID, err := c.getSessionID(ctx) if err != nil { @@ -98,8 +98,8 @@ func (c *apiClient) authenticate(ctx context.Context, username, password string, token: cookieToken, sessionID: sessionID, } - modulusPGPClearSigned, serverEphemeralBase64, saltBase64, - srpSessionHex, version, err := c.authInfo(ctx, username, unauthCookie) + username, modulusPGPClearSigned, serverEphemeralBase64, saltBase64, + srpSessionHex, version, err := c.authInfo(ctx, email, unauthCookie) if err != nil { return cookie{}, fmt.Errorf("getting auth information: %w", err) } @@ -118,7 +118,7 @@ func (c *apiClient) authenticate(ctx context.Context, username, password string, return cookie{}, fmt.Errorf("generating SRP proofs: %w", err) } - authCookie, err = c.auth(ctx, unauthCookie, username, srpSessionHex, proofs) + authCookie, err = c.auth(ctx, unauthCookie, email, srpSessionHex, proofs) if err != nil { return cookie{}, fmt.Errorf("authentifying: %w", err) } @@ -299,48 +299,45 @@ func (c *apiClient) cookieToken(ctx context.Context, sessionID, tokenType, acces return "", fmt.Errorf("%w", ErrAuthCookieNotFound) } -var ( - ErrUsernameDoesNotExist = errors.New("username does not exist") - ErrUsernameMismatch = errors.New("username in response does not match request username") -) +var ErrUsernameDoesNotExist = errors.New("username does not exist") // authInfo fetches SRP parameters for the account. -func (c *apiClient) authInfo(ctx context.Context, username string, unauthCookie cookie) ( - modulusPGPClearSigned, serverEphemeralBase64, saltBase64, srpSessionHex string, +func (c *apiClient) authInfo(ctx context.Context, email string, unauthCookie cookie) ( + username, modulusPGPClearSigned, serverEphemeralBase64, saltBase64, srpSessionHex string, version int, err error, ) { type requestBodySchema struct { - Intent string `json:"Intent"` // "Proton" - Username string `json:"Username"` // username without @domain.com + Intent string `json:"Intent"` // "Proton" + Username string `json:"Username"` } requestBody := requestBodySchema{ Intent: "Proton", - Username: username, + Username: email, } buffer := bytes.NewBuffer(nil) encoder := json.NewEncoder(buffer) if err := encoder.Encode(requestBody); err != nil { - return "", "", "", "", 0, fmt.Errorf("encoding request body: %w", err) + return "", "", "", "", "", 0, fmt.Errorf("encoding request body: %w", err) } request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.apiURLBase+"/core/v4/auth/info", buffer) if err != nil { - return "", "", "", "", 0, fmt.Errorf("creating request: %w", err) + return "", "", "", "", "", 0, fmt.Errorf("creating request: %w", err) } c.setHeaders(request, unauthCookie) response, err := c.httpClient.Do(request) if err != nil { - return "", "", "", "", 0, err + return "", "", "", "", "", 0, err } defer response.Body.Close() responseBody, err := io.ReadAll(response.Body) if err != nil { - return "", "", "", "", 0, fmt.Errorf("reading response body: %w", err) + return "", "", "", "", "", 0, fmt.Errorf("reading response body: %w", err) } else if response.StatusCode != http.StatusOK { - return "", "", "", "", 0, buildError(response.StatusCode, responseBody) + return "", "", "", "", "", 0, buildError(response.StatusCode, responseBody) } var info struct { @@ -354,32 +351,30 @@ func (c *apiClient) authInfo(ctx context.Context, username string, unauthCookie } err = json.Unmarshal(responseBody, &info) if err != nil { - return "", "", "", "", 0, fmt.Errorf("decoding response body: %w", err) + return "", "", "", "", "", 0, fmt.Errorf("decoding response body: %w", err) } const successCode = 1000 switch { case info.Code != successCode: - return "", "", "", "", 0, fmt.Errorf("%w: expected %d got %d", + return "", "", "", "", "", 0, fmt.Errorf("%w: expected %d got %d", ErrCodeNotSuccess, successCode, info.Code) case info.Modulus == "": - return "", "", "", "", 0, fmt.Errorf("%w: modulus is empty", ErrDataFieldMissing) + return "", "", "", "", "", 0, fmt.Errorf("%w: modulus is empty", ErrDataFieldMissing) case info.ServerEphemeral == "": - return "", "", "", "", 0, fmt.Errorf("%w: server ephemeral is empty", ErrDataFieldMissing) + return "", "", "", "", "", 0, fmt.Errorf("%w: server ephemeral is empty", ErrDataFieldMissing) case info.Salt == "": - return "", "", "", "", 0, fmt.Errorf("%w (salt data field is empty)", ErrUsernameDoesNotExist) + return "", "", "", "", "", 0, fmt.Errorf("%w (salt data field is empty)", ErrUsernameDoesNotExist) case info.SRPSession == "": - return "", "", "", "", 0, fmt.Errorf("%w: SRP session is empty", ErrDataFieldMissing) - - case !strings.EqualFold(info.Username, username): - return "", "", "", "", 0, fmt.Errorf("%w: expected %s got %s", - ErrUsernameMismatch, username, info.Username) + return "", "", "", "", "", 0, fmt.Errorf("%w: SRP session is empty", ErrDataFieldMissing) + case info.Username == "": + return "", "", "", "", "", 0, fmt.Errorf("%w: username is empty", ErrDataFieldMissing) case info.Version == nil: - return "", "", "", "", 0, fmt.Errorf("%w: version is missing", ErrDataFieldMissing) + return "", "", "", "", "", 0, fmt.Errorf("%w: version is missing", ErrDataFieldMissing) } version = int(*info.Version) //nolint:gosec - return info.Modulus, info.ServerEphemeral, info.Salt, + return info.Username, info.Modulus, info.ServerEphemeral, info.Salt, info.SRPSession, version, nil } diff --git a/internal/provider/protonvpn/updater/servers.go b/internal/provider/protonvpn/updater/servers.go index e625e09e..ff919937 100644 --- a/internal/provider/protonvpn/updater/servers.go +++ b/internal/provider/protonvpn/updater/servers.go @@ -14,8 +14,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( servers []models.Server, err error, ) { switch { - case u.username == "": - return nil, fmt.Errorf("%w: username is empty", common.ErrCredentialsMissing) + case u.email == "": + return nil, fmt.Errorf("%w: email is empty", common.ErrCredentialsMissing) case u.password == "": return nil, fmt.Errorf("%w: password is empty", common.ErrCredentialsMissing) } @@ -25,7 +25,7 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( return nil, fmt.Errorf("creating API client: %w", err) } - cookie, err := apiClient.authenticate(ctx, u.username, u.password) + cookie, err := apiClient.authenticate(ctx, u.email, u.password) if err != nil { return nil, fmt.Errorf("authentifying with Proton: %w", err) } diff --git a/internal/provider/protonvpn/updater/updater.go b/internal/provider/protonvpn/updater/updater.go index 3159078a..bc62a778 100644 --- a/internal/provider/protonvpn/updater/updater.go +++ b/internal/provider/protonvpn/updater/updater.go @@ -8,15 +8,15 @@ import ( type Updater struct { client *http.Client - username string + email string password string warner common.Warner } -func New(client *http.Client, warner common.Warner, username, password string) *Updater { +func New(client *http.Client, warner common.Warner, email, password string) *Updater { return &Updater{ client: client, - username: username, + email: email, password: password, warner: warner, } diff --git a/internal/provider/providers.go b/internal/provider/providers.go index 8ed69743..2091c7cf 100644 --- a/internal/provider/providers.go +++ b/internal/provider/providers.go @@ -75,7 +75,7 @@ func NewProviders(storage Storage, timeNow func() time.Time, providers.Privado: privado.New(storage, randSource, ipFetcher, unzipper, updaterWarner, parallelResolver), providers.PrivateInternetAccess: privateinternetaccess.New(storage, randSource, timeNow, client), providers.Privatevpn: privatevpn.New(storage, randSource, unzipper, updaterWarner, parallelResolver), - providers.Protonvpn: protonvpn.New(storage, randSource, client, updaterWarner, *credentials.ProtonUsername, *credentials.ProtonPassword), + providers.Protonvpn: protonvpn.New(storage, randSource, client, updaterWarner, *credentials.ProtonEmail, *credentials.ProtonPassword), providers.Purevpn: purevpn.New(storage, randSource, ipFetcher, unzipper, updaterWarner, parallelResolver), providers.SlickVPN: slickvpn.New(storage, randSource, client, updaterWarner, parallelResolver), providers.Surfshark: surfshark.New(storage, randSource, client, unzipper, updaterWarner, parallelResolver),