NejCommon.NET/Models/Api/ModelBinder.cs
2024-09-15 20:08:31 +02:00

267 lines
9.6 KiB
C#

using AngleSharp.Dom;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Mvc.Abstractions;
using Microsoft.AspNetCore.Mvc.ApiExplorer;
using Microsoft.AspNetCore.Mvc.Controllers;
using Microsoft.AspNetCore.Mvc.ModelBinding;
using Microsoft.AspNetCore.Mvc.ModelBinding.Binders;
using Microsoft.AspNetCore.Mvc.ModelBinding.Validation;
using Microsoft.AspNetCore.Mvc.Routing;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.OpenApi.Models;
using Swashbuckle.AspNetCore.SwaggerGen;
using System.Linq.Expressions;
using System.Reflection;
using System.Reflection.Emit;
using System.Runtime.CompilerServices;
namespace NejCommon.Models
{
/// <summary>
/// Describes one request value (looked up by <see cref="Name"/>) that participates in locating an entity.
/// </summary>
public abstract class RouteValue
{
    /// <summary>The key under which the value is read from the request's value providers.</summary>
    public string Name { get; }

    /// <summary>
    /// Marks the primary value. When set and the binding context has a non-blank model name,
    /// the entity binder reads this value under the model name instead of <see cref="Name"/>.
    /// </summary>
    public bool Primary { get; }

    /// <param name="name">Lookup key for the value.</param>
    /// <param name="primary">Whether this is the primary value (defaults to <c>false</c>).</param>
    public RouteValue(string name, bool primary = false)
    {
        (Name, Primary) = (name, primary);
    }
}
/// <summary>
/// A <see cref="RouteValue"/> paired with a predicate that decides whether a candidate entity
/// matches the raw string value captured from the request.
/// </summary>
/// <typeparam name="TEntity">The entity type being matched.</typeparam>
public class RouteValue<TEntity> : RouteValue where TEntity : class
{
    /// <summary>Predicate over (candidate entity, captured string value).</summary>
    public Expression<Func<TEntity, string, bool>> Expression { get; }

    /// <summary>Creates a non-primary route value.</summary>
    /// <param name="name">Lookup key for the value.</param>
    /// <param name="expression">Predicate used to match the entity against the captured value.</param>
    public RouteValue(string name, Expression<Func<TEntity, string, bool>> expression)
        : this(name, false, expression)
    {
    }

