//  ************************************************************************************************
//
//  BornAgain: simulate and fit reflection and scattering
//
//! @file      Device/Resolution/ConvolutionDetectorResolution.cpp
//! @brief     Implements class ConvolutionDetectorResolution.
//!
//! @homepage  http://www.bornagainproject.org
//! @license   GNU General Public License v3 or higher (see COPYING)
//! @copyright Forschungszentrum Jülich GmbH 2018
//! @authors   Scientific Computing Group at MLZ (see CITATION, AUTHORS)
//
//  ************************************************************************************************

#include "Device/Resolution/ConvolutionDetectorResolution.h"
#include "Base/Axis/Frame.h"
#include "Base/Axis/Scale.h"
#include "Base/Util/Assert.h"
#include "Device/Resolution/Convolve.h"
#include <stdexcept>

ConvolutionDetectorResolution::ConvolutionDetectorResolution(cumulative_DF_1d res_function_1d)
    : m_rank(1)
    , m_res_function_1d(res_function_1d)
{
}

ConvolutionDetectorResolution::ConvolutionDetectorResolution(
    const IResolutionFunction2D& res_function_2d)
    : m_rank(2)
    , m_res_function_1d(nullptr)
{
    setResolutionFunction(res_function_2d);
}

//! Destructor, defaulted in this translation unit.
// NOTE(review): presumably kept out of line so the header can hold a smart pointer to an
// incomplete IResolutionFunction2D type — confirm against the header.
ConvolutionDetectorResolution::~ConvolutionDetectorResolution() = default;

//! Copy constructor. Copies the rank and the 1d CDF, and deep-copies (via clone()) the
//! 2d resolution function when `other` holds one.
ConvolutionDetectorResolution::ConvolutionDetectorResolution(
    const ConvolutionDetectorResolution& other)
    : m_rank(other.m_rank)
    , m_res_function_1d(other.m_res_function_1d)
{
    if (!other.m_res_function_2d)
        return; // nothing to duplicate for the 2d case
    setResolutionFunction(*other.m_res_function_2d);
}

//! Returns a new copy of this object made with the copy constructor.
//! The copy is allocated with `new` and is not retained here.
ConvolutionDetectorResolution* ConvolutionDetectorResolution::clone() const
{
    auto* copy = new ConvolutionDetectorResolution(*this);
    return copy;
}

//! Returns the child nodes of this object, built from the 2d resolution function member.
// NOTE(review): the `operator<<` used here is declared elsewhere; whether it skips an empty
// m_res_function_2d (as in rank-1 instances) is not visible in this file — confirm.
std::vector<const INode*> ConvolutionDetectorResolution::nodeChildren() const
{
    return std::vector<const INode*>() << m_res_function_2d;
}

//! Convolves `intensity_map` in place with this detector resolution.
//!
//! @throws std::runtime_error if the map's rank differs from this resolution's rank,
//!         or if this object was constructed with a rank other than 1 or 2.
void ConvolutionDetectorResolution::applyDetectorResolution(Datafield* intensity_map) const
{
    const bool rank_matches = intensity_map->rank() == m_rank;
    if (!rank_matches)
        throw std::runtime_error(
            "ConvolutionDetectorResolution::applyDetectorResolution -> Error! "
            "Intensity map must have same dimension as detector resolution function.");

    if (m_rank == 1) {
        apply1dConvolution(intensity_map);
        return;
    }
    if (m_rank == 2) {
        apply2dConvolution(intensity_map);
        return;
    }
    throw std::runtime_error(
        "ConvolutionDetectorResolution::applyDetectorResolution -> Error! "
        "Class ConvolutionDetectorResolution must be initialized with dimension 1 or 2.");
}

//! Replaces the stored 2d resolution function with a clone of `resFunc`.
void ConvolutionDetectorResolution::setResolutionFunction(const IResolutionFunction2D& resFunc)
{
    auto* duplicate = resFunc.clone();
    m_res_function_2d.reset(duplicate);
}

