diff --git a/Project.toml b/Project.toml index 6f40c99..6913b3e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ScatteringTransform" uuid = "eadaac29-8b6b-5395-80fd-9ce36bb46293" -version = "0.8.1" +version = "0.8.2" author = ["David Weber ", "Naoki Saito ", "Jared White "] [deps] diff --git a/docs/make.jl b/docs/make.jl index 7fe9f45..9c5ed78 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -9,7 +9,7 @@ makedocs( sitename = "ScatteringTransform.jl", format = Documenter.HTML(), modules = [ScatteringTransform, ScatteringPlots], - authors="David Weber, Naoki Saito", + authors="David Weber, Naoki Saito, Jared White", clean=true, checkdocs = :exports, # This ignores the ContinuousWavelets warnings during doctests diff --git a/docs/src/figures/firstLayer.png b/docs/src/figures/firstLayer.png index 7e438b0..3e8fc6c 100644 Binary files a/docs/src/figures/firstLayer.png and b/docs/src/figures/firstLayer.png differ diff --git a/docs/src/figures/jointPlot.png b/docs/src/figures/jointPlot.png new file mode 100644 index 0000000..77fcd7a Binary files /dev/null and b/docs/src/figures/jointPlot.png differ diff --git a/docs/src/figures/secondLayer.png b/docs/src/figures/secondLayer.png new file mode 100644 index 0000000..66beb75 Binary files /dev/null and b/docs/src/figures/secondLayer.png differ diff --git a/docs/src/figures/secondLayerSpecificPath.png b/docs/src/figures/secondLayerSpecificPath.png new file mode 100644 index 0000000..0d0f23c Binary files /dev/null and b/docs/src/figures/secondLayerSpecificPath.png differ diff --git a/docs/src/figures/zerothLayer.png b/docs/src/figures/zerothLayer.png new file mode 100644 index 0000000..6aefead Binary files /dev/null and b/docs/src/figures/zerothLayer.png differ diff --git a/docs/src/index.md b/docs/src/index.md index 0692850..e2ee5da 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,6 +1,6 @@ # ScatteringTransform.jl -A julia implementation of the scattering transform, which provides a prestructured alternative to a convolutional neural network. +A Julia implementation of the scattering transform, which provides a prestructured alternative to a convolutional neural network. In a similar vein to a CNN, it alternates between continuous wavelet transforms, nonlinear function applications, and subsampling. This library is end-to-end differentiable and runs on the GPU; there is a companion package, [ParallelScattering.jl](https://github.com/dsweber2/ParallelScattering.jl/) that runs on parallelized CPUs. @@ -26,6 +26,7 @@ N = 2047 f = testfunction(N, "Doppler") plot(f, legend=false, title="Doppler signal") savefig("figures/rawDoppler.svg"); #hide +nothing # hide ``` ![](figures/rawDoppler.svg) @@ -47,7 +48,7 @@ The results `sf` are stored in the `ScatteredOut` type; for a two layer scatteri The zeroth layer is simply a moving average of the original signal: ```@example ex -plot(sf[0][:, 1, 1], title="Zeroth Layer", legend=false) +plotZerothLayer1D(sf) ``` ### First Layer @@ -55,18 +56,14 @@ plot(sf[0][:, 1, 1], title="Zeroth Layer", legend=false) The first layer is the average of the absolute value of the scalogram: ```@example ex -plotFirstLayer(sf, St, "figures/firstLayer.png") -nothing # hide +plotFirstLayer(sf, St) ``` -![](figures/firstLayer.png) With the plotting utilities included in this package, you are able to display the previous plot along with the original signal and the first layer wavelet gradients: ```@example ex -plotFirstLayer1DAll(sf, f, "figures/firstLayerAll.png") -nothing # hide +plotFirstLayer1DAll(sf, f) ``` -![](figures/firstLayerAll.png) ### Second Layer @@ -76,7 +73,7 @@ With our plotting utilities, you can display the second layer with respect to sp To this end, lets make two gifs, the first with the _first_ layer frequency varying with time: ```@example ex -plotSecondLayerFixAndVary(sf, St, 1:30, 1, "figures/sliceByFirst.gif", 1) +plotSecondLayerFixAndVary(sf, St, 1:30, 1, fps=1, saveTo="figures/sliceByFirst.gif") nothing # hide ``` ![](figures/sliceByFirst.gif) @@ -87,7 +84,7 @@ As the first layer frequency increases, the energy concentrates to the beginning The second has the _second_ layer frequency varying with time: ```@example ex -plotSecondLayerFixAndVary(sf, St, 1, 1:28, "figures/sliceBySecond.gif", 1) +plotSecondLayerFixAndVary(sf, St, 1, 1:28, fps=1, saveTo="figures/sliceBySecond.gif") nothing # hide ``` ![](figures/sliceBySecond.gif) @@ -96,9 +93,7 @@ If desired, this package allows one to plot the results of a specific path. Here ```@example ex plotSecondLayerSpecificPath(sf, St, 3, 1, f) -savefig("figures/specificPath.png"); #hide ``` -![](figures/specificPath.png) For any fixed second layer frequency, we get approximately the curve in the first layer scalogram, with different portions emphasized, and the overall mass decreasing as the frequency increases, corresponding to the decreasing amplitude of the envelope for the doppler signal. These plots can also be created using various plotting utilities defined in this package. @@ -107,11 +102,8 @@ For example, we can generate a denser representation with the `plotSecondLayer` ```@example ex plotSecondLayer(sf, St) -savefig("figures/secondLayer.png"); #hide ``` -![](figures/secondLayer.png) - where the frequencies are along the axes, the heatmap gives the largest value across time for that path, and at each path is a small plot of the averaged timecourse. @@ -121,7 +113,4 @@ Finally, we can constuct a joint plot of much of our prior information. This plo ```@example ex jointPlot(sf, "Scattering Transform", :viridis, St) -savefig("figures/jointPlot.png"); #hide ``` - -![](figures/jointPlot.png) diff --git a/docs/src/plots.md b/docs/src/plots.md index c4080fb..b48985c 100644 --- a/docs/src/plots.md +++ b/docs/src/plots.md @@ -1,7 +1,14 @@ # Plotting Scattering Transforms -```@autodocs -Modules = [ScatteringTransform, ScatteringPlots] -Pages = ["scatteringplots.jl"] -Order = [:function] +```@docs +ScatteringTransform.plotZerothLayer1D +ScatteringTransform.plotFirstLayer1D +ScatteringTransform.gifFirstLayer +ScatteringTransform.plotFirstLayer1DAll +ScatteringTransform.plotFirstLayer +ScatteringTransform.plotSecondLayerSpecificPath +ScatteringTransform.plotSecondLayer1DSubsetGif +ScatteringTransform.plotSecondLayerFixAndVary +ScatteringTransform.plotSecondLayer +ScatteringTransform.jointPlot ``` diff --git a/src/ScatteringTransform.jl b/src/ScatteringTransform.jl index 9d81a01..3f33f86 100644 --- a/src/ScatteringTransform.jl +++ b/src/ScatteringTransform.jl @@ -46,5 +46,5 @@ export getWavelets, flatten, roll, importantCoords, batchOff, getParameters, get export roll, wrap, flatten include("adjoints.jl") include("scatteringplots.jl") -export jointPlot, plotFirstLayer1D, gifFirstLayer, plotFirstLayer1DAll, plotFirstLayer, plotSecondLayer, plotSecondLayer1D, plotSecondLayerSpecificPath, plotSecondLayer1DSubsetGif, plotSecondLayerFixAndVary +export plotZerothLayer1D, plotFirstLayer1D, gifFirstLayer, plotFirstLayer1DAll, plotFirstLayer, plotSecondLayer, plotSecondLayer1D, plotSecondLayerSpecificPath, plotSecondLayer1DSubsetGif, plotSecondLayerFixAndVary, jointPlot end # end Module diff --git a/src/scatteringplots.jl b/src/scatteringplots.jl index 9bed67a..b40b28d 100644 --- a/src/scatteringplots.jl +++ b/src/scatteringplots.jl @@ -1,11 +1,23 @@ """ - plotFirstLayer1D(j, origLoc, origSig, index) + plotZerothLayer1D(sf; saveTo=nothing, index=1) +Function that plots the zeroth layer of the scattering transform at a specified example index. +""" +function plotZerothLayer1D(sf; saveTo=nothing, index=1) + plt = plot(sf[0][:, 1, index], title="Zeroth Layer", legend=false, xlim=(0, length(sf[0][:, 1, index])+1), color=:blue, margin=5Plots.mm) + if !isnothing(saveTo) + savefig(plt, saveTo) + end + return plt +end + +""" + plotFirstLayer1D(j, origLoc, origSig; saveTo=nothing, index=1) Function that plots the first layer gradient wavelet at index `j` across space, along with the original signal. It also includes heatmaps of the gradient wavelet in both the spatial and frequency domains. The variable `j` specifies which wavelet to plot from the first layer, `index` specifies which example in the batch to plot, `origLoc` is the `ScatteredOut` object containing the scattering transform results, and `origSig` is the original input signal. """ -function plotFirstLayer1D(j, origLoc, origSig, index=1) +function plotFirstLayer1D(j, origLoc, origSig; saveTo=nothing, index=1) space = plot(origLoc[1][:, j, index], xlim=(0, length(origLoc[1][:, j, index])+1), legend=false, color=:red, title="First Layer - Gradient Wavelet $j - Varying Location") org = plot(origSig[:,:,index], legend=false, color=:red, title="Original Signal", xlim=(0, length(origSig[:,:,index])+1)) @@ -14,32 +26,38 @@ function plotFirstLayer1D(j, origLoc, origSig, index=1) ∇̂h = heatmap(log.(abs.(rfft(origLoc[1][:, j, index], 1)) .^ 2)', xlabel="frequency", yticks=false, ylabel="", title="Log-power Frequency Domain - Wavelet j=$j") l = Plots.@layout [a; b{0.1h}; [b c]] - plot(space, org, ∇h, ∇̂h, layout=l, size=(1200, 800), margin=5Plots.mm) + plt = plot(space, org, ∇h, ∇̂h, layout=l, size=(1200, 800), margin=5Plots.mm) + if !isnothing(saveTo) + savefig(plt, saveTo) + end + return plt end """ - gifFirstLayer(origLoc, origSig, saveTo="tmp.gif", fps = 2, index) + gifFirstLayer(origLoc, origSig; fps=2, saveTo=nothing, index=1) Function to create a GIF visualizing all wavelets in the first layer across space for each example in the batch. The variable `origLoc` is the `ScatteredOut` object containing the scattering transform results, `index` specifies which example in the batch to plot, `origSig` is the original input signal, `saveTo` specifies the file path to save the GIF, and `fps` sets the frames per second for the GIF animation. +If `saveTo` is provided, the GIF is saved to that file path. """ -function gifFirstLayer(origLoc, origSig, saveTo="gradientFigures/tmp.gif", fps=2, index=1) +function gifFirstLayer(origLoc, origSig; fps=2, saveTo=nothing, index=1) anim = Animation() for j = 1:size(origLoc[1])[end-1] - plotFirstLayer1D(j, origLoc, origSig, index) + plotFirstLayer1D(j, origLoc, origSig; index=index) frame(anim) end - gif(anim, saveTo, fps=fps) + filepath = isnothing(saveTo) ? "tmp.gif" : saveTo + return gif(anim, filepath, fps=fps) end """ - plotFirstLayer1DAll(origLoc, origSig, saveTo="gradientFigures/tmp2.png", index=1, cline=:darkrainbow) + plotFirstLayer1DAll(origLoc, origSig; saveTo=nothing, index=1, cline=:darkrainbow) Function that plots all first layer gradient wavelets for a specific example signal `index` across space, along with the original signal. It also includes heatmaps of the gradient wavelets in both the spatial and frequency domains. The variable `index` specifies which example in the batch to plot, `origLoc` is the `ScatteredOut` object containing the scattering transform results, `origSig` is the original input signal, and `saveTo` is the file path to save the plot. """ -function plotFirstLayer1DAll(origLoc, origSig, saveTo="gradientFigures/tmp2.png", index=1, cline=:darkrainbow) +function plotFirstLayer1DAll(origLoc, origSig; saveTo=nothing, index=1, cline=:darkrainbow) space = plot(origLoc[1][:, :, index], line_z=(1:size(origLoc[1], 2))', xlim=(0, length(origLoc[1][:, 1, index])+1), legend=false, colorbar=true, color=cline, title="first layer gradient wavelets") org = plot(origSig[:,:,index], legend=false, color=:red, title="Original Signal", xlim=(0, length(origSig[:,:,index])+1)) @@ -49,21 +67,27 @@ function plotFirstLayer1DAll(origLoc, origSig, saveTo="gradientFigures/tmp2.png" ylabel="wavelet index", title="Log-power Frequency Domain") l = Plots.@layout [a; b{0.1h}; [b c]] plt = plot(space, org, ∇h, ∇̂h, layout=l, size=(1200, 800), margin=5Plots.mm) - savefig(plt, saveTo) + if !isnothing(saveTo) + savefig(plt, saveTo) + end + return plt end """ - plotFirstLayer(stw, St, saveTo="gradientFigures/firstLayer.png", index=1) + plotFirstLayer(stw, St; saveTo=nothing, index=1) Function that creates a heatmap of the first layer scattering transform results at a specified example index. The variable `stw` is the scattered output, `St` is the scattering transform object, `saveTo` is the file path to save the plot, and `index` specifies which example in the batch to plot. """ -function plotFirstLayer(stw, St, saveTo="gradientFigures/firstLayer.png", index=1) +function plotFirstLayer(stw, St; saveTo=nothing, index=1) f1, f2, f3 = getMeanFreq(St) # the mean frequencies for the wavelets in each layer. plt = heatmap(1:size(stw[1], 1), f1[1:end-1], stw[1][:, :, index]', - xlabel="time index", ylabel="Frequency (Hz)", + xlabel="time index", ylabel="Frequency (Hz)", margin=5Plots.mm, color=:viridis, title="First Layer", size=(1200, 800)) - savefig(plt, saveTo) + if !isnothing(saveTo) + savefig(plt, saveTo) + end + return plt end @@ -102,12 +126,12 @@ function plotSecondLayer1D(loc, origLoc, wave1, wave2, original=false, subsamSz= end """ - plotSecondLayerSpecificPath(stw, St, firstLayerWaveletIndex, secondLayerWaveletIndex, original, index=1) + plotSecondLayerSpecificPath(stw, St, firstLayerWaveletIndex, secondLayerWaveletIndex, original; saveTo=nothing, index=1) `stw` is the scattered output, `St` is the scattering transform object, `firstLayerWaveletIndex` and `secondLayerWaveletIndex` specify the path to plot, `original` is the original signal, and `index` specifies which example in the batch to plot. This function creates a plot showing the original signal and the scattering result for the specified path. It also displays the mean frequencies associated with the selected wavelets. Finally, it displays the log-power norm of the second layer signal for the specified path. This value is used elsewhere to create heatmaps of the second layer scattering results. """ -function plotSecondLayerSpecificPath(stw, St, firstLayerWaveletIndex, secondLayerWaveletIndex, original, index=1) +function plotSecondLayerSpecificPath(stw, St, firstLayerWaveletIndex, secondLayerWaveletIndex, original; saveTo=nothing, index=1) # Plot of original signal. org = plot(original[:,:,index], legend=false, color=:red, title="Original Signal", xlabel="time (samples)", ylabel="amplitude", xlims=(0, length(original[:,:,index])+1)) f1, f2, f3 = getMeanFreq(St) @@ -127,32 +151,38 @@ function plotSecondLayerSpecificPath(stw, St, firstLayerWaveletIndex, secondLaye normPlot = plot(title="Second Layer Signal norm (log-power) = $(round(secondLayerNorm, sigdigits=4))", grid=false, showaxis=false, xticks=nothing, yticks=nothing, titlefontsize=11, framestyle=:none) l = Plots.@layout [a{0.4h}; title{0.05h}; c; d{0.05h}] - plot(org, titlePlot, ∇h, normPlot, layout=l, margin=4Plots.mm, size=(900,600)) + plt = plot(org, titlePlot, ∇h, normPlot, layout=l, margin=4Plots.mm, size=(900,600)) + if !isnothing(saveTo) + savefig(plt, saveTo) + end + return plt end """ - plotSecondLayer1DSubsetGif(stw, St, firstLayerWavelets, secondLayerWavelets, original, saveTo="secondLayerFigures/tmp2.gif", fps=2, index=1) + plotSecondLayer1DSubsetGif(stw, St, firstLayerWavelets, secondLayerWavelets, original; fps=2, saveTo=nothing, index=1) Create a GIF visualizing the second layer scattering results for specified subsets of wavelets from the first and second layers. The variables `firstLayerWavelets` and `secondLayerWavelets` are arrays containing the indices of the wavelets to be visualized from the first and second layers, respectively. For example, to visualize all the wavelets from the first layer with respect to a specific wavelet from the second layer, you can set `firstLayerWavelets = 1:size(stw[1], 2)` and `secondLayerWavelets = k`, where `k` is the index of the desired second layer wavelet. Once again, the `index` -parameter specifies which example in the batch to plot. It defaults to the first example in the batch. +parameter specifies which example in the batch to plot. It defaults to the first example in the batch. If `saveTo` is not provided, the GIF is saved to "tmp.gif". """ -function plotSecondLayer1DSubsetGif(stw, St, firstLayerWavelets, secondLayerWavelets, original, saveTo="secondLayerFigures/tmp2.gif", fps=2, index=1) +function plotSecondLayer1DSubsetGif(stw, St, firstLayerWavelets, secondLayerWavelets, original; fps=2, saveTo=nothing, index=1) anim = Animation() for j in firstLayerWavelets, k in secondLayerWavelets - plotSecondLayerSpecificPath(stw, St, j, k, original, index) - frame(anim) + plt = plotSecondLayerSpecificPath(stw, St, j, k, original; index=index) + frame(anim, plt) end - gif(anim, saveTo, fps=fps) + filepath = isnothing(saveTo) ? "tmp.gif" : saveTo + return gif(anim, filepath, fps=fps) end """ - plotSecondLayerFixAndVary(stw, St, firstLayerWavelets, secondLayerWavelets, saveTo="secondLayerFigures/sliceByLayer.gif", fps=2, index=1) + plotSecondLayerFixAndVary(stw, St, firstLayerWavelets, secondLayerWavelets; fps=2, saveTo=nothing, index=1) Create a GIF visualizing slices of the second layer scattering results by fixing one layer's wavelet and varying the other layer's wavelet. The variables `firstLayerWavelets` and `secondLayerWavelets` are arrays containing the indices of the wavelets to be visualized from the first and second layers, respectively. If `firstLayerWavelets` contains only one index, the function fixes that wavelet and varies the second layer wavelets, and vice versa. +If `saveTo` is not provided, the GIF is saved to "tmp.gif". """ -function plotSecondLayerFixAndVary(stw, St, firstLayerWavelets, secondLayerWavelets, saveTo="secondLayerFigures/sliceByLayer.gif", fps=2, index=1) +function plotSecondLayerFixAndVary(stw, St, firstLayerWavelets, secondLayerWavelets; fps=2, saveTo=nothing, index=1) f1, f2, f3 = getMeanFreq(St) anim = Animation() @@ -164,9 +194,9 @@ function plotSecondLayerFixAndVary(stw, St, firstLayerWavelets, secondLayerWavel continue end toPlot = stw[2][:, jj, :, index] - heatmap(1:size(toPlot, 1), f1[1:end-1], toPlot', title="Second Layer Wavelet $jj, Frequency=$(round(f2[jj],sigdigits=4))Hz", + plt = heatmap(1:size(toPlot, 1), f1[1:end-1], toPlot', title="Second Layer Wavelet $jj, Frequency=$(round(f2[jj], sigdigits=4))Hz", xlabel="time (samples)", ylabel="First Layer Frequency (Hz)", c=cgrad(:viridis, scale=:exp)) - frame(anim) + frame(anim, plt) end else # Fixed second layer, vary first layer. @@ -175,16 +205,17 @@ function plotSecondLayerFixAndVary(stw, St, firstLayerWavelets, secondLayerWavel continue end toPlot = stw[2][:, :, jj, index] - heatmap(1:size(toPlot, 1), f2[1:end-1], toPlot', title="First Layer Wavelet $jj, Frequency=$(round(f1[jj],sigdigits=4))Hz", + plt = heatmap(1:size(toPlot, 1), f2[1:end-1], toPlot', title="First Layer Wavelet $jj, Frequency=$(round(f1[jj], sigdigits=4))Hz", xlabel="time (samples)", ylabel="Second Layer Frequency (Hz)", c=cgrad(:viridis, scale=:exp)) - frame(anim) + frame(anim, plt) end end - gif(anim, saveTo, fps=fps) + filepath = isnothing(saveTo) ? "tmp.gif" : saveTo + return gif(anim, filepath, fps=fps) end """ - plotSecondLayer(stw, St, index=1; title="Second Layer results", xVals=-1, yVals=-1, logPower=true, toHeat=nothing, c=cgrad(:viridis, [0,.9]), threshold=0, linePalette=:greys, minLog=NaN, kwargs...) + plotSecondLayer(stw, St; saveTo=nothing, index=1, title="Second Layer results", xVals=-1, yVals=-1, logPower=true, toHeat=nothing, c=cgrad(:viridis, [0,.9]), threshold=0, linePalette=:greys, minLog=NaN, kwargs...) TODO fix the similarity of these names. xVals and yVals give the spacing of the grid, as it doesn't seem to be done correctly by default. xVals gives the distance from the left and the right @@ -193,16 +224,16 @@ also as a tuple. Default values are `xVals = (.037, .852), yVals = (.056, .939)` If you have no colorbar, set `xVals = (.0015, .997), yVals = (.002, .992)` In the case that arbitrary space has been introduced, if you have a title, use `xVals = (.037, .852), yVals = (.056, .939)`, or if you have no title, use `xVals = (.0105, .882), yVals = (.056, .939)` """ -function plotSecondLayer(stw::ScatteredOut, St, index=1; kwargs...) +function plotSecondLayer(stw::ScatteredOut, St; saveTo=nothing, index=1, kwargs...) secondLayerRes = stw[2] if ndims(secondLayerRes) > 3 - return plotSecondLayer(secondLayerRes[:, :, :, index], St; kwargs...) + return plotSecondLayer(secondLayerRes[:, :, :, index], St; saveTo=saveTo, kwargs...) else - return plotSecondLayer(secondLayerRes, St; kwargs...) + return plotSecondLayer(secondLayerRes, St; saveTo=saveTo, kwargs...) end end -function plotSecondLayer(stw, St; title="Second Layer results", xVals=-1, yVals=-1, logPower=true, toHeat=nothing, c=cgrad(:viridis, [0, 0.9]), threshold=0, freqsigdigits=3, linePalette=:greys, minLog=NaN, subClims=(Inf, -Inf), δt=1000, firstFreqSpacing=nothing, secondFreqSpacing=nothing, transp=true, labelRot=30, xlabel=nothing, ylabel=nothing, frameTypes=:box, miniFillAlpha=0.5, kwargs...) +function plotSecondLayer(stw, St; saveTo=nothing, title="Second Layer results", xVals=-1, yVals=-1, logPower=true, toHeat=nothing, c=cgrad(:viridis, [0, 0.9]), threshold=0, freqsigdigits=3, linePalette=:greys, minLog=NaN, subClims=(Inf, -Inf), δt=1000, firstFreqSpacing=nothing, secondFreqSpacing=nothing, transp=true, labelRot=30, xlabel=nothing, ylabel=nothing, frameTypes=:box, miniFillAlpha=0.5, kwargs...) n, m = size(stw)[2:3] freqs = getMeanFreq(St, δt) freqs = map(x -> round.(x, sigdigits=freqsigdigits), freqs)[1:2] @@ -260,7 +291,7 @@ function plotSecondLayer(stw, St; title="Second Layer results", xVals=-1, yVals= bottom = min(minimum(toHeat), subClims[1]) top = max(subClims[2], maximum(toHeat)) totalRange = top - bottom - # substitute given x and y labels if needed + # Substitute given x and y labels if needed if isnothing(xlabel) xlabel = "Layer $(xInd) frequency (Hz)" end @@ -286,16 +317,19 @@ function plotSecondLayer(stw, St; title="Second Layer results", xVals=-1, yVals= nPlot += 1 end end - plt + if !isnothing(saveTo) + savefig(plt, saveTo) + end + return plt end """ - jointPlot(thingToPlot, thingName, cSymbol, St; sharedColorScaling=:exp, targetExample=1, δt=1000, freqigdigits=3, sharedColorbar=false, extraPlot=nothing, allPositive=false, logPower=false) + jointPlot(thingToPlot, thingName, cSymbol, St; saveTo=nothing, sharedColorScaling=:exp, targetExample=1, δt=1000, freqigdigits=3, sharedColorbar=false, extraPlot=nothing, allPositive=false, logPower=false) Create a joint plot visualizing the zeroth, first, and second layer scattering results for a specified example. The variable `thingToPlot` is a tuple containing the scattering results for the zeroth, first, and second layers, `thingName` is the title for the plot, `cSymbol` specifies the color gradient to use, and `St` is the scattering transform object. The function allows for various customization options, including shared color scaling, target example selection, frequency digit rounding, and additional plotting options. """ -function jointPlot(thingToPlot, thingName, cSymbol, St; sharedColorScaling=:exp, targetExample=1, δt=1000, freqigdigits=3, sharedColorbar=false, extraPlot=nothing, allPositive=false, logPower=false, kwargs...) +function jointPlot(thingToPlot, thingName, cSymbol, St; saveTo=nothing, sharedColorScaling=:exp, targetExample=1, δt=1000, freqigdigits=3, sharedColorbar=false, extraPlot=nothing, allPositive=false, logPower=false, kwargs...) if sharedColorbar clims = (min(minimum.(thingToPlot)...), max(maximum.(thingToPlot)...)) climszero = clims @@ -338,18 +372,22 @@ function jointPlot(thingToPlot, thingName, cSymbol, St; sharedColorScaling=:exp, # define the spatial locations as they correspond to the input spaceLocs = range(1, size(St)[1], length=length(zeroLay)) - p2 = plotSecondLayer(thingToPlot[2][:, :, :, targetExample], St; title="Second Layer", toHeat=toHeat, logPower=logPower, c=c, clims=climssecond, subClims=climssecond, cbar=false, xVals=(0.000, 0.993), yVals=(0.0, 0.994), transp=true, xlabel="", kwargs...) + p2 = plotSecondLayer(thingToPlot[2][:, :, :, targetExample], St; title="Second Layer", toHeat=toHeat, logPower=logPower, c=c, clims=climssecond, subClims=climssecond, cbar=false, xVals=(0.000, 0.993), yVals=(0.0, 0.994), xlabel="", transp=true, kwargs...) freqs = getMeanFreq(St, δt) freqs = map(x -> round.(x, sigdigits=freqigdigits), freqs) - p1 = heatmap(firstLay, c=c, title="First Layer", clims=climsfirst, cbar=false, yticks=((1:size(firstLay, 1)), ""), xticks=((1:size(firstLay, 2)), ""), bottom_margin=-10Plots.px) - p0 = heatmap(spaceLocs, 1:1, zeroLay, c=c, xlabel="time (ms)\nZeroth Layer", clims=climszero, cbar=false, yticks=nothing, top_margin=-10Plots.px, bottom_margin=10Plots.px) + p1 = heatmap(firstLay, c=c, title="First Layer", clims=climsfirst, cbar=false, yticks=((1:size(firstLay, 1)), ""), xticks=((1:size(firstLay, 2)), ""), bottom_margin=0Plots.px) + p0 = heatmap(spaceLocs, 1:1, zeroLay, c=c, xlabel="Zeroth Layer", clims=climszero, cbar=false, yticks=nothing, top_margin=0Plots.px, bottom_margin=5Plots.px) colorbarOnly = scatter([0, 0], [0, 1], zcolor=[0, 3], clims=climssecond, xlims=(1, 1.1), xshowaxis=false, yshowaxis=false, label="", c=c, grid=false, framestyle=:none) if extraPlot == nothing # extraPlot = scatter([0,0], [0,1], legend=false, grid=false, x="Layer 2 frequency", foreground_color_subplot=:white, top_margin=-10Plots.px, showaxis=false, yticks=nothing) - extraPlot = plot(xlabel="Layer 2 frequency (Hz)", grid=false, xticks=(1:5, ""), showaxis=false, yticks=nothing, bottom_margin=-10Plots.px, top_margin=-20Plots.px) + extraPlot = plot(xlabel="Layer 2 frequency (Hz)", grid=false, xticks=(1:5, ""), showaxis=false, yticks=nothing, bottom_margin=0Plots.px, top_margin=3Plots.px) # extraPlot = heatmap(zeroLay, c=c, xlabel="location\nZeroth Layer", clims=climszero, cbar=false, yticks=nothing, top_margin=-10Plots.px, bottom_margin=10Plots.px) end - titlePlot = plot(title=thingName, grid=false, showaxis=false, xticks=nothing, yticks=nothing, bottom_margin=-10Plots.px) + titlePlot = plot(title=thingName, grid=false, showaxis=false, xticks=nothing, yticks=nothing, bottom_margin=0Plots.px) lay = Plots.@layout [o{0.00001h}; [[a b; c{0.1h} d{0.1h}] b{0.04w}]] - plot(titlePlot, p2, p1, extraPlot, p0, colorbarOnly, layout=lay, size=(1200,800)) + plt = plot(titlePlot, p2, p1, extraPlot, p0, colorbarOnly, layout=lay, size=(1500,1000), margin=6Plots.mm) + if !isnothing(saveTo) + savefig(plt, saveTo) + end + return plt end