Refactor TwoFactorIsEnabledQuery to optimize premium access checks and improve two-factor provider handling. Introduced bulk fetching of premium status for users with only premium providers and streamlined the logic for determining if two-factor authentication is enabled.

This commit is contained in:
Rui Tome 2025-12-08 12:16:42 +00:00
parent 303b81c002
commit 30ff3da4b7
No known key found for this signature in database
GPG Key ID: 526239D96A8EC066

View File

@ -100,13 +100,25 @@ public class TwoFactorIsEnabledQuery : ITwoFactorIsEnabledQuery
}
var users = await _userRepository.GetManyAsync([.. userIds]);
var premiumStatus = await _hasPremiumAccessQuery.HasPremiumAccessAsync(userIds);
// Get enabled providers for each user
var usersTwoFactorProvidersMap = users.ToDictionary(u => u.Id, GetEnabledTwoFactorProviders);
// Bulk fetch premium status only for users who need it (those with only premium providers)
var userIdsNeedingPremium = usersTwoFactorProvidersMap
.Where(kvp => kvp.Value.Any() && kvp.Value.All(TwoFactorProvider.RequiresPremium))
.Select(kvp => kvp.Key)
.ToList();
var premiumStatusMap = userIdsNeedingPremium.Count > 0
? await _hasPremiumAccessQuery.HasPremiumAccessAsync(userIdsNeedingPremium)
: new Dictionary<Guid, bool>();
foreach (var user in users)
{
var twoFactorProviders = user.GetTwoFactorProviders();
var hasPremiumAccess = premiumStatus.GetValueOrDefault(user.Id, false);
var twoFactorIsEnabled = TwoFactorIsEnabled(twoFactorProviders, hasPremiumAccess);
var userTwoFactorProviders = usersTwoFactorProvidersMap[user.Id];
var twoFactorIsEnabled = userTwoFactorProviders.Any() &&
(!premiumStatusMap.TryGetValue(user.Id, out var hasPremium) || hasPremium);
result.Add((user.Id, twoFactorIsEnabled));
}
@ -146,50 +158,41 @@ public class TwoFactorIsEnabledQuery : ITwoFactorIsEnabledQuery
public async Task<bool> TwoFactorIsEnabledVNextAsync(User user)
{
var providers = user.GetTwoFactorProviders();
var hasPremium = await _hasPremiumAccessQuery.HasPremiumAccessAsync(user.Id);
var enabledProviders = GetEnabledTwoFactorProviders(user);
return TwoFactorIsEnabled(providers, hasPremium);
if (!enabledProviders.Any())
{
return false;
}
// If all providers require premium, check if user has premium access
if (enabledProviders.All(TwoFactorProvider.RequiresPremium))
{
return await _hasPremiumAccessQuery.HasPremiumAccessAsync(user.Id);
}
// User has at least one non-premium provider
return true;
}
/// <summary>
/// Checks to see what kind of two-factor is enabled.
/// Synchronous version used when premium access status is already known.
/// Gets all enabled two-factor provider types for a user.
/// </summary>
/// <param name="providers">dictionary of two factor providers</param>
/// <param name="hasPremiumAccess">whether the user has premium access</param>
/// <returns>true if the user has two factor enabled; false otherwise</returns>
private static bool TwoFactorIsEnabled(
Dictionary<TwoFactorProviderType, TwoFactorProvider> providers,
bool hasPremiumAccess)
/// <param name="user">user with two factor providers</param>
/// <returns>list of enabled provider types</returns>
private static IList<TwoFactorProviderType> GetEnabledTwoFactorProviders(User user)
{
// If there are no providers, then two factor is not enabled
var providers = user.GetTwoFactorProviders();
if (providers == null || providers.Count == 0)
{
return false;
return Array.Empty<TwoFactorProviderType>();
}
// Get all enabled providers
// TODO: PM-21210: In practice we don't save disabled providers to the database, worth looking into.
var enabledProviderKeys = from provider in providers
where provider.Value?.Enabled ?? false
select provider.Key;
// If no providers are enabled then two factor is not enabled
if (!enabledProviderKeys.Any())
{
return false;
}
// If there are only premium two factor options then check premium access
var onlyHasPremiumTwoFactor = enabledProviderKeys.All(TwoFactorProvider.RequiresPremium);
if (onlyHasPremiumTwoFactor)
{
return hasPremiumAccess;
}
// The user has at least one non-premium two factor option
return true;
return (from provider in providers
where provider.Value?.Enabled ?? false
select provider.Key).ToList();
}
/// <summary>