!===============================================================================
! Copyright 2021-2022 Intel Corporation.
!
! This software and the related documents are Intel copyrighted  materials,  and
! your use of  them is  governed by the  express license  under which  they were
! provided to you (License).  Unless the License provides otherwise, you may not
! use, modify, copy, publish, distribute,  disclose or transmit this software or
! the related documents without Intel's prior written permission.
!
! This software and the related documents  are provided as  is,  with no express
! or implied  warranties,  other  than those  that are  expressly stated  in the
! License.
!===============================================================================

! Content:
! A simple example of asynchronous single-precision real-to-complex and
! complex-to-real out-of-place 1D FFTs using Intel(R) oneAPI Math Kernel
! Library (oneMKL) DFTI with OpenMP offload
!
!*****************************************************************************

include "mkl_dfti_omp_offload.f90"

program sp_real_1d_outofplace_async

  use MKL_DFTI_OMP_OFFLOAD, forget => DFTI_SINGLE, DFTI_SINGLE => DFTI_SINGLE_R
  use omp_lib, ONLY : omp_get_num_devices
  use, intrinsic :: ISO_C_BINDING
  implicit none

  ! Lengths of the two 1D transforms
  integer, parameter :: N1 = 16
  integer, parameter :: N2 = 64

  ! Number of complex elements in each conjugate-even (CCE) half-spectrum
  integer, parameter :: halfN1plus1 = N1/2 + 1
  integer, parameter :: halfN2plus1 = N2/2 + 1

  ! Arbitrary harmonics used to verify the FFTs
  integer, parameter :: H1 = 1
  integer, parameter :: H2 = N2/2

  ! Working precision is single precision
  integer, parameter :: WP = selected_real_kind(6,37)

  ! Execution status of the steps for transform 1 / transform 2 (0 = success)
  integer :: status1, status2, ignored_status

  ! Real arrays: forward-transform input, backward-transform output
  real(WP), allocatable :: x1 (:)
  real(WP), allocatable :: x2 (:)
  ! Complex CCE arrays: forward-transform output, backward-transform input
  complex(WP), allocatable :: y1 (:)
  complex(WP), allocatable :: y2 (:)

  ! DFTI descriptor handles
  type(DFTI_DESCRIPTOR), POINTER :: hand1, hand2

  hand1 => null()
  hand2 => null()
  ! Start both statuses at success so the error report at label 999 never
  ! prints (or tests) an undefined value when a step fails before the
  ! other transform's status has been assigned.
  status1 = 0
  status2 = 0

  print *,"Example sp_real_1d_outofplace_async"
  print *,"Forward and backward single-precision real-to-complex",        &
    &      " and complex-to-real out-of-place 1D asynchronous transform"
  print *,"Configuration parameters:"
  print *,"DFTI_PRECISION      = DFTI_SINGLE"
  print *,"DFTI_FORWARD_DOMAIN = DFTI_REAL"
  print *,"DFTI_DIMENSION      = 1"
  print '(" DFTI_LENGTHS        = /", I0, "/ & /", I0, "/ ")', N1, N2

  print *,"Create DFTI descriptor 1"
  status1 = DftiCreateDescriptor(hand1, DFTI_SINGLE, DFTI_REAL, 1, N1)
  if (0 /= status1) goto 999

  print *,"Create DFTI descriptor 2"
  status2 = DftiCreateDescriptor(hand2, DFTI_SINGLE, DFTI_REAL, 1, N2)
  if (0 /= status2) goto 999

  ! Each descriptor's result goes into its own status variable, and that
  ! same variable is tested immediately afterwards.
  print *,"Set DFTI descriptor 1 for out-of-place computation"
  status1 = DftiSetValue(hand1, DFTI_PLACEMENT, DFTI_NOT_INPLACE)
  if (0 /= status1) goto 999

  print *,"Set DFTI descriptor 2 for out-of-place computation"
  status2 = DftiSetValue(hand2, DFTI_PLACEMENT, DFTI_NOT_INPLACE)
  if (0 /= status2) goto 999

  print *,"Set DFTI descriptor 1 for CCE storage "
  status1 = DftiSetValue(hand1, DFTI_CONJUGATE_EVEN_STORAGE, &
                         DFTI_COMPLEX_COMPLEX)
  if (0 /= status1) goto 999

  ! Configure hand2 (not hand1) so both descriptors use CCE storage,
  ! matching the complex(WP) y2 array of N2/2+1 elements.
  print *,"Set DFTI descriptor 2 for CCE storage "
  status2 = DftiSetValue(hand2, DFTI_CONJUGATE_EVEN_STORAGE, &
                         DFTI_COMPLEX_COMPLEX)
  if (0 /= status2) goto 999

  print *,"Commit DFTI descriptor 1"
  !$omp dispatch
  status1 = DftiCommitDescriptor(hand1)
  if (0 /= status1) goto 999

  print *,"Commit DFTI descriptor 2"
  !$omp dispatch
  status2 = DftiCommitDescriptor(hand2)
  if (0 /= status2) goto 999

  print *,"Allocate array for input data 1"
  allocate ( x1(N1), STAT = status1)
  if (0 /= status1) goto 999

  print *,"Allocate array for input data 2"
  allocate ( x2(N2), STAT = status2)
  if (0 /= status2) goto 999

  print *,"Allocate array for output data 1"
  allocate ( y1(halfN1plus1), STAT = status1)
  if (0 /= status1) goto 999

  print *,"Allocate array for output data 2"
  allocate ( y2(halfN2plus1), STAT = status2)
  if (0 /= status2) goto 999

  print *,"Initialize inputs for real-to-complex forward FFT"
  call init_r(x1, N1, H1)
  call init_r(x2, N2, H2)

  ! Both forward transforms are dispatched with nowait. The taskwait waits
  ! for them before the end of the target data region copies y1/y2 back,
  ! and the status values are only examined after that point.
  !$omp target data map(to:x1, x2) map(from:y1, y2)
  print *,"Compute forward transform 1"
  !$omp dispatch nowait
  status1 = DftiComputeForward(hand1, x1, y1)
  print *,"Compute forward transform 2"
  !$omp dispatch nowait
  status2 = DftiComputeForward(hand2, x2, y2)
  !$omp taskwait
  !$omp end target data
  if (0 /= status1) goto 999
  if (0 /= status2) goto 999

  print *,"Verify the result of FFT1"
  status1 = verify_c(y1, N1, H1)
  if (0 /= status1) goto 999
  print *,"Verify the result of FFT2"
  status2 = verify_c(y2, N2, H2)
  if (0 /= status2) goto 999

  print *,"Initialize inputs for complex-to-real backward transform"
  call init_c(y1, N1, H1)
  call init_c(y2, N2, H2)

  ! Same asynchronous pattern as the forward pass, with the roles of the
  ! real and complex arrays swapped.
  print *,"Compute backward transforms out-of-place"
  !$omp target data map(to:y1, y2) map(from:x1, x2)
  print *,"Compute backward transform 1"
  !$omp dispatch nowait
  status1 = DftiComputeBackward(hand1, y1, x1)
  print *,"Compute backward transform 2"
  !$omp dispatch nowait
  status2 = DftiComputeBackward(hand2, y2, x2)
  !$omp taskwait
  !$omp end target data
  if (0 /= status1) goto 999
  if (0 /= status2) goto 999

  print *,"Verify the results"
  status1 = verify_r(x1, N1, H1)
  status2 = verify_r(x2, N2, H2)
  if (0 /= status1) goto 999
  if (0 /= status2) goto 999

  ! Common cleanup, reached on success and (via label 999) on error