//! Convolves the rank-1 intensity map in place with the 1d resolution function.
//!
//! The kernel is sampled at the map's bin centers, shifted so that the middle bin is at zero
//! (Convolve expects the kernel's zero at its midpoint). Each kernel element is the
//! CDF difference over one mean step around that offset. Negative results, which can arise
//! from the finite precision of the FFT, are clipped to 0. Maps with fewer than two data
//! points are left unchanged.
//!
//! @throws std::runtime_error if the map is not rank 1 or its axis size differs from its
//!         data size.
void ConvolutionDetectorResolution::apply1dConvolution(Datafield* intensity_map) const
{
    // A 1d cumulative distribution function is required to build the kernel.
    // (This check was previously inverted, which made 1d convolution always fail.)
    ASSERT(m_res_function_1d != nullptr);
    if (intensity_map->rank() != 1)
        throw std::runtime_error(
            "ConvolutionDetectorResolution::apply1dConvolution -> Error! "
            "Number of axes for intensity map does not correspond to the dimension of the map.");
    const Scale& axis = intensity_map->axis(0);

    // Construct source vector from original intensity map
    std::vector<double> source_vector = intensity_map->flatVector();
    const size_t data_size = source_vector.size();
    if (data_size < 2)
        return; // No convolution for sets of zero or one element
    if (axis.size() != data_size)
        throw std::runtime_error(
            "ConvolutionDetectorResolution::apply1dConvolution -> Error! "
            "Size of axis for intensity map does not correspond to size of data in the map.");

    // Construct kernel vector from resolution function
    const double step_size =
        std::abs(axis.binCenter(0) - axis.binCenter(data_size - 1)) / (data_size - 1);
    const double mid_value = axis.binCenter(data_size / 2); // Convolve expects zero at midpoint
    std::vector<double> kernel;
    kernel.reserve(data_size);
    for (size_t index = 0; index < data_size; ++index)
        kernel.push_back(getIntegratedPDF1d(axis.binCenter(index) - mid_value, step_size));

    // Calculate convolution
    std::vector<double> result;
    Convolve().fftconvolve(source_vector, kernel, result);

    // Truncate negative values that can arise because of finite precision of Fourier Transform
    for (double& e : result)
        e = std::max(0.0, e);

    // Populate intensity map with results
    intensity_map->setVector(result);
}

