diff --git a/Controllers/AutoChildController.cs b/Controllers/AutoChildController.cs index bc36567..2469f1c 100644 --- a/Controllers/AutoChildController.cs +++ b/Controllers/AutoChildController.cs @@ -17,8 +17,8 @@ namespace NejCommon.Controllers { public abstract class AutoChildController : AutoChildController where TType : class, new() - where TRequest : IAutomappedAttribute, new() - where TResponse : IAutomappedAttribute, new() + where TRequest : class, IAutomappedAttribute, new() + where TResponse : class, IAutomappedAttribute, new() { public AutoChildController(CommonDbContext appDb, IServiceProvider providers) : base(appDb, providers) { @@ -34,9 +34,9 @@ namespace NejCommon.Controllers [ApiController] public abstract class AutoChildController : ControllerBase where TType : class, new() - where TGetResponse : IAutomappedAttribute, new() - where TUpdateRequest : IAutomappedAttribute, new() - where TUpdateResponse : IAutomappedAttribute, new() + where TGetResponse : class, IAutomappedAttribute, new() + where TUpdateRequest : class, IAutomappedAttribute, new() + where TUpdateResponse : class, IAutomappedAttribute, new() { protected readonly CommonDbContext db; protected readonly IServiceProvider providers; @@ -99,8 +99,6 @@ namespace NejCommon.Controllers }*/ body.ApplyTo(providers, entity); - var dat = new TUpdateResponse().ApplyFrom(providers, entity); - var res = await db.ApiSaveChangesAsyncOk(providers, entity); //use the private constructor thru reflection diff --git a/Controllers/AutoController.cs b/Controllers/AutoController.cs index 29b9096..0f99023 100644 --- a/Controllers/AutoController.cs +++ b/Controllers/AutoController.cs @@ -17,9 +17,9 @@ namespace NejCommon.Controllers { public abstract class AutoController : AutoController - where TType : class, new() - where TRequest : IAutomappedAttribute, new() - where TResponse : IAutomappedAttribute, new() + where TType : class + where TRequest : class, IAutomappedAttribute, new() + where TResponse : class, IAutomappedAttribute, new() { public AutoController(CommonDbContext appDb, IServiceProvider providers) : base(appDb, providers) { @@ -34,13 +34,13 @@ namespace NejCommon.Controllers /// The response type [ApiController] public abstract partial class AutoController : AutoGetterController - where TType : class, new() - where TGetAllResponse : IAutomappedAttribute, new() - where TCreateRequest : IAutomappedAttribute, new() - where TCreateResponse : IAutomappedAttribute, new() - where TGetResponse : IAutomappedAttribute, new() - where TUpdateRequest : IAutomappedAttribute, new() - where TUpdateResponse : IAutomappedAttribute, new() + where TType : class + where TGetAllResponse : class, IAutomappedAttribute, new() + where TCreateRequest : class, IAutomappedAttribute, new() + where TCreateResponse : class, IAutomappedAttribute, new() + where TGetResponse : class, IAutomappedAttribute, new() + where TUpdateRequest : class, IAutomappedAttribute, new() + where TUpdateResponse : class, IAutomappedAttribute, new() { public AutoController(CommonDbContext appDb, IServiceProvider providers) : base(appDb, providers) @@ -59,6 +59,12 @@ namespace NejCommon.Controllers props.SetValue(entity, comp); return entity; } + protected override IAutomappedAttribute GetResponseType(RequestedResponseType type, TType entity) => type switch + { + RequestedResponseType.Create => new TCreateResponse(), + RequestedResponseType.Update => new TUpdateResponse(), + _ => base.GetResponseType(type, entity), + }; /// /// Creates the @@ -69,15 +75,16 @@ namespace NejCommon.Controllers [HttpPost] public virtual async Task, CreatedAtRoute>> Create([FromServices] TOwner company, [FromBody] TCreateRequest body) { - var entity = db.Create(); + var type = body.GetSourceType(); + var entity = (TType)db.Create(type); entity = AssociateWithParent(entity, company); await db.AddAsync(entity); - body.ApplyTo(providers, entity); + body.ApplyTo(providers, (object)entity); - return await db.ApiSaveChangesAsyncCreate(providers, entity); + return await db.ApiSaveChangesAsyncCreate(providers, entity, true, (TCreateResponse)GetResponseType(RequestedResponseType.Create, entity)); } /// @@ -106,11 +113,11 @@ namespace NejCommon.Controllers [Route("{id}/")] public virtual async Task, Ok>> Update([FromServices][ModelBinder(Name = "id")] TType entity, [FromBody] TUpdateRequest body) { - body.ApplyTo(providers, entity); + body.ApplyTo(providers, (object)entity); var dat = new TUpdateResponse().ApplyFrom(providers, entity); - var res = await db.ApiSaveChangesAsyncOk(providers, entity); + var res = await db.ApiSaveChangesAsyncOk(providers, entity, true, (TUpdateResponse)GetResponseType(RequestedResponseType.Update, entity)); //use the private constructor thru reflection var ctor = typeof(Results, Ok>).GetConstructors(BindingFlags.NonPublic | BindingFlags.Instance)[0]; @@ -119,7 +126,7 @@ namespace NejCommon.Controllers } [ApiController] public abstract class AutoGetterController : ControllerBase - where TType : class, new() + where TType : class where TGetAllResponse : IAutomappedAttribute, new() where TGetResponse : IAutomappedAttribute, new() { @@ -134,6 +141,18 @@ namespace NejCommon.Controllers protected abstract IQueryable GetQuery(TOwner comp); + public enum RequestedResponseType + { + Get, + Create, + Update + } + protected virtual IAutomappedAttribute GetResponseType(RequestedResponseType type, TType entity) => type switch + { + RequestedResponseType.Get => new TGetResponse(), + _ => throw new InvalidOperationException("Not implemented"), + }; + protected virtual IQueryable ApplyDefaultOrdering(IQueryable query) { return query; @@ -169,8 +188,8 @@ namespace NejCommon.Controllers [Route("{id}/")] public virtual async Task>> Get([FromServices][ModelBinder(Name = "id")] TType entity) { - var dat = new TGetResponse().ApplyFrom(providers, entity); - return TypedResults.Ok(dat); + var dat = GetResponseType(RequestedResponseType.Get, entity).ApplyFrom(providers, entity); + return TypedResults.Ok((TGetResponse)dat); } } } diff --git a/Controllers/TypedResultsPolyfill.cs b/Controllers/TypedResultsPolyfill.cs index 0154ced..5bb51a9 100644 --- a/Controllers/TypedResultsPolyfill.cs +++ b/Controllers/TypedResultsPolyfill.cs @@ -1,7 +1,14 @@ using System.Diagnostics.Eventing.Reader; +using System.Reflection; using System.Runtime.CompilerServices; +using System.Text.RegularExpressions; using Microsoft.AspNetCore.Http.HttpResults; +using Microsoft.AspNetCore.Http.Metadata; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Controllers; +using Microsoft.AspNetCore.WebUtilities; +using Microsoft.Extensions.Options; using Microsoft.OpenApi.Models; using Swashbuckle.AspNetCore.SwaggerGen; @@ -9,100 +16,132 @@ namespace NejCommon.Controllers { public class TypedResultsMetadataProvider : IOperationFilter { - public void Apply(OpenApiOperation operation, OperationFilterContext context) + private readonly Lazy _contentTypes; + + /// + /// Constructor to inject services + /// + /// MVC options to define response content types + public TypedResultsMetadataProvider(IOptions mvc) { - var responseType = context.MethodInfo.ReturnType; - //Console.WriteLine(context.MethodInfo.DeclaringType.Name); - //Console.WriteLine(context.MethodInfo.Name); - //Console.WriteLine(responseType); - var t = IsSubclassOfRawGeneric(typeof(Microsoft.AspNetCore.Http.HttpResults.Results<,>), responseType); - if (t == null) + _contentTypes = new Lazy(() => { - return; - } - - var parArg = t.GetGenericArguments(); - if (operation.Responses.ContainsKey("200")) - operation.Responses.Remove("200"); - - foreach (var arg in parArg) - { - if (arg == typeof(NotFound)) + var apiResponseTypes = new List(); + if (mvc.Value == null) { - operation.Responses.Add("404", new OpenApiResponse { Description = "Not found" }); - } - else if (arg == typeof(Ok)) - { - operation.Responses.Add("200", new OpenApiResponse { Description = "Success" }); - } - else if (IsSubclassOfRawGeneric(typeof(Ok<>), arg) != null) - { - - var okArg = IsSubclassOfRawGeneric(typeof(Ok<>), arg).GetGenericArguments()[0]; - // Console.WriteLine("Adding: " + okArg); - - //get or generate the schema - var schema = context.SchemaGenerator.GenerateSchema(okArg, context.SchemaRepository); - operation.Responses.Add("200", new OpenApiResponse { Description = "Success", Content = { { "application/json", new OpenApiMediaType { Schema = schema } } } }); - - } - else if (arg == typeof(CreatedAtRoute)) - { - operation.Responses.Add("201", new OpenApiResponse { Description = "Success" }); - } - else if (IsSubclassOfRawGeneric(typeof(CreatedAtRoute<>), arg) != null) - { - if (operation.Responses.ContainsKey("201")) - operation.Responses.Remove("201"); - - var okArg = IsSubclassOfRawGeneric(typeof(CreatedAtRoute<>), arg).GetGenericArguments()[0]; - // Console.WriteLine("Adding: " + okArg); - - //get or generate the schema - var schema = context.SchemaGenerator.GenerateSchema(okArg, context.SchemaRepository); - operation.Responses.Add("201", new OpenApiResponse { Description = "Success", Content = { { "application/json", new OpenApiMediaType { Schema = schema } } } }); - } - else if (arg == typeof(BadRequest)) - { - operation.Responses.Add("400", new OpenApiResponse { Description = "There was an error" }); - } - else if (IsSubclassOfRawGeneric(typeof(BadRequest<>), arg) != null) - { - if (operation.Responses.ContainsKey("400")) - operation.Responses.Remove("400"); - - var okArg = IsSubclassOfRawGeneric(typeof(BadRequest<>), arg).GetGenericArguments()[0]; - // Console.WriteLine("Adding: " + okArg); - - //get or generate the schema - var schema = context.SchemaGenerator.GenerateSchema(okArg, context.SchemaRepository); - operation.Responses.Add("400", new OpenApiResponse { Description = "There was an error", Content = { { "application/json", new OpenApiMediaType { Schema = schema } } } }); - } - else if (arg == typeof(FileStreamHttpResult)){ - operation.Responses.Add("200", new OpenApiResponse { Description = "Success", Content = { { "application/octet-stream", new OpenApiMediaType { Schema = new OpenApiSchema { Type = "string", Format = "binary" } } } } }); + apiResponseTypes.Add("application/json"); } else { - Console.WriteLine("Unknown type: " + arg); + var jsonApplicationType = mvc.Value.FormatterMappings.GetMediaTypeMappingForFormat("json"); + if (jsonApplicationType != null) + apiResponseTypes.Add(jsonApplicationType); + var xmlApplicationType = mvc.Value.FormatterMappings.GetMediaTypeMappingForFormat("xml"); + if (xmlApplicationType != null) + apiResponseTypes.Add(xmlApplicationType); } + return apiResponseTypes.ToArray(); + }); + } + + void IOperationFilter.Apply(OpenApiOperation operation, OperationFilterContext context) + { + + if (!IsControllerAction(context)) return; + + var actionReturnType = UnwrapTask(context.MethodInfo.ReturnType); + if (!IsHttpResults(actionReturnType)) return; + + if (typeof(IEndpointMetadataProvider).IsAssignableFrom(actionReturnType)) + { + var populateMetadataMethod = actionReturnType.GetMethod("Microsoft.AspNetCore.Http.Metadata.IEndpointMetadataProvider.PopulateMetadata", BindingFlags.Static | BindingFlags.NonPublic); + if (populateMetadataMethod == null) return; + + var endpointBuilder = new MetadataEndpointBuilder(); + populateMetadataMethod.Invoke(null, new object[] { context.MethodInfo, endpointBuilder }); + + var responseTypes = endpointBuilder.Metadata.Cast().ToList(); + if (!responseTypes.Any()) return; + operation.Responses.Clear(); + foreach (var responseType in responseTypes) + { + var statusCode = responseType.StatusCode.ToString(); + var oar = new OpenApiResponse { Description = GetResponseDescription(statusCode) }; + + if (responseType.Type != null && responseType.Type != typeof(void)) + { + var schema = context.SchemaGenerator.GenerateSchema(responseType.Type, context.SchemaRepository); + foreach (var contentType in _contentTypes.Value) + { + oar.Content.Add(contentType, new OpenApiMediaType { Schema = schema }); + } + } + + operation.Responses.Add(statusCode, oar); + } + } + else if (actionReturnType == typeof(UnauthorizedHttpResult)) + { + operation.Responses.Clear(); + operation.Responses.Add("401", new OpenApiResponse { Description = ReasonPhrases.GetReasonPhrase(401) }); + } } - static Type? IsSubclassOfRawGeneric(Type generic, Type toCheck) + private static bool IsControllerAction(OperationFilterContext context) + => context.ApiDescription.ActionDescriptor is ControllerActionDescriptor; + + private static bool IsHttpResults(Type type) + => type.Namespace == "Microsoft.AspNetCore.Http.HttpResults"; + + private static Type UnwrapTask(Type type) { - while (toCheck != null && toCheck != typeof(object)) + if (type.IsGenericType) { - //if Task is used, we need to check the underlying type - var realTypeNoTask = toCheck.IsGenericType && toCheck.GetGenericTypeDefinition() == typeof(Task<>) ? toCheck.GetGenericArguments()[0] : toCheck; - var cur = realTypeNoTask.IsGenericType ? realTypeNoTask.GetGenericTypeDefinition() : realTypeNoTask; - //Console.WriteLine(cur); - if (generic == cur) + var genericType = type.GetGenericTypeDefinition(); + if (genericType == typeof(Task<>) || genericType == typeof(ValueTask<>)) { - return realTypeNoTask; + return type.GetGenericArguments()[0]; } - toCheck = toCheck.BaseType; } - return null; + return type; + } + + private static string? GetResponseDescription(string statusCode) + => ResponseDescriptionMap + .FirstOrDefault(entry => Regex.IsMatch(statusCode, entry.Key)) + .Value; + + private static readonly IReadOnlyCollection> ResponseDescriptionMap = new[] + { + new KeyValuePair("1\\d{2}", "Information"), + + new KeyValuePair("201", "Created"), + new KeyValuePair("202", "Accepted"), + new KeyValuePair("204", "No Content"), + new KeyValuePair("2\\d{2}", "Success"), + + new KeyValuePair("304", "Not Modified"), + new KeyValuePair("3\\d{2}", "Redirect"), + + new KeyValuePair("400", "Bad Request"), + new KeyValuePair("401", "Unauthorized"), + new KeyValuePair("403", "Forbidden"), + new KeyValuePair("404", "Not Found"), + new KeyValuePair("405", "Method Not Allowed"), + new KeyValuePair("406", "Not Acceptable"), + new KeyValuePair("408", "Request Timeout"), + new KeyValuePair("409", "Conflict"), + new KeyValuePair("429", "Too Many Requests"), + new KeyValuePair("4\\d{2}", "Client Error"), + + new KeyValuePair("5\\d{2}", "Server Error"), + new KeyValuePair("default", "Error") + }; + + private sealed class MetadataEndpointBuilder : EndpointBuilder + { + public override Endpoint Build() => throw new NotImplementedException(); } } } \ No newline at end of file diff --git a/Models/CommonDbContext.cs b/Models/CommonDbContext.cs index ccc8766..26d557b 100644 --- a/Models/CommonDbContext.cs +++ b/Models/CommonDbContext.cs @@ -11,158 +11,171 @@ using Microsoft.EntityFrameworkCore; namespace NejCommon.Models; - public abstract class CommonDbContext : DbContext +public abstract class CommonDbContext : DbContext +{ + public CommonDbContext() : base() { - public CommonDbContext() : base() + } + public CommonDbContext(DbContextOptions options) + : base(options) + { + } + + public abstract CommonDbContext CreateCopy(); + + public async Task ApiSaveChangesAsync() + { + try { + await SaveChangesAsync(); + return true; } - public CommonDbContext(DbContextOptions options) - : base(options) + catch (Exception ex) { + Console.WriteLine("Error saving db: " + ex.Message); + Console.WriteLine(ex.StackTrace); + return false; + } + } + + public static BadRequest SaveError = TypedResults.BadRequest(new Error + { + Message = "Error saving data to database" + }); + + public async Task, T1>> ApiSaveChangesAsync(T1 value) where T1 : IResult + { + var res = await ApiSaveChangesAsync(); + + if (res) + return value; + else + return SaveError; + } + public async Task, CreatedAtRoute>> ApiSaveChangesAsyncCreate(IServiceProvider providers, T1 value, bool apply = true, T2? response = null) where T2 : class, IAutomappedAttribute, new() + { + if(response == null) + response = new T2(); + if (!apply) + return TypedResults.CreatedAtRoute(response.ApplyFrom(providers, value)); + + var res = await ApiSaveChangesAsync(); + + if (res) + return TypedResults.CreatedAtRoute(response.ApplyFrom(providers, value)); + else + return SaveError; + } + public async Task, Ok>> ApiSaveChangesAsyncOk(IServiceProvider providers, T1 value, bool apply = true, T2? response = null) where T2 : class, IAutomappedAttribute, new() + { + if(response == null) + response = new T2(); + if (!apply) + return TypedResults.Ok(response.ApplyFrom(providers, value)); + + var res = await ApiSaveChangesAsync(); + + if (res) + return TypedResults.Ok(response.ApplyFrom(providers, value)); + else + return SaveError; + } + public async Task FindOrCreateAsync(Expression> predicate, Func factory) where T : class + { + var entity = ChangeTracker.Entries().Select(x => x.Entity).FirstOrDefault(predicate.Compile()); + + //find in change tracker + if (entity != null && Entry(entity).State == EntityState.Deleted) + { + Entry(entity).State = EntityState.Modified; } - public abstract CommonDbContext CreateCopy(); - - public async Task ApiSaveChangesAsync() + //find in currentDb + if (entity == null) { - try - { - await SaveChangesAsync(); - return true; - } - catch (Exception ex) - { - Console.WriteLine("Error saving db: " + ex.Message); - Console.WriteLine(ex.StackTrace); - return false; - } + entity = await Set().FirstOrDefaultAsync(predicate); } - public static BadRequest SaveError = TypedResults.BadRequest(new Error + //find in up-to-date db + if (entity == null) { - Message = "Error saving data to database" - }); + var newAppDB = CreateCopy(); - public async Task, T1>> ApiSaveChangesAsync(T1 value) where T1 : IResult - { - var res = await ApiSaveChangesAsync(); + entity = await newAppDB.Set().FirstOrDefaultAsync(predicate); - if (res) - return value; - else - return SaveError; - } - public async Task, CreatedAtRoute>> ApiSaveChangesAsyncCreate(IServiceProvider providers, T1 value, bool apply = true) where T2 : IAutomappedAttribute, new() - { - if (!apply) - return TypedResults.CreatedAtRoute(new T2().ApplyFrom(providers, value)); - - var res = await ApiSaveChangesAsync(); - - if (res) - return TypedResults.CreatedAtRoute(new T2().ApplyFrom(providers, value)); - else - return SaveError; - } - public async Task, Ok>> ApiSaveChangesAsyncOk(IServiceProvider providers, T1 value, bool apply = true) where T2 : IAutomappedAttribute, new() - { - if (!apply) - return TypedResults.Ok(new T2().ApplyFrom(providers, value)); - - var res = await ApiSaveChangesAsync(); - - if (res) - return TypedResults.Ok(new T2().ApplyFrom(providers, value)); - else - return SaveError; - } - public async Task FindOrCreateAsync(Expression> predicate, Func factory) where T : class - { - var entity = ChangeTracker.Entries().Select(x => x.Entity).FirstOrDefault(predicate.Compile()); - - //find in change tracker - if (entity != null && Entry(entity).State == EntityState.Deleted) - { - Entry(entity).State = EntityState.Modified; - } - - //find in currentDb - if (entity == null) - { - entity = await Set().FirstOrDefaultAsync(predicate); - } - - //find in up-to-date db - if (entity == null) - { - var newAppDB = CreateCopy(); - - entity = await newAppDB.Set().FirstOrDefaultAsync(predicate); - - //track the entity if it's not null and not already being tracked - if (entity != null && this.Entry(entity).State == EntityState.Detached) - Attach(entity); - } - - //create if not found - if (entity == null) - { - var newEntity = factory(); - await this.AddAsync(newEntity); - entity = newEntity; - } - - return entity; - } - public T FindOrCreate(Expression> predicate, Func factory) where T : class - { - var entity = ChangeTracker.Entries().Where(e => e.State != EntityState.Deleted).Select(x => x.Entity).FirstOrDefault(predicate.Compile()); - - //find in change tracker - if (entity != null && Entry(entity).State == EntityState.Deleted) - { - Entry(entity).State = EntityState.Modified; - } - - //find in currentDb - if (entity == null) - { - entity = Set().FirstOrDefault(predicate); - } - - //find in up-to-date db - if (entity == null) - { - var newAppDB = CreateCopy(); - - entity = newAppDB.Set().FirstOrDefault(predicate); - if (entity != null) - Attach(entity); - } - - //create if not found - if (entity == null) - { - var newEntity = factory(); - this.Add(newEntity); - entity = newEntity; - } - - return entity; + //track the entity if it's not null and not already being tracked + if (entity != null && this.Entry(entity).State == EntityState.Detached) + Attach(entity); } - public T Create(Action config = null, params object[] constructorArguments) + //create if not found + if (entity == null) { - var entity = this.CreateProxy(constructorArguments); - - config?.Invoke(entity); - this.Add(entity); - this.ChangeTracker.DetectChanges(); - return entity; + var newEntity = factory(); + await this.AddAsync(newEntity); + entity = newEntity; } - public void ApplyRelationships() + return entity; + } + public T FindOrCreate(Expression> predicate, Func factory) where T : class + { + var entity = ChangeTracker.Entries().Where(e => e.State != EntityState.Deleted).Select(x => x.Entity).FirstOrDefault(predicate.Compile()); + + //find in change tracker + if (entity != null && Entry(entity).State == EntityState.Deleted) { - this.ChangeTracker.DetectChanges(); + Entry(entity).State = EntityState.Modified; } - } \ No newline at end of file + + //find in currentDb + if (entity == null) + { + entity = Set().FirstOrDefault(predicate); + } + + //find in up-to-date db + if (entity == null) + { + var newAppDB = CreateCopy(); + + entity = newAppDB.Set().FirstOrDefault(predicate); + if (entity != null) + Attach(entity); + } + + //create if not found + if (entity == null) + { + var newEntity = factory(); + this.Add(newEntity); + entity = newEntity; + } + + return entity; + } + + public T Create(Action config = null, params object[] constructorArguments) + { + var entity = this.CreateProxy(constructorArguments); + + config?.Invoke(entity); + this.Add(entity); + this.ChangeTracker.DetectChanges(); + return entity; + } + public object Create(Type entityType, Action config = null, params object[] constructorArguments) + { + var entity = this.CreateProxy(entityType, constructorArguments); + + config?.Invoke(entity); + this.Add(entity); + this.ChangeTracker.DetectChanges(); + return entity; + } + + public void ApplyRelationships() + { + this.ChangeTracker.DetectChanges(); + } +} \ No newline at end of file diff --git a/Utils/Extensions.cs b/Utils/Extensions.cs index c42c2de..b24b7aa 100644 --- a/Utils/Extensions.cs +++ b/Utils/Extensions.cs @@ -33,6 +33,20 @@ public static class Extensions Data = query.Select(projector).AsAsyncEnumerable() }; } + public static async Task> ApplyPaginationRes(this IQueryable query, IServiceProvider providers, Pagination pag, Expression> projector) where TResponseType : IAutomappedAttribute, new() + { + var totalCount = await query.CountAsync(); + query = query.Skip(pag.Offset); + query = query.Take(pag.Count); + + return new PaginationResponse + { + TotalCount = totalCount, + Offset = pag.Offset, + Count = pag.Count, + Data = query.Select(projector).AsAsyncEnumerable() + }; + } public static async Task> ApplySearchPaginationRes(this IQueryable query, string? search, IServiceProvider providers, Pagination pag, List>> matchers) where TType : class