Adjust email domain restriction logic, consolidate into a function

This commit is contained in:
Mike Amirault
2026-04-10 14:19:38 -04:00
parent 2bafe528bc
commit c43c079376
2 changed files with 14 additions and 15 deletions

View File

@@ -86,7 +86,6 @@ public class SendControlsSyncPolicyEvent(
{
var orgUsers = await organizationUserRepository.GetManyByOrganizationAsync(postUpsertedPolicyState.OrganizationId, null);
var orgUserIds = orgUsers.Where(w => w.UserId != null).Select(s => s.UserId!.Value).ToList();
var domains = (sendControlsPolicyData.AllowedDomains ?? "").Split(",").Select(d => d.Trim()).Where(d => d != "");
var enabled = new List<Guid>();
var enabledSendUserIds = new List<Guid>();
var disabled = new List<Guid>();
@@ -109,7 +108,7 @@ public class SendControlsSyncPolicyEvent(
(sendControlsPolicyData.DisableHideEmail && (userSend.HideEmail ?? false)) ||
(sendControlsPolicyData.WhoCanAccess == SendWhoCanAccessType.PasswordProtected && userSend.AuthType != AuthType.Password) ||
(sendControlsPolicyData.WhoCanAccess == SendWhoCanAccessType.SpecificPeople && userSend.AuthType != AuthType.Email) ||
(sendControlsPolicyData.WhoCanAccess == SendWhoCanAccessType.SpecificPeople && domains.Any() && (userSend.Emails ?? "").Split(",").Select(e => e.Trim()).Any(e => !domains.Any(d => SendValidationService.SendEmailMatchesDomain(e, d)))))
(sendControlsPolicyData.WhoCanAccess == SendWhoCanAccessType.SpecificPeople && !SendValidationService.SendAllEmailsHaveAllowedDomains(userSend.Emails, sendControlsPolicyData.AllowedDomains)))
{
disabled.Add(userSend.Id);
userHadSendsDisabled = true;
@@ -135,6 +134,5 @@ public class SendControlsSyncPolicyEvent(
{
await sendRepository.UpdateManyDisabledAsync(disabled, true, disabledSendUserIds);
}
return;
}
}

View File

@@ -1,7 +1,5 @@
// FIXME: Update this file to be null safe and then delete the line below
#nullable disable
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
@@ -81,20 +79,23 @@ public class SendValidationService : ISendValidationService
if (emailsRequired && sendControlsRequirement.AllowedDomains != null)
{
var domains = sendControlsRequirement.AllowedDomains.Split(",").Select(domain => domain.Trim());
var emails = send.Emails.Split(",").Select(email => email.Trim());
if (emails.Any(email => !domains.Any(domain => SendEmailMatchesDomain(email, domain))))
if (!SendAllEmailsHaveAllowedDomains(send.Emails, sendControlsRequirement.AllowedDomains))
{
throw new BadRequestException($"Due to an Enterprise Policy your Sends must be protected by email verification and access granted only to the following domain(s): {string.Join(", ", domains)}");
throw new BadRequestException($"Due to an Enterprise Policy your Sends must be protected by email verification and access granted only to the following domain(s): {sendControlsRequirement.AllowedDomains}");
}
}
}
public static bool SendEmailMatchesDomain(string email, string domain)
public static bool SendAllEmailsHaveAllowedDomains(string? emailsString, string? domainsString)
{
var emailDomain = EmailValidation.GetDomain(email);
return emailDomain.Equals(domain, StringComparison.OrdinalIgnoreCase)
|| emailDomain.EndsWith("." + domain, StringComparison.OrdinalIgnoreCase);
var domains = (domainsString ?? "").Split(",", StringSplitOptions.TrimEntries | StringSplitOptions.RemoveEmptyEntries);
var emails = (emailsString ?? "").Split(",", StringSplitOptions.TrimEntries | StringSplitOptions.RemoveEmptyEntries);
return emails.All(email => domains.Any(domain =>
{
var emailDomain = EmailValidation.GetDomain(email);
return emailDomain.Equals(domain, StringComparison.OrdinalIgnoreCase)
|| emailDomain.EndsWith("." + domain, StringComparison.OrdinalIgnoreCase);
}));
}
public async Task<long> StorageRemainingForSendAsync(Send send)
@@ -102,7 +103,7 @@ public class SendValidationService : ISendValidationService
var storageBytesRemaining = 0L;
if (send.UserId.HasValue)
{
var user = await _userRepository.GetByIdAsync(send.UserId.Value);
var user = await _userRepository.GetByIdAsync(send.UserId.Value) ?? throw new NotFoundException("Send user not found");
if (!await _userService.CanAccessPremium(user))
{
throw new BadRequestException("You must have premium status to use file Sends.");
@@ -137,7 +138,7 @@ public class SendValidationService : ISendValidationService
}
else if (send.OrganizationId.HasValue)
{
var org = await _organizationRepository.GetByIdAsync(send.OrganizationId.Value);
var org = await _organizationRepository.GetByIdAsync(send.OrganizationId.Value) ?? throw new NotFoundException("Send organization not found");
if (!org.MaxStorageGb.HasValue)
{
throw new BadRequestException("This organization cannot use file sends.");