﻿// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.EntityFrameworkCore.Sqlite.Internal;

// ReSharper disable once CheckNamespace
namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal;

/// <summary>
///     This is an internal API that supports the Entity Framework Core infrastructure and not subject to
///     the same compatibility standards as public APIs. It may be changed or removed without notice in
///     any release. You should only use it directly in your code with extreme caution and knowing that
///     doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
/// <remarks>
///     Translates <see cref="Queryable" /> <c>Average</c>, <c>Max</c>, <c>Min</c> and <c>Sum</c> over
///     <see cref="decimal" /> (by provider type) into calls to the <c>ef_avg</c>, <c>ef_max</c>, <c>ef_min</c>
///     and <c>ef_sum</c> SQL functions. <c>Max</c>/<c>Min</c> over <see cref="DateTimeOffset" />,
///     <see cref="TimeSpan" /> or <see cref="ulong" /> throw <see cref="NotSupportedException" />.
///     Every other case returns <see langword="null" />, so that a later translator can handle it.
/// </remarks>
public class SqliteQueryableAggregateMethodTranslator : IAggregateMethodCallTranslator
{
    private readonly ISqlExpressionFactory _sqlExpressionFactory;

    /// <summary>
    ///     This is an internal API that supports the Entity Framework Core infrastructure and not subject to
    ///     the same compatibility standards as public APIs. It may be changed or removed without notice in
    ///     any release. You should only use it directly in your code with extreme caution and knowing that
    ///     doing so can result in application failures when updating to a new Entity Framework Core release.
    /// </summary>
    public SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpressionFactory)
        => _sqlExpressionFactory = sqlExpressionFactory;

    /// <summary>
    ///     This is an internal API that supports the Entity Framework Core infrastructure and not subject to
    ///     the same compatibility standards as public APIs. It may be changed or removed without notice in
    ///     any release. You should only use it directly in your code with extreme caution and knowing that
    ///     doing so can result in application failures when updating to a new Entity Framework Core release.
    /// </summary>
    public virtual SqlExpression? Translate(
        MethodInfo method,
        EnumerableExpression source,
        IReadOnlyList<SqlExpression> arguments,
        IDiagnosticsLogger<DbLoggerCategory.Query> logger)
    {
        if (method.DeclaringType != typeof(Queryable))
        {
            return null;
        }

        // The QueryableMethods comparisons below are against generic method definitions.
        var methodInfo = method.IsGenericMethod
            ? method.GetGenericMethodDefinition()
            : method;

        switch (methodInfo.Name)
        {
            case nameof(Queryable.Average)
                when (QueryableMethods.IsAverageWithoutSelector(methodInfo)
                    || QueryableMethods.IsAverageWithSelector(methodInfo))
                && source.Selector is SqlExpression averageSqlExpression:
                return TranslateDecimalAggregate("ef_avg", source, averageSqlExpression);

            case nameof(Queryable.Max)
                when (methodInfo == QueryableMethods.MaxWithoutSelector
                    || methodInfo == QueryableMethods.MaxWithSelector)
                && source.Selector is SqlExpression maxSqlExpression:
                return TranslateMinMax(nameof(Queryable.Max), "ef_max", source, maxSqlExpression);

            case nameof(Queryable.Min)
                when (methodInfo == QueryableMethods.MinWithoutSelector
                    || methodInfo == QueryableMethods.MinWithSelector)
                && source.Selector is SqlExpression minSqlExpression:
                return TranslateMinMax(nameof(Queryable.Min), "ef_min", source, minSqlExpression);

            case nameof(Queryable.Sum)
                when (QueryableMethods.IsSumWithoutSelector(methodInfo)
                    || QueryableMethods.IsSumWithSelector(methodInfo))
                && source.Selector is SqlExpression sumSqlExpression:
                return TranslateDecimalAggregate("ef_sum", source, sumSqlExpression);
        }

        return null;
    }

    /// <summary>
    ///     Shared <c>Max</c>/<c>Min</c> handling. Rejects provider types this translator does not support,
    ///     then delegates to <see cref="TranslateDecimalAggregate" />.
    /// </summary>
    /// <param name="operationName">The LINQ operator name used in the exception message.</param>
    /// <param name="functionName">The SQL function to call for <see cref="decimal" /> arguments.</param>
    /// <param name="source">The enumerable whose predicate/distinct flag is folded into the argument.</param>
    /// <param name="sqlExpression">The aggregated selector expression.</param>
    /// <returns>The translated function call, or <see langword="null" /> if the provider type is not decimal.</returns>
    /// <exception cref="NotSupportedException">
    ///     The provider type is <see cref="DateTimeOffset" />, <see cref="TimeSpan" /> or <see cref="ulong" />.
    /// </exception>
    private SqlExpression? TranslateMinMax(
        string operationName,
        string functionName,
        EnumerableExpression source,
        SqlExpression sqlExpression)
    {
        var providerType = GetProviderType(sqlExpression);
        if (providerType == typeof(DateTimeOffset)
            || providerType == typeof(TimeSpan)
            || providerType == typeof(ulong))
        {
            throw new NotSupportedException(
                SqliteStrings.AggregateOperationNotSupported(operationName, providerType.ShortDisplayName()));
        }

        return TranslateDecimalAggregate(functionName, source, sqlExpression);
    }

    /// <summary>
    ///     When the provider type of <paramref name="sqlExpression" /> is <see cref="decimal" />, builds a
    ///     nullable call to <paramref name="functionName" /> over the combined term. Otherwise returns
    ///     <see langword="null" />.
    /// </summary>
    /// <param name="functionName">The SQL function to call.</param>
    /// <param name="source">The enumerable whose predicate/distinct flag is folded into the argument.</param>
    /// <param name="sqlExpression">The aggregated selector expression.</param>
    private SqlExpression? TranslateDecimalAggregate(
        string functionName,
        EnumerableExpression source,
        SqlExpression sqlExpression)
    {
        if (GetProviderType(sqlExpression) != typeof(decimal))
        {
            return null;
        }

        var argument = CombineTerms(source, sqlExpression);

        // The cached array avoids a per-translation allocation; the argument does not propagate nullability.
        return _sqlExpressionFactory.Function(
            functionName,
            [argument],
            nullable: true,
            argumentsPropagateNullability: Statics.FalseArrays[1],
            argument.Type,
            argument.TypeMapping);
    }

    /// <summary>
    ///     Returns the CLR type the value is stored as. This is the value converter's provider type when
    ///     present, then the type mapping's CLR type, then the expression's own type.
    /// </summary>
    private static Type GetProviderType(SqlExpression expression)
        => expression.TypeMapping?.Converter?.ProviderClrType
            ?? expression.TypeMapping?.ClrType
            ?? expression.Type;

    /// <summary>
    ///     Folds the enumerable's predicate and distinct flag into the aggregate argument.
    ///     A predicate becomes <c>CASE WHEN predicate THEN expr END</c> (null otherwise), and
    ///     a distinct flag wraps the result in a <see cref="DistinctExpression" />.
    /// </summary>
    private SqlExpression CombineTerms(EnumerableExpression enumerableExpression, SqlExpression sqlExpression)
    {
        if (enumerableExpression.Predicate != null)
        {
            sqlExpression = _sqlExpressionFactory.Case(
                new List<CaseWhenClause> { new(enumerableExpression.Predicate, sqlExpression) },
                elseResult: null);
        }

        if (enumerableExpression.IsDistinct)
        {
            sqlExpression = new DistinctExpression(sqlExpression);
        }

        return sqlExpression;
    }
}
