#CBC vs CUSP injection/ event Testing
#CBC and CUSP Injection Recovery
#Aidan Brophy, SUNY New Paltz Fall 2021 Senior Project

#This program takes a file that has been pre-downloaded from the GWOSC data portal and runs
#an optimal matched filter analysis for either or both templates for CBC and CUSP events.
# To use this program, change the GPS time to the appropriate time of the event, and the parameters for
# the two template programs as necessary.

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.mlab as mlab
import readligo as rl
import template 
import FindCuspTemplate
from scipy.io import wavfile
pi = np.pi



# Read the pre-downloaded GWOSC HDF5 file. readligo hands back the strain
# samples, their time stamps and the data-quality information.
strain, strain_times, dq = rl.loaddata("H-H1_LOSC_4_V1-816332800-4096 (2).hdf5")

# Stop with a clear message instead of failing on the indexing below when the
# file could not be read.
# NOTE(review): assumes loaddata signals a failed read by returning None
# values -- confirm against the readligo version in use.
if strain is None or strain_times is None or len(strain_times) < 2:
    raise SystemExit("Failed to read file.")

# Sample spacing and sampling frequency, taken from the first two time stamps
dt = strain_times[1] - strain_times[0]
fs = 1.0 / dt


# Build the analysis segment: a window reaching 32 s either side of the GPS
# time of the injection.
gpstime_event = 816335770

half_window = 32
start, stop = gpstime_event - half_window, gpstime_event + half_window

# Pull the H1 strain covering [start, stop] and derive its duration and a
# time axis that begins at `start` and advances by dt per sample.
seg_strains, meta, dq = rl.getstrain(start, stop, 'H1')
n_seg_samples = len(seg_strains)
seg_duration = n_seg_samples / fs
seg_times = start + dt * np.arange(n_seg_samples)

print("Last segment time: ", seg_times[-1], " GPS end: " , stop)
print("The injection segment is {0} s long".format(seg_duration))

# Build the frequency-domain templates for both searches: the inspiral (CBC)
# template and the cosmic-string cusp template.
temp, temp_freq = template.createTemplate(fs, seg_duration, 35, 25)
ctemp, ctemp_freq = FindCuspTemplate.create_a_template(fs, seg_duration, 25, 220)

# Zero, in place, every template bin whose frequency lies below 25 Hz
low_freq_cutoff = 25
for _template, _template_freq in ((temp, temp_freq), (ctemp, ctemp_freq)):
    _template[_template_freq < low_freq_cutoff] = 0

# Plot template magnitude against frequency.
# Only the cusp curve is enabled, so the figure carries the cusp-only title;
# un-comment the CBC line (and adjust the title) to compare both templates.
plt.figure()
#plt.plot(temp_freq, abs(temp), label = "CBC Template")
plt.plot(ctemp_freq, abs(ctemp), label = "Cusp Template")
#plt.axis([10, 250, 1e-20, 1.5e-20])
plt.title("FINDCUSP Template in Frequency Domain")
plt.xlabel("Frequency (Hz)")
plt.ylabel("Template value (Strain/Hz)")
plt.grid()
plt.legend()
#plt.show()

# Inverse-FFT both templates so they can be inspected as time series.
# The time axis has one entry per template bin, spaced by dt.
t_temp = abs(temp)
t = np.arange(len(t_temp))*dt
inv_cusp_temp = np.fft.ifft(ctemp)
invtemp = np.fft.ifft(temp)

# Rotate each series left by half its length so that sample 0 of the inverse
# FFT ends up in the middle of the plotted array.
# (len - 1 because the first element is index 0; integer division keeps the
# shift an int without a float round trip.)
middleIndex = (len(inv_cusp_temp) - 1) // 2
middleIndex2 = (len(invtemp) - 1) // 2

inv_cusp_tempT = np.roll(inv_cusp_temp, -middleIndex)
invtempT = np.roll(invtemp, -middleIndex2)