    /// <summary>Creates a route value with an explicit primary flag.</summary>
    /// <param name="name">Lookup key for the value.</param>
    /// <param name="primary">Whether this is the primary value.</param>
    /// <param name="expression">Predicate used to match the entity against the captured value.</param>
    public RouteValue(string name, bool primary, Expression<Func<TEntity, string, bool>> expression)
        : base(name, primary)
    {
        Expression = expression;
    }
}
/// <summary>
/// Non-generic base shared by every entity binder, so code can test whether a type is an entity binder
/// (e.g. <c>typeof(EntityBinder).IsAssignableFrom(type)</c>) without knowing its generic arguments.
/// </summary>
public abstract class EntityBinder { }
/// <summary>
/// Shorthand for an entity binder whose identifier type is <see cref="string"/>.
/// </summary>
/// <typeparam name="TEntity">The entity type to bind.</typeparam>
public abstract class EntityBinder<TEntity> : EntityBinder<TEntity, string>
    where TEntity : class
{
    /// <param name="db">Context used to query <typeparamref name="TEntity"/>.</param>
    /// <param name="routeValues">Values (with match predicates) used to locate the entity.</param>
    public EntityBinder(DbContext db, RouteValue<TEntity>[] routeValues)
        : base(db, routeValues) { }
}
/// <summary>
/// Model binder that loads a single <typeparamref name="TEntity"/> from the database. For every configured
/// <see cref="RouteValue{TEntity}"/>, the binder reads a string from the request's value providers and
/// applies that route value's predicate as a filter.
/// </summary>
/// <typeparam name="TEntity">The EF Core entity type to load.</typeparam>
/// <typeparam name="TIdType">
/// The entity's identifier type (e.g. Guid, int, string). The binding logic does not use it;
/// captured values are always matched as strings.
/// </typeparam>
public abstract class EntityBinder<TEntity, TIdType> : EntityBinder, IModelBinder
    where TEntity : class
    where TIdType : IEquatable<TIdType>
{
    private readonly DbContext _db;
    private readonly RouteValue<TEntity>[] _routeValues;

    /// <summary>The route values (and their match predicates) used to locate the entity.</summary>
    public RouteValue<TEntity>[] RouteValues => _routeValues;

    /// <param name="db">Context used to query <typeparamref name="TEntity"/>.</param>
    /// <param name="routeValues">Values (with match predicates) used to locate the entity.</param>
    public EntityBinder(DbContext db, RouteValue<TEntity>[] routeValues)
    {
        _db = db;
        _routeValues = routeValues;
    }

    /// <summary>
    /// Resolves the entity. This method:
    /// <list type="bullet">
    /// <item>adds a model error and fails if any route value is missing;</item>
    /// <item>sets a 404 status and fails if no entity matches;</item>
    /// <item>otherwise succeeds and suppresses validation of the loaded entity.</item>
    /// </list>
    /// </summary>
    /// <param name="bindingContext">The MVC binding context.</param>
    /// <exception cref="ArgumentNullException"><paramref name="bindingContext"/> is null.</exception>
    public async Task BindModelAsync(ModelBindingContext bindingContext)
    {
        if (bindingContext == null)
        {
            throw new ArgumentNullException(nameof(bindingContext));
        }

        // Capture the raw string for every route value. Track failures locally rather than using
        // ModelState.IsValid: that property also reflects entries written by other parameters'
        // binders (including Unvalidated ones), so it would make this binder fail spuriously.
        var capturedRouteData = new Dictionary<string, string>();
        var anyMissing = false;
        foreach (var routeValue in _routeValues)
        {
            // The primary value is read under the parameter's model name when one is available.
            var name = routeValue.Primary && !string.IsNullOrWhiteSpace(bindingContext.ModelName)
                ? bindingContext.ModelName
                : routeValue.Name;

            var value = bindingContext.ValueProvider.GetValue(name).FirstOrDefault();
            if (value is null)
            {
                bindingContext.ModelState.AddModelError(name, "Route value not found");
                anyMissing = true;
            }
            else
            {
                capturedRouteData[routeValue.Name] = value;
            }
        }

        if (anyMissing)
        {
            bindingContext.Result = ModelBindingResult.Failed();
            return;
        }

        // Narrow the query with each route value's predicate.
        var query = _db.Set<TEntity>().AsQueryable();
        foreach (var routeValue in _routeValues)
        {
            var value = capturedRouteData[routeValue.Name];
            // Reference the value through a closure instead of Expression.Constant. EF Core then sends it
            // as a SQL parameter, so all ids share one cached query rather than compiling a query per id.
            Expression<Func<string>> valueAccessor = () => value;
            var entity = Expression.Parameter(typeof(TEntity));
            var body = Expression.Invoke(routeValue.Expression, entity, valueAccessor.Body);
            query = query.Where(Expression.Lambda<Func<TEntity, bool>>(body, entity));
        }

        // Abort the database lookup if the client disconnects.
        var model = await query.FirstOrDefaultAsync(bindingContext.HttpContext.RequestAborted);
        if (model == null)
        {
            bindingContext.HttpContext.Response.StatusCode = 404;
            bindingContext.ModelState.AddModelError(bindingContext.FieldName, "Not found in DB");
            bindingContext.Result = ModelBindingResult.Failed();
            return;
        }

        bindingContext.Result = ModelBindingResult.Success(model);
        // The entity comes straight from the database; skip validating it as user input.
        bindingContext.ValidationState[bindingContext.Result] = new ValidationStateEntry
        {
            SuppressValidation = true
        };
    }
}
/// <summary>
/// Swashbuckle operation filter for actions with parameters bound by an <see cref="EntityBinder"/>.
/// It removes those parameters from the operation and removes the bound entity's schema from the
/// schema repository.
/// </summary>
public class EntityBinderOperationFilter : IOperationFilter
{
    /// <summary>
    /// Instantiates <paramref name="entityBinderType"/>. The method invokes the public constructor with the
    /// fewest parameters and passes <c>null</c> for every argument.
    /// </summary>
    /// <param name="entityBinderType">A concrete type deriving from <see cref="EntityBinder"/>.</param>
    /// <returns>The created binder.</returns>
    /// <exception cref="InvalidOperationException">The type has no public constructors.</exception>
    /// <exception cref="InvalidCastException">The created object is not an <see cref="EntityBinder"/>.</exception>
    public object CreateDummyInstance(Type entityBinderType)
    {
        var constructor = entityBinderType.GetConstructors()
            .OrderBy(x => x.GetParameters().Length)
            .First();
        // A fresh object?[] is already all nulls.
        var arguments = new object?[constructor.GetParameters().Length];
        var entityBinder = (EntityBinder)constructor.Invoke(arguments);
        return entityBinder;
    }

