/*******************************************************************************
 *
 * MIT License
 *
 * Copyright (c) 2023 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/

#include <vector>
#include <cstdint>

#include <miopen/env.hpp>

// Debug switch for the whole solver: IsApplicable() returns false whenever env::disabled()
// reports this variable as disabled.
MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_3D_CONV_IMPLICIT_GEMM_HIP_WRW_XDLOPS)
// Debug switch for the AI heuristic only: HeuristicInit() skips its AI prediction step whenever
// env::disabled() reports this variable as disabled.
MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_3D_CONV_IMPLICIT_GEMM_HIP_WRW_XDLOPS_AI_HEUR)

#include <miopen/conv/solvers.hpp>
#include <miopen/generic_search.hpp>
#include <miopen/conv/wrw_invoke_params.hpp>
#include <miopen/solver/problem_description_interpreter.hpp>
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
#include <miopen/solver/ck_utility_common.hpp>
#if MIOPEN_ENABLE_AI_KERNEL_TUNING
#include <miopen/conv/heuristics/ai_heuristics.hpp>
#include <miopen/conv/heuristics/ai_candidate_selection.hpp>
#include <miopen/conv/heuristics/ai_conv_3d_kernel_tuning_utils.hpp>
#endif
#endif
#include <miopen/solver/implicitgemm_ck_util.hpp>
#include <miopen/solver/implicitgemm_util.hpp>

namespace miopen {
namespace solver {
namespace conv {

#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL

namespace {

// Argument bundle for the Composable Kernel (CK) grouped backward-weight device operations
// used by this solver.
//
// The constructor derives tensor lengths/strides, convolution strides/dilations and paddings
// from a ProblemDescription. The member templates turn that data into CK argument objects
// (one flavour per alpha/beta epilogue) and query a CK device-op instance for support and
// workspace requirements.
//
// DataType selects which DeviceOpGBwdWeight{Bilinear,Scale,Default}<DataType> instance type
// MakeArgPtr() dispatches on.
template <typename DataType>
struct CKArgs
{
    // Reads every dimension through ProblemInterpreter and precomputes the length/stride
    // arrays that are later forwarded to MakeArgumentPointer().
    CKArgs(const ::miopen::conv::ProblemDescription& problem)
    {
        G               = ProblemInterpreter::GetGroupCountG(problem);
        N               = ProblemInterpreter::GetBatchN(problem);
        K1              = ProblemInterpreter::GetOutputChannelK(problem);
        C1              = ProblemInterpreter::GetInputChannelC(problem);
        C               = C1 / G; // Number of input Channel per group
        K               = K1 / G; // Number of output Channel per group
        Hi              = ProblemInterpreter::GetInputHeightHi(problem);
        Wi              = ProblemInterpreter::GetInputWidthWi(problem);
        Ho              = ProblemInterpreter::GetOutputHeightHo(problem);
        Wo              = ProblemInterpreter::GetOutputWidthWo(problem);
        Y               = ProblemInterpreter::GetFilterHeightY(problem);
        X               = ProblemInterpreter::GetFilterWidthX(problem);
        Di              = ProblemInterpreter::GetInputDepthDi(problem);
        Do              = ProblemInterpreter::GetOutputDepthDo(problem);
        Z               = ProblemInterpreter::GetFilterDepthZ(problem);
        data_type       = ProblemInterpreter::GetOutputDataType(problem);
        alpha_beta_case = ProblemInterpreter::GetAlphaBetaCase(problem);

        // Lengths are listed group-first: {G, N|K, per-group channels, depth, height, width}.
        in_lengths  = {G, N, C, Di, Hi, Wi};
        out_lengths = {G, N, K, Do, Ho, Wo};
        wei_lengths = {G, K, C, Z, Y, X};

        // CK strides are in GNCDHW order
        if(problem.IsLayoutNHWC())
        {
            // first entry reserved for G's stride
            // (the tensor's own strides are copied into slots 1..5 of the 6-entry array)
            auto copy_strides = [](const auto& src, auto& dst) {
                assert(dst.size() == (src.size() + 1));
                std::copy(src.begin(), src.end(), dst.begin() + 1);
            };
            copy_strides(problem.GetIn().GetStrides(), in_strides);
            copy_strides(problem.GetOut().GetStrides(), out_strides);
            copy_strides(problem.GetWeights().GetStrides(), wei_strides);

            // On a backward pass, problem.GetIn() means y(or out),
            // and problem.GetOut means x(or in)
            /// \todo remove this when we stop swapping in and out tensors/descriptors
            std::swap(in_strides, out_strides);

            // Now compute G's stride
            // (must happen after the swap above so that slot 0 ends up in the right array)
            in_strides[0]  = C;
            out_strides[0] = K;
            wei_strides[0] = K * wei_strides[1];
        }
        else
        {
            assert(problem.IsLayoutDefault()); // already checked in IsApplicable
            // for default layout, we produce packed strides for NHWC layout
            // because we transpose to NHWC layout before calling CK kernel
            in_strides  = {C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C};
            out_strides = {K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K};
            wei_strides = {K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C};
        }

        // Per-dimension {D, H, W} convolution parameters, using the "adjusted" values
        // reported by ProblemInterpreter.
        filter_strides   = {ProblemInterpreter::GetAdjustedConvolutionStrideD(problem),
                          ProblemInterpreter::GetAdjustedConvolutionStrideH(problem),
                          ProblemInterpreter::GetAdjustedConvolutionStrideW(problem)};
        filter_dilations = {ProblemInterpreter::GetAdjustedConvolutionDilationD(problem),
                            ProblemInterpreter::GetAdjustedConvolutionDilationH(problem),
                            ProblemInterpreter::GetAdjustedConvolutionDilationW(problem)};
        lPadding         = {ProblemInterpreter::GetInputLeftPadD(problem),
                    ProblemInterpreter::GetInputLeftPadH(problem),
                    ProblemInterpreter::GetInputLeftPadW(problem)};
        rPadding         = {ProblemInterpreter::GetAdjustedInputRightPadD(problem),
                    ProblemInterpreter::GetAdjustedInputRightPadH(problem),
                    ProblemInterpreter::GetAdjustedInputRightPadW(problem)};
    }
    CKArgs(const CKArgs&) = default;
    CKArgs(CKArgs&&)      = default;
    CKArgs& operator=(const CKArgs&) = default;

    // Builds the CK argument object for conv_ptr. The overload of the epilogue is chosen at
    // compile time from the pointee type of conv_ptr:
    //   - bilinear instances receive both alpha and beta,
    //   - scale instances receive alpha only,
    //   - anything else must be the default (pass-through) instance and ignores both.
    // x / dw / dy are the raw tensor buffers; split_k is forwarded unchanged to CK.
    template <typename ConvPtr>
    auto MakeArgPtr(const ConvPtr& conv_ptr,
                    ConstData_t x,
                    Data_t dw,
                    ConstData_t dy,
                    float alpha,
                    float beta,
                    int split_k) const
    {
        using DeviceP = std::remove_pointer_t<decltype(conv_ptr.get())>;
        if constexpr(std::is_same_v<DeviceP, DeviceOpGBwdWeightBilinear<DataType>>)
        {
            return MakeBilinearArgPtr(conv_ptr, x, dw, dy, alpha, beta, split_k);
        }
        else if constexpr(std::is_same_v<DeviceP, DeviceOpGBwdWeightScale<DataType>>)
        {
            (void)beta;
            return MakeScaleArgPtr(conv_ptr, x, dw, dy, alpha, split_k);
        }
        else
        {
            (void)alpha;
            (void)beta;
            static_assert(std::is_same_v<DeviceP, DeviceOpGBwdWeightDefault<DataType>>,
                          "Default should be wrw pass through");
            return MakeDefaultArgPtr(conv_ptr, x, dw, dy, split_k);
        }
    }
    // Argument object for the bilinear epilogue (Bilinear{alpha, beta}).
    // dw is passed twice: as the result buffer and as the single entry of the extra tensor
    // list, described by the weight lengths/strides.
    // NOTE(review): presumably that extra tensor is the operand scaled by beta - confirm
    // against the CK device-op interface.
    template <typename ConvPtr>
    auto MakeBilinearArgPtr(const ConvPtr& conv_ptr,
                            ConstData_t x,
                            Data_t dw,
                            ConstData_t dy,
                            float alpha,
                            float beta,
                            int split_k) const
    {
        return conv_ptr->MakeArgumentPointer(x,
                                             dw,
                                             dy,
                                             {dw},
                                             in_lengths,
                                             in_strides,
                                             wei_lengths,
                                             wei_strides,
                                             out_lengths,
                                             out_strides,
                                             {wei_lengths},
                                             {wei_strides},
                                             filter_strides,
                                             filter_dilations,
                                             lPadding,
                                             rPadding,
                                             PassThrough{},
                                             Bilinear{alpha, beta},
                                             PassThrough{},
                                             split_k);
    }

    // Argument object for the scale epilogue (Scale{alpha}); the extra tensor list and its
    // lengths/strides are passed empty.
    template <typename ConvPtr>
    auto MakeScaleArgPtr(const ConvPtr& conv_ptr,
                         ConstData_t x,
                         Data_t dw,
                         ConstData_t dy,
                         float alpha,
                         int split_k) const
    {
        return conv_ptr->MakeArgumentPointer(x,
                                             dw,
                                             dy,
                                             {},
                                             in_lengths,
                                             in_strides,
                                             wei_lengths,
                                             wei_strides,
                                             out_lengths,
                                             out_strides,
                                             {},
                                             {},
                                             filter_strides,
                                             filter_dilations,
                                             lPadding,
                                             rPadding,
                                             PassThrough{},
                                             Scale{alpha},
                                             PassThrough{},
                                             split_k);
    }

    // Argument object for the default instances: all three element-wise operations are
    // PassThrough and no extra tensor list is passed at all.
    template <typename ConvPtr>
    auto MakeDefaultArgPtr(
        const ConvPtr& conv_ptr, ConstData_t x, Data_t dw, ConstData_t dy, int split_k) const
    {
        return conv_ptr->MakeArgumentPointer(x,
                                             dw,
                                             dy,
                                             in_lengths,
                                             in_strides,
                                             wei_lengths,
                                             wei_strides,
                                             out_lengths,
                                             out_strides,
                                             filter_strides,
                                             filter_dilations,
                                             lPadding,
                                             rPadding,
                                             PassThrough{},
                                             PassThrough{},
                                             PassThrough{},
                                             split_k);
    }

    // Convenience overload: unpacks x / dw / dy from a ConvWrwTensors bundle and forwards
    // to the pointer-based MakeArgPtr above.
    template <typename ConvPtr>
    auto MakeArgPtr(const ConvPtr& conv_ptr,
                    const ConvWrwTensors& tensors,
                    float alpha,
                    float beta,
                    int split_k) const
    {
        return MakeArgPtr(conv_ptr, tensors.x, tensors.dw, tensors.dy, alpha, beta, split_k);
    }

    // Asks conv_ptr whether it supports this problem with split_k fixed to 1.
    // The probe uses null tensor pointers, alpha = 1 and beta = 0. When the instance reports
    // a non-zero workspace size, the address of the local size variable is installed as the
    // workspace pointer.
    // NOTE(review): presumably this is only a non-null placeholder so the support check can
    // run without allocating a real workspace - confirm.
    template <typename ConvPtr>
    bool IsSupportedBy(const ConvPtr& conv_ptr) const
    {
        auto arg_ptr        = MakeArgPtr(conv_ptr, nullptr, nullptr, nullptr, 1.0f, 0.0f, 1);
        auto workspace_size = conv_ptr->GetWorkSpaceSize(arg_ptr.get());
        if(workspace_size != 0)
            conv_ptr->SetWorkSpacePointer(arg_ptr.get(), &workspace_size);
        return conv_ptr->IsSupportedArgument(arg_ptr.get());
    }

    // Same probe as IsSupportedBy(), but for a caller-supplied split_k value.
    template <typename ConvPtr>
    bool IsSupportedBySplitK(const ConvPtr& conv_ptr, int split_k) const
    {
        auto arg_ptr        = MakeArgPtr(conv_ptr, nullptr, nullptr, nullptr, 1.0f, 0.0f, split_k);
        auto workspace_size = conv_ptr->GetWorkSpaceSize(arg_ptr.get());
        if(workspace_size != 0)
        {
            conv_ptr->SetWorkSpacePointer(arg_ptr.get(), &workspace_size);
        }
        return conv_ptr->IsSupportedArgument(arg_ptr.get());
    }

    // Returns the workspace size conv_ptr reports for this problem at the given split_k
    // (queried with null tensor pointers, alpha = 1 and beta = 0).
    template <typename ConvPtr>
    std::size_t GetCKSplitkWorkspaceSize(const ConvPtr& conv_ptr, int split_k) const
    {
        auto arg_ptr = MakeArgPtr(conv_ptr, nullptr, nullptr, nullptr, 1.0f, 0.0f, split_k);
        return conv_ptr->GetWorkSpaceSize(arg_ptr.get());
    }

    // Problem dimensions as read by the constructor.
    int G;  // group count
    int N;  // batch size
    int K;  // output channels per group (K1 / G)
    int C;  // input channels per group (C1 / G)
    int C1; // total input channels
    int K1; // total output channels
    int Hi; // input height
    int Wi; // input width
    int Di; // input depth
    int Ho; // output height
    int Wo; // output width
    int Do; // output depth
    int Y;  // filter height
    int X;  // filter width
    int Z;  // filter depth
    miopenAlphaBetaCase_t alpha_beta_case;
    miopenDataType_t data_type; // taken from ProblemInterpreter::GetOutputDataType()
    // 6-entry arrays start with the group dimension; see the constructor for the exact order.
    std::array<ck::index_t, 6> in_lengths;
    std::array<ck::index_t, 6> in_strides;
    std::array<ck::index_t, 6> out_lengths;
    std::array<ck::index_t, 6> out_strides;
    std::array<ck::index_t, 6> wei_lengths;
    std::array<ck::index_t, 6> wei_strides;
    // 3-entry arrays are ordered {D, H, W}.
    std::array<ck::index_t, 3> filter_strides;
    std::array<ck::index_t, 3> filter_dilations;
    std::array<ck::index_t, 3> lPadding;
    std::array<ck::index_t, 3> rPadding;
};

// Returns the kernel IDs that FillValidKernelsIDs() reports for `problem`, using the CK
// device-op family (bilinear, scale or default) that matches the problem's alpha/beta case.
template <typename DataType>
std::vector<std::string>
FillValidKernelsByAlphaBeta(const ::miopen::conv::ProblemDescription& problem)
{
    using Args = CKArgs<DataType>;

    const auto alpha_beta_case = problem.GetAlphaBetaCase();
    if(alpha_beta_case == BILINEAR)
        return FillValidKernelsIDs<DeviceOpGBwdWeightBilinearPtrs<DataType>, Args>(problem);
    if(alpha_beta_case == SCALE)
        return FillValidKernelsIDs<DeviceOpGBwdWeightScalePtrs<DataType>, Args>(problem);

    // Every other case is served by the default instances.
    return FillValidKernelsIDs<DeviceOpGBwdWeightDefaultPtrs<DataType>, Args>(problem);
}
} // namespace

// Rebuilds valid_kernels for `problem` and resets the configuration to the first entry with
// split_k = 1. kernel_id is encoded as "<kernel name>+<split_k>".
//
// If no CK instance is valid for the problem, valid_kernels stays empty and kernel_id is set
// to the "None" sentinel (the same value HeuristicInit() starts from) instead of indexing
// into the empty vector, which would be undefined behaviour. Callers detect this case by
// checking valid_kernels.empty() afterwards.
template <typename DataType>
void PerformanceConfigHipImplicitGemm3DGroupWrwXdlops::Init(
    const ::miopen::conv::ProblemDescription& problem)
{
    valid_kernels = FillValidKernelsByAlphaBeta<DataType>(problem);
    index         = 0;
    split_k       = 1;
    if(valid_kernels.empty())
    {
        // Nothing to select: leave the config in a well-defined "no kernel" state.
        kernel_id = "None";
        return;
    }
    kernel_id = valid_kernels[index] + "+" + std::to_string(split_k);
}

// Checks, via IsCKArgsSupported(), whether the kernel named by kernel_id accepts `problem`.
// The CK device-op family that is queried follows the problem's alpha/beta case.
template <typename DataType>
bool PerformanceConfigHipImplicitGemm3DGroupWrwXdlops::CheckIsSupportCKArgs(
    const ::miopen::conv::ProblemDescription& problem) const
{
    using Args = CKArgs<DataType>;

    const auto alpha_beta_case = problem.GetAlphaBetaCase();
    if(alpha_beta_case == BILINEAR)
        return IsCKArgsSupported<DeviceOpGBwdWeightBilinearPtrs<DataType>, Args>(problem,
                                                                                 kernel_id);
    if(alpha_beta_case == SCALE)
        return IsCKArgsSupported<DeviceOpGBwdWeightScalePtrs<DataType>, Args>(problem, kernel_id);

    // Every other case is served by the default instances.
    return IsCKArgsSupported<DeviceOpGBwdWeightDefaultPtrs<DataType>, Args>(problem, kernel_id);
}

// Forwards to IsCKApplicable() for the CK device-op family (bilinear, scale or default) that
// matches the problem's alpha/beta case and returns its verdict.
template <typename DataType>
bool ConvHipImplicitGemm3DGroupWrwXdlops::CheckCKApplicability(
    const ::miopen::conv::ProblemDescription& problem) const
{
    using Args = CKArgs<DataType>;

    const auto alpha_beta_case = problem.GetAlphaBetaCase();
    if(alpha_beta_case == BILINEAR)
        return IsCKApplicable<DeviceOpGBwdWeightBilinearPtrs<DataType>, Args>(problem);
    if(alpha_beta_case == SCALE)
        return IsCKApplicable<DeviceOpGBwdWeightScalePtrs<DataType>, Args>(problem);

    // Every other case is served by the default instances.
    return IsCKApplicable<DeviceOpGBwdWeightDefaultPtrs<DataType>, Args>(problem);
}
// Runs Init<T>() with the CK element type that corresponds to the problem's input data type.
// Data types without a branch below are unsupported: nothing is called for them, so
// valid_kernels is left as it was (empty for a fresh config).
void PerformanceConfigHipImplicitGemm3DGroupWrwXdlops::InitValidKernels(
    const ::miopen::conv::ProblemDescription& problem)
{
    const auto in_data_type = problem.GetInDataType();

    if(in_data_type == miopenHalf)
        Init<ck::half_t>(problem);
    else if(in_data_type == miopenFloat)
        Init<float>(problem);
    else if(in_data_type == miopenInt8)
        Init<int8_t>(problem);
    else if(in_data_type == miopenBFloat16)
        Init<ck::bhalf_t>(problem);
}
#endif

// Chooses the initial (untuned) configuration for `problem`.
//
// index, kernel_id and split_k are first reset to 0, "None" and 1. In HIP builds with
// Composable Kernel the function then proceeds in two steps:
//   1. Only when compiled with MIOPEN_ENABLE_AI_KERNEL_TUNING, when `ctx` is not the dummy
//      context and when the AI-heuristic env var is not disabled: run the parameter
//      prediction model. If it reports success with a non-empty result, return immediately.
//   2. Otherwise populate valid_kernels through InitValidKernels() and, if any kernel is
//      available, select the first one with split_k = 1. If none is available a warning is
//      logged.
// In builds without HIP + Composable Kernel only the initial reset is performed.
void PerformanceConfigHipImplicitGemm3DGroupWrwXdlops::HeuristicInit(
    [[maybe_unused]] const miopen::ExecutionContext& ctx,
    const ::miopen::conv::ProblemDescription& problem)
{
    index     = 0;
    kernel_id = "None";
    split_k   = 1; // This solver uses split_k, so initialize to 1

#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
    // 1. AI heuristics (if enabled)
#if MIOPEN_ENABLE_AI_KERNEL_TUNING
    if(&ctx != &GetDummyCtx() &&
       !env::disabled(MIOPEN_DEBUG_3D_CONV_IMPLICIT_GEMM_HIP_WRW_XDLOPS_AI_HEUR))
    {
        MIOPEN_LOG_I2(
            "Step 1: Attempting AI heuristics for data type: " << problem.GetInDataType());

        std::string solver_name = "ConvHipImplicitGemm3DGroupWrwXdlops";

        bool ai_success = false;
        miopen::ai::tuning::candidate_selection::CandidateSelectionResult result;

        // Generic helper: the (otherwise unused) argument only carries the CK element type T.
        // It hands this config's valid_kernels / index / split_k / kernel_id members to
        // RunParameterPredictionModel<T>() together with the two callbacks defined below, and
        // returns that call's result (unpacked via std::tie into ai_success and result).
        auto run_ai_heuristics = [&](auto CKDataType) {
            using T = decltype(CKDataType);
            // Callback that enumerates valid kernel IDs for a problem.
            // Note: its parameter shadows the enclosing function's `problem`.
            auto fill_valid_kernels =
                [=](const ::miopen::conv::ProblemDescription& problem) -> std::vector<std::string> {
                return FillValidKernelsByAlphaBeta<T>(problem);
            };

            // Validation lambda for AI-predicted kernel + split_k combinations
            // Rejects out-of-range indices, then asks IsCKArgsSupported() about the
            // "<kernel name>+<split_k>" ID for the device-op family of the alpha/beta case.
            auto is_kernel_split_k_valid = [&](int kernel_index, int split_k_value) -> bool {
                if(kernel_index < 0 || kernel_index >= static_cast<int>(valid_kernels.size()))
                    return false;

                std::string test_kernel_id =
                    valid_kernels[kernel_index] + "+" + std::to_string(split_k_value);

                switch(problem.GetAlphaBetaCase())
                {
                case BILINEAR:
                    return IsCKArgsSupported<DeviceOpGBwdWeightBilinearPtrs<T>,
                                             CKArgs<T>,
                                             miopen::conv::ProblemDescription,
                                             true>(problem, test_kernel_id);
                case SCALE:
                    return IsCKArgsSupported<DeviceOpGBwdWeightScalePtrs<T>,
                                             CKArgs<T>,
                                             miopen::conv::ProblemDescription,
                                             true>(problem, test_kernel_id);
                default:
                    return IsCKArgsSupported<DeviceOpGBwdWeightDefaultPtrs<T>,
                                             CKArgs<T>,
                                             miopen::conv::ProblemDescription,
                                             true>(problem, test_kernel_id);
                }
            };
            return miopen::solver::conv::RunParameterPredictionModel<T>(ctx,
                                                                        problem,
                                                                        valid_kernels,
                                                                        index,
                                                                        split_k,
                                                                        kernel_id,
                                                                        fill_valid_kernels,
                                                                        solver_name,
                                                                        is_kernel_split_k_valid);
        };
        // Only half, float and bfloat16 are routed to the AI model. Any other input type
        // (including int8, which InitValidKernels() does handle) leaves ai_success false and
        // falls through to the default initialization below.
        switch(problem.GetInDataType())
        {
        case miopenHalf: std::tie(ai_success, result) = run_ai_heuristics(ck::half_t{}); break;
        case miopenFloat: std::tie(ai_success, result) = run_ai_heuristics(float{}); break;
        case miopenBFloat16: std::tie(ai_success, result) = run_ai_heuristics(ck::bhalf_t{}); break;
        default: break;
        }

        if(ai_success && !result.IsEmpty())
        {
            MIOPEN_LOG_I("Step 1: AI heuristics selected kernel: " << kernel_id);
            return;
        }
        else
        {
            MIOPEN_LOG_I2("Step 1: AI heuristics failed, proceeding to default initialization");
            // print results to log to help debugging
            if(!ai_success)
                MIOPEN_LOG_I2("Step 1: AI heuristics internal failure");
            else if(result.IsEmpty())
                MIOPEN_LOG_I2("Step 1: AI heuristics returned empty result");
        }
    }
    else
    {
        MIOPEN_LOG_I2("Step 1: AI heuristics skipped (disabled or dummy context)");
    }
#else
    MIOPEN_LOG_I2("Step 1: AI heuristics not available (MIOPEN_ENABLE_AI_KERNEL_TUNING disabled)");
#endif

    // 2. Default: index remains 0, first valid_kernel will be used
    MIOPEN_LOG_I2("Step 2: Using default initialization (index=0)");
    InitValidKernels(problem);
    if(!valid_kernels.empty())
    {
        // Re-assert the selection explicitly: first kernel, split_k = 1.
        index     = 0;
        split_k   = 1;
        kernel_id = valid_kernels[index] + "+" + std::to_string(split_k);
        MIOPEN_LOG_I("Step 2: Default initialization selected kernel: " << kernel_id
                                                                        << " at index: 0");
    }
    else
    {
        MIOPEN_LOG_W("Step 2: Default initialization failed - no valid kernels found");
    }
#endif
}

// Advances the configuration to the next (kernel, split_k) combination for generic search.
//
// Returns true when the config now describes a combination to try (kernel_id is refreshed as
// "<kernel name>+<split_k>"), and false when there is nothing (more) to iterate: either no
// valid kernel exists for `problem`, or both NextCKSplitkValue() and NextLinear() returned
// true, which this function treats as "all split_k and index values were iterated".
// Builds without Composable Kernel always return true.
bool PerformanceConfigHipImplicitGemm3DGroupWrwXdlops::SetNextValue(
    const ::miopen::conv::ProblemDescription& problem)
{
#if MIOPEN_USE_COMPOSABLEKERNEL
    if(valid_kernels.empty())
    {
        // Generic search wants the complete kernel list rather than a heuristic pick, so
        // populate it lazily on first use.
        InitValidKernels(problem);
        if(valid_kernels.empty())
            return false;
    }

    // Step split_k first; the kernel index is only stepped when the split_k step reports
    // true. The search space is exhausted once the index step reports true as well.
    const bool split_k_step = NextCKSplitkValue<1, 128>(split_k);
    if(split_k_step && NextLinear(0, valid_kernels.size() - 1, index))
        return false;

    kernel_id = valid_kernels[index] + "+" + std::to_string(split_k);
#endif
    return true;
}

// A configuration value is valid only while `index` addresses an entry of valid_kernels.
bool PerformanceConfigHipImplicitGemm3DGroupWrwXdlops::IsValidValue() const
{
    const auto kernel_count = valid_kernels.size();
    return index < kernel_count;
}

// Reports whether this configuration (i.e. the kernel named by kernel_id) can be used for
// `problem`. The check is delegated to CheckIsSupportCKArgs<T>() with the CK element type
// matching the problem's input data type; all other data types, and builds without
// HIP + Composable Kernel, yield false.
bool PerformanceConfigHipImplicitGemm3DGroupWrwXdlops::IsValid(
    [[maybe_unused]] const ::miopen::conv::ProblemDescription& problem) const
{
    bool supported = false;
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
    switch(problem.GetInDataType())
    {
    case miopenHalf: supported = CheckIsSupportCKArgs<ck::half_t>(problem); break;
    case miopenFloat: supported = CheckIsSupportCKArgs<float>(problem); break;
    case miopenInt8: supported = CheckIsSupportCKArgs<int8_t>(problem); break;
    case miopenBFloat16: supported = CheckIsSupportCKArgs<ck::bhalf_t>(problem); break;
    // The remaining data types are listed explicitly and are never supported.
    case miopenInt64:
    case miopenInt32:
    case miopenFloat8_fnuz:
    case miopenBFloat8_fnuz:
    case miopenDouble: break;
    }
#endif
    return supported;
}

// Two configurations compare equal when their kernel_id strings match. kernel_id is built
// elsewhere in this file as "<kernel name>+<split_k>", so index, split_k and valid_kernels
// are not compared separately.
bool PerformanceConfigHipImplicitGemm3DGroupWrwXdlops::operator==(
    const PerformanceConfigHipImplicitGemm3DGroupWrwXdlops& other) const
{
    return kernel_id == other.kernel_id;
}

// Produces the configuration used when no tuned one is supplied: a default-constructed
// config that has been run through HeuristicInit() for the given context and problem.
PerformanceConfigHipImplicitGemm3DGroupWrwXdlops
ConvHipImplicitGemm3DGroupWrwXdlops::GetDefaultPerformanceConfig(
    const ExecutionContext& ctx, const ::miopen::conv::ProblemDescription& problem) const
{
    PerformanceConfigHipImplicitGemm3DGroupWrwXdlops default_config;
    default_config.HeuristicInit(ctx, problem);
    return default_config;
}

// Solver-side validity check for a performance config. The execution context is not
// consulted; the decision is entirely config.IsValid(problem).
bool ConvHipImplicitGemm3DGroupWrwXdlops::IsValidPerformanceConfig(
    const ExecutionContext&,
    const ::miopen::conv::ProblemDescription& problem,
    const PerformanceConfigHipImplicitGemm3DGroupWrwXdlops& config) const
{
    return config.IsValid(problem);
}

// Returns what GetCKSplitkMaxWorkspaceSize() reports for `problem`, queried on the CK
// device-op family (bilinear, scale or default) matching the problem's alpha/beta case.
// Builds without HIP + Composable Kernel return 0.
template <typename DataType>
size_t ConvHipImplicitGemm3DGroupWrwXdlops::GetCKMaxWorkspaceSize(
    const ::miopen::conv::ProblemDescription& problem) const
{
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
    using Args = CKArgs<DataType>;

    const auto alpha_beta_case = problem.GetAlphaBetaCase();
    if(alpha_beta_case == BILINEAR)
        return GetCKSplitkMaxWorkspaceSize<DeviceOpGBwdWeightBilinearPtrs<DataType>, Args>(
            problem);
    if(alpha_beta_case == SCALE)
        return GetCKSplitkMaxWorkspaceSize<DeviceOpGBwdWeightScalePtrs<DataType>, Args>(problem);

    // Every other case is served by the default instances.
    return GetCKSplitkMaxWorkspaceSize<DeviceOpGBwdWeightDefaultPtrs<DataType>, Args>(problem);
#else
    return 0;
#endif
}

// Data-type dispatcher for the templated GetCKMaxWorkspaceSize<T>(): picks the CK element
// type from the problem's input data type. Data types this solver does not handle, and
// builds without HIP + Composable Kernel, yield 0.
size_t ConvHipImplicitGemm3DGroupWrwXdlops::GetCKMaxWorkspaceSize(
    const ::miopen::conv::ProblemDescription& problem) const
{
    size_t max_ws_size = 0; // stays 0 for types not applicable to this solver
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
    switch(problem.GetInDataType())
    {
    case miopenHalf: max_ws_size = GetCKMaxWorkspaceSize<ck::half_t>(problem); break;
    case miopenFloat: max_ws_size = GetCKMaxWorkspaceSize<float>(problem); break;
    case miopenInt8: max_ws_size = GetCKMaxWorkspaceSize<int8_t>(problem); break;
    case miopenBFloat16: max_ws_size = GetCKMaxWorkspaceSize<ck::bhalf_t>(problem); break;
    // The remaining data types are listed explicitly and need no CK workspace.
    case miopenInt64:
    case miopenInt32:
    case miopenFloat8_fnuz:
    case miopenBFloat8_fnuz:
    case miopenDouble: break;
    }
#endif
    return max_ws_size;
}

// Total workspace requirement of the solver: the CK workspace size from
// GetCKMaxWorkspaceSize() is handed to GetWorkspaceSizeLayoutTransformConv(), whose result
// is returned. The execution context is not used.
size_t ConvHipImplicitGemm3DGroupWrwXdlops::GetWorkspaceSize(
    const ExecutionContext&, const ::miopen::conv::ProblemDescription& problem) const
{
    return GetWorkspaceSizeLayoutTransformConv(problem, GetCKMaxWorkspaceSize(problem));
}

// Tuning entry point: delegates to GenericSearch() with this solver, the context, the
// problem and the invoke parameters, and returns the configuration it produces.
PerformanceConfigHipImplicitGemm3DGroupWrwXdlops
ConvHipImplicitGemm3DGroupWrwXdlops::Search(const ExecutionContext& ctx,
                                            const ::miopen::conv::ProblemDescription& problem,
                                            const AnyInvokeParams& invoke_ctx) const
{
    return GenericSearch(*this, ctx, problem, invoke_ctx);
}

// Decides whether this solver may be used for `problem` on the device behind `ctx`.
//
// All of the following must hold (checked in this order):
//   - the solver is not disabled through MIOPEN_DEBUG_3D_CONV_IMPLICIT_GEMM_HIP_WRW_XDLOPS;
//   - the convolution is not marked deterministic, all tensor dims fit into int, data types
//     are not mixed, the direction is backward-weights and the problem is 3D;
//   - the layout is NHWC, or the default layout with packed tensors only;
//   - the device name passes ck_utility::is_ck_whitelist();
//   - CheckCKApplicability<T>() succeeds for the input data type (bfloat16 additionally
//     requires device "gfx942" or a name starting with "gfx95").
// Builds without HIP + Composable Kernel always return false.
bool ConvHipImplicitGemm3DGroupWrwXdlops::IsApplicable(
    [[maybe_unused]] const ExecutionContext& ctx,
    [[maybe_unused]] const ::miopen::conv::ProblemDescription& problem) const
{
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
    if(env::disabled(MIOPEN_DEBUG_3D_CONV_IMPLICIT_GEMM_HIP_WRW_XDLOPS))
        return false;

    // Basic problem shape requirements; && short-circuits in the listed order.
    const bool problem_shape_ok = !problem.GetConv().attribute.deterministic &&
                                  problem.AllTensorsDimsFitIntoInt() &&
                                  !problem.HasMixedDataTypes() &&
                                  problem.IsDirectionBackwardWrW() && problem.Is3d();
    if(!problem_shape_ok)
        return false;

    if(!problem.IsLayoutNHWC() && !problem.IsLayoutDefault())
        return false;
    // needed because layout transpose kernel does not support non-packed tensors
    if(problem.IsLayoutDefault() && problem.HasNonPackedTensors())
        return false;

    const auto device_name = ctx.GetStream().GetDeviceName();
    if(!ck_utility::is_ck_whitelist(device_name))
        return false;

    switch(problem.GetInDataType())
    {
    case miopenHalf: return CheckCKApplicability<ck::half_t>(problem);
    case miopenFloat: return CheckCKApplicability<float>(problem);
    case miopenInt8: return CheckCKApplicability<int8_t>(problem);
    case miopenBFloat16: {
        // bfloat16 is restricted to a subset of the whitelisted devices.
        const bool bf16_device = device_name == "gfx942" || StartsWith(device_name, "gfx95");
        return bf16_device && CheckCKApplicability<ck::bhalf_t>(problem);
    }
    case miopenInt64:
    case miopenInt32:
    case miopenFloat8_fnuz:
    case miopenBFloat8_fnuz:
    case miopenDouble: break;
    }
#endif
    return false;
}

// Builds the ConvSolution for `problem` using the kernel named by config.kernel_id.
//
// MakeSolutionGroupConvImplicitGemmXdlops() receives two generic callbacks. Each takes a
// value whose type is the CK element type T and returns an invoker factory for the CK
// device-op family (bilinear, scale or default) matching the problem's alpha/beta case:
// the first goes through InitInvokerFactoryWrwNCHW<3, ...>, the second through
// InitInvokerFactoryNHWC<...>.
// Builds without HIP + Composable Kernel return an empty solution.
ConvSolution ConvHipImplicitGemm3DGroupWrwXdlops::GetSolution(
    [[maybe_unused]] const ExecutionContext& ctx,
    [[maybe_unused]] const ::miopen::conv::ProblemDescription& problem,
    [[maybe_unused]] const PerformanceConfigHipImplicitGemm3DGroupWrwXdlops& config) const
{
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
    return MakeSolutionGroupConvImplicitGemmXdlops(
        problem,
        [&](auto data_type_val) {
            using T      = decltype(data_type_val);
            using Args   = CKArgs<T>;
            using Params = miopen::conv::WrWInvokeParams;

            const auto alpha_beta_case = problem.GetAlphaBetaCase();
            if(alpha_beta_case == BILINEAR)
                return InitInvokerFactoryWrwNCHW<3,
                                                 false,
                                                 DeviceOpGBwdWeightBilinearPtrs<T>,
                                                 Args,
                                                 Params>(ctx, problem, config.kernel_id);
            if(alpha_beta_case == SCALE)
                return InitInvokerFactoryWrwNCHW<3,
                                                 false,
                                                 DeviceOpGBwdWeightScalePtrs<T>,
                                                 Args,
                                                 Params>(ctx, problem, config.kernel_id);
            // Every other case is served by the default instances.
            return InitInvokerFactoryWrwNCHW<3,
                                             false,
                                             DeviceOpGBwdWeightDefaultPtrs<T>,
                                             Args,
                                             Params>(ctx, problem, config.kernel_id);
        },
        [&](auto data_type_val) {
            using T      = decltype(data_type_val);
            using Args   = CKArgs<T>;
            using Params = miopen::conv::WrWInvokeParams;

            const auto alpha_beta_case = problem.GetAlphaBetaCase();
            if(alpha_beta_case == BILINEAR)
                return InitInvokerFactoryNHWC<false,
                                              DeviceOpGBwdWeightBilinearPtrs<T>,
                                              Args,
                                              Params>(ctx, problem, config.kernel_id);
            if(alpha_beta_case == SCALE)
                return InitInvokerFactoryNHWC<false, DeviceOpGBwdWeightScalePtrs<T>, Args, Params>(
                    ctx, problem, config.kernel_id);
            // Every other case is served by the default instances.
            return InitInvokerFactoryNHWC<false, DeviceOpGBwdWeightDefaultPtrs<T>, Args, Params>(
                ctx, problem, config.kernel_id);
        });

#else
    return {};
#endif
}

#if !(MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL)
void miopen::solver::conv::PerformanceConfigHipImplicitGemm3DGroupWrwXdlops::InitValidKernels(
    const ::miopen::conv::ProblemDescription&)
{
    // No-op stub for non-CK builds
}
#endif

} // namespace conv
} // namespace solver
} // namespace miopen