# Only the CBC curve is enabled; the cusp curve and the ratio of the two are
# kept as optional overlays.
plt.figure()
plt.plot((t), invtempT.real, label = "CBC Template")
#plt.plot((t), inv_cusp_tempT.real, label = "Cusp Template")
#residual = (invtempT.real)/(inv_cusp_tempT.real)
#plt.plot((t), residual, label = "Residual")
plt.axis([14,17, -1.5e-22,1.5e-22])
plt.xlabel("Time (s)")
plt.ylabel("Template value (Strain)")
# grid must be *called*; a bare `plt.grid` is a no-op expression
plt.grid()
plt.legend()
plt.title("CBC Template in Time Domain")
#plt.show()



# Taper the segment with a Blackman window and take its real FFT; data_fft is
# the frequency-domain data used by the matched filter.
window = np.blackman(seg_strains.size)
windowed_strain = seg_strains*window
data_fft = np.fft.rfft(windowed_strain)

# Spectrogram of the (untapered) segment, using 1024-sample Blackman-windowed
# FFTs.
NFFT = 1024
windowSP = np.blackman(NFFT)
# figure must be *called*; a bare `plt.figure` is a no-op expression and the
# spectrogram would be drawn into whichever figure is already current.
plt.figure()
spec_power, freqs, bins, im = plt.specgram(seg_strains, NFFT=NFFT, Fs=fs,
                                    window=windowSP, cmap='inferno')

# Median power of each frequency row, and the spectrogram with every row
# divided by its own median.
# NOTE(review): norm_spec_power is computed but not drawn; the image (and the
# colorbar labelled 'Normalized Energy') come from specgram's own output.
med_power = np.median(spec_power, axis=1)
norm_spec_power = spec_power / med_power[:, np.newaxis]

plt.title("Spectrogram of Strain Segment")
plt.colorbar().set_label('Normalized Energy')
plt.ylabel("Frequency(hz)")
plt.xlabel("Time(s)")
#plt.show()



# Fetch the strain used to estimate the noise power spectral density (Pxx).
# The window runs from 512 s before the analysis segment's start up to 1 s
# before the injection GPS time.
# NOTE(review): because start = gpstime_event - 32, this window also covers the
# first 31 s of the analysis segment itself -- confirm that is intended.
noise_strain, meta1, dq1 = rl.getstrain(start-512, gpstime_event-1  , 'H1')
#noise_strain, meta1, dq1 = rl.getstrain( gpstime_event+2, 'H1')

# Estimate the PSD with NFFT equal to the analysis-segment length, so that Pxx
# is meant to share its frequency bins with data_fft (same Fs, same length).
Pxx, psd_freq = mlab.psd(noise_strain, Fs=fs, NFFT=len(seg_strains))

# Matched-filter integrand for each template: data times the complex conjugate
# of the template, weighted by the inverse of the noise PSD.
# NOTE(review): assumes temp and ctemp have the same length as data_fft and
# Pxx; the element-wise product fails otherwise -- confirm in the template
# modules.
integrand = data_fft*np.ma.conjugate(temp)/Pxx
cusp_integrand = data_fft*np.ma.conjugate(ctemp)/Pxx
# Append zeros so each integrand has as many entries as the time-domain segment
# before the inverse FFT.
num_zeros = len(seg_strains) - len(data_fft)
#print("The number of zeros to pad:", num_zeros)
padded_int = np.append( integrand, np.zeros(num_zeros) )
cusp_padded_int = np.append( cusp_integrand, np.zeros(num_zeros) )
# Complex matched-filter output as a function of time (z: CBC, cz: cusp),
# scaled by 4 to match the factor 4 used for the sigma^2 normalisation.
z = 4*np.fft.ifft(padded_int)
cz = 4*np.fft.ifft(cusp_padded_int)




### Compute the normalization for the CBC template:
### sigma^2 = 4 * sum(|template|^2 / Pxx) * df, with df the PSD bin width
kernal =  (np.abs(temp))**2 / Pxx
df = psd_freq[1] - psd_freq[0]
sig_sqr = 4*kernal.sum()*df
sigma = np.sqrt(sig_sqr)

### Compute the normalization for the cusp template in the same way.
### It uses its own bin width (cusp_df) so this section does not depend on the
### CBC one above.
cusp_kernal =  (np.abs(ctemp))**2 / Pxx
cusp_df = psd_freq[1] - psd_freq[0]
cusp_sig_sqr = 4*cusp_kernal.sum()*cusp_df
cusp_sigma = np.sqrt(cusp_sig_sqr)

# NOTE(review): 440 is an unexplained constant (possibly a source distance
# relative to the template's reference distance) -- confirm against
# template.createTemplate. expected_SNR is only used by a disabled print.
expected_SNR = sigma / 440

#### Compute the SNR
# Divide out the Blackman taper applied before the FFT, and blank the first and
# last 20 s of the segment, where the taper is small and dividing by it would
# inflate the result. The edge length is derived from the actual sampling
# frequency rather than assuming 4096 Hz data.
edge_samples = int(round(20 * fs))
inv_win = (1.0 / window)
inv_win[:edge_samples] = 0
inv_win[-edge_samples:] = 0
rho = abs(z) / sigma * inv_win
crho = abs(cz) / cusp_sigma * inv_win

# Peak CBC SNR over the usable part of the segment
snr = rho.max()

#### Output the results
# -- Plot rho as a function of time (every 8th sample to keep the plot light).
# The title and x-label are built from the GPS times actually analysed so they
# cannot go stale when gpstime_event is changed.
plt.figure()
plt.title("SNR for Injection at GPS {0}".format(gpstime_event))

plt.plot(seg_times[::8]-seg_times[0], crho[::8], "tab:orange",  label = "FINDCUSP SNR")
plt.plot(seg_times[::8]-seg_times[0], rho[::8], "tab:blue", label = "FINDCHIRP SNR")

plt.xlabel("Time (s) since GPS Time {0:.0f}".format(seg_times[0]))
plt.ylabel("SNR")
# x-range 20-44 s matches the un-blanked middle of a 64 s segment
plt.axis([20,44, 0,14])
plt.legend()
plt.show()


# If needed, the data-quality flags can be displayed for the segment.
# The plotting code below is disabled by wrapping it in a bare string literal;
# remove the surrounding triple quotes to enable it. Each flag is offset by a
# different constant so the traces stack instead of overlapping.
# NOTE(review): the y-label inside reads "SNR" although the traces are
# data-quality flags -- adjust when enabling.
"""
plt.plot(dq['HW_CBC'] + 8, label='HW_CBC')
plt.plot(dq['DEFAULT']+1, label='Good Data')
plt.plot(dq['HW_BURST'] + 2, label='HW_BURST')
plt.plot(dq['BURST_CAT1'] + 3, label='BURST_CAT1')
plt.plot(dq['BURST_CAT2'] + 4 , label='BURST_CAT2')
plt.plot(dq['BURST_CAT3'] + 5, label='BURST_CAT3')
plt.plot(dq['BURST_CAT2E'] + 6, label='BURST_CAT2E')
plt.plot(dq['BURST_CAT3E']+ 7, label='BURST_CAT3E')
plt.xlabel("Seconds since GPS {0:.0f}".format(seg_times[0]) )
plt.ylabel("SNR")
plt.legend()
plt.show()
"""

# Summarise the recovery for both searches.
# For each SNR series: the peak value, the mean over the whole segment, and
# the peak-to-mean ratio. The means and ratios are computed but not printed.
snrx = np.mean(rho)
snrr = snr / snrx
# GPS time(s) of the sample(s) at which the CBC SNR reaches its peak
found_time = seg_times[np.nonzero(rho == snr)]

cusp_snr = crho.max()
cusp_snrx = np.mean(crho)
cusp_snrr = cusp_snr / cusp_snrx

# Report the recovered CBC time and the peak SNR of each search
print("Recovered time GPS {0:.1f}".format(found_time[0]))
print("Recovered SNR is ", snr)
print("Recovered  Cusp SNR is ", cusp_snr)