    /// <summary>
    /// Strips EntityBinder-bound parameters from <paramref name="operation"/> and removes their entity
    /// schemas, looked up by type name, from the context's schema repository.
    /// </summary>
    public void Apply(OpenApiOperation operation, OperationFilterContext context)
    {
        var actionDescriptor = context.ApiDescription.ActionDescriptor;

        // Action parameters whose binder type derives from EntityBinder.
        var modelBinderParameters = actionDescriptor.Parameters
            .Where(p => p.BindingInfo?.BinderType is { } binderType && typeof(EntityBinder).IsAssignableFrom(binderType))
            .ToList();
        if (modelBinderParameters.Count == 0)
        {
            return;
        }

        //Console.WriteLine("Applying EntityBinderOperationFilter to: " + actionDescriptor.DisplayName);
        if (operation.Parameters is not null)
        {
            operation.Parameters = operation.Parameters.Where(p => !modelBinderParameters.Any(mp => mp.Name == p.Name)).ToList();
        }
        /*
        // Get the EntityBinder RouteValues
        foreach (var parameter in modelBinderParameters)
        {
            var entityBinderType = parameter.BindingInfo!.BinderType;
            var routeValuesProperty = entityBinderType!.GetProperty("RouteValues");
            if (routeValuesProperty == null)
            {
                continue;
            }
            var routeValues = (IEnumerable<RouteValue>?)routeValuesProperty.GetValue(CreateDummyInstance(entityBinderType));
            if (routeValues == null)
            {
                continue;
            }
            foreach (var routeValue in routeValues)
            {
                operation.Parameters.Add(new OpenApiParameter
                {
                    Name = routeValue.Name,
                    In = ParameterLocation.Path,
                    Required = true,
                    Schema = new OpenApiSchema
                    {
                        Type = "string"
                    }
                });
            }
        }*/

        // Exclude the EntityBinder entity types from the document schemas.
        var schemaRep = context.SchemaRepository;
        foreach (var parameter in modelBinderParameters)
        {
            var entityBinderType = parameter.BindingInfo?.BinderType;
            var entityType = GetBinderEntityType(entityBinderType);
            if (entityType == null)
            {
                Console.WriteLine("Couldn't find entityType of: " + entityBinderType);
                continue;
            }

            // Remove is a no-op for absent keys, so no ContainsKey pre-check is needed.
            schemaRep.Schemas.Remove(entityType.Name);
        }
    }

    /// <summary>
    /// Walks up the inheritance chain of <paramref name="binderType"/> to find the entity type argument
    /// of <c>EntityBinder&lt;TEntity&gt;</c> or <c>EntityBinder&lt;TEntity, TIdType&gt;</c>.
    /// Both generic definitions must be recognised: binders with a non-string id (e.g. Guid) derive
    /// directly from the two-argument form.
    /// </summary>
    /// <returns>The entity type, or <c>null</c> if <paramref name="binderType"/> is null or not an entity binder.</returns>
    private Type? GetBinderEntityType(Type? binderType)
    {
        for (var current = binderType; current != null; current = current.BaseType)
        {
            if (!current.IsGenericType)
            {
                continue;
            }

            var definition = current.GetGenericTypeDefinition();
            if (definition == typeof(EntityBinder<>) || definition == typeof(EntityBinder<,>))
            {
                // TEntity is the first generic argument in both definitions.
                return current.GetGenericArguments()[0];
            }
        }

        return null;
    }
}
/// <summary>
/// Class-level attribute that associates a type with <typeparamref name="TEntityBinder"/>. Through the
/// implemented MVC metadata interfaces it:
/// <list type="bullet">
/// <item>supplies the binder type with a custom binding source;</item>
/// <item>hides the item from API descriptions;</item>
/// <item>disables validation of the entry.</item>
/// </list>
/// </summary>
/// <typeparam name="TEntityBinder">The binder used to load <typeparamref name="TEntity"/>.</typeparam>
/// <typeparam name="TEntity">The entity type the binder produces.</typeparam>
[AttributeUsage(AttributeTargets.Class, AllowMultiple = false, Inherited = true)]
public sealed class EntityAttribute<TEntityBinder, TEntity> : Attribute, IPropertyValidationFilter, IApiDescriptionVisibilityProvider, IBinderTypeProviderMetadata
    where TEntityBinder : EntityBinder<TEntity>
    where TEntity : class
{
    /// <inheritdoc />
    public Type? BinderType => typeof(TEntityBinder);

    /// <inheritdoc />
    public BindingSource? BindingSource => BindingSource.Custom;

    /// <inheritdoc />
    public bool IgnoreApi => true;

    /// <inheritdoc />
    public bool ShouldValidateEntry(ValidationEntry entry, ValidationEntry parentEntry) => false;
}
}