/*******************************************************************************
* Copyright 2023 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
! An example of SINGLE-precision batch real-to-complex out-of-place 1D FFTs on a 
! (GPU) device using the OpenMP target (offload) interface of oneMKL DFTI
!******************************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <float.h>
#include <omp.h>

#include "mkl_dfti_omp_offload.h"

// Forward declarations of the data helpers defined after main().
// In all four: M is the number of transforms in the batch, N is the transform
// length, REAL_DIST / CMPLX_DIST is the element distance between the starts of
// consecutive transforms in the array, and H is the harmonic used to build or
// check the data.
static void init_r(float *x, MKL_LONG M, MKL_LONG N, MKL_LONG REAL_DIST, MKL_LONG H);
static int verify_c(MKL_Complex8 *x, MKL_LONG M, MKL_LONG N, MKL_LONG CMPLX_DIST, MKL_LONG H);
static void init_c(MKL_Complex8 *x, MKL_LONG M, MKL_LONG N, MKL_LONG CMPLX_DIST, MKL_LONG H);
static int verify_r(float *x, MKL_LONG M, MKL_LONG N, MKL_LONG REAL_DIST, MKL_LONG H);

// printf conversion specification for MKL_LONG values. LI is a string literal
// that is concatenated into format strings: "%li" unless MKL_ILP64 is defined,
// "%lli" when it is.
#if !defined(MKL_ILP64)
#define LI "%li"
#else
#define LI "%lli"
#endif

int main(void)
{
    const int devNum = 0;

    // Size of 1D FFT
    const MKL_LONG N = 32;
    
    const MKL_LONG REAL_DIST = N;
    const MKL_LONG CMPLX_DIST = N/2 + 1;
    
    // Number of transforms
    const MKL_LONG BATCH = 2;


    // Arbitrary harmonic used to verify FFT
    MKL_LONG H = -1;

    MKL_LONG status = 0;

    // Pointers to input and output data
    float *x = NULL;
    MKL_Complex8 *y = NULL;
    
    DFTI_DESCRIPTOR_HANDLE descHandle    = NULL;

    printf("DFTI_LENGTHS                  = {" LI "}\n", N);
    printf("DFTI_PLACEMENT                = DFTI_NOT_INPLACE\n");
    printf("DFTI_CONJUGATE_EVEN_STORAGE   = DFTI_COMPLEX_COMPLEX\n");
    printf("DFTI_NUMBER_OF_TRANSFORMS     = {" LI "}\n", BATCH);

    printf("Create DFTI descriptor\n");
    status = DftiCreateDescriptor(&descHandle, DFTI_SINGLE, DFTI_REAL, 1, N);
    if (status != DFTI_NO_ERROR) goto failed;

    printf("Set configuration: DFTI_NOT_INPLACE\n");
    status = DftiSetValue(descHandle, DFTI_PLACEMENT, DFTI_NOT_INPLACE);
    if (status != DFTI_NO_ERROR) goto failed;
    
    printf("Set configuration: CCE storage\n");
    status = DftiSetValue(descHandle,
                          DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX);
    if (status != DFTI_NO_ERROR) goto failed;
    
    printf("Set DFTI descriptor configuration: DFTI_NUMBER_OF_TRANSFORMS = " LI "\n", BATCH);
    status = DftiSetValue(descHandle, DFTI_NUMBER_OF_TRANSFORMS, BATCH);
    if (status != DFTI_NO_ERROR) goto failed;

    printf("Set DFTI descriptor configuration: DFTI_INPUT_DISTANCE = " LI "\n", REAL_DIST);
    status = DftiSetValue(descHandle, DFTI_INPUT_DISTANCE, REAL_DIST);
    if (status != DFTI_NO_ERROR) goto failed;
    
    printf("Set DFTI descriptor configuration: DFTI_OUTPUT_DISTANCE = " LI "\n", CMPLX_DIST);
    status = DftiSetValue(descHandle, DFTI_OUTPUT_DISTANCE, CMPLX_DIST);
    if (status != DFTI_NO_ERROR) goto failed;
    
    printf("Commit descriptor\n");
    status = DftiCommitDescriptor(descHandle);
    if (status != DFTI_NO_ERROR) goto failed;

    printf("Commit descriptor\n");
    {
#pragma omp dispatch device(devNum)
        status = DftiCommitDescriptor(descHandle);
    }
    if (status != DFTI_NO_ERROR) goto failed;

    printf("Allocate data arrays\n");
    x  = (float*)mkl_malloc(REAL_DIST*BATCH*sizeof(float), 64);
    if (x == NULL) goto failed;
    y  = (MKL_Complex8*)mkl_malloc(CMPLX_DIST*BATCH*sizeof(MKL_Complex8), 64);
    if (y == NULL) goto failed;

    printf("Initialize data for real-to-complex FFT\n");
    init_r(x, BATCH, N, REAL_DIST, H);
    
    printf("Compute forward FFT\n");
#pragma omp target data map(to:x[0:REAL_DIST*BATCH]) map(from:y[0:CMPLX_DIST*BATCH]) device(devNum)
    {
        {
// Use need_device_ptr clause for out of place computation because
// DftiComputeForward is a variadic function where the out of place
// output is not explicit in the function declaration.
// The argument to the need_device_ptr clause is the one-based index
// of the pointer in the dispatched function's argument list.
// The input pointer is explicit in the function declaration, so the
// need_device_ptr clause is optional for it. That is, either
// need_device_ptr(2,3), referencing xGPU_real and xGPU_cmplx, or
// need_device_ptr(3), referencing just xGPU_cmplx, will work.
#pragma omp dispatch device(devNum) need_device_ptr(2,3)
            status = DftiComputeForward(descHandle, x, y);
        }
    }
    if (status != DFTI_NO_ERROR) goto failed;
    
    printf("Verify the complex result\n");
    status = verify_c(y, BATCH, N, CMPLX_DIST, H);
    if (status != 0) goto failed;

    printf("Initialize data for complex-to-real FFT\n");
    init_c(y, BATCH, N, CMPLX_DIST, H);

    printf("Set DFTI descriptor configuration: DFTI_INPUT_DISTANCE = " LI "\n", CMPLX_DIST);
    status = DftiSetValue(descHandle, DFTI_INPUT_DISTANCE, CMPLX_DIST);
    if (status != DFTI_NO_ERROR) goto failed;
    
    printf("Set DFTI descriptor configuration: DFTI_OUTPUT_DISTANCE = " LI "\n", REAL_DIST);
    status = DftiSetValue(descHandle, DFTI_OUTPUT_DISTANCE, REAL_DIST);
    if (status != DFTI_NO_ERROR) goto failed;
    
    printf("Commit descriptor\n");
    {
#pragma omp dispatch device(devNum)
        status = DftiCommitDescriptor(descHandle);
    }
    if (status != DFTI_NO_ERROR) goto failed;

    printf("Compute backward FFT\n");
#pragma omp target data map(to:y[0:CMPLX_DIST*BATCH]) map(from:x[0:REAL_DIST*BATCH]) device(devNum)
    {
        {
// Use need_device_ptr clause for out of place computation because
// DftiComputeBackward is a variadic function where the out of place
// output is not explicit in the function declaration.
// The argument to the need_device_ptr clause is the one-based index
// of the pointer in the dispatched function's argument list.
// The input pointer is explicit in the function declaration, so the
// need_device_ptr clause is optional for it. That is, either
// need_device_ptr(2,3), referencing xGPU_real and xGPU_cmplx, or
// need_device_ptr(3), referencing just xGPU_cmplx, will work.
#pragma omp dispatch device(devNum) need_device_ptr(2,3)
            status = DftiComputeBackward(descHandle, y, x);
        }
    }
    if (status != DFTI_NO_ERROR) goto failed;

    printf("Verify the result\n");
    status = verify_r(x, BATCH, N, REAL_DIST, H);
    if (status != 0) goto failed;

 cleanup:

    printf("Free DFTI descriptor\n");
    DftiFreeDescriptor(&descHandle);

    printf("Free data array\n");
    mkl_free(x);
    mkl_free(y);

    {
        printf("TEST %s\n", status == 0 ? "PASSED" : "FAILED");
        return status;
    }

 failed:
    printf(" ERROR, status = " LI "\n", status);
    goto cleanup;
}

// Return the remainder of K*L divided by M, converted to float.
// The product is formed in long long before the remainder is taken; as with
// any C '%', the sign of the result follows the sign of the product.
static float moda(MKL_LONG K, MKL_LONG L, MKL_LONG M)
{
    const long long product = (long long)K * L;
    const long long rem = product % M;
    return (float)rem;
}

// 2*pi as a single-precision constant; init_r and init_c multiply it by a
// phase fraction to obtain the argument passed to cosf/sinf.
const float TWOPI = 6.2831853071795864769f;

// Fill each of the M real sequences in x (sequence m starts at
// x[m*REAL_DIST], length N) with a cosine of harmonic H scaled by 1/N, so
// that the forward transform produces unit peaks at y(H) and y(N-H).
// The amplitude is 1 when 2*(N-H) is a multiple of N and 2 otherwise.
static void init_r(float* x, MKL_LONG M, MKL_LONG N, MKL_LONG REAL_DIST, MKL_LONG H)
{
    float amplitude = 2.0f;
    if (2 * (N - H) % N == 0) {
        amplitude = 1.0f;
    }

    for (MKL_LONG batch = 0; batch < M; ++batch) {
        float *seq = x + batch * REAL_DIST;
        for (MKL_LONG k = 0; k < N; ++k) {
            const float phase = moda(k, H, N) / N;
            const float angle = TWOPI * phase;
            seq[k] = amplitude * cosf(angle) / N;
        }
    }
}

// Verify that each of the M complex sequences in x has a unit peak at H.
//
// Sequence m starts at x[m*CMPLX_DIST]; its first N/2+1 elements are checked.
// The expected value is (1,0) at every index n for which n-H or -n-H is a
// multiple of N, and (0,0) elsewhere. The error of an element is
// |re_got - re_exp| + |im_got - im_exp| and must be below
// 2.5 * log2(N) * FLT_EPSILON.
//
// Returns 0 when every element passes; prints the offending element and
// returns 1 at the first element that does not.
static int verify_c(MKL_Complex8 *x, MKL_LONG M, MKL_LONG N, MKL_LONG CMPLX_DIST, MKL_LONG H)
{
    const float errthr = 2.5f * logf((float) N) / logf(2.0f) * FLT_EPSILON;
    printf(" Verifying the result, max err threshold = %.3lg\n", errthr);

    float maxerr = 0.0f;
    for (MKL_LONG m = 0; m < M; ++m) {
        for (MKL_LONG n = 0; n < N/2+1; ++n) {
            float re_exp = 0.0f, im_exp = 0.0f, re_got, im_got;

            if ((n-H)%N == 0 || (-n-H)%N == 0) re_exp = 1.0f;

            re_got = x[m*CMPLX_DIST + n].real;
            im_got = x[m*CMPLX_DIST + n].imag;
            float err  = fabsf(re_got - re_exp) + fabsf(im_got - im_exp);
            if (err > maxerr) maxerr = err;
            // Written as a negated '<' so that a NaN error also fails
            if (!(err < errthr)) {
                // Space before "x[" keeps the batch number from running into
                // the element name (previously printed as e.g. "#0x[3]")
                printf(" Batch #" LI " x[" LI "]: ", m, n);
                printf(" expected (%.7g,%.7g), ", re_exp, im_exp);
                printf(" got (%.7g,%.7g), ", re_got, im_got);
                printf(" err %.3lg\n", err);
                printf(" Verification FAILED\n");
                return 1;
            }
        }
    }
    printf(" Verified,  maximum error was %.3lg\n", maxerr);
    return 0;
}

// Fill the first N/2+1 elements of each of the M complex sequences in x
// (sequence m starts at x[m*CMPLX_DIST]) with (cos, -sin) of harmonic H,
// scaled by 1/N, so that the backward transform produces a unit peak at y(H).
static void init_c(MKL_Complex8 *x, MKL_LONG M, MKL_LONG N, MKL_LONG CMPLX_DIST, MKL_LONG H)
{
    for (MKL_LONG batch = 0; batch < M; ++batch) {
        MKL_Complex8 *seq = x + batch * CMPLX_DIST;
        for (MKL_LONG k = 0; k < N/2+1; ++k) {
            const float phase = moda(k, H, N) / N;
            const float angle = TWOPI * phase;
            seq[k].real =  cosf(angle) / N;
            seq[k].imag = -sinf(angle) / N;
        }
    }
}

// Verify that each of the M real sequences in x has a unit peak at H.
//
// Sequence m starts at x[m*REAL_DIST] and has N elements. The expected value
// is 1 at every index n for which n-H is a multiple of N and 0 elsewhere; the
// absolute error must be below 2.5 * log2(N) * FLT_EPSILON.
//
// Returns 0 when every element passes; prints the offending element and
// returns 1 at the first element that does not.
static int verify_r(float *x, MKL_LONG M, MKL_LONG N, MKL_LONG REAL_DIST, MKL_LONG H)
{
    const float tolerance = 2.5f * logf((float) N) / logf(2.0f) * FLT_EPSILON;
    printf(" Check if err is below errthr %.3lg\n", tolerance);

    float worst = 0.0f;
    for (MKL_LONG batch = 0; batch < M; ++batch) {
        const float *seq = x + batch * REAL_DIST;
        for (MKL_LONG k = 0; k < N; ++k) {
            const float expected = ((k - H) % N == 0) ? 1.0f : 0.0f;
            const float got = seq[k];
            const float err = fabsf(got - expected);

            if (err > worst) {
                worst = err;
            }
            // Only a strictly smaller error passes; anything else fails
            if (err < tolerance) {
                continue;
            }

            printf(" Batch #" LI " x[" LI "]: ", batch, k);
            printf(" expected %.7g, ", expected);
            printf(" got %.7g, ", got);
            printf(" err %.3lg\n", err);
            printf(" Verification FAILED\n");
            return 1;
        }
    }
    printf(" Verified,  maximum error was %.3lg\n", worst);
    return 0;
}