//! Convolves the rank-2 intensity map in place with the 2d resolution function.
//!
//! The kernel is sampled at the map's bin centers, shifted so that the middle bin of each
//! axis is at zero (Convolve expects the kernel's zero at its midpoint). Each kernel element
//! is the CDF mass of the resolution function over one mean step in each direction. Negative
//! results, which can arise from the finite precision of the FFT, are clipped to 0. Maps with
//! fewer than two bins along either axis are left unchanged.
//!
//! @throws std::runtime_error if the map is not rank 2 or its data size differs from the
//!         product of its axis sizes.
void ConvolutionDetectorResolution::apply2dConvolution(Datafield* intensity_map) const
{
    ASSERT(m_res_function_2d);
    if (intensity_map->rank() != 2)
        throw std::runtime_error(
            "ConvolutionDetectorResolution::apply2dConvolution -> Error! "
            "Number of axes for intensity map does not correspond to the dimension of the map.");
    const Scale& axis_1 = intensity_map->axis(0);
    const Scale& axis_2 = intensity_map->axis(1);
    const size_t axis_size_1 = axis_1.size();
    const size_t axis_size_2 = axis_2.size();
    if (axis_size_1 < 2 || axis_size_2 < 2)
        return; // No 2d convolution for 1d data

    const std::vector<double> raw_source_vector = intensity_map->flatVector();
    const size_t raw_data_size = raw_source_vector.size();
    if (raw_data_size != axis_size_1 * axis_size_2)
        throw std::runtime_error(
            "ConvolutionDetectorResolution::apply2dConvolution -> Error! "
            "Intensity map data size does not match the product of its axes' sizes");
    ASSERT(axis_size_1 * axis_size_2 == intensity_map->size());

    // Build source[i0][i1] (i0 along axis 0, i1 along axis 1) using the frame's own index
    // projection. The result is written back with the same projection, so reading and
    // writing agree whatever the frame's flat-index ordering is. (Splitting the flat vector
    // into rows of axis_size_2 assumed axis 1 varies fastest.)
    std::vector<std::vector<double>> source(axis_size_1, std::vector<double>(axis_size_2, 0.0));
    for (size_t i = 0; i < raw_data_size; ++i) {
        const size_t i0 = intensity_map->frame().projectedIndex(i, 0);
        const size_t i1 = intensity_map->frame().projectedIndex(i, 1);
        source[i0][i1] = raw_source_vector[i];
    }

    // Construct kernel from resolution function; offsets are taken relative to the middle
    // bin of each axis because Convolve expects zero at midpoint.
    const double mid_value_1 = axis_1.binCenter(axis_size_1 / 2);
    const double mid_value_2 = axis_2.binCenter(axis_size_2 / 2);
    const double step_size_1 =
        std::abs(axis_1.binCenter(0) - axis_1.binCenter(axis_size_1 - 1)) / (axis_size_1 - 1);
    const double step_size_2 =
        std::abs(axis_2.binCenter(0) - axis_2.binCenter(axis_size_2 - 1)) / (axis_size_2 - 1);

    // Offsets along axis 2 do not depend on the axis-1 index, so compute them once.
    std::vector<double> offsets_2(axis_size_2);
    for (size_t index_2 = 0; index_2 < axis_size_2; ++index_2)
        offsets_2[index_2] = axis_2.binCenter(index_2) - mid_value_2;

    std::vector<std::vector<double>> kernel(axis_size_1, std::vector<double>(axis_size_2, 0.0));
    for (size_t index_1 = 0; index_1 < axis_size_1; ++index_1) {
        const double value_1 = axis_1.binCenter(index_1) - mid_value_1;
        for (size_t index_2 = 0; index_2 < axis_size_2; ++index_2)
            kernel[index_1][index_2] =
                getIntegratedPDF2d(value_1, step_size_1, offsets_2[index_2], step_size_2);
    }

    // Calculate convolution
    std::vector<std::vector<double>> result;
    Convolve().fftconvolve(source, kernel, result);

    // Populate intensity map with results, truncating negative values that can arise
    // because of finite precision of Fourier Transform.
    for (size_t i = 0; i < raw_data_size; ++i) {
        const size_t i0 = intensity_map->frame().projectedIndex(i, 0);
        const size_t i1 = intensity_map->frame().projectedIndex(i, 1);
        (*intensity_map)[i] = std::max(0.0, result[i0][i1]);
    }
}

//! Returns the probability mass of the 1d resolution function in the interval
//! [x - step/2, x + step/2], computed as the difference of the cumulative distribution
//! function at the interval ends.
double ConvolutionDetectorResolution::getIntegratedPDF1d(double x, double step) const
{
    ASSERT(m_res_function_1d != nullptr);
    const double half = step / 2.0;
    return m_res_function_1d(x + half) - m_res_function_1d(x - half);
}

double ConvolutionDetectorResolution::getIntegratedPDF2d(double x, double step_x, double y,
                                                         double step_y) const
{
    double halfstepx = step_x / 2.0;
    double halfstepy = step_y / 2.0;
    double xmin = x - halfstepx;
    double xmax = x + halfstepx;
    double ymin = y - halfstepy;
    double ymax = y + halfstepy;
    double result =
        m_res_function_2d->evaluateCDF(xmax, ymax) - m_res_function_2d->evaluateCDF(xmax, ymin)
        - m_res_function_2d->evaluateCDF(xmin, ymax) + m_res_function_2d->evaluateCDF(xmin, ymin);
    return result;
}