100 continue

  print *,"Release the DFTI descriptors"
  ! The handles are null-initialized above; skip any that were never
  ! associated so DftiFreeDescriptor only sees created descriptors.
  if (associated(hand1)) ignored_status = DftiFreeDescriptor(hand1)
  if (associated(hand2)) ignored_status = DftiFreeDescriptor(hand2)

  if (allocated(x1)) then
      print *,"Deallocate input data array 1"
      deallocate(x1)
  endif

  if (allocated(x2)) then
      print *,"Deallocate input data array 2"
      deallocate(x2)
  endif

  if (allocated(y1)) then
      print *,"Deallocate output data array 1"
      deallocate(y1)
  endif

  if (allocated(y2)) then
      print *,"Deallocate output data array 2"
      deallocate(y2)
  endif

  if (status1 == 0 .AND. status2 == 0) then
    print *,"TEST PASSED"
    call exit(0)
  else
    print *,"TEST FAILED"
    call exit(1)
  endif

  ! Error path: report both statuses, then jump back to the cleanup
999 print '("  Error, status1 = ",I0)', status1
  print '("  Error, status2 = ",I0)', status2
  goto 100

contains

  ! Compute mod(k*l, m) exactly and return it as real(WP).
  ! The product k*l is formed in an integer kind with at least 18 decimal
  ! digits, so it cannot overflow default integer before the reduction.
  ! All MOD arguments share that kind, as the standard requires.
  pure real(WP) function moda(k,l,m)
    integer, intent(in) :: k,l,m
    integer, parameter :: IK = selected_int_kind(18)
    moda = real(mod(int(k, IK)*int(l, IK), int(m, IK)), WP)
  end function moda

  ! Initialize real array x(1:N) so that its forward DFT has unit peaks at
  ! 0-based harmonics H and N-H:
  !   x(k) = factor * cos(2*pi*(k-1)*H/N) / N,   k = 1..N
  ! factor is 1 when H and N-H coincide modulo N (e.g. H = 0, or H = N/2
  ! for even N), because the two half-height peaks then add into one.
  ! Otherwise factor is 2.
  subroutine init_r(x, N, H)
    integer, intent(in) :: N, H
    real(WP), intent(out) :: x(:)

    integer :: k
    real(WP), parameter :: TWOPI = 6.2831853071795864769_WP
    real(WP) :: factor

    if (mod(2*(N - H), N) == 0) then
      factor = 1.0_WP
    else
      factor = 2.0_WP
    end if

    do k = 1, N
      x(k) = factor * cos(TWOPI*moda(k-1, H, N)/N) / N
    end do
  end subroutine init_r

  ! Verify the conjugate-even forward result y(1:N/2+1) of a length-N
  ! real-to-complex FFT. y(k) must be 1 where k-1 is congruent to H or -H
  ! modulo N, and 0 elsewhere, to within errthr.
  ! Returns 0 on success, or 100 at the first element that fails.
  integer function verify_c(y, N, H)
    integer, intent(in) :: N, H
    complex(WP), intent(in) :: y(:)

    integer :: k
    real(WP) :: err, errthr, maxerr
    complex(WP) :: res_exp, res_got

    ! Note, this simple error bound doesn't take into account error of
    ! input data
    errthr = 2.5_WP * log(real(N, WP)) / log(2.0_WP) * EPSILON(1.0_WP)
    print '("  Check if err is below errthr ", G10.3)', errthr

    maxerr = 0.0_WP
    do k = 1, N/2+1
      if (mod(k-1-H,N)==0 .OR. mod(1-k-H,N)==0) then
        res_exp = 1.0_WP
      else
        res_exp = 0.0_WP
      end if
      res_got = y(k)
      err = abs(res_got - res_exp)
      maxerr = max(err,maxerr)
      ! Written as .not.(err < errthr) so that a NaN error also fails
      if (.not.(err < errthr)) then
        ! Non-advancing writes keep the whole diagnostic on one line
        write (*, '("  y(", I0, "): ")', advance='no') k
        write (*, '(" expected (", G14.7, ", ", G14.7, "),")', advance='no') res_exp
        write (*, '(" got (", G14.7, ", ", G14.7, "),")', advance='no') res_got
        print '(" err ", G10.3)', err
        print *," Verification FAILED"
        verify_c = 100
        return
      end if
    end do
    print '("  Verified,  maximum error was ", G10.3)', maxerr
    verify_c = 0
  end function verify_c
  
  ! Initialize the conjugate-even half-spectrum y(1:N/2+1) so that a
  ! length-N complex-to-real backward FFT yields a unit peak at 0-based
  ! index H, i.e. at x(H+1):
  !   y(k) = exp(-i*2*pi*(k-1)*H/N) / N,   k = 1..N/2+1
  subroutine init_c(y, N, H)
    integer, intent(in) :: N, H
    complex(WP), intent(out) :: y(:)

    integer :: k
    real(WP), parameter :: TWOPI = 6.2831853071795864769_WP
    real(WP) :: TWOPI_phase

    do k = 1, N/2 + 1
      TWOPI_phase = TWOPI*moda(k-1, H, N)/N
      ! KIND=WP is required. Without it CMPLX returns default complex and
      ! would silently truncate if WP were wider than default real.
      y(k) = CMPLX(cos(TWOPI_phase)/N, -sin(TWOPI_phase)/N, KIND=WP)
    end do
  end subroutine init_c
  
  ! Verify the real backward result x(1:N) of a length-N complex-to-real
  ! FFT. x(k) must be 1 where k-1 is congruent to H modulo N, and 0
  ! elsewhere, to within errthr.
  ! Returns 0 on success, or 100 at the first element that fails.
  integer function verify_r(x, N, H)
    integer, intent(in) :: N, H
    real(WP), intent(in) :: x(:)

    integer :: k
    real(WP) :: err, errthr, maxerr
    real(WP) :: res_exp, res_got

    ! Note, this simple error bound doesn't take into account error of
    ! input data
    errthr = 2.5_WP * log(real(N, WP)) / log(2.0_WP) * EPSILON(1.0_WP)
    print '("  Check if err is below errthr ", G10.3)', errthr

    maxerr = 0.0_WP
    do k = 1, N
      if (mod(k-1-H,N)==0) then
        res_exp = 1.0_WP
      else
        res_exp = 0.0_WP
      end if
      res_got = x(k)
      err = abs(res_got - res_exp)
      maxerr = max(err,maxerr)
      ! Written as .not.(err < errthr) so that a NaN error also fails
      if (.not.(err < errthr)) then
        ! Non-advancing writes keep the whole diagnostic on one line
        write (*, '("  x(", I0, "): ")', advance='no') k
        write (*, '(" expected ", G14.7, ",")', advance='no') res_exp
        write (*, '(" got ", G14.7, ",")', advance='no') res_got
        print '(" err ", G10.3)', err
        print *," Verification FAILED"
        verify_r = 100
        return
      end if
    end do
    print '("  Verified,  maximum error was ", G10.3)', maxerr
    verify_r = 0
  end function verify_r

end program sp_real_1d_outofplace_async
