diff --git a/test/Infrastructure.IntegrationTest/DatabaseDataAttribute.cs b/test/Infrastructure.IntegrationTest/DatabaseDataAttribute.cs index 498cc668c0..c458969748 100644 --- a/test/Infrastructure.IntegrationTest/DatabaseDataAttribute.cs +++ b/test/Infrastructure.IntegrationTest/DatabaseDataAttribute.cs @@ -10,129 +10,29 @@ using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Time.Testing; +using Xunit; using Xunit.Sdk; +using Xunit.v3; namespace Bit.Infrastructure.IntegrationTest; public class DatabaseDataAttribute : DataAttribute { + private static IConfiguration? _cachedConfiguration; + private static IConfiguration GetConfiguration() + { + return _cachedConfiguration ??= new ConfigurationBuilder() + .AddUserSecrets(optional: true, reloadOnChange: false) + .AddEnvironmentVariables("BW_TEST_") + .AddCommandLine(Environment.GetCommandLineArgs()) + .Build(); + } + + public bool SelfHosted { get; set; } public bool UseFakeTimeProvider { get; set; } public string? MigrationName { get; set; } - public override IEnumerable GetData(MethodInfo testMethod) - { - var parameters = testMethod.GetParameters(); - - var config = DatabaseTheoryAttribute.GetConfiguration(); - - var serviceProviders = GetDatabaseProviders(config); - - foreach (var provider in serviceProviders) - { - var objects = new object[parameters.Length]; - for (var i = 0; i < parameters.Length; i++) - { - objects[i] = provider.GetRequiredService(parameters[i].ParameterType); - } - yield return objects; - } - } - - protected virtual IEnumerable GetDatabaseProviders(IConfiguration config) - { - // This is for the device repository integration testing. - var userRequestExpiration = 15; - - var configureLogging = (ILoggingBuilder builder) => - { - if (!config.GetValue("Quiet")) - { - builder.AddConfiguration(config); - builder.AddConsole(); - builder.AddDebug(); - } - }; - - var databases = config.GetDatabases(); - - foreach (var database in databases) - { - if (database.Type == SupportedDatabaseProviders.SqlServer && !database.UseEf) - { - var dapperSqlServerCollection = new ServiceCollection(); - AddCommonServices(dapperSqlServerCollection, configureLogging); - dapperSqlServerCollection.AddDapperRepositories(SelfHosted); - var globalSettings = new GlobalSettings - { - DatabaseProvider = "sqlServer", - SqlServer = new GlobalSettings.SqlSettings - { - ConnectionString = database.ConnectionString, - }, - PasswordlessAuth = new GlobalSettings.PasswordlessAuthSettings - { - UserRequestExpiration = TimeSpan.FromMinutes(userRequestExpiration), - } - }; - dapperSqlServerCollection.AddSingleton(globalSettings); - dapperSqlServerCollection.AddSingleton(globalSettings); - dapperSqlServerCollection.AddSingleton(database); - dapperSqlServerCollection.AddDistributedSqlServerCache(o => - { - o.ConnectionString = database.ConnectionString; - o.SchemaName = "dbo"; - o.TableName = "Cache"; - }); - - if (!string.IsNullOrEmpty(MigrationName)) - { - AddSqlMigrationTester(dapperSqlServerCollection, database.ConnectionString, MigrationName); - } - - yield return dapperSqlServerCollection.BuildServiceProvider(); - } - else - { - var efCollection = new ServiceCollection(); - AddCommonServices(efCollection, configureLogging); - efCollection.SetupEntityFramework(database.ConnectionString, database.Type); - efCollection.AddPasswordManagerEFRepositories(SelfHosted); - - var globalSettings = new GlobalSettings - { - PasswordlessAuth = new GlobalSettings.PasswordlessAuthSettings - { - UserRequestExpiration = TimeSpan.FromMinutes(userRequestExpiration), - } - }; - efCollection.AddSingleton(globalSettings); - efCollection.AddSingleton(globalSettings); - - efCollection.AddSingleton(database); - efCollection.AddSingleton(); - - if (!string.IsNullOrEmpty(MigrationName)) - { - AddEfMigrationTester(efCollection, database.Type, MigrationName); - } - - yield return efCollection.BuildServiceProvider(); - } - } - } - - private void AddCommonServices(IServiceCollection services, Action configureLogging) - { - services.AddLogging(configureLogging); - services.AddDataProtection(); - - if (UseFakeTimeProvider) - { - services.AddSingleton(); - } - } - private void AddSqlMigrationTester(IServiceCollection services, string connectionString, string migrationName) { services.AddSingleton(_ => new SqlMigrationTesterService(connectionString, migrationName)); @@ -146,4 +46,171 @@ public class DatabaseDataAttribute : DataAttribute return new EfMigrationTesterService(dbContext, databaseType, migrationName); }); } + + public override ValueTask> GetData(MethodInfo testMethod, DisposalTracker disposalTracker) + { + var config = GetConfiguration(); + + HashSet unconfiguredDatabases = + [ + SupportedDatabaseProviders.MySql, + SupportedDatabaseProviders.Postgres, + SupportedDatabaseProviders.Sqlite, + SupportedDatabaseProviders.SqlServer + ]; + + var theories = new List(); + + foreach (var database in config.GetDatabases()) + { + unconfiguredDatabases.Remove(database.Type); + + if (!database.Enabled) + { + var theory = new TheoryDataRow() + .WithSkip("Not-Enabled") + .WithTrait("Database", database.Type.ToString()); + theory.Label = database.Type.ToString(); + theories.Add(theory); + continue; + } + + var services = new ServiceCollection(); + AddCommonServices(services); + + if (database.Type == SupportedDatabaseProviders.SqlServer && !database.UseEf) + { + // Dapper services + AddDapperServices(services, database); + } + else + { + // Ef services + AddEfServices(services, database); + } + + var serviceProvider = services.BuildServiceProvider(); + disposalTracker.Add(serviceProvider); + + var serviceTheory = new ServiceBasedTheoryDataRow(serviceProvider, testMethod) + .WithTrait("Database", database.Type.ToString()) + .WithTrait("ConnectionString", database.ConnectionString); + + serviceTheory.Label = database.Type.ToString(); + theories.Add(serviceTheory); + } + + foreach (var unconfiguredDatabase in unconfiguredDatabases) + { + var theory = new TheoryDataRow() + .WithSkip("Unconfigured") + .WithTrait("Database", unconfiguredDatabase.ToString()); + theory.Label = unconfiguredDatabase.ToString(); + theories.Add(theory); + } + + return new(theories); + } + + private void AddCommonServices(IServiceCollection services) + { + // Common services + services.AddDataProtection(); + services.AddLogging(logging => + { + logging.AddProvider(new XUnitLoggerProvider()); + }); + if (UseFakeTimeProvider) + { + services.AddSingleton(); + } + } + + private void AddDapperServices(IServiceCollection services, Database database) + { + services.AddDapperRepositories(SelfHosted); + var globalSettings = new GlobalSettings + { + DatabaseProvider = "sqlServer", + SqlServer = new GlobalSettings.SqlSettings + { + ConnectionString = database.ConnectionString, + }, + PasswordlessAuth = new GlobalSettings.PasswordlessAuthSettings + { + UserRequestExpiration = TimeSpan.FromMinutes(15), + } + }; + services.AddSingleton(globalSettings); + services.AddSingleton(globalSettings); + services.AddSingleton(database); + services.AddDistributedSqlServerCache(o => + { + o.ConnectionString = database.ConnectionString; + o.SchemaName = "dbo"; + o.TableName = "Cache"; + }); + + if (!string.IsNullOrEmpty(MigrationName)) + { + AddSqlMigrationTester(services, database.ConnectionString, MigrationName); + } + } + + private void AddEfServices(IServiceCollection services, Database database) + { + services.SetupEntityFramework(database.ConnectionString, database.Type); + services.AddPasswordManagerEFRepositories(SelfHosted); + + var globalSettings = new GlobalSettings + { + PasswordlessAuth = new GlobalSettings.PasswordlessAuthSettings + { + UserRequestExpiration = TimeSpan.FromMinutes(15), + }, + }; + services.AddSingleton(globalSettings); + services.AddSingleton(globalSettings); + + services.AddSingleton(database); + services.AddSingleton(); + + if (!string.IsNullOrEmpty(MigrationName)) + { + AddEfMigrationTester(services, database.Type, MigrationName); + } + } + + public override bool SupportsDiscoveryEnumeration() + { + return true; + } + + private class ServiceBasedTheoryDataRow : TheoryDataRowBase + { + private readonly IServiceProvider _serviceProvider; + private readonly MethodInfo _testMethod; + + public ServiceBasedTheoryDataRow(IServiceProvider serviceProvider, MethodInfo testMethod) + { + _serviceProvider = serviceProvider; + _testMethod = testMethod; + } + + protected override object?[] GetData() + { + var parameters = _testMethod.GetParameters(); + + var services = new object?[parameters.Length]; + + for (var i = 0; i < parameters.Length; i++) + { + var parameter = parameters[i]; + // TODO: Could support keyed services/optional/nullable + services[i] = _serviceProvider.GetRequiredService(parameter.ParameterType); + } + + return services; + } + } } diff --git a/test/Infrastructure.IntegrationTest/DatabaseTheoryAttribute.cs b/test/Infrastructure.IntegrationTest/DatabaseTheoryAttribute.cs index 1dc6dc76ed..f897220652 100644 --- a/test/Infrastructure.IntegrationTest/DatabaseTheoryAttribute.cs +++ b/test/Infrastructure.IntegrationTest/DatabaseTheoryAttribute.cs @@ -1,32 +1,17 @@ -using Microsoft.Extensions.Configuration; +using System.Runtime.CompilerServices; using Xunit; namespace Bit.Infrastructure.IntegrationTest; +[Obsolete("This attribute is no longer needed and can be replaced with a [Theory]")] public class DatabaseTheoryAttribute : TheoryAttribute { - private static IConfiguration? _cachedConfiguration; - public DatabaseTheoryAttribute() { - if (!HasAnyDatabaseSetup()) - { - Skip = "No databases setup."; - } + } - private static bool HasAnyDatabaseSetup() + public DatabaseTheoryAttribute([CallerFilePath] string? sourceFilePath = null, [CallerLineNumber] int sourceLineNumber = -1) : base(sourceFilePath, sourceLineNumber) { - var config = GetConfiguration(); - return config.GetDatabases().Length > 0; - } - - public static IConfiguration GetConfiguration() - { - return _cachedConfiguration ??= new ConfigurationBuilder() - .AddUserSecrets(optional: true, reloadOnChange: false) - .AddEnvironmentVariables("BW_TEST_") - .AddCommandLine(Environment.GetCommandLineArgs()) - .Build(); } } diff --git a/test/Infrastructure.IntegrationTest/DistributedCacheTests.cs b/test/Infrastructure.IntegrationTest/DistributedCacheTests.cs index 875f9d16c6..974b8e0c18 100644 --- a/test/Infrastructure.IntegrationTest/DistributedCacheTests.cs +++ b/test/Infrastructure.IntegrationTest/DistributedCacheTests.cs @@ -65,7 +65,7 @@ public class DistributedCacheTests [DatabaseTheory, DatabaseData] public async Task MultipleWritesOnSameKey_ShouldNotThrow(IDistributedCache cache) { - await cache.SetAsync("test-duplicate", "some-value"u8.ToArray()); - await cache.SetAsync("test-duplicate", "some-value"u8.ToArray()); + await cache.SetAsync("test-duplicate", "some-value"u8.ToArray(), TestContext.Current.CancellationToken); + await cache.SetAsync("test-duplicate", "some-value"u8.ToArray(), TestContext.Current.CancellationToken); } } diff --git a/test/Infrastructure.IntegrationTest/Infrastructure.IntegrationTest.csproj b/test/Infrastructure.IntegrationTest/Infrastructure.IntegrationTest.csproj index 6d9e0d6667..a2215e3453 100644 --- a/test/Infrastructure.IntegrationTest/Infrastructure.IntegrationTest.csproj +++ b/test/Infrastructure.IntegrationTest/Infrastructure.IntegrationTest.csproj @@ -12,8 +12,8 @@ - - + + runtime; build; native; contentfiles; analyzers; buildtransitive all diff --git a/test/Infrastructure.IntegrationTest/XUnitLoggerProvider.cs b/test/Infrastructure.IntegrationTest/XUnitLoggerProvider.cs new file mode 100644 index 0000000000..43310496f5 --- /dev/null +++ b/test/Infrastructure.IntegrationTest/XUnitLoggerProvider.cs @@ -0,0 +1,47 @@ +using Microsoft.Extensions.Logging; +using Xunit; + +namespace Bit.Infrastructure.IntegrationTest; + +public sealed class XUnitLoggerProvider : ILoggerProvider +{ + public ILogger CreateLogger(string categoryName) + { + return new XUnitLogger(categoryName); + } + + public void Dispose() + { + + } + + private class XUnitLogger : ILogger + { + private readonly string _categoryName; + + public XUnitLogger(string categoryName) + { + _categoryName = categoryName; + } + + public IDisposable? BeginScope(TState state) where TState : notnull + { + return null; + } + + public bool IsEnabled(LogLevel logLevel) + { + return true; + } + + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) + { + if (TestContext.Current?.TestOutputHelper is not ITestOutputHelper testOutputHelper) + { + return; + } + + testOutputHelper.WriteLine($"[{_categoryName}] {formatter(state, exception)}"); + } + } +}