Move enable/disable operations to SubscriberService

This commit is contained in:
Alex Morask 2026-02-03 12:54:46 -06:00
parent 158cbb9447
commit fe0bc1516b
No known key found for this signature in database
GPG Key ID: 23E38285B743E3A8
12 changed files with 394 additions and 210 deletions

View File

@ -4,6 +4,7 @@ using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Notifications;
using Bit.Core.Billing.Pricing;
using Bit.Core.Repositories;
using Bit.Core.Services;

View File

@ -2,6 +2,7 @@
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Notifications;
using Bit.Core.Billing.Services;
using Bit.Core.Repositories;
using OneOf;

View File

@ -1,8 +1,6 @@
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Subscriptions.Models;
using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces;
using Bit.Core.Repositories;
@ -22,13 +20,9 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
private readonly IStripeFacade _stripeFacade;
private readonly IOrganizationSponsorshipRenewCommand _organizationSponsorshipRenewCommand;
private readonly IUserService _userService;
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationEnableCommand _organizationEnableCommand;
private readonly IOrganizationDisableCommand _organizationDisableCommand;
private readonly IPricingClient _pricingClient;
private readonly IProviderRepository _providerRepository;
private readonly IProviderService _providerService;
private readonly IPushNotificationAdapter _pushNotificationAdapter;
private readonly ISubscriberService _subscriberService;
private readonly IOrganizationRepository _organizationRepository;
public SubscriptionUpdatedHandler(
IStripeEventService stripeEventService,
@ -37,29 +31,19 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
IStripeFacade stripeFacade,
IOrganizationSponsorshipRenewCommand organizationSponsorshipRenewCommand,
IUserService userService,
IOrganizationRepository organizationRepository,
IOrganizationEnableCommand organizationEnableCommand,
IOrganizationDisableCommand organizationDisableCommand,
IPricingClient pricingClient,
IProviderRepository providerRepository,
IProviderService providerService,
IPushNotificationAdapter pushNotificationAdapter)
ISubscriberService subscriberService,
IOrganizationRepository organizationRepository)
{
_stripeEventService = stripeEventService;
_stripeEventUtilityService = stripeEventUtilityService;
_organizationService = organizationService;
_providerService = providerService;
_stripeFacade = stripeFacade;
_organizationSponsorshipRenewCommand = organizationSponsorshipRenewCommand;
_userService = userService;
_organizationRepository = organizationRepository;
_providerRepository = providerRepository;
_organizationEnableCommand = organizationEnableCommand;
_organizationDisableCommand = organizationDisableCommand;
_pricingClient = pricingClient;
_providerRepository = providerRepository;
_providerService = providerService;
_pushNotificationAdapter = pushNotificationAdapter;
_subscriberService = subscriberService;
_organizationRepository = organizationRepository;
}
public async Task HandleAsync(Event parsedEvent)
@ -71,12 +55,12 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
if (SubscriptionWentUnpaid(parsedEvent, subscription))
{
await DisableSubscriberAsync(subscriberId, currentPeriodEnd);
await _subscriberService.DisableSubscriberAsync(subscriberId, currentPeriodEnd);
await SetSubscriptionToCancelAsync(subscription);
}
else if (SubscriptionBecameActive(parsedEvent, subscription))
{
await EnableSubscriberAsync(subscriberId, currentPeriodEnd);
await _subscriberService.EnableSubscriberAsync(subscriberId, currentPeriodEnd);
await RemovePendingCancellationAsync(subscription);
}
@ -125,50 +109,6 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
LatestInvoice.BillingReason: BillingReasons.SubscriptionCreate or BillingReasons.SubscriptionCycle
};
private Task DisableSubscriberAsync(SubscriberId subscriberId, DateTime? currentPeriodEnd) =>
subscriberId.Match(
userId => _userService.DisablePremiumAsync(userId.Value, currentPeriodEnd),
async organizationId =>
{
await _organizationDisableCommand.DisableAsync(organizationId.Value, currentPeriodEnd);
var organization = await _organizationRepository.GetByIdAsync(organizationId.Value);
if (organization != null)
{
await _pushNotificationAdapter.NotifyEnabledChangedAsync(organization);
}
},
async providerId =>
{
var provider = await _providerRepository.GetByIdAsync(providerId.Value);
if (provider != null)
{
provider.Enabled = false;
await _providerService.UpdateAsync(provider);
}
});
private Task EnableSubscriberAsync(SubscriberId subscriberId, DateTime? currentPeriodEnd) =>
subscriberId.Match(
userId => _userService.EnablePremiumAsync(userId.Value, currentPeriodEnd),
async organizationId =>
{
await _organizationEnableCommand.EnableAsync(organizationId.Value, currentPeriodEnd);
var organization = await _organizationRepository.GetByIdAsync(organizationId.Value);
if (organization != null)
{
await _pushNotificationAdapter.NotifyEnabledChangedAsync(organization);
}
},
async providerId =>
{
var provider = await _providerRepository.GetByIdAsync(providerId.Value);
if (provider != null)
{
provider.Enabled = true;
await _providerService.UpdateAsync(provider);
}
});
private async Task SetSubscriptionToCancelAsync(Subscription subscription)
{
if (subscription.TestClock != null)

View File

@ -101,7 +101,6 @@ public class Startup
services.AddScoped<IStripeFacade, StripeFacade>();
services.AddScoped<IStripeEventService, StripeEventService>();
services.AddScoped<IProviderEventService, ProviderEventService>();
services.AddScoped<IPushNotificationAdapter, PushNotificationAdapter>();
// Add Quartz services first
services.AddQuartz(q =>

View File

@ -2,6 +2,7 @@
using Bit.Core.Billing.Caches.Implementations;
using Bit.Core.Billing.Licenses;
using Bit.Core.Billing.Licenses.Extensions;
using Bit.Core.Billing.Notifications;
using Bit.Core.Billing.Organizations.Commands;
using Bit.Core.Billing.Organizations.Queries;
using Bit.Core.Billing.Organizations.Services;
@ -31,6 +32,7 @@ public static class ServiceCollectionExtensions
services.AddTransient<IPremiumUserBillingService, PremiumUserBillingService>();
services.AddTransient<ISetupIntentCache, SetupIntentDistributedCache>();
services.AddTransient<ISubscriberService, SubscriberService>();
services.AddTransient<IPushNotificationAdapter, PushNotificationAdapter>();
services.AddLicenseServices();
services.AddLicenseOperations();
services.AddPricingClient();

View File

@ -1,7 +1,7 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
namespace Bit.Billing.Services;
namespace Bit.Core.Billing.Notifications;
public interface IPushNotificationAdapter
{

View File

@ -6,7 +6,7 @@ using Bit.Core.Enums;
using Bit.Core.Models;
using Bit.Core.Platform.Push;
namespace Bit.Billing.Services.Implementations;
namespace Bit.Core.Billing.Notifications;
public class PushNotificationAdapter(
IProviderUserRepository providerUserRepository,

View File

@ -2,6 +2,7 @@
#nullable disable
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Subscriptions.Models;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Entities;
using Bit.Core.Enums;
@ -143,4 +144,20 @@ public interface ISubscriberService
/// <returns><see langword="true"/> if the gateway subscription ID is valid or empty; <see langword="false"/> if the subscription doesn't exist in the gateway.</returns>
/// <exception cref="ArgumentNullException">Thrown when the <paramref name="subscriber"/> is <see langword="null"/>.</exception>
Task<bool> IsValidGatewaySubscriptionIdAsync(ISubscriber subscriber);
/// <summary>
/// Disables a subscriber based on the <paramref name="subscriberId"/> type.
/// For users, this disables premium. For organizations and providers, this disables the entity.
/// </summary>
/// <param name="subscriberId">The subscriber identifier (UserId, OrganizationId, or ProviderId).</param>
/// <param name="currentPeriodEnd">The current billing period end date to set as the expiration date.</param>
Task DisableSubscriberAsync(SubscriberId subscriberId, DateTime? currentPeriodEnd);
/// <summary>
/// Enables a subscriber based on the <paramref name="subscriberId"/> type.
/// For users, this enables premium. For organizations and providers, this enables the entity.
/// </summary>
/// <param name="subscriberId">The subscriber identifier (UserId, OrganizationId, or ProviderId).</param>
/// <param name="currentPeriodEnd">The current billing period end date to set as the expiration date.</param>
Task EnableSubscriberAsync(SubscriberId subscriberId, DateTime? currentPeriodEnd);
}

View File

@ -3,18 +3,23 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Notifications;
using Bit.Core.Billing.Subscriptions.Models;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Utilities;
using Braintree;
@ -33,12 +38,17 @@ public class SubscriberService(
IBraintreeGateway braintreeGateway,
IGlobalSettings globalSettings,
ILogger<SubscriberService> logger,
IOrganizationDisableCommand organizationDisableCommand,
IOrganizationEnableCommand organizationEnableCommand,
IOrganizationRepository organizationRepository,
IProviderRepository providerRepository,
IProviderService providerService,
IPushNotificationAdapter pushNotificationAdapter,
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter,
ITaxService taxService,
IUserRepository userRepository) : ISubscriberService
IUserRepository userRepository,
IUserService userService) : ISubscriberService
{
public async Task CancelSubscription(
ISubscriber subscriber,
@ -817,6 +827,50 @@ public class SubscriberService(
}
}
public Task DisableSubscriberAsync(SubscriberId subscriberId, DateTime? currentPeriodEnd) =>
subscriberId.Match(
userId => userService.DisablePremiumAsync(userId.Value, currentPeriodEnd),
async organizationId =>
{
await organizationDisableCommand.DisableAsync(organizationId.Value, currentPeriodEnd);
var organization = await organizationRepository.GetByIdAsync(organizationId.Value);
if (organization != null)
{
await pushNotificationAdapter.NotifyEnabledChangedAsync(organization);
}
},
async providerId =>
{
var provider = await providerRepository.GetByIdAsync(providerId.Value);
if (provider != null)
{
provider.Enabled = false;
await providerService.UpdateAsync(provider);
}
});
public Task EnableSubscriberAsync(SubscriberId subscriberId, DateTime? currentPeriodEnd) =>
subscriberId.Match(
userId => userService.EnablePremiumAsync(userId.Value, currentPeriodEnd),
async organizationId =>
{
await organizationEnableCommand.EnableAsync(organizationId.Value, currentPeriodEnd);
var organization = await organizationRepository.GetByIdAsync(organizationId.Value);
if (organization != null)
{
await pushNotificationAdapter.NotifyEnabledChangedAsync(organization);
}
},
async providerId =>
{
var provider = await providerRepository.GetByIdAsync(providerId.Value);
if (provider != null)
{
provider.Enabled = true;
await providerService.UpdateAsync(provider);
}
});
#region Shared Utilities
private async Task AddBraintreeCustomerIdAsync(

View File

@ -4,6 +4,7 @@ using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Notifications;
using Bit.Core.Billing.Services;
using Bit.Core.Repositories;
using NSubstitute;

View File

@ -2,11 +2,10 @@
using Bit.Billing.Services.Implementations;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Subscriptions.Models;
using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces;
using Bit.Core.Repositories;
using Bit.Core.Services;
@ -14,7 +13,6 @@ using Bit.Core.Test.Billing.Mocks;
using Bit.Core.Test.Billing.Mocks.Plans;
using Newtonsoft.Json.Linq;
using NSubstitute;
using NSubstitute.ReturnsExtensions;
using Stripe;
using Xunit;
using static Bit.Core.Billing.Constants.StripeConstants;
@ -30,13 +28,9 @@ public class SubscriptionUpdatedHandlerTests
private readonly IStripeFacade _stripeFacade;
private readonly IOrganizationSponsorshipRenewCommand _organizationSponsorshipRenewCommand;
private readonly IUserService _userService;
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationEnableCommand _organizationEnableCommand;
private readonly IOrganizationDisableCommand _organizationDisableCommand;
private readonly IPricingClient _pricingClient;
private readonly IProviderRepository _providerRepository;
private readonly IProviderService _providerService;
private readonly IPushNotificationAdapter _pushNotificationAdapter;
private readonly ISubscriberService _subscriberService;
private readonly IOrganizationRepository _organizationRepository;
private readonly SubscriptionUpdatedHandler _sut;
public SubscriptionUpdatedHandlerTests()
@ -47,14 +41,9 @@ public class SubscriptionUpdatedHandlerTests
_stripeFacade = Substitute.For<IStripeFacade>();
_organizationSponsorshipRenewCommand = Substitute.For<IOrganizationSponsorshipRenewCommand>();
_userService = Substitute.For<IUserService>();
_providerService = Substitute.For<IProviderService>();
_organizationRepository = Substitute.For<IOrganizationRepository>();
_organizationEnableCommand = Substitute.For<IOrganizationEnableCommand>();
_organizationDisableCommand = Substitute.For<IOrganizationDisableCommand>();
_pricingClient = Substitute.For<IPricingClient>();
_providerRepository = Substitute.For<IProviderRepository>();
_providerService = Substitute.For<IProviderService>();
_pushNotificationAdapter = Substitute.For<IPushNotificationAdapter>();
_subscriberService = Substitute.For<ISubscriberService>();
_organizationRepository = Substitute.For<IOrganizationRepository>();
_sut = new SubscriptionUpdatedHandler(
_stripeEventService,
@ -63,13 +52,9 @@ public class SubscriptionUpdatedHandlerTests
_stripeFacade,
_organizationSponsorshipRenewCommand,
_userService,
_organizationRepository,
_organizationEnableCommand,
_organizationDisableCommand,
_pricingClient,
_providerRepository,
_providerService,
_pushNotificationAdapter);
_subscriberService,
_organizationRepository);
}
[Fact]
@ -119,8 +104,6 @@ public class SubscriptionUpdatedHandlerTests
_stripeEventService.GetSubscription(Arg.Any<Event>(), Arg.Any<bool>(), Arg.Any<List<string>>())
.Returns(subscription);
_organizationRepository.GetByIdAsync(organizationId).Returns(organization);
var plan = new Enterprise2023Plan(true);
_pricingClient.GetPlanOrThrow(organization.PlanType).Returns(plan);
_pricingClient.ListPlans().Returns(MockPlans.Plans);
@ -129,10 +112,10 @@ public class SubscriptionUpdatedHandlerTests
await _sut.HandleAsync(parsedEvent);
// Assert
await _organizationDisableCommand.Received(1)
.DisableAsync(organizationId, currentPeriodEnd);
await _pushNotificationAdapter.Received(1)
.NotifyEnabledChangedAsync(organization);
await _subscriberService.Received(1)
.DisableSubscriberAsync(
Arg.Is<SubscriberId>(s => s.Match(_ => false, o => o.Value == organizationId, _ => false)),
currentPeriodEnd);
await _stripeFacade.Received(1).UpdateSubscription(
subscriptionId,
Arg.Is<SubscriptionUpdateOptions>(options =>
@ -150,6 +133,7 @@ public class SubscriptionUpdatedHandlerTests
// Arrange
var providerId = Guid.NewGuid();
var subscriptionId = "sub_test123";
var currentPeriodEnd = DateTime.UtcNow.AddDays(30);
var previousSubscription = new Subscription
{
@ -165,7 +149,7 @@ public class SubscriptionUpdatedHandlerTests
{
Data =
[
new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) }
new SubscriptionItem { CurrentPeriodEnd = currentPeriodEnd }
]
},
Metadata = new Dictionary<string, string> { ["providerId"] = providerId.ToString() },
@ -182,17 +166,16 @@ public class SubscriptionUpdatedHandlerTests
}
};
var provider = new Provider { Id = providerId, Enabled = true };
_stripeEventService.GetSubscription(parsedEvent, true, Arg.Any<List<string>>()).Returns(currentSubscription);
_providerRepository.GetByIdAsync(providerId).Returns(provider);
// Act
await _sut.HandleAsync(parsedEvent);
// Assert
Assert.False(provider.Enabled);
await _providerService.Received(1).UpdateAsync(provider);
await _subscriberService.Received(1)
.DisableSubscriberAsync(
Arg.Is<SubscriberId>(s => s.Match(_ => false, _ => false, p => p.Value == providerId)),
currentPeriodEnd);
// Verify that UpdateSubscription was called with CancelAt
await _stripeFacade.Received(1).UpdateSubscription(
@ -233,8 +216,6 @@ public class SubscriptionUpdatedHandlerTests
LatestInvoice = new Invoice { BillingReason = BillingReasons.SubscriptionCycle }
};
var provider = new Provider { Id = providerId, Name = "Test Provider", Enabled = true };
var parsedEvent = new Event
{
Data = new EventData
@ -247,15 +228,11 @@ public class SubscriptionUpdatedHandlerTests
_stripeEventService.GetSubscription(Arg.Any<Event>(), Arg.Any<bool>(), Arg.Any<List<string>>())
.Returns(subscription);
_providerRepository.GetByIdAsync(providerId)
.Returns(provider);
// Act
await _sut.HandleAsync(parsedEvent);
// Assert - No disable or cancellation since there was no valid status transition
Assert.True(provider.Enabled);
await _providerService.DidNotReceive().UpdateAsync(Arg.Any<Provider>());
await _subscriberService.DidNotReceive().DisableSubscriberAsync(Arg.Any<SubscriberId>(), Arg.Any<DateTime?>());
await _stripeFacade.DidNotReceive().UpdateSubscription(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>());
}
@ -288,8 +265,6 @@ public class SubscriptionUpdatedHandlerTests
LatestInvoice = new Invoice { BillingReason = BillingReasons.SubscriptionCycle }
};
var provider = new Provider { Id = providerId, Name = "Test Provider", Enabled = true };
var parsedEvent = new Event
{
Data = new EventData
@ -302,15 +277,11 @@ public class SubscriptionUpdatedHandlerTests
_stripeEventService.GetSubscription(Arg.Any<Event>(), Arg.Any<bool>(), Arg.Any<List<string>>())
.Returns(subscription);
_providerRepository.GetByIdAsync(providerId)
.Returns(provider);
// Act
await _sut.HandleAsync(parsedEvent);
// Assert - No disable or cancellation since the previous status (Canceled) is not a valid transition source
Assert.True(provider.Enabled);
await _providerService.DidNotReceive().UpdateAsync(Arg.Any<Provider>());
await _subscriberService.DidNotReceive().DisableSubscriberAsync(Arg.Any<SubscriberId>(), Arg.Any<DateTime?>());
await _stripeFacade.DidNotReceive().UpdateSubscription(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>());
}
@ -344,8 +315,6 @@ public class SubscriptionUpdatedHandlerTests
LatestInvoice = new Invoice { BillingReason = "renewal" }
};
var provider = new Provider { Id = providerId, Name = "Test Provider", Enabled = true };
var parsedEvent = new Event
{
Data = new EventData
@ -358,20 +327,17 @@ public class SubscriptionUpdatedHandlerTests
_stripeEventService.GetSubscription(Arg.Any<Event>(), Arg.Any<bool>(), Arg.Any<List<string>>())
.Returns(subscription);
_providerRepository.GetByIdAsync(providerId)
.Returns(provider);
// Act
await _sut.HandleAsync(parsedEvent);
// Assert - IncompleteExpired status is not handled by the new logic
Assert.True(provider.Enabled);
await _providerService.DidNotReceive().UpdateAsync(Arg.Any<Provider>());
await _subscriberService.DidNotReceive().DisableSubscriberAsync(Arg.Any<SubscriberId>(), Arg.Any<DateTime?>());
await _subscriberService.DidNotReceive().EnableSubscriberAsync(Arg.Any<SubscriberId>(), Arg.Any<DateTime?>());
await _stripeFacade.DidNotReceive().UpdateSubscription(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>());
}
[Fact]
public async Task HandleAsync_UnpaidProviderSubscription_WhenProviderNotFound_StillSetsCancellation()
public async Task HandleAsync_UnpaidProviderSubscription_StillSetsCancellation()
{
// Arrange
var providerId = Guid.NewGuid();
@ -411,14 +377,14 @@ public class SubscriptionUpdatedHandlerTests
_stripeEventService.GetSubscription(Arg.Any<Event>(), Arg.Any<bool>(), Arg.Any<List<string>>())
.Returns(subscription);
_providerRepository.GetByIdAsync(providerId)
.Returns((Provider)null);
// Act
await _sut.HandleAsync(parsedEvent);
// Assert - Provider not updated (since not found), but cancellation is still set
await _providerService.DidNotReceive().UpdateAsync(Arg.Any<Provider>());
// Assert - DisableSubscriberAsync is called and cancellation is set
await _subscriberService.Received(1)
.DisableSubscriberAsync(
Arg.Is<SubscriberId>(s => s.Match(_ => false, _ => false, p => p.Value == providerId)),
currentPeriodEnd);
await _stripeFacade.Received(1).UpdateSubscription(
subscriptionId,
Arg.Is<SubscriptionUpdateOptions>(options =>
@ -474,8 +440,10 @@ public class SubscriptionUpdatedHandlerTests
await _sut.HandleAsync(parsedEvent);
// Assert
await _userService.Received(1)
.DisablePremiumAsync(userId, currentPeriodEnd);
await _subscriberService.Received(1)
.DisableSubscriberAsync(
Arg.Is<SubscriberId>(s => s.Match(u => u.Value == userId, _ => false, _ => false)),
currentPeriodEnd);
await _stripeFacade.Received(1).UpdateSubscription(
subscriptionId,
Arg.Is<SubscriptionUpdateOptions>(options =>
@ -569,7 +537,6 @@ public class SubscriptionUpdatedHandlerTests
LatestInvoice = new Invoice { BillingReason = BillingReasons.SubscriptionCycle }
};
var organization = new Organization { Id = organizationId, PlanType = PlanType.EnterpriseAnnually2023 };
var parsedEvent = new Event
{
Data = new EventData
@ -582,11 +549,8 @@ public class SubscriptionUpdatedHandlerTests
_stripeEventService.GetSubscription(Arg.Any<Event>(), Arg.Any<bool>(), Arg.Any<List<string>>())
.Returns(subscription);
_organizationRepository.GetByIdAsync(organizationId)
.Returns(organization);
var plan = new Enterprise2023Plan(true);
_pricingClient.GetPlanOrThrow(organization.PlanType)
_pricingClient.GetPlanOrThrow(PlanType.EnterpriseAnnually2023)
.Returns(plan);
_pricingClient.ListPlans()
.Returns(MockPlans.Plans);
@ -595,12 +559,12 @@ public class SubscriptionUpdatedHandlerTests
await _sut.HandleAsync(parsedEvent);
// Assert
await _organizationEnableCommand.Received(1)
.EnableAsync(organizationId, currentPeriodEnd);
await _subscriberService.Received(1)
.EnableSubscriberAsync(
Arg.Is<SubscriberId>(s => s.Match(_ => false, o => o.Value == organizationId, _ => false)),
currentPeriodEnd);
await _organizationService.Received(1)
.UpdateExpirationDateAsync(organizationId, currentPeriodEnd);
await _pushNotificationAdapter.Received(1)
.NotifyEnabledChangedAsync(organization);
await _stripeFacade.Received(1).UpdateSubscription(
subscriptionId,
Arg.Is<SubscriptionUpdateOptions>(options =>
@ -653,8 +617,10 @@ public class SubscriptionUpdatedHandlerTests
await _sut.HandleAsync(parsedEvent);
// Assert
await _userService.Received(1)
.EnablePremiumAsync(userId, currentPeriodEnd);
await _subscriberService.Received(1)
.EnableSubscriberAsync(
Arg.Is<SubscriberId>(s => s.Match(u => u.Value == userId, _ => false, _ => false)),
currentPeriodEnd);
await _userService.Received(1)
.UpdatePremiumExpirationAsync(userId, currentPeriodEnd);
await _stripeFacade.Received(1).UpdateSubscription(
@ -881,16 +847,13 @@ public class SubscriptionUpdatedHandlerTests
Subscription previousSubscription)
{
// Arrange
var (providerId, newSubscription, provider, parsedEvent) =
var (providerId, newSubscription, _, parsedEvent) =
CreateProviderTestInputsForUpdatedActiveSubscriptionStatus(previousSubscription);
_stripeEventService
.GetSubscription(Arg.Any<Event>(), Arg.Any<bool>(), Arg.Any<List<string>>())
.Returns(newSubscription);
_providerRepository
.GetByIdAsync(Arg.Any<Guid>())
.Returns(provider);
_stripeFacade
.UpdateSubscription(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>())
.Returns(newSubscription);
@ -902,12 +865,11 @@ public class SubscriptionUpdatedHandlerTests
await _stripeEventService
.Received(1)
.GetSubscription(parsedEvent, true, Arg.Any<List<string>>());
await _providerRepository
await _subscriberService
.Received(1)
.GetByIdAsync(providerId);
await _providerService
.Received(1)
.UpdateAsync(Arg.Is<Provider>(p => p.Id == providerId && p.Enabled == true));
.EnableSubscriberAsync(
Arg.Is<SubscriberId>(s => s.Match(_ => false, _ => false, p => p.Value == providerId)),
Arg.Any<DateTime?>());
await _stripeFacade
.Received(1)
.UpdateSubscription(newSubscription.Id,
@ -922,15 +884,12 @@ public class SubscriptionUpdatedHandlerTests
{
// Arrange
var previousSubscription = new Subscription { Id = "sub_123", Status = SubscriptionStatus.Canceled };
var (providerId, newSubscription, provider, parsedEvent) =
var (_, newSubscription, _, parsedEvent) =
CreateProviderTestInputsForUpdatedActiveSubscriptionStatus(previousSubscription);
_stripeEventService
.GetSubscription(Arg.Any<Event>(), Arg.Any<bool>(), Arg.Any<List<string>>())
.Returns(newSubscription);
_providerRepository
.GetByIdAsync(Arg.Any<Guid>())
.Returns(provider);
// Act
await _sut.HandleAsync(parsedEvent);
@ -939,10 +898,7 @@ public class SubscriptionUpdatedHandlerTests
await _stripeEventService
.Received(1)
.GetSubscription(parsedEvent, true, Arg.Any<List<string>>());
await _providerRepository.DidNotReceive().GetByIdAsync(Arg.Any<Guid>());
await _providerService
.DidNotReceive()
.UpdateAsync(Arg.Any<Provider>());
await _subscriberService.DidNotReceive().EnableSubscriberAsync(Arg.Any<SubscriberId>(), Arg.Any<DateTime?>());
await _stripeFacade
.DidNotReceiveWithAnyArgs()
.UpdateSubscription(Arg.Any<string>());
@ -954,15 +910,12 @@ public class SubscriptionUpdatedHandlerTests
{
// Arrange
var previousSubscription = new Subscription { Id = "sub_123", Status = SubscriptionStatus.Active };
var (providerId, newSubscription, provider, parsedEvent) =
var (_, newSubscription, _, parsedEvent) =
CreateProviderTestInputsForUpdatedActiveSubscriptionStatus(previousSubscription);
_stripeEventService
.GetSubscription(Arg.Any<Event>(), Arg.Any<bool>(), Arg.Any<List<string>>())
.Returns(newSubscription);
_providerRepository
.GetByIdAsync(Arg.Any<Guid>())
.Returns(provider);
// Act
await _sut.HandleAsync(parsedEvent);
@ -971,10 +924,7 @@ public class SubscriptionUpdatedHandlerTests
await _stripeEventService
.Received(1)
.GetSubscription(parsedEvent, true, Arg.Any<List<string>>());
await _providerRepository.DidNotReceive().GetByIdAsync(Arg.Any<Guid>());
await _providerService
.DidNotReceive()
.UpdateAsync(Arg.Any<Provider>());
await _subscriberService.DidNotReceive().EnableSubscriberAsync(Arg.Any<SubscriberId>(), Arg.Any<DateTime?>());
await _stripeFacade
.DidNotReceiveWithAnyArgs()
.UpdateSubscription(Arg.Any<string>());
@ -986,15 +936,12 @@ public class SubscriptionUpdatedHandlerTests
{
// Arrange
var previousSubscription = new Subscription { Id = "sub_123", Status = SubscriptionStatus.Trialing };
var (providerId, newSubscription, provider, parsedEvent) =
var (_, newSubscription, _, parsedEvent) =
CreateProviderTestInputsForUpdatedActiveSubscriptionStatus(previousSubscription);
_stripeEventService
.GetSubscription(Arg.Any<Event>(), Arg.Any<bool>(), Arg.Any<List<string>>())
.Returns(newSubscription);
_providerRepository
.GetByIdAsync(Arg.Any<Guid>())
.Returns(provider);
// Act
await _sut.HandleAsync(parsedEvent);
@ -1003,10 +950,7 @@ public class SubscriptionUpdatedHandlerTests
await _stripeEventService
.Received(1)
.GetSubscription(parsedEvent, true, Arg.Any<List<string>>());
await _providerRepository.DidNotReceive().GetByIdAsync(Arg.Any<Guid>());
await _providerService
.DidNotReceive()
.UpdateAsync(Arg.Any<Provider>());
await _subscriberService.DidNotReceive().EnableSubscriberAsync(Arg.Any<SubscriberId>(), Arg.Any<DateTime?>());
await _stripeFacade
.DidNotReceiveWithAnyArgs()
.UpdateSubscription(Arg.Any<string>());
@ -1018,15 +962,12 @@ public class SubscriptionUpdatedHandlerTests
{
// Arrange
var previousSubscription = new Subscription { Id = "sub_123", Status = SubscriptionStatus.PastDue };
var (providerId, newSubscription, provider, parsedEvent) =
var (_, newSubscription, _, parsedEvent) =
CreateProviderTestInputsForUpdatedActiveSubscriptionStatus(previousSubscription);
_stripeEventService
.GetSubscription(Arg.Any<Event>(), Arg.Any<bool>(), Arg.Any<List<string>>())
.Returns(newSubscription);
_providerRepository
.GetByIdAsync(Arg.Any<Guid>())
.Returns(provider);
// Act
await _sut.HandleAsync(parsedEvent);
@ -1035,17 +976,14 @@ public class SubscriptionUpdatedHandlerTests
await _stripeEventService
.Received(1)
.GetSubscription(parsedEvent, true, Arg.Any<List<string>>());
await _providerRepository.DidNotReceive().GetByIdAsync(Arg.Any<Guid>());
await _providerService
.DidNotReceive()
.UpdateAsync(Arg.Any<Provider>());
await _subscriberService.DidNotReceive().EnableSubscriberAsync(Arg.Any<SubscriberId>(), Arg.Any<DateTime?>());
await _stripeFacade
.DidNotReceiveWithAnyArgs()
.UpdateSubscription(Arg.Any<string>());
}
[Fact]
public async Task HandleAsync_ActiveProviderSubscriptionEvent_AndProviderDoesNotExist_NoChanges()
public async Task HandleAsync_ActiveProviderSubscriptionEvent_EnablesProviderViaSubscriberService()
{
// Arrange
var previousSubscription = new Subscription { Id = "sub_123", Status = SubscriptionStatus.Unpaid };
@ -1055,9 +993,6 @@ public class SubscriptionUpdatedHandlerTests
_stripeEventService
.GetSubscription(Arg.Any<Event>(), Arg.Any<bool>(), Arg.Any<List<string>>())
.Returns(newSubscription);
_providerRepository
.GetByIdAsync(Arg.Any<Guid>())
.ReturnsNull();
// Act
await _sut.HandleAsync(parsedEvent);
@ -1066,15 +1001,14 @@ public class SubscriptionUpdatedHandlerTests
await _stripeEventService
.Received(1)
.GetSubscription(parsedEvent, true, Arg.Any<List<string>>());
await _providerRepository
await _subscriberService
.Received(1)
.GetByIdAsync(providerId);
await _providerService
.DidNotReceive()
.UpdateAsync(Arg.Any<Provider>());
.EnableSubscriberAsync(
Arg.Is<SubscriberId>(s => s.Match(_ => false, _ => false, p => p.Value == providerId)),
Arg.Any<DateTime?>());
await _stripeFacade
.DidNotReceive()
.UpdateSubscription(Arg.Any<string>());
.Received(1)
.UpdateSubscription(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>());
}
[Fact]
@ -1082,15 +1016,12 @@ public class SubscriptionUpdatedHandlerTests
{
// Arrange - Using a previous status (Canceled) that doesn't trigger SubscriptionBecameActive
var previousSubscription = new Subscription { Id = "sub_123", Status = SubscriptionStatus.Canceled };
var (providerId, newSubscription, provider, parsedEvent) =
var (_, newSubscription, _, parsedEvent) =
CreateProviderTestInputsForUpdatedActiveSubscriptionStatus(previousSubscription);
_stripeEventService
.GetSubscription(Arg.Any<Event>(), Arg.Any<bool>(), Arg.Any<List<string>>())
.Returns(newSubscription);
_providerRepository
.GetByIdAsync(Arg.Any<Guid>())
.Returns(provider);
// Act
await _sut.HandleAsync(parsedEvent);
@ -1099,13 +1030,10 @@ public class SubscriptionUpdatedHandlerTests
await _stripeEventService
.Received(1)
.GetSubscription(parsedEvent, true, Arg.Any<List<string>>());
await _providerRepository.DidNotReceive().GetByIdAsync(Arg.Any<Guid>());
await _providerService
.DidNotReceive()
.UpdateAsync(Arg.Any<Provider>());
await _subscriberService.DidNotReceive().EnableSubscriberAsync(Arg.Any<SubscriberId>(), Arg.Any<DateTime?>());
await _stripeFacade
.DidNotReceive()
.UpdateSubscription(Arg.Any<string>());
.UpdateSubscription(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>());
}
private static (Guid providerId, Subscription newSubscription, Provider provider, Event parsedEvent)

View File

@ -1,12 +1,19 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Notifications;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Implementations;
using Bit.Core.Billing.Subscriptions.Models;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
@ -1771,4 +1778,238 @@ public class SubscriberServiceTests
}
#endregion
#region DisableSubscriberAsync
[Theory, BitAutoData]
public async Task DisableSubscriberAsync_UserId_DisablesPremium(
SutProvider<SubscriberService> sutProvider)
{
// Arrange
var userId = Guid.NewGuid();
var subscriberId = new UserId(userId);
var currentPeriodEnd = DateTime.UtcNow.AddDays(30);
var userService = sutProvider.GetDependency<IUserService>();
// Act
await sutProvider.Sut.DisableSubscriberAsync(subscriberId, currentPeriodEnd);
// Assert
await userService.Received(1).DisablePremiumAsync(userId, currentPeriodEnd);
}
[Theory, BitAutoData]
public async Task DisableSubscriberAsync_OrganizationId_DisablesOrganizationAndNotifies(
Organization organization,
SutProvider<SubscriberService> sutProvider)
{
// Arrange
var organizationId = organization.Id;
var subscriberId = new OrganizationId(organizationId);
var currentPeriodEnd = DateTime.UtcNow.AddDays(30);
var organizationDisableCommand = sutProvider.GetDependency<IOrganizationDisableCommand>();
var organizationRepository = sutProvider.GetDependency<IOrganizationRepository>();
var pushNotificationAdapter = sutProvider.GetDependency<IPushNotificationAdapter>();
organizationRepository.GetByIdAsync(organizationId).Returns(organization);
// Act
await sutProvider.Sut.DisableSubscriberAsync(subscriberId, currentPeriodEnd);
// Assert
await organizationDisableCommand.Received(1).DisableAsync(organizationId, currentPeriodEnd);
await organizationRepository.Received(1).GetByIdAsync(organizationId);
await pushNotificationAdapter.Received(1).NotifyEnabledChangedAsync(organization);
}
[Theory, BitAutoData]
public async Task DisableSubscriberAsync_OrganizationId_OrganizationNotFound_DoesNotNotify(
SutProvider<SubscriberService> sutProvider)
{
// Arrange
var organizationId = Guid.NewGuid();
var subscriberId = new OrganizationId(organizationId);
var currentPeriodEnd = DateTime.UtcNow.AddDays(30);
var organizationDisableCommand = sutProvider.GetDependency<IOrganizationDisableCommand>();
var organizationRepository = sutProvider.GetDependency<IOrganizationRepository>();
var pushNotificationAdapter = sutProvider.GetDependency<IPushNotificationAdapter>();
organizationRepository.GetByIdAsync(organizationId).ReturnsNull();
// Act
await sutProvider.Sut.DisableSubscriberAsync(subscriberId, currentPeriodEnd);
// Assert
await organizationDisableCommand.Received(1).DisableAsync(organizationId, currentPeriodEnd);
await organizationRepository.Received(1).GetByIdAsync(organizationId);
await pushNotificationAdapter.DidNotReceive().NotifyEnabledChangedAsync(Arg.Any<Organization>());
}
[Theory, BitAutoData]
public async Task DisableSubscriberAsync_ProviderId_DisablesProvider(
Provider provider,
SutProvider<SubscriberService> sutProvider)
{
// Arrange
var providerId = provider.Id;
provider.Enabled = true;
var subscriberId = new ProviderId(providerId);
var currentPeriodEnd = DateTime.UtcNow.AddDays(30);
var providerRepository = sutProvider.GetDependency<IProviderRepository>();
var providerService = sutProvider.GetDependency<IProviderService>();
providerRepository.GetByIdAsync(providerId).Returns(provider);
// Act
await sutProvider.Sut.DisableSubscriberAsync(subscriberId, currentPeriodEnd);
// Assert
await providerRepository.Received(1).GetByIdAsync(providerId);
await providerService.Received(1).UpdateAsync(Arg.Is<Provider>(p => p.Id == providerId && p.Enabled == false));
}
[Theory, BitAutoData]
public async Task DisableSubscriberAsync_ProviderId_ProviderNotFound_DoesNotUpdate(
SutProvider<SubscriberService> sutProvider)
{
// Arrange
var providerId = Guid.NewGuid();
var subscriberId = new ProviderId(providerId);
var currentPeriodEnd = DateTime.UtcNow.AddDays(30);
var providerRepository = sutProvider.GetDependency<IProviderRepository>();
var providerService = sutProvider.GetDependency<IProviderService>();
providerRepository.GetByIdAsync(providerId).ReturnsNull();
// Act
await sutProvider.Sut.DisableSubscriberAsync(subscriberId, currentPeriodEnd);
// Assert
await providerRepository.Received(1).GetByIdAsync(providerId);
await providerService.DidNotReceive().UpdateAsync(Arg.Any<Provider>());
}
#endregion
#region EnableSubscriberAsync
[Theory, BitAutoData]
public async Task EnableSubscriberAsync_UserId_EnablesPremium(
SutProvider<SubscriberService> sutProvider)
{
// Arrange
var userId = Guid.NewGuid();
var subscriberId = new UserId(userId);
var currentPeriodEnd = DateTime.UtcNow.AddDays(30);
var userService = sutProvider.GetDependency<IUserService>();
// Act
await sutProvider.Sut.EnableSubscriberAsync(subscriberId, currentPeriodEnd);
// Assert
await userService.Received(1).EnablePremiumAsync(userId, currentPeriodEnd);
}
[Theory, BitAutoData]
public async Task EnableSubscriberAsync_OrganizationId_EnablesOrganizationAndNotifies(
Organization organization,
SutProvider<SubscriberService> sutProvider)
{
// Arrange
var organizationId = organization.Id;
var subscriberId = new OrganizationId(organizationId);
var currentPeriodEnd = DateTime.UtcNow.AddDays(30);
var organizationEnableCommand = sutProvider.GetDependency<IOrganizationEnableCommand>();
var organizationRepository = sutProvider.GetDependency<IOrganizationRepository>();
var pushNotificationAdapter = sutProvider.GetDependency<IPushNotificationAdapter>();
organizationRepository.GetByIdAsync(organizationId).Returns(organization);
// Act
await sutProvider.Sut.EnableSubscriberAsync(subscriberId, currentPeriodEnd);
// Assert
await organizationEnableCommand.Received(1).EnableAsync(organizationId, currentPeriodEnd);
await organizationRepository.Received(1).GetByIdAsync(organizationId);
await pushNotificationAdapter.Received(1).NotifyEnabledChangedAsync(organization);
}
[Theory, BitAutoData]
public async Task EnableSubscriberAsync_OrganizationId_OrganizationNotFound_DoesNotNotify(
SutProvider<SubscriberService> sutProvider)
{
// Arrange
var organizationId = Guid.NewGuid();
var subscriberId = new OrganizationId(organizationId);
var currentPeriodEnd = DateTime.UtcNow.AddDays(30);
var organizationEnableCommand = sutProvider.GetDependency<IOrganizationEnableCommand>();
var organizationRepository = sutProvider.GetDependency<IOrganizationRepository>();
var pushNotificationAdapter = sutProvider.GetDependency<IPushNotificationAdapter>();
organizationRepository.GetByIdAsync(organizationId).ReturnsNull();
// Act
await sutProvider.Sut.EnableSubscriberAsync(subscriberId, currentPeriodEnd);
// Assert
await organizationEnableCommand.Received(1).EnableAsync(organizationId, currentPeriodEnd);
await organizationRepository.Received(1).GetByIdAsync(organizationId);
await pushNotificationAdapter.DidNotReceive().NotifyEnabledChangedAsync(Arg.Any<Organization>());
}
[Theory, BitAutoData]
public async Task EnableSubscriberAsync_ProviderId_EnablesProvider(
Provider provider,
SutProvider<SubscriberService> sutProvider)
{
// Arrange
var providerId = provider.Id;
provider.Enabled = false;
var subscriberId = new ProviderId(providerId);
var currentPeriodEnd = DateTime.UtcNow.AddDays(30);
var providerRepository = sutProvider.GetDependency<IProviderRepository>();
var providerService = sutProvider.GetDependency<IProviderService>();
providerRepository.GetByIdAsync(providerId).Returns(provider);
// Act
await sutProvider.Sut.EnableSubscriberAsync(subscriberId, currentPeriodEnd);
// Assert
await providerRepository.Received(1).GetByIdAsync(providerId);
await providerService.Received(1).UpdateAsync(Arg.Is<Provider>(p => p.Id == providerId && p.Enabled == true));
}
[Theory, BitAutoData]
public async Task EnableSubscriberAsync_ProviderId_ProviderNotFound_DoesNotUpdate(
SutProvider<SubscriberService> sutProvider)
{
// Arrange
var providerId = Guid.NewGuid();
var subscriberId = new ProviderId(providerId);
var currentPeriodEnd = DateTime.UtcNow.AddDays(30);
var providerRepository = sutProvider.GetDependency<IProviderRepository>();
var providerService = sutProvider.GetDependency<IProviderService>();
providerRepository.GetByIdAsync(providerId).ReturnsNull();
// Act
await sutProvider.Sut.EnableSubscriberAsync(subscriberId, currentPeriodEnd);
// Assert
await providerRepository.Received(1).GetByIdAsync(providerId);
await providerService.DidNotReceive().UpdateAsync(Arg.Any<Provider>());
}
#endregion
}