From fe0bc1516b86a44cfadc7c2cd31bbb269051046e Mon Sep 17 00:00:00 2001 From: Alex Morask Date: Tue, 3 Feb 2026 12:54:46 -0600 Subject: [PATCH] Move enable/disable operations to SubscriberService --- .../PaymentSucceededHandler.cs | 1 + .../SetupIntentSucceededHandler.cs | 1 + .../SubscriptionUpdatedHandler.cs | 80 +----- src/Billing/Startup.cs | 1 - .../Extensions/ServiceCollectionExtensions.cs | 2 + .../IPushNotificationAdapter.cs | 2 +- .../Notifications}/PushNotificationAdapter.cs | 2 +- .../Billing/Services/ISubscriberService.cs | 17 ++ .../Implementations/SubscriberService.cs | 56 +++- .../SetupIntentSucceededHandlerTests.cs | 1 + .../SubscriptionUpdatedHandlerTests.cs | 200 +++++---------- .../Services/SubscriberServiceTests.cs | 241 ++++++++++++++++++ 12 files changed, 394 insertions(+), 210 deletions(-) rename src/{Billing/Services => Core/Billing/Notifications}/IPushNotificationAdapter.cs (88%) rename src/{Billing/Services/Implementations => Core/Billing/Notifications}/PushNotificationAdapter.cs (98%) diff --git a/src/Billing/Services/Implementations/PaymentSucceededHandler.cs b/src/Billing/Services/Implementations/PaymentSucceededHandler.cs index 443227f7bf..7b2eb554db 100644 --- a/src/Billing/Services/Implementations/PaymentSucceededHandler.cs +++ b/src/Billing/Services/Implementations/PaymentSucceededHandler.cs @@ -4,6 +4,7 @@ using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Notifications; using Bit.Core.Billing.Pricing; using Bit.Core.Repositories; using Bit.Core.Services; diff --git a/src/Billing/Services/Implementations/SetupIntentSucceededHandler.cs b/src/Billing/Services/Implementations/SetupIntentSucceededHandler.cs index 89e40f0e43..2324951ad8 100644 --- a/src/Billing/Services/Implementations/SetupIntentSucceededHandler.cs +++ b/src/Billing/Services/Implementations/SetupIntentSucceededHandler.cs @@ -2,6 +2,7 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Notifications; using Bit.Core.Billing.Services; using Bit.Core.Repositories; using OneOf; diff --git a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs index 4507d9e308..1f2a3aaddf 100644 --- a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs +++ b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs @@ -1,8 +1,6 @@ -using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; -using Bit.Core.AdminConsole.Repositories; -using Bit.Core.AdminConsole.Services; -using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Billing.Subscriptions.Models; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; @@ -22,13 +20,9 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler private readonly IStripeFacade _stripeFacade; private readonly IOrganizationSponsorshipRenewCommand _organizationSponsorshipRenewCommand; private readonly IUserService _userService; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationEnableCommand _organizationEnableCommand; - private readonly IOrganizationDisableCommand _organizationDisableCommand; private readonly IPricingClient _pricingClient; - private readonly IProviderRepository _providerRepository; - private readonly IProviderService _providerService; - private readonly IPushNotificationAdapter _pushNotificationAdapter; + private readonly ISubscriberService _subscriberService; + private readonly IOrganizationRepository _organizationRepository; public SubscriptionUpdatedHandler( IStripeEventService stripeEventService, @@ -37,29 +31,19 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler IStripeFacade stripeFacade, IOrganizationSponsorshipRenewCommand organizationSponsorshipRenewCommand, IUserService userService, - IOrganizationRepository organizationRepository, - IOrganizationEnableCommand organizationEnableCommand, - IOrganizationDisableCommand organizationDisableCommand, IPricingClient pricingClient, - IProviderRepository providerRepository, - IProviderService providerService, - IPushNotificationAdapter pushNotificationAdapter) + ISubscriberService subscriberService, + IOrganizationRepository organizationRepository) { _stripeEventService = stripeEventService; _stripeEventUtilityService = stripeEventUtilityService; _organizationService = organizationService; - _providerService = providerService; _stripeFacade = stripeFacade; _organizationSponsorshipRenewCommand = organizationSponsorshipRenewCommand; _userService = userService; - _organizationRepository = organizationRepository; - _providerRepository = providerRepository; - _organizationEnableCommand = organizationEnableCommand; - _organizationDisableCommand = organizationDisableCommand; _pricingClient = pricingClient; - _providerRepository = providerRepository; - _providerService = providerService; - _pushNotificationAdapter = pushNotificationAdapter; + _subscriberService = subscriberService; + _organizationRepository = organizationRepository; } public async Task HandleAsync(Event parsedEvent) @@ -71,12 +55,12 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler if (SubscriptionWentUnpaid(parsedEvent, subscription)) { - await DisableSubscriberAsync(subscriberId, currentPeriodEnd); + await _subscriberService.DisableSubscriberAsync(subscriberId, currentPeriodEnd); await SetSubscriptionToCancelAsync(subscription); } else if (SubscriptionBecameActive(parsedEvent, subscription)) { - await EnableSubscriberAsync(subscriberId, currentPeriodEnd); + await _subscriberService.EnableSubscriberAsync(subscriberId, currentPeriodEnd); await RemovePendingCancellationAsync(subscription); } @@ -125,50 +109,6 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler LatestInvoice.BillingReason: BillingReasons.SubscriptionCreate or BillingReasons.SubscriptionCycle }; - private Task DisableSubscriberAsync(SubscriberId subscriberId, DateTime? currentPeriodEnd) => - subscriberId.Match( - userId => _userService.DisablePremiumAsync(userId.Value, currentPeriodEnd), - async organizationId => - { - await _organizationDisableCommand.DisableAsync(organizationId.Value, currentPeriodEnd); - var organization = await _organizationRepository.GetByIdAsync(organizationId.Value); - if (organization != null) - { - await _pushNotificationAdapter.NotifyEnabledChangedAsync(organization); - } - }, - async providerId => - { - var provider = await _providerRepository.GetByIdAsync(providerId.Value); - if (provider != null) - { - provider.Enabled = false; - await _providerService.UpdateAsync(provider); - } - }); - - private Task EnableSubscriberAsync(SubscriberId subscriberId, DateTime? currentPeriodEnd) => - subscriberId.Match( - userId => _userService.EnablePremiumAsync(userId.Value, currentPeriodEnd), - async organizationId => - { - await _organizationEnableCommand.EnableAsync(organizationId.Value, currentPeriodEnd); - var organization = await _organizationRepository.GetByIdAsync(organizationId.Value); - if (organization != null) - { - await _pushNotificationAdapter.NotifyEnabledChangedAsync(organization); - } - }, - async providerId => - { - var provider = await _providerRepository.GetByIdAsync(providerId.Value); - if (provider != null) - { - provider.Enabled = true; - await _providerService.UpdateAsync(provider); - } - }); - private async Task SetSubscriptionToCancelAsync(Subscription subscription) { if (subscription.TestClock != null) diff --git a/src/Billing/Startup.cs b/src/Billing/Startup.cs index f5f98bfd53..6e2b93563d 100644 --- a/src/Billing/Startup.cs +++ b/src/Billing/Startup.cs @@ -101,7 +101,6 @@ public class Startup services.AddScoped(); services.AddScoped(); services.AddScoped(); - services.AddScoped(); // Add Quartz services first services.AddQuartz(q => diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs index c61c4e6279..13a120a1f4 100644 --- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs +++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs @@ -2,6 +2,7 @@ using Bit.Core.Billing.Caches.Implementations; using Bit.Core.Billing.Licenses; using Bit.Core.Billing.Licenses.Extensions; +using Bit.Core.Billing.Notifications; using Bit.Core.Billing.Organizations.Commands; using Bit.Core.Billing.Organizations.Queries; using Bit.Core.Billing.Organizations.Services; @@ -31,6 +32,7 @@ public static class ServiceCollectionExtensions services.AddTransient(); services.AddTransient(); services.AddTransient(); + services.AddTransient(); services.AddLicenseServices(); services.AddLicenseOperations(); services.AddPricingClient(); diff --git a/src/Billing/Services/IPushNotificationAdapter.cs b/src/Core/Billing/Notifications/IPushNotificationAdapter.cs similarity index 88% rename from src/Billing/Services/IPushNotificationAdapter.cs rename to src/Core/Billing/Notifications/IPushNotificationAdapter.cs index 2f74f35eec..7981c50428 100644 --- a/src/Billing/Services/IPushNotificationAdapter.cs +++ b/src/Core/Billing/Notifications/IPushNotificationAdapter.cs @@ -1,7 +1,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; -namespace Bit.Billing.Services; +namespace Bit.Core.Billing.Notifications; public interface IPushNotificationAdapter { diff --git a/src/Billing/Services/Implementations/PushNotificationAdapter.cs b/src/Core/Billing/Notifications/PushNotificationAdapter.cs similarity index 98% rename from src/Billing/Services/Implementations/PushNotificationAdapter.cs rename to src/Core/Billing/Notifications/PushNotificationAdapter.cs index 673ae1415e..81a9244383 100644 --- a/src/Billing/Services/Implementations/PushNotificationAdapter.cs +++ b/src/Core/Billing/Notifications/PushNotificationAdapter.cs @@ -6,7 +6,7 @@ using Bit.Core.Enums; using Bit.Core.Models; using Bit.Core.Platform.Push; -namespace Bit.Billing.Services.Implementations; +namespace Bit.Core.Billing.Notifications; public class PushNotificationAdapter( IProviderUserRepository providerUserRepository, diff --git a/src/Core/Billing/Services/ISubscriberService.cs b/src/Core/Billing/Services/ISubscriberService.cs index 343a0e4f38..da8877a33a 100644 --- a/src/Core/Billing/Services/ISubscriberService.cs +++ b/src/Core/Billing/Services/ISubscriberService.cs @@ -2,6 +2,7 @@ #nullable disable using Bit.Core.Billing.Models; +using Bit.Core.Billing.Subscriptions.Models; using Bit.Core.Billing.Tax.Models; using Bit.Core.Entities; using Bit.Core.Enums; @@ -143,4 +144,20 @@ public interface ISubscriberService /// if the gateway subscription ID is valid or empty; if the subscription doesn't exist in the gateway. /// Thrown when the is . Task IsValidGatewaySubscriptionIdAsync(ISubscriber subscriber); + + /// + /// Disables a subscriber based on the type. + /// For users, this disables premium. For organizations and providers, this disables the entity. + /// + /// The subscriber identifier (UserId, OrganizationId, or ProviderId). + /// The current billing period end date to set as the expiration date. + Task DisableSubscriberAsync(SubscriberId subscriberId, DateTime? currentPeriodEnd); + + /// + /// Enables a subscriber based on the type. + /// For users, this enables premium. For organizations and providers, this enables the entity. + /// + /// The subscriber identifier (UserId, OrganizationId, or ProviderId). + /// The current billing period end date to set as the expiration date. + Task EnableSubscriberAsync(SubscriberId subscriberId, DateTime? currentPeriodEnd); } diff --git a/src/Core/Billing/Services/Implementations/SubscriberService.cs b/src/Core/Billing/Services/Implementations/SubscriberService.cs index 7acbe20014..9aa6e1f34d 100644 --- a/src/Core/Billing/Services/Implementations/SubscriberService.cs +++ b/src/Core/Billing/Services/Implementations/SubscriberService.cs @@ -3,18 +3,23 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; +using Bit.Core.Billing.Notifications; +using Bit.Core.Billing.Subscriptions.Models; using Bit.Core.Billing.Tax.Models; using Bit.Core.Billing.Tax.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; +using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Utilities; using Braintree; @@ -33,12 +38,17 @@ public class SubscriberService( IBraintreeGateway braintreeGateway, IGlobalSettings globalSettings, ILogger logger, + IOrganizationDisableCommand organizationDisableCommand, + IOrganizationEnableCommand organizationEnableCommand, IOrganizationRepository organizationRepository, IProviderRepository providerRepository, + IProviderService providerService, + IPushNotificationAdapter pushNotificationAdapter, ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ITaxService taxService, - IUserRepository userRepository) : ISubscriberService + IUserRepository userRepository, + IUserService userService) : ISubscriberService { public async Task CancelSubscription( ISubscriber subscriber, @@ -817,6 +827,50 @@ public class SubscriberService( } } + public Task DisableSubscriberAsync(SubscriberId subscriberId, DateTime? currentPeriodEnd) => + subscriberId.Match( + userId => userService.DisablePremiumAsync(userId.Value, currentPeriodEnd), + async organizationId => + { + await organizationDisableCommand.DisableAsync(organizationId.Value, currentPeriodEnd); + var organization = await organizationRepository.GetByIdAsync(organizationId.Value); + if (organization != null) + { + await pushNotificationAdapter.NotifyEnabledChangedAsync(organization); + } + }, + async providerId => + { + var provider = await providerRepository.GetByIdAsync(providerId.Value); + if (provider != null) + { + provider.Enabled = false; + await providerService.UpdateAsync(provider); + } + }); + + public Task EnableSubscriberAsync(SubscriberId subscriberId, DateTime? currentPeriodEnd) => + subscriberId.Match( + userId => userService.EnablePremiumAsync(userId.Value, currentPeriodEnd), + async organizationId => + { + await organizationEnableCommand.EnableAsync(organizationId.Value, currentPeriodEnd); + var organization = await organizationRepository.GetByIdAsync(organizationId.Value); + if (organization != null) + { + await pushNotificationAdapter.NotifyEnabledChangedAsync(organization); + } + }, + async providerId => + { + var provider = await providerRepository.GetByIdAsync(providerId.Value); + if (provider != null) + { + provider.Enabled = true; + await providerService.UpdateAsync(provider); + } + }); + #region Shared Utilities private async Task AddBraintreeCustomerIdAsync( diff --git a/test/Billing.Test/Services/SetupIntentSucceededHandlerTests.cs b/test/Billing.Test/Services/SetupIntentSucceededHandlerTests.cs index a7aefe3163..976fd962b8 100644 --- a/test/Billing.Test/Services/SetupIntentSucceededHandlerTests.cs +++ b/test/Billing.Test/Services/SetupIntentSucceededHandlerTests.cs @@ -4,6 +4,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Notifications; using Bit.Core.Billing.Services; using Bit.Core.Repositories; using NSubstitute; diff --git a/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs b/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs index 6d74146b03..52a1b2a0ed 100644 --- a/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs +++ b/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs @@ -2,11 +2,10 @@ using Bit.Billing.Services.Implementations; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; -using Bit.Core.AdminConsole.Repositories; -using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; +using Bit.Core.Billing.Subscriptions.Models; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; @@ -14,7 +13,6 @@ using Bit.Core.Test.Billing.Mocks; using Bit.Core.Test.Billing.Mocks.Plans; using Newtonsoft.Json.Linq; using NSubstitute; -using NSubstitute.ReturnsExtensions; using Stripe; using Xunit; using static Bit.Core.Billing.Constants.StripeConstants; @@ -30,13 +28,9 @@ public class SubscriptionUpdatedHandlerTests private readonly IStripeFacade _stripeFacade; private readonly IOrganizationSponsorshipRenewCommand _organizationSponsorshipRenewCommand; private readonly IUserService _userService; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationEnableCommand _organizationEnableCommand; - private readonly IOrganizationDisableCommand _organizationDisableCommand; private readonly IPricingClient _pricingClient; - private readonly IProviderRepository _providerRepository; - private readonly IProviderService _providerService; - private readonly IPushNotificationAdapter _pushNotificationAdapter; + private readonly ISubscriberService _subscriberService; + private readonly IOrganizationRepository _organizationRepository; private readonly SubscriptionUpdatedHandler _sut; public SubscriptionUpdatedHandlerTests() @@ -47,14 +41,9 @@ public class SubscriptionUpdatedHandlerTests _stripeFacade = Substitute.For(); _organizationSponsorshipRenewCommand = Substitute.For(); _userService = Substitute.For(); - _providerService = Substitute.For(); - _organizationRepository = Substitute.For(); - _organizationEnableCommand = Substitute.For(); - _organizationDisableCommand = Substitute.For(); _pricingClient = Substitute.For(); - _providerRepository = Substitute.For(); - _providerService = Substitute.For(); - _pushNotificationAdapter = Substitute.For(); + _subscriberService = Substitute.For(); + _organizationRepository = Substitute.For(); _sut = new SubscriptionUpdatedHandler( _stripeEventService, @@ -63,13 +52,9 @@ public class SubscriptionUpdatedHandlerTests _stripeFacade, _organizationSponsorshipRenewCommand, _userService, - _organizationRepository, - _organizationEnableCommand, - _organizationDisableCommand, _pricingClient, - _providerRepository, - _providerService, - _pushNotificationAdapter); + _subscriberService, + _organizationRepository); } [Fact] @@ -119,8 +104,6 @@ public class SubscriptionUpdatedHandlerTests _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(subscription); - _organizationRepository.GetByIdAsync(organizationId).Returns(organization); - var plan = new Enterprise2023Plan(true); _pricingClient.GetPlanOrThrow(organization.PlanType).Returns(plan); _pricingClient.ListPlans().Returns(MockPlans.Plans); @@ -129,10 +112,10 @@ public class SubscriptionUpdatedHandlerTests await _sut.HandleAsync(parsedEvent); // Assert - await _organizationDisableCommand.Received(1) - .DisableAsync(organizationId, currentPeriodEnd); - await _pushNotificationAdapter.Received(1) - .NotifyEnabledChangedAsync(organization); + await _subscriberService.Received(1) + .DisableSubscriberAsync( + Arg.Is(s => s.Match(_ => false, o => o.Value == organizationId, _ => false)), + currentPeriodEnd); await _stripeFacade.Received(1).UpdateSubscription( subscriptionId, Arg.Is(options => @@ -150,6 +133,7 @@ public class SubscriptionUpdatedHandlerTests // Arrange var providerId = Guid.NewGuid(); var subscriptionId = "sub_test123"; + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); var previousSubscription = new Subscription { @@ -165,7 +149,7 @@ public class SubscriptionUpdatedHandlerTests { Data = [ - new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) } + new SubscriptionItem { CurrentPeriodEnd = currentPeriodEnd } ] }, Metadata = new Dictionary { ["providerId"] = providerId.ToString() }, @@ -182,17 +166,16 @@ public class SubscriptionUpdatedHandlerTests } }; - var provider = new Provider { Id = providerId, Enabled = true }; - _stripeEventService.GetSubscription(parsedEvent, true, Arg.Any>()).Returns(currentSubscription); - _providerRepository.GetByIdAsync(providerId).Returns(provider); // Act await _sut.HandleAsync(parsedEvent); // Assert - Assert.False(provider.Enabled); - await _providerService.Received(1).UpdateAsync(provider); + await _subscriberService.Received(1) + .DisableSubscriberAsync( + Arg.Is(s => s.Match(_ => false, _ => false, p => p.Value == providerId)), + currentPeriodEnd); // Verify that UpdateSubscription was called with CancelAt await _stripeFacade.Received(1).UpdateSubscription( @@ -233,8 +216,6 @@ public class SubscriptionUpdatedHandlerTests LatestInvoice = new Invoice { BillingReason = BillingReasons.SubscriptionCycle } }; - var provider = new Provider { Id = providerId, Name = "Test Provider", Enabled = true }; - var parsedEvent = new Event { Data = new EventData @@ -247,15 +228,11 @@ public class SubscriptionUpdatedHandlerTests _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(subscription); - _providerRepository.GetByIdAsync(providerId) - .Returns(provider); - // Act await _sut.HandleAsync(parsedEvent); // Assert - No disable or cancellation since there was no valid status transition - Assert.True(provider.Enabled); - await _providerService.DidNotReceive().UpdateAsync(Arg.Any()); + await _subscriberService.DidNotReceive().DisableSubscriberAsync(Arg.Any(), Arg.Any()); await _stripeFacade.DidNotReceive().UpdateSubscription(Arg.Any(), Arg.Any()); } @@ -288,8 +265,6 @@ public class SubscriptionUpdatedHandlerTests LatestInvoice = new Invoice { BillingReason = BillingReasons.SubscriptionCycle } }; - var provider = new Provider { Id = providerId, Name = "Test Provider", Enabled = true }; - var parsedEvent = new Event { Data = new EventData @@ -302,15 +277,11 @@ public class SubscriptionUpdatedHandlerTests _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(subscription); - _providerRepository.GetByIdAsync(providerId) - .Returns(provider); - // Act await _sut.HandleAsync(parsedEvent); // Assert - No disable or cancellation since the previous status (Canceled) is not a valid transition source - Assert.True(provider.Enabled); - await _providerService.DidNotReceive().UpdateAsync(Arg.Any()); + await _subscriberService.DidNotReceive().DisableSubscriberAsync(Arg.Any(), Arg.Any()); await _stripeFacade.DidNotReceive().UpdateSubscription(Arg.Any(), Arg.Any()); } @@ -344,8 +315,6 @@ public class SubscriptionUpdatedHandlerTests LatestInvoice = new Invoice { BillingReason = "renewal" } }; - var provider = new Provider { Id = providerId, Name = "Test Provider", Enabled = true }; - var parsedEvent = new Event { Data = new EventData @@ -358,20 +327,17 @@ public class SubscriptionUpdatedHandlerTests _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(subscription); - _providerRepository.GetByIdAsync(providerId) - .Returns(provider); - // Act await _sut.HandleAsync(parsedEvent); // Assert - IncompleteExpired status is not handled by the new logic - Assert.True(provider.Enabled); - await _providerService.DidNotReceive().UpdateAsync(Arg.Any()); + await _subscriberService.DidNotReceive().DisableSubscriberAsync(Arg.Any(), Arg.Any()); + await _subscriberService.DidNotReceive().EnableSubscriberAsync(Arg.Any(), Arg.Any()); await _stripeFacade.DidNotReceive().UpdateSubscription(Arg.Any(), Arg.Any()); } [Fact] - public async Task HandleAsync_UnpaidProviderSubscription_WhenProviderNotFound_StillSetsCancellation() + public async Task HandleAsync_UnpaidProviderSubscription_StillSetsCancellation() { // Arrange var providerId = Guid.NewGuid(); @@ -411,14 +377,14 @@ public class SubscriptionUpdatedHandlerTests _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(subscription); - _providerRepository.GetByIdAsync(providerId) - .Returns((Provider)null); - // Act await _sut.HandleAsync(parsedEvent); - // Assert - Provider not updated (since not found), but cancellation is still set - await _providerService.DidNotReceive().UpdateAsync(Arg.Any()); + // Assert - DisableSubscriberAsync is called and cancellation is set + await _subscriberService.Received(1) + .DisableSubscriberAsync( + Arg.Is(s => s.Match(_ => false, _ => false, p => p.Value == providerId)), + currentPeriodEnd); await _stripeFacade.Received(1).UpdateSubscription( subscriptionId, Arg.Is(options => @@ -474,8 +440,10 @@ public class SubscriptionUpdatedHandlerTests await _sut.HandleAsync(parsedEvent); // Assert - await _userService.Received(1) - .DisablePremiumAsync(userId, currentPeriodEnd); + await _subscriberService.Received(1) + .DisableSubscriberAsync( + Arg.Is(s => s.Match(u => u.Value == userId, _ => false, _ => false)), + currentPeriodEnd); await _stripeFacade.Received(1).UpdateSubscription( subscriptionId, Arg.Is(options => @@ -569,7 +537,6 @@ public class SubscriptionUpdatedHandlerTests LatestInvoice = new Invoice { BillingReason = BillingReasons.SubscriptionCycle } }; - var organization = new Organization { Id = organizationId, PlanType = PlanType.EnterpriseAnnually2023 }; var parsedEvent = new Event { Data = new EventData @@ -582,11 +549,8 @@ public class SubscriptionUpdatedHandlerTests _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(subscription); - _organizationRepository.GetByIdAsync(organizationId) - .Returns(organization); - var plan = new Enterprise2023Plan(true); - _pricingClient.GetPlanOrThrow(organization.PlanType) + _pricingClient.GetPlanOrThrow(PlanType.EnterpriseAnnually2023) .Returns(plan); _pricingClient.ListPlans() .Returns(MockPlans.Plans); @@ -595,12 +559,12 @@ public class SubscriptionUpdatedHandlerTests await _sut.HandleAsync(parsedEvent); // Assert - await _organizationEnableCommand.Received(1) - .EnableAsync(organizationId, currentPeriodEnd); + await _subscriberService.Received(1) + .EnableSubscriberAsync( + Arg.Is(s => s.Match(_ => false, o => o.Value == organizationId, _ => false)), + currentPeriodEnd); await _organizationService.Received(1) .UpdateExpirationDateAsync(organizationId, currentPeriodEnd); - await _pushNotificationAdapter.Received(1) - .NotifyEnabledChangedAsync(organization); await _stripeFacade.Received(1).UpdateSubscription( subscriptionId, Arg.Is(options => @@ -653,8 +617,10 @@ public class SubscriptionUpdatedHandlerTests await _sut.HandleAsync(parsedEvent); // Assert - await _userService.Received(1) - .EnablePremiumAsync(userId, currentPeriodEnd); + await _subscriberService.Received(1) + .EnableSubscriberAsync( + Arg.Is(s => s.Match(u => u.Value == userId, _ => false, _ => false)), + currentPeriodEnd); await _userService.Received(1) .UpdatePremiumExpirationAsync(userId, currentPeriodEnd); await _stripeFacade.Received(1).UpdateSubscription( @@ -881,16 +847,13 @@ public class SubscriptionUpdatedHandlerTests Subscription previousSubscription) { // Arrange - var (providerId, newSubscription, provider, parsedEvent) = + var (providerId, newSubscription, _, parsedEvent) = CreateProviderTestInputsForUpdatedActiveSubscriptionStatus(previousSubscription); _stripeEventService .GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(newSubscription); - _providerRepository - .GetByIdAsync(Arg.Any()) - .Returns(provider); _stripeFacade .UpdateSubscription(Arg.Any(), Arg.Any()) .Returns(newSubscription); @@ -902,12 +865,11 @@ public class SubscriptionUpdatedHandlerTests await _stripeEventService .Received(1) .GetSubscription(parsedEvent, true, Arg.Any>()); - await _providerRepository + await _subscriberService .Received(1) - .GetByIdAsync(providerId); - await _providerService - .Received(1) - .UpdateAsync(Arg.Is(p => p.Id == providerId && p.Enabled == true)); + .EnableSubscriberAsync( + Arg.Is(s => s.Match(_ => false, _ => false, p => p.Value == providerId)), + Arg.Any()); await _stripeFacade .Received(1) .UpdateSubscription(newSubscription.Id, @@ -922,15 +884,12 @@ public class SubscriptionUpdatedHandlerTests { // Arrange var previousSubscription = new Subscription { Id = "sub_123", Status = SubscriptionStatus.Canceled }; - var (providerId, newSubscription, provider, parsedEvent) = + var (_, newSubscription, _, parsedEvent) = CreateProviderTestInputsForUpdatedActiveSubscriptionStatus(previousSubscription); _stripeEventService .GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(newSubscription); - _providerRepository - .GetByIdAsync(Arg.Any()) - .Returns(provider); // Act await _sut.HandleAsync(parsedEvent); @@ -939,10 +898,7 @@ public class SubscriptionUpdatedHandlerTests await _stripeEventService .Received(1) .GetSubscription(parsedEvent, true, Arg.Any>()); - await _providerRepository.DidNotReceive().GetByIdAsync(Arg.Any()); - await _providerService - .DidNotReceive() - .UpdateAsync(Arg.Any()); + await _subscriberService.DidNotReceive().EnableSubscriberAsync(Arg.Any(), Arg.Any()); await _stripeFacade .DidNotReceiveWithAnyArgs() .UpdateSubscription(Arg.Any()); @@ -954,15 +910,12 @@ public class SubscriptionUpdatedHandlerTests { // Arrange var previousSubscription = new Subscription { Id = "sub_123", Status = SubscriptionStatus.Active }; - var (providerId, newSubscription, provider, parsedEvent) = + var (_, newSubscription, _, parsedEvent) = CreateProviderTestInputsForUpdatedActiveSubscriptionStatus(previousSubscription); _stripeEventService .GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(newSubscription); - _providerRepository - .GetByIdAsync(Arg.Any()) - .Returns(provider); // Act await _sut.HandleAsync(parsedEvent); @@ -971,10 +924,7 @@ public class SubscriptionUpdatedHandlerTests await _stripeEventService .Received(1) .GetSubscription(parsedEvent, true, Arg.Any>()); - await _providerRepository.DidNotReceive().GetByIdAsync(Arg.Any()); - await _providerService - .DidNotReceive() - .UpdateAsync(Arg.Any()); + await _subscriberService.DidNotReceive().EnableSubscriberAsync(Arg.Any(), Arg.Any()); await _stripeFacade .DidNotReceiveWithAnyArgs() .UpdateSubscription(Arg.Any()); @@ -986,15 +936,12 @@ public class SubscriptionUpdatedHandlerTests { // Arrange var previousSubscription = new Subscription { Id = "sub_123", Status = SubscriptionStatus.Trialing }; - var (providerId, newSubscription, provider, parsedEvent) = + var (_, newSubscription, _, parsedEvent) = CreateProviderTestInputsForUpdatedActiveSubscriptionStatus(previousSubscription); _stripeEventService .GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(newSubscription); - _providerRepository - .GetByIdAsync(Arg.Any()) - .Returns(provider); // Act await _sut.HandleAsync(parsedEvent); @@ -1003,10 +950,7 @@ public class SubscriptionUpdatedHandlerTests await _stripeEventService .Received(1) .GetSubscription(parsedEvent, true, Arg.Any>()); - await _providerRepository.DidNotReceive().GetByIdAsync(Arg.Any()); - await _providerService - .DidNotReceive() - .UpdateAsync(Arg.Any()); + await _subscriberService.DidNotReceive().EnableSubscriberAsync(Arg.Any(), Arg.Any()); await _stripeFacade .DidNotReceiveWithAnyArgs() .UpdateSubscription(Arg.Any()); @@ -1018,15 +962,12 @@ public class SubscriptionUpdatedHandlerTests { // Arrange var previousSubscription = new Subscription { Id = "sub_123", Status = SubscriptionStatus.PastDue }; - var (providerId, newSubscription, provider, parsedEvent) = + var (_, newSubscription, _, parsedEvent) = CreateProviderTestInputsForUpdatedActiveSubscriptionStatus(previousSubscription); _stripeEventService .GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(newSubscription); - _providerRepository - .GetByIdAsync(Arg.Any()) - .Returns(provider); // Act await _sut.HandleAsync(parsedEvent); @@ -1035,17 +976,14 @@ public class SubscriptionUpdatedHandlerTests await _stripeEventService .Received(1) .GetSubscription(parsedEvent, true, Arg.Any>()); - await _providerRepository.DidNotReceive().GetByIdAsync(Arg.Any()); - await _providerService - .DidNotReceive() - .UpdateAsync(Arg.Any()); + await _subscriberService.DidNotReceive().EnableSubscriberAsync(Arg.Any(), Arg.Any()); await _stripeFacade .DidNotReceiveWithAnyArgs() .UpdateSubscription(Arg.Any()); } [Fact] - public async Task HandleAsync_ActiveProviderSubscriptionEvent_AndProviderDoesNotExist_NoChanges() + public async Task HandleAsync_ActiveProviderSubscriptionEvent_EnablesProviderViaSubscriberService() { // Arrange var previousSubscription = new Subscription { Id = "sub_123", Status = SubscriptionStatus.Unpaid }; @@ -1055,9 +993,6 @@ public class SubscriptionUpdatedHandlerTests _stripeEventService .GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(newSubscription); - _providerRepository - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); // Act await _sut.HandleAsync(parsedEvent); @@ -1066,15 +1001,14 @@ public class SubscriptionUpdatedHandlerTests await _stripeEventService .Received(1) .GetSubscription(parsedEvent, true, Arg.Any>()); - await _providerRepository + await _subscriberService .Received(1) - .GetByIdAsync(providerId); - await _providerService - .DidNotReceive() - .UpdateAsync(Arg.Any()); + .EnableSubscriberAsync( + Arg.Is(s => s.Match(_ => false, _ => false, p => p.Value == providerId)), + Arg.Any()); await _stripeFacade - .DidNotReceive() - .UpdateSubscription(Arg.Any()); + .Received(1) + .UpdateSubscription(Arg.Any(), Arg.Any()); } [Fact] @@ -1082,15 +1016,12 @@ public class SubscriptionUpdatedHandlerTests { // Arrange - Using a previous status (Canceled) that doesn't trigger SubscriptionBecameActive var previousSubscription = new Subscription { Id = "sub_123", Status = SubscriptionStatus.Canceled }; - var (providerId, newSubscription, provider, parsedEvent) = + var (_, newSubscription, _, parsedEvent) = CreateProviderTestInputsForUpdatedActiveSubscriptionStatus(previousSubscription); _stripeEventService .GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(newSubscription); - _providerRepository - .GetByIdAsync(Arg.Any()) - .Returns(provider); // Act await _sut.HandleAsync(parsedEvent); @@ -1099,13 +1030,10 @@ public class SubscriptionUpdatedHandlerTests await _stripeEventService .Received(1) .GetSubscription(parsedEvent, true, Arg.Any>()); - await _providerRepository.DidNotReceive().GetByIdAsync(Arg.Any()); - await _providerService - .DidNotReceive() - .UpdateAsync(Arg.Any()); + await _subscriberService.DidNotReceive().EnableSubscriberAsync(Arg.Any(), Arg.Any()); await _stripeFacade .DidNotReceive() - .UpdateSubscription(Arg.Any()); + .UpdateSubscription(Arg.Any(), Arg.Any()); } private static (Guid providerId, Subscription newSubscription, Provider provider, Event parsedEvent) diff --git a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs index 2f938065e5..7a0a0581bc 100644 --- a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs +++ b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs @@ -1,12 +1,19 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; +using Bit.Core.Billing.Notifications; using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Implementations; +using Bit.Core.Billing.Subscriptions.Models; using Bit.Core.Billing.Tax.Models; using Bit.Core.Enums; +using Bit.Core.Repositories; +using Bit.Core.Services; using Bit.Core.Settings; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -1771,4 +1778,238 @@ public class SubscriberServiceTests } #endregion + + #region DisableSubscriberAsync + + [Theory, BitAutoData] + public async Task DisableSubscriberAsync_UserId_DisablesPremium( + SutProvider sutProvider) + { + // Arrange + var userId = Guid.NewGuid(); + var subscriberId = new UserId(userId); + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + + var userService = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.DisableSubscriberAsync(subscriberId, currentPeriodEnd); + + // Assert + await userService.Received(1).DisablePremiumAsync(userId, currentPeriodEnd); + } + + [Theory, BitAutoData] + public async Task DisableSubscriberAsync_OrganizationId_DisablesOrganizationAndNotifies( + Organization organization, + SutProvider sutProvider) + { + // Arrange + var organizationId = organization.Id; + var subscriberId = new OrganizationId(organizationId); + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + + var organizationDisableCommand = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var pushNotificationAdapter = sutProvider.GetDependency(); + + organizationRepository.GetByIdAsync(organizationId).Returns(organization); + + // Act + await sutProvider.Sut.DisableSubscriberAsync(subscriberId, currentPeriodEnd); + + // Assert + await organizationDisableCommand.Received(1).DisableAsync(organizationId, currentPeriodEnd); + await organizationRepository.Received(1).GetByIdAsync(organizationId); + await pushNotificationAdapter.Received(1).NotifyEnabledChangedAsync(organization); + } + + [Theory, BitAutoData] + public async Task DisableSubscriberAsync_OrganizationId_OrganizationNotFound_DoesNotNotify( + SutProvider sutProvider) + { + // Arrange + var organizationId = Guid.NewGuid(); + var subscriberId = new OrganizationId(organizationId); + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + + var organizationDisableCommand = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var pushNotificationAdapter = sutProvider.GetDependency(); + + organizationRepository.GetByIdAsync(organizationId).ReturnsNull(); + + // Act + await sutProvider.Sut.DisableSubscriberAsync(subscriberId, currentPeriodEnd); + + // Assert + await organizationDisableCommand.Received(1).DisableAsync(organizationId, currentPeriodEnd); + await organizationRepository.Received(1).GetByIdAsync(organizationId); + await pushNotificationAdapter.DidNotReceive().NotifyEnabledChangedAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task DisableSubscriberAsync_ProviderId_DisablesProvider( + Provider provider, + SutProvider sutProvider) + { + // Arrange + var providerId = provider.Id; + provider.Enabled = true; + var subscriberId = new ProviderId(providerId); + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + + var providerRepository = sutProvider.GetDependency(); + var providerService = sutProvider.GetDependency(); + + providerRepository.GetByIdAsync(providerId).Returns(provider); + + // Act + await sutProvider.Sut.DisableSubscriberAsync(subscriberId, currentPeriodEnd); + + // Assert + await providerRepository.Received(1).GetByIdAsync(providerId); + await providerService.Received(1).UpdateAsync(Arg.Is(p => p.Id == providerId && p.Enabled == false)); + } + + [Theory, BitAutoData] + public async Task DisableSubscriberAsync_ProviderId_ProviderNotFound_DoesNotUpdate( + SutProvider sutProvider) + { + // Arrange + var providerId = Guid.NewGuid(); + var subscriberId = new ProviderId(providerId); + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + + var providerRepository = sutProvider.GetDependency(); + var providerService = sutProvider.GetDependency(); + + providerRepository.GetByIdAsync(providerId).ReturnsNull(); + + // Act + await sutProvider.Sut.DisableSubscriberAsync(subscriberId, currentPeriodEnd); + + // Assert + await providerRepository.Received(1).GetByIdAsync(providerId); + await providerService.DidNotReceive().UpdateAsync(Arg.Any()); + } + + #endregion + + #region EnableSubscriberAsync + + [Theory, BitAutoData] + public async Task EnableSubscriberAsync_UserId_EnablesPremium( + SutProvider sutProvider) + { + // Arrange + var userId = Guid.NewGuid(); + var subscriberId = new UserId(userId); + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + + var userService = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.EnableSubscriberAsync(subscriberId, currentPeriodEnd); + + // Assert + await userService.Received(1).EnablePremiumAsync(userId, currentPeriodEnd); + } + + [Theory, BitAutoData] + public async Task EnableSubscriberAsync_OrganizationId_EnablesOrganizationAndNotifies( + Organization organization, + SutProvider sutProvider) + { + // Arrange + var organizationId = organization.Id; + var subscriberId = new OrganizationId(organizationId); + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + + var organizationEnableCommand = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var pushNotificationAdapter = sutProvider.GetDependency(); + + organizationRepository.GetByIdAsync(organizationId).Returns(organization); + + // Act + await sutProvider.Sut.EnableSubscriberAsync(subscriberId, currentPeriodEnd); + + // Assert + await organizationEnableCommand.Received(1).EnableAsync(organizationId, currentPeriodEnd); + await organizationRepository.Received(1).GetByIdAsync(organizationId); + await pushNotificationAdapter.Received(1).NotifyEnabledChangedAsync(organization); + } + + [Theory, BitAutoData] + public async Task EnableSubscriberAsync_OrganizationId_OrganizationNotFound_DoesNotNotify( + SutProvider sutProvider) + { + // Arrange + var organizationId = Guid.NewGuid(); + var subscriberId = new OrganizationId(organizationId); + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + + var organizationEnableCommand = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var pushNotificationAdapter = sutProvider.GetDependency(); + + organizationRepository.GetByIdAsync(organizationId).ReturnsNull(); + + // Act + await sutProvider.Sut.EnableSubscriberAsync(subscriberId, currentPeriodEnd); + + // Assert + await organizationEnableCommand.Received(1).EnableAsync(organizationId, currentPeriodEnd); + await organizationRepository.Received(1).GetByIdAsync(organizationId); + await pushNotificationAdapter.DidNotReceive().NotifyEnabledChangedAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task EnableSubscriberAsync_ProviderId_EnablesProvider( + Provider provider, + SutProvider sutProvider) + { + // Arrange + var providerId = provider.Id; + provider.Enabled = false; + var subscriberId = new ProviderId(providerId); + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + + var providerRepository = sutProvider.GetDependency(); + var providerService = sutProvider.GetDependency(); + + providerRepository.GetByIdAsync(providerId).Returns(provider); + + // Act + await sutProvider.Sut.EnableSubscriberAsync(subscriberId, currentPeriodEnd); + + // Assert + await providerRepository.Received(1).GetByIdAsync(providerId); + await providerService.Received(1).UpdateAsync(Arg.Is(p => p.Id == providerId && p.Enabled == true)); + } + + [Theory, BitAutoData] + public async Task EnableSubscriberAsync_ProviderId_ProviderNotFound_DoesNotUpdate( + SutProvider sutProvider) + { + // Arrange + var providerId = Guid.NewGuid(); + var subscriberId = new ProviderId(providerId); + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + + var providerRepository = sutProvider.GetDependency(); + var providerService = sutProvider.GetDependency(); + + providerRepository.GetByIdAsync(providerId).ReturnsNull(); + + // Act + await sutProvider.Sut.EnableSubscriberAsync(subscriberId, currentPeriodEnd); + + // Assert + await providerRepository.Received(1).GetByIdAsync(providerId); + await providerService.DidNotReceive().UpdateAsync(Arg.Any()); + } + + #endregion }