diff --git a/behavelet/morlet.py b/behavelet/morlet.py index c4d708a..3d52d66 100644 --- a/behavelet/morlet.py +++ b/behavelet/morlet.py @@ -84,7 +84,7 @@ def _morlet_conj_ft(omegas, omega0=5.0, gpu=False): return ft_wavelet -def _morlet_fft_convolution(X, freqs, scales, dtime, omega0=5.0, gpu=False): +def _morlet_fft_convolution(X, freqs, scales, dtime, omega0=5.0, return_complex=False, gpu=False): """ Calculates a Morlet continuous wavelet transform for a given signal across a range of frequencies @@ -163,7 +163,8 @@ def _morlet_fft_convolution(X, freqs, scales, dtime, omega0=5.0, gpu=False): convolved *= backend.sqrt(scale) convolved = convolved[idx0:idx1] # remove zero padding - convolved = backend.abs(convolved) # use the norm of the complex values + if not return_complex: + convolved = backend.abs(convolved) # use the norm of the complex values # scale power to account for disproportionally # large wavelet response at low frequencies @@ -183,7 +184,7 @@ def _morlet_fft_convolution_parallel(feed_dict): def wavelet_transform(X, n_freqs, fsample, fmin, fmax, prob=True, omega0=5.0, log_scale=True, - n_jobs=1, gpu=False): + return_complex=False, n_jobs=1, gpu=False): """ Applies a Morlet continuous wavelet transform to a data set across a range of frequencies. @@ -215,6 +216,8 @@ def wavelet_transform(X, n_freqs, fsample, fmin, fmax, Whether to sample the frequencies on a log scale. omega0 : float (default = 5.0) Dimensionless omega0 parameter for wavelet transform. + return_complex: bool (default = False) + Whether to return complex wavelet transform. n_jobs : int (default = 1) Number of jobs to use for performing the wavelet transform. If -1, all CPUs are used. If 1 is given, no parallel computing is @@ -284,6 +287,7 @@ def wavelet_transform(X, n_freqs, fsample, fmin, fmax, "scales": scales, "dtime": dtime, "omega0": omega0, + "return_complex": return_complex, "gpu": gpu} for feature in X.T] @@ -299,7 +303,7 @@ def wavelet_transform(X, n_freqs, fsample, fmin, fmax, # for idx, conv in enumerate(convolved): # X_new[:, (n_freqs * idx):(n_freqs * (idx + 1))] = conv.T - power = X_new.sum(axis=1, keepdims=True) + power = np.abs(X_new).sum(axis=1, keepdims=True) if prob: X_new /= power