diff --git a/docs/docs/tutorials/analysis.ipynb b/docs/docs/tutorials/analysis.ipynb new file mode 100644 index 00000000..8403f24a --- /dev/null +++ b/docs/docs/tutorials/analysis.ipynb @@ -0,0 +1,352 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8643b10c", + "metadata": {}, + "source": [ + "# Analysis\n", + "It is time to analyse some data. We here show how to set up an Analysis object and use it to first fit an artificial vanadium measurement. Next, we use the fitted resolution to fit an artificial measurement of a model with diffusion and some elastic scattering. \n", + "\n", + "We extract and plot the relevant parameters. Finally, we show how to fit directly to the diffusion model.\n", + "\n", + "In the near future, it will be possible to fit the width and area of the Lorentzian to the diffusion model as well." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bca91d3c", + "metadata": {}, + "outputs": [], + "source": [ + "# Imports\n", + "from easydynamics.analysis.analysis import Analysis\n", + "from easydynamics.experiment import Experiment\n", + "from easydynamics.sample_model import BrownianTranslationalDiffusion\n", + "from easydynamics.sample_model import ComponentCollection\n", + "from easydynamics.sample_model import DeltaFunction\n", + "from easydynamics.sample_model import Gaussian\n", + "from easydynamics.sample_model import Lorentzian\n", + "from easydynamics.sample_model import Polynomial\n", + "from easydynamics.sample_model.background_model import BackgroundModel\n", + "from easydynamics.sample_model.instrument_model import InstrumentModel\n", + "from easydynamics.sample_model.resolution_model import ResolutionModel\n", + "from easydynamics.sample_model.sample_model import SampleModel\n", + "\n", + "%matplotlib widget" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8deca9b6", + "metadata": {}, + "outputs": [], + "source": [ + "# Load the vanadium data\n", + "vanadium_experiment = 
Experiment('Vanadium')\n", + "vanadium_experiment.load_hdf5(filename='vanadium_data_example.h5')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6762faba", + "metadata": {}, + "outputs": [], + "source": [ + "# Example of Analysis with a simple sample model and instrument model\n", + "# The scattering from vanadium is purely elastic, so we model it with a\n", + "# delta function\n", + "delta_function = DeltaFunction(display_name='DeltaFunction', area=1)\n", + "sample_model = SampleModel(\n", + " components=delta_function,\n", + ")\n", + "\n", + "# The resolution is in this case modeled as a Gaussian. However, we can\n", + "# add as many components as we like to the resolution model\n", + "res_gauss = Gaussian(width=0.1)\n", + "res_gauss.area.fixed = True\n", + "resolution_components = ComponentCollection()\n", + "resolution_components.append_component(res_gauss)\n", + "resolution_model = ResolutionModel(components=resolution_components)\n", + "\n", + "# The background model is created in the same way. In this case, we use\n", + "# a flat background\n", + "background_model = BackgroundModel(components=Polynomial(coefficients=[0.001]))\n", + "\n", + "# We combine the resolution abd background model into an instrument\n", + "# model. 
This model also contains a small energy offset to account for\n", + "# instrument misalignment.\n", + "\n", + "instrument_model = InstrumentModel(\n", + " resolution_model=resolution_model,\n", + " background_model=background_model,\n", + ")\n", + "\n", + "# Collect everything into an analysis object.\n", + "vanadium_analysis = Analysis(\n", + " display_name='Vanadium Full Analysis',\n", + " experiment=vanadium_experiment,\n", + " sample_model=sample_model,\n", + " instrument_model=instrument_model,\n", + ")\n", + "\n", + "# Let us first fit a single Q index and plot the data and model to see\n", + "# how it looks\n", + "fit_result_independent_single_Q = vanadium_analysis.fit(fit_method='independent', Q_index=5)\n", + "vanadium_analysis.plot_data_and_model(Q_index=5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e98e3d65", + "metadata": {}, + "outputs": [], + "source": [ + "# It looks good, so let us fit all Q indices independently and plot the\n", + "# results\n", + "fit_result_independent_all_Q = vanadium_analysis.fit(fit_method='independent')\n", + "vanadium_analysis.plot_data_and_model()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "133e682e", + "metadata": {}, + "outputs": [], + "source": [ + "# Inspect the Parameters as a scipp Dataset\n", + "vanadium_analysis.parameters_to_dataset()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dfacdf24", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot some of fitted parameters as a function of Q\n", + "vanadium_analysis.plot_parameters(names=['DeltaFunction area'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b6f9f316", + "metadata": {}, + "outputs": [], + "source": [ + "vanadium_analysis.plot_parameters(names=['Gaussian width'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "572664a0", + "metadata": {}, + "outputs": [], + "source": [ + 
"vanadium_analysis.plot_parameters(names=['energy_offset'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3609e6c1", + "metadata": {}, + "outputs": [], + "source": [ + "# Now it's time to look at the data we want to fit. We first load the\n", + "# data\n", + "diffusion_experiment = Experiment('Diffusion')\n", + "diffusion_experiment.load_hdf5(filename='diffusion_data_example.h5')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e685909a", + "metadata": {}, + "outputs": [], + "source": [ + "# Now we set up the model, similarly to how we set up the model for the\n", + "# vanadium data.\n", + "delta_function = DeltaFunction(display_name='DeltaFunction', area=0.2)\n", + "lorentzian = Lorentzian(display_name='Lorentzian', area=0.5, width=0.3)\n", + "component_collection = ComponentCollection(\n", + " components=[delta_function, lorentzian],\n", + ")\n", + "\n", + "sample_model = SampleModel(\n", + " components=component_collection,\n", + ")\n", + "\n", + "background_model = BackgroundModel(components=Polynomial(coefficients=[0.001]))\n", + "\n", + "instrument_model = InstrumentModel(\n", + " background_model=background_model,\n", + ")\n", + "\n", + "diffusion_analysis = Analysis(\n", + " display_name='Diffusion Full Analysis',\n", + " experiment=diffusion_experiment,\n", + " sample_model=sample_model,\n", + " instrument_model=instrument_model,\n", + ")\n", + "\n", + "# We need to hack in the resolution model from the vanadium analysis,\n", + "# since the setters and getters overwrite the model. 
This will be fixed\n", + "# asap.\n", + "diffusion_analysis.instrument_model._resolution_model = (\n", + " vanadium_analysis.instrument_model.resolution_model\n", + ")\n", + "\n", + "# We fix all parameters of the resolution model.\n", + "diffusion_analysis.instrument_model.resolution_model.fix_all_parameters()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c66828eb", + "metadata": {}, + "outputs": [], + "source": [ + "# Let us see how good the starting parameters are\n", + "diffusion_analysis.plot_data_and_model()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a4b7572", + "metadata": {}, + "outputs": [], + "source": [ + "# Now we fit the data and plot the result. Looks good!\n", + "diffusion_analysis.fit(fit_method='independent')\n", + "diffusion_analysis.plot_data_and_model()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df14b5c4", + "metadata": {}, + "outputs": [], + "source": [ + "# Let us look at the most interesting fit parameters\n", + "diffusion_analysis.plot_parameters(names=['Lorentzian width', 'Lorentzian area'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb226c8f", + "metadata": {}, + "outputs": [], + "source": [ + "# It will be possible to fit this to a DiffusionModel, but that will\n", + "# come later." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4c6d7808", + "metadata": {}, + "outputs": [], + "source": [ + "# Let us now fit directly to a diffusion model. 
We replace the\n", + "# Lorentzian with a Brownian translational diffusion model and keep the\n", + "# other parameters the same.\n", + "delta_function = DeltaFunction(display_name='DeltaFunction', area=0.2)\n", + "component_collection = ComponentCollection(\n", + " components=[delta_function],\n", + ")\n", + "diffusion_model = BrownianTranslationalDiffusion(\n", + " display_name='Brownian Translational Diffusion', diffusion_coefficient=2.4e-9, scale=0.5\n", + ")\n", + "\n", + "sample_model = SampleModel(\n", + " components=component_collection,\n", + " diffusion_models=diffusion_model,\n", + ")\n", + "\n", + "background_model = BackgroundModel(components=Polynomial(coefficients=[0.001]))\n", + "\n", + "instrument_model = InstrumentModel(\n", + " background_model=background_model,\n", + ")\n", + "\n", + "diffusion_model_analysis = Analysis(\n", + " display_name='Diffusion Full Analysis',\n", + " experiment=diffusion_experiment,\n", + " sample_model=sample_model,\n", + " instrument_model=instrument_model,\n", + ")\n", + "\n", + "# We again need to hack in the resolution model from the vanadium\n", + "# analysis, since the setters and getters overwrite the model. This will\n", + "# be fixed asap.\n", + "diffusion_model_analysis.instrument_model._resolution_model = (\n", + " vanadium_analysis.instrument_model.resolution_model\n", + ")\n", + "diffusion_model_analysis.instrument_model.resolution_model.fix_all_parameters()\n", + "\n", + "# Let us see how good the starting parameters are\n", + "diffusion_model_analysis.plot_data_and_model()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fd04d359", + "metadata": {}, + "outputs": [], + "source": [ + "# We now fit all the data simultaneously to the diffusion model, then\n", + "# plot the result. 
Looks good.\n", + "diffusion_model_analysis.fit(fit_method='simultaneous')\n", + "diffusion_model_analysis.plot_data_and_model()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "842c1f01", + "metadata": {}, + "outputs": [], + "source": [ + "# Let us look at the fitted diffusion coefficient\n", + "diffusion_model.get_all_parameters()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "easydynamics_newbase", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/docs/tutorials/analysis1d.ipynb b/docs/docs/tutorials/analysis1d.ipynb new file mode 100644 index 00000000..48eb8082 --- /dev/null +++ b/docs/docs/tutorials/analysis1d.ipynb @@ -0,0 +1,101 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8643b10c", + "metadata": {}, + "source": [ + "# Analysis1d\n", + "Sometimes, you will only be interested in a particular Q, not the full dataset. For this, use the Analysis1d object. We here show how to set it up to fit an artificial vanadium measurement." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "bca91d3c", + "metadata": {}, + "outputs": [], + "source": [ + "from easydynamics.analysis.analysis1d import Analysis1d\n", + "from easydynamics.experiment import Experiment\n", + "from easydynamics.sample_model import DeltaFunction\n", + "from easydynamics.sample_model import Gaussian\n", + "from easydynamics.sample_model import Polynomial\n", + "from easydynamics.sample_model.background_model import BackgroundModel\n", + "from easydynamics.sample_model.instrument_model import InstrumentModel\n", + "from easydynamics.sample_model.resolution_model import ResolutionModel\n", + "from easydynamics.sample_model.sample_model import SampleModel\n", + "\n", + "%matplotlib widget" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8deca9b6", + "metadata": {}, + "outputs": [], + "source": [ + "vanadium_experiment = Experiment('Vanadium')\n", + "vanadium_experiment.load_hdf5(filename='vanadium_data_example.h5')" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "41f842f0", + "metadata": {}, + "outputs": [], + "source": [ + "# Example of Analysis1d with a simple sample model and instrument model\n", + "delta_function = DeltaFunction(display_name='DeltaFunction', area=1)\n", + "sample_model = SampleModel(\n", + " components=delta_function,\n", + ")\n", + "\n", + "res_gauss = Gaussian(width=0.1)\n", + "resolution_model = ResolutionModel(components=res_gauss)\n", + "\n", + "\n", + "background_model = BackgroundModel(components=Polynomial(coefficients=[0.001]))\n", + "\n", + "instrument_model = InstrumentModel(\n", + " resolution_model=resolution_model,\n", + " background_model=background_model,\n", + ")\n", + "\n", + "my_analysis = Analysis1d(\n", + " display_name='Vanadium Analysis',\n", + " experiment=vanadium_experiment,\n", + " sample_model=sample_model,\n", + " instrument_model=instrument_model,\n", + " Q_index=5,\n", + ")\n", + "\n", + "fit_result = my_analysis.fit()\n", 
+ "fig = my_analysis.plot_data_and_model()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "easydynamics_newbase", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/docs/tutorials/convolution.ipynb b/docs/docs/tutorials/convolution.ipynb index 922970f9..b13d7973 100644 --- a/docs/docs/tutorials/convolution.ipynb +++ b/docs/docs/tutorials/convolution.ipynb @@ -109,7 +109,7 @@ "\n", "\n", "temperature = 10.0 # Temperature in Kelvin\n", - "offset = 0.5\n", + "energy_offset = 0.5\n", "upsample_factor = 5\n", "extension_factor = 0.5\n", "plt.figure()\n", @@ -119,7 +119,7 @@ "convolver = Convolution(\n", " sample_components=sample_components,\n", " resolution_components=resolution_components,\n", - " energy=energy - offset,\n", + " energy=energy - energy_offset,\n", " upsample_factor=upsample_factor,\n", " extension_factor=extension_factor,\n", " temperature=temperature,\n", @@ -132,8 +132,8 @@ "\n", "plt.plot(\n", " energy,\n", - " sample_components.evaluate(energy - offset)\n", - " * detailed_balance_factor(energy - offset, temperature),\n", + " sample_components.evaluate(energy - energy_offset)\n", + " * detailed_balance_factor(energy - energy_offset, temperature),\n", " label='Sample Model with DB',\n", " linestyle='--',\n", ")\n", @@ -145,6 +145,70 @@ "plt.ylim(0, 2.5)\n", "plt.show()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c318f9b8", + "metadata": {}, + "outputs": [], + "source": [ + "# Use some of the extra settings for the numerical convolution\n", + "sample_components = ComponentCollection()\n", + "gaussian = Gaussian(display_name='Gaussian', width=0.3, area=1)\n", + "dho = 
DampedHarmonicOscillator(display_name='DHO', center=1.0, width=0.3, area=2.0)\n", + "lorentzian = Lorentzian(display_name='Lorentzian', center=-1.0, width=0.2, area=1.0)\n", + "delta = DeltaFunction(display_name='Delta', center=0.4, area=0.5)\n", + "sample_components.append_component(gaussian)\n", + "# sample_components.append_component(dho)\n", + "sample_components.append_component(lorentzian)\n", + "# sample_components.append_component(delta)\n", + "\n", + "resolution_components = ComponentCollection()\n", + "resolution_gaussian = Gaussian(display_name='Resolution Gaussian', width=0.15, area=0.8)\n", + "resolution_lorentzian = Lorentzian(display_name='Resolution Lorentzian', width=0.25, area=0.2)\n", + "resolution_components.append_component(resolution_gaussian)\n", + "# resolution_components.append_component(resolution_lorentzian)\n", + "\n", + "energy = np.linspace(-2, 2, 100)\n", + "\n", + "\n", + "temperature = 10.0 # Temperature in Kelvin\n", + "energy_offset = 0.2\n", + "upsample_factor = 5\n", + "extension_factor = 0.5\n", + "plt.figure()\n", + "plt.xlabel('Energy (meV)')\n", + "plt.ylabel('Intensity (arb. 
units)')\n", + "\n", + "convolver = Convolution(\n", + " sample_components=sample_components,\n", + " resolution_components=resolution_components,\n", + " energy=energy,\n", + " upsample_factor=upsample_factor,\n", + " extension_factor=extension_factor,\n", + " energy_offset=energy_offset,\n", + " temperature=temperature,\n", + ")\n", + "y = convolver.convolution()\n", + "\n", + "\n", + "plt.plot(energy, y, label='Convoluted Model')\n", + "\n", + "plt.plot(\n", + " energy,\n", + " sample_components.evaluate(energy - energy_offset),\n", + " label='Sample Model',\n", + " linestyle='--',\n", + ")\n", + "\n", + "plt.plot(energy, resolution_components.evaluate(energy), label='Resolution Model', linestyle=':')\n", + "plt.title('Convolution of Sample Model with Resolution Model')\n", + "\n", + "plt.legend()\n", + "plt.ylim(0, 2.5)\n", + "plt.show()" + ] } ], "metadata": { diff --git a/pixi.lock b/pixi.lock index f51bc65b..2e77e94b 100644 --- a/pixi.lock +++ b/pixi.lock @@ -5,6 +5,8 @@ environments: - url: https://conda.anaconda.org/conda-forge/ indexes: - https://pypi.org/simple + options: + pypi-prerelease-mode: if-necessary-or-explicit packages: linux-64: - conda: https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 @@ -78,7 +80,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/cc/8f/ec6289987824b29529d0dfda0d74a07cec60e54b9c92f3c9da4c0ac732de/contourpy-1.3.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b8/01/74922a1c552137c05a41fee0c61153753dddc9117d19c7c5902c146c25ab/copier-9.11.3-py3-none-any.whl - - pypi: git+https://github.com/easyscience/corelib.git?rev=develop#bd106537fcf522336aa0176aa6ccf215be8a5b86 + - pypi: git+https://github.com/easyscience/corelib.git#ac01d891e271c7e2e5044da69b9ecd7b7114f0c3 - pypi: 
https://files.pythonhosted.org/packages/ea/b4/694159c15c52b9f7ec7adf49d50e5f8ee71d3e9ef38adb4445d13dd56c20/coverage-7.13.1-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl - pypi: https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f2/f2/728f041460f1b9739b85ee23b45fa5a505962ea11fd85bdbe2a02b021373/darkdetect-0.8.0-py3-none-any.whl @@ -333,7 +335,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/be/45/adfee365d9ea3d853550b2e735f9d66366701c65db7855cd07621732ccfc/contourpy-1.3.3-cp312-cp312-macosx_10_13_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b8/01/74922a1c552137c05a41fee0c61153753dddc9117d19c7c5902c146c25ab/copier-9.11.3-py3-none-any.whl - - pypi: git+https://github.com/easyscience/corelib.git?rev=develop#bd106537fcf522336aa0176aa6ccf215be8a5b86 + - pypi: git+https://github.com/easyscience/corelib.git#ac01d891e271c7e2e5044da69b9ecd7b7114f0c3 - pypi: https://files.pythonhosted.org/packages/ce/8a/87af46cccdfa78f53db747b09f5f9a21d5fc38d796834adac09b30a8ce74/coverage-7.13.1-cp312-cp312-macosx_10_13_x86_64.whl - pypi: https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f2/f2/728f041460f1b9739b85ee23b45fa5a505962ea11fd85bdbe2a02b021373/darkdetect-0.8.0-py3-none-any.whl @@ -588,7 +590,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/53/3e/405b59cfa13021a56bba395a6b3aca8cec012b45bf177b0eaf7a202cde2c/contourpy-1.3.3-cp312-cp312-macosx_11_0_arm64.whl - 
pypi: https://files.pythonhosted.org/packages/b8/01/74922a1c552137c05a41fee0c61153753dddc9117d19c7c5902c146c25ab/copier-9.11.3-py3-none-any.whl - - pypi: git+https://github.com/easyscience/corelib.git?rev=develop#bd106537fcf522336aa0176aa6ccf215be8a5b86 + - pypi: git+https://github.com/easyscience/corelib.git#ac01d891e271c7e2e5044da69b9ecd7b7114f0c3 - pypi: https://files.pythonhosted.org/packages/82/a8/6e22fdc67242a4a5a153f9438d05944553121c8f4ba70cb072af4c41362e/coverage-7.13.1-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f2/f2/728f041460f1b9739b85ee23b45fa5a505962ea11fd85bdbe2a02b021373/darkdetect-0.8.0-py3-none-any.whl @@ -836,7 +838,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/19/e8/6026ed58a64563186a9ee3f29f41261fd1828f527dd93d33b60feca63352/contourpy-1.3.3-cp312-cp312-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/b8/01/74922a1c552137c05a41fee0c61153753dddc9117d19c7c5902c146c25ab/copier-9.11.3-py3-none-any.whl - - pypi: git+https://github.com/easyscience/corelib.git?rev=develop#bd106537fcf522336aa0176aa6ccf215be8a5b86 + - pypi: git+https://github.com/easyscience/corelib.git#ac01d891e271c7e2e5044da69b9ecd7b7114f0c3 - pypi: https://files.pythonhosted.org/packages/fa/dc/7282856a407c621c2aad74021680a01b23010bb8ebf427cf5eacda2e876f/coverage-7.13.1-cp312-cp312-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f2/f2/728f041460f1b9739b85ee23b45fa5a505962ea11fd85bdbe2a02b021373/darkdetect-0.8.0-py3-none-any.whl @@ -1029,6 +1031,8 @@ environments: - url: 
https://conda.anaconda.org/conda-forge/ indexes: - https://pypi.org/simple + options: + pypi-prerelease-mode: if-necessary-or-explicit packages: linux-64: - conda: https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 @@ -1102,7 +1106,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5f/4b/6157f24ca425b89fe2eb7e7be642375711ab671135be21e6faa100f7448c/contourpy-1.3.3-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b8/01/74922a1c552137c05a41fee0c61153753dddc9117d19c7c5902c146c25ab/copier-9.11.3-py3-none-any.whl - - pypi: git+https://github.com/easyscience/corelib.git?rev=develop#bd106537fcf522336aa0176aa6ccf215be8a5b86 + - pypi: git+https://github.com/easyscience/corelib.git#ac01d891e271c7e2e5044da69b9ecd7b7114f0c3 - pypi: https://files.pythonhosted.org/packages/f7/7c/347280982982383621d29b8c544cf497ae07ac41e44b1ca4903024131f55/coverage-7.13.1-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl - pypi: https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f2/f2/728f041460f1b9739b85ee23b45fa5a505962ea11fd85bdbe2a02b021373/darkdetect-0.8.0-py3-none-any.whl @@ -1358,7 +1362,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/91/2e/c4390a31919d8a78b90e8ecf87cd4b4c4f05a5b48d05ec17db8e5404c6f4/contourpy-1.3.3-cp311-cp311-macosx_10_9_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b8/01/74922a1c552137c05a41fee0c61153753dddc9117d19c7c5902c146c25ab/copier-9.11.3-py3-none-any.whl - - pypi: 
git+https://github.com/easyscience/corelib.git?rev=develop#bd106537fcf522336aa0176aa6ccf215be8a5b86 + - pypi: git+https://github.com/easyscience/corelib.git#ac01d891e271c7e2e5044da69b9ecd7b7114f0c3 - pypi: https://files.pythonhosted.org/packages/b4/9b/77baf488516e9ced25fc215a6f75d803493fc3f6a1a1227ac35697910c2a/coverage-7.13.1-cp311-cp311-macosx_10_9_x86_64.whl - pypi: https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f2/f2/728f041460f1b9739b85ee23b45fa5a505962ea11fd85bdbe2a02b021373/darkdetect-0.8.0-py3-none-any.whl @@ -1614,7 +1618,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0d/44/c4b0b6095fef4dc9c420e041799591e3b63e9619e3044f7f4f6c21c0ab24/contourpy-1.3.3-cp311-cp311-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/b8/01/74922a1c552137c05a41fee0c61153753dddc9117d19c7c5902c146c25ab/copier-9.11.3-py3-none-any.whl - - pypi: git+https://github.com/easyscience/corelib.git?rev=develop#bd106537fcf522336aa0176aa6ccf215be8a5b86 + - pypi: git+https://github.com/easyscience/corelib.git#ac01d891e271c7e2e5044da69b9ecd7b7114f0c3 - pypi: https://files.pythonhosted.org/packages/d7/cd/7ab01154e6eb79ee2fab76bf4d89e94c6648116557307ee4ebbb85e5c1bf/coverage-7.13.1-cp311-cp311-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f2/f2/728f041460f1b9739b85ee23b45fa5a505962ea11fd85bdbe2a02b021373/darkdetect-0.8.0-py3-none-any.whl @@ -1863,7 +1867,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl - 
pypi: https://files.pythonhosted.org/packages/98/4b/9bd370b004b5c9d8045c6c33cf65bae018b27aca550a3f657cdc99acdbd8/contourpy-1.3.3-cp311-cp311-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/b8/01/74922a1c552137c05a41fee0c61153753dddc9117d19c7c5902c146c25ab/copier-9.11.3-py3-none-any.whl - - pypi: git+https://github.com/easyscience/corelib.git?rev=develop#bd106537fcf522336aa0176aa6ccf215be8a5b86 + - pypi: git+https://github.com/easyscience/corelib.git#ac01d891e271c7e2e5044da69b9ecd7b7114f0c3 - pypi: https://files.pythonhosted.org/packages/27/56/c216625f453df6e0559ed666d246fcbaaa93f3aa99eaa5080cea1229aa3d/coverage-7.13.1-cp311-cp311-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f2/f2/728f041460f1b9739b85ee23b45fa5a505962ea11fd85bdbe2a02b021373/darkdetect-0.8.0-py3-none-any.whl @@ -2057,6 +2061,8 @@ environments: - url: https://conda.anaconda.org/conda-forge/ indexes: - https://pypi.org/simple + options: + pypi-prerelease-mode: if-necessary-or-explicit packages: linux-64: - conda: https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 @@ -2130,7 +2136,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/cc/8f/ec6289987824b29529d0dfda0d74a07cec60e54b9c92f3c9da4c0ac732de/contourpy-1.3.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b8/01/74922a1c552137c05a41fee0c61153753dddc9117d19c7c5902c146c25ab/copier-9.11.3-py3-none-any.whl - - pypi: git+https://github.com/easyscience/corelib.git?rev=develop#bd106537fcf522336aa0176aa6ccf215be8a5b86 + - pypi: git+https://github.com/easyscience/corelib.git#ac01d891e271c7e2e5044da69b9ecd7b7114f0c3 - pypi: 
https://files.pythonhosted.org/packages/ea/b4/694159c15c52b9f7ec7adf49d50e5f8ee71d3e9ef38adb4445d13dd56c20/coverage-7.13.1-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl - pypi: https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f2/f2/728f041460f1b9739b85ee23b45fa5a505962ea11fd85bdbe2a02b021373/darkdetect-0.8.0-py3-none-any.whl @@ -2385,7 +2391,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/be/45/adfee365d9ea3d853550b2e735f9d66366701c65db7855cd07621732ccfc/contourpy-1.3.3-cp312-cp312-macosx_10_13_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b8/01/74922a1c552137c05a41fee0c61153753dddc9117d19c7c5902c146c25ab/copier-9.11.3-py3-none-any.whl - - pypi: git+https://github.com/easyscience/corelib.git?rev=develop#bd106537fcf522336aa0176aa6ccf215be8a5b86 + - pypi: git+https://github.com/easyscience/corelib.git#ac01d891e271c7e2e5044da69b9ecd7b7114f0c3 - pypi: https://files.pythonhosted.org/packages/ce/8a/87af46cccdfa78f53db747b09f5f9a21d5fc38d796834adac09b30a8ce74/coverage-7.13.1-cp312-cp312-macosx_10_13_x86_64.whl - pypi: https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f2/f2/728f041460f1b9739b85ee23b45fa5a505962ea11fd85bdbe2a02b021373/darkdetect-0.8.0-py3-none-any.whl @@ -2640,7 +2646,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/53/3e/405b59cfa13021a56bba395a6b3aca8cec012b45bf177b0eaf7a202cde2c/contourpy-1.3.3-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/b8/01/74922a1c552137c05a41fee0c61153753dddc9117d19c7c5902c146c25ab/copier-9.11.3-py3-none-any.whl - - pypi: git+https://github.com/easyscience/corelib.git?rev=develop#bd106537fcf522336aa0176aa6ccf215be8a5b86 + - pypi: git+https://github.com/easyscience/corelib.git#ac01d891e271c7e2e5044da69b9ecd7b7114f0c3 - pypi: https://files.pythonhosted.org/packages/82/a8/6e22fdc67242a4a5a153f9438d05944553121c8f4ba70cb072af4c41362e/coverage-7.13.1-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f2/f2/728f041460f1b9739b85ee23b45fa5a505962ea11fd85bdbe2a02b021373/darkdetect-0.8.0-py3-none-any.whl @@ -2888,7 +2894,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/19/e8/6026ed58a64563186a9ee3f29f41261fd1828f527dd93d33b60feca63352/contourpy-1.3.3-cp312-cp312-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/b8/01/74922a1c552137c05a41fee0c61153753dddc9117d19c7c5902c146c25ab/copier-9.11.3-py3-none-any.whl - - pypi: git+https://github.com/easyscience/corelib.git?rev=develop#bd106537fcf522336aa0176aa6ccf215be8a5b86 + - pypi: git+https://github.com/easyscience/corelib.git#ac01d891e271c7e2e5044da69b9ecd7b7114f0c3 - pypi: https://files.pythonhosted.org/packages/fa/dc/7282856a407c621c2aad74021680a01b23010bb8ebf427cf5eacda2e876f/coverage-7.13.1-cp312-cp312-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/f2/f2/728f041460f1b9739b85ee23b45fa5a505962ea11fd85bdbe2a02b021373/darkdetect-0.8.0-py3-none-any.whl @@ -4085,7 +4091,7 @@ packages: requires_python: '>=3.5' - pypi: ./ name: easydynamics - version: 0.1.0+devdirty6 + version: 0.1.1+devdirty20 sha256: de299c914d4a865b9e2fdefa5e3947f37b1f26f73ff9087f7918ee417f3dd288 requires_dist: - darkdetect @@ -4128,8 +4134,7 @@ packages: - validate-pyproject[all] ; extra == 'dev' - versioningit ; extra == 'dev' requires_python: '>=3.11' - editable: true -- pypi: git+https://github.com/easyscience/corelib.git?rev=develop#bd106537fcf522336aa0176aa6ccf215be8a5b86 +- pypi: git+https://github.com/easyscience/corelib.git#ac01d891e271c7e2e5044da69b9ecd7b7114f0c3 name: easyscience version: 2.1.0 requires_dist: diff --git a/src/easydynamics/analysis/__init__.py b/src/easydynamics/analysis/__init__.py new file mode 100644 index 00000000..4cb511b4 --- /dev/null +++ b/src/easydynamics/analysis/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: 2025-2026 EasyDynamics contributors +# SPDX-License-Identifier: BSD-3-Clause + +from .analysis import Analysis + +__all__ = [ + 'Analysis', +] diff --git a/src/easydynamics/analysis/analysis.py b/src/easydynamics/analysis/analysis.py new file mode 100644 index 00000000..ba120037 --- /dev/null +++ b/src/easydynamics/analysis/analysis.py @@ -0,0 +1,564 @@ +# SPDX-FileCopyrightText: 2025-2026 EasyDynamics contributors +# SPDX-License-Identifier: BSD-3-Clause + + +import numpy as np +import scipp as sc +from easyscience.fitting.minimizers.utils import FitResults +from easyscience.fitting.multi_fitter import MultiFitter +from easyscience.variable import Parameter +from plopp.backends.matplotlib.figure import InteractiveFigure +from scipp import UnitError + +from easydynamics.analysis.analysis1d import Analysis1d +from easydynamics.analysis.analysis_base import AnalysisBase +from easydynamics.experiment import Experiment +from easydynamics.sample_model import 
SampleModel +from easydynamics.sample_model.instrument_model import InstrumentModel +from easydynamics.utils.utils import _in_notebook + + +class Analysis(AnalysisBase): + """For analysing two-dimensional data, i.e. intensity as function of + energy and Q. Supports independent fits of each Q value and + simultaneous fits of all Q. + + Args: + display_name (str): Display name of the analysis. + unique_name (str or None): Unique name of the analysis. If None, + a unique name is automatically generated. + experiment (Experiment | None): The Experiment associated with + this Analysis. If None, a default Experiment is created. + sample_model (SampleModel | None): The SampleModel associated + with this Analysis. If None, a default SampleModel is + created. + instrument_model (InstrumentModel | None): The InstrumentModel + associated with this Analysis. If None, a default + InstrumentModel is created. + extra_parameters (Parameter | list[Parameter] | None): + Extra parameters to be included in the analysis for advanced + users. If None, no extra parameters are added. + + Attributes: + experiment (Experiment): The Experiment associated with this + Analysis. + sample_model (SampleModel): The SampleModel associated with this + Analysis. + instrument_model (InstrumentModel): The InstrumentModel + associated with this Analysis. + Q (sc.Variable | None): The Q values from the associated + Experiment, if available. + energy (sc.Variable | None): The energy values from the + associated Experiment, if available. + temperature (Parameter | None): The temperature from the + associated SampleModel, if available. + extra_parameters (list[Parameter]): The extra parameters + included in this Analysis. 
+ """ + + def __init__( + self, + display_name: str = 'MyAnalysis', + unique_name: str | None = None, + experiment: Experiment | None = None, + sample_model: SampleModel | None = None, + instrument_model: InstrumentModel | None = None, + extra_parameters: Parameter | list[Parameter] | None = None, + ): + + # Avoid triggering updates before the object is fully + # initialized + self._call_updaters = False + super().__init__( + display_name=display_name, + unique_name=unique_name, + experiment=experiment, + sample_model=sample_model, + instrument_model=instrument_model, + extra_parameters=extra_parameters, + ) + + self._analysis_list = [] + if self.Q is not None: + for Q_index in range(len(self.Q)): + analysis = Analysis1d( + display_name=f'{self.display_name}_Q{Q_index}', + unique_name=(f'{self.unique_name}_Q{Q_index}'), + experiment=self.experiment, + sample_model=self.sample_model, + instrument_model=self.instrument_model, + extra_parameters=self._extra_parameters, + Q_index=Q_index, + ) + self._analysis_list.append(analysis) + # Now we can allow updates to trigger recalculations + self._call_updaters = True + + ############# + # Properties + ############# + + @property + def analysis_list(self) -> list[Analysis1d]: + """Get the Analysis1d objects associated with this Analysis. + + Returns: + list[Analysis1d]: A list of Analysis1d objects, one for + each Q index. + """ + return self._analysis_list + + @analysis_list.setter + def analysis_list(self, value: list[Analysis1d]) -> None: + """analysis_list is read-only. + + To change the analysis list, modify the experiment, sample + model, or instrument model. + + Raises: + AttributeError: Always raised, since analysis_list is + read-only. + """ + + raise AttributeError( + 'analysis_list is read-only. ' + 'To change the analysis list, modify the experiment, sample model, ' + 'or instrument model.' 
+ ) + + ############# + # Other methods + ############# + def calculate( + self, + Q_index: int | None = None, + ) -> list[np.ndarray] | np.ndarray: + """Calculate model data for a specific Q index. If Q_index is + None, calculate for all Q indices and return a list of arrays. + + Args: + Q_index (int or None): Index of the Q value to calculate + for. If None, calculate for all Q values. + + Returns: + list[np.ndarray] | np.ndarray: If Q_index is None, returns + a list of numpy arrays, one for each Q index. + If Q_index is an integer, returns a single numpy array + for that Q index. + + Raises: + IndexError: If Q_index is not None and is out of bounds. + """ + + if Q_index is None: + return [analysis.calculate() for analysis in self.analysis_list] + + Q_index = self._verify_Q_index(Q_index) + return self.analysis_list[Q_index].calculate() + + def fit( + self, + fit_method: str = 'independent', + Q_index: int | None = None, + ) -> FitResults | list[FitResults]: + """Fit the model to the experimental data. + + Args: + fit_method (str): Method to use for fitting. Options are + "independent" (fit each Q index independently, one after + the other) or "simultaneous" (fit all Q indices + simultaneously). Default is "independent". + Q_index (int or None): If fit_method is "independent", + specify which Q index to fit. If None, fit all Q indices + independently. Ignored if fit_method is "simultaneous". + Default is None. + + Returns: Fit results, which may be a list of FitResults if + fitting independently, or a single FitResults object if + fitting simultaneously. + + Raises: + ValueError: If fit_method is not "independent" or + "simultaneous" + IndexError: If fit_method is "independent" and Q_index is + out of bounds. + """ + + if self.Q is None: + raise ValueError( + 'No Q values available for fitting. Please check the experiment data.' 
+ ) + + Q_index = self._verify_Q_index(Q_index) + + if fit_method == 'independent': + if Q_index is not None: + return self._fit_single_Q(Q_index) + else: + return self._fit_all_Q_independently() + elif fit_method == 'simultaneous': + return self._fit_all_Q_simultaneously() + else: + raise ValueError("Invalid fit method. Choose 'independent' or 'simultaneous'.") + + def plot_data_and_model( + self, + Q_index: int | None = None, + plot_components: bool = True, + add_background: bool = True, + **kwargs, + ) -> InteractiveFigure: + """Plot the experimental data and the model prediction. + Optionally also plot the individual components of the model. + + Uses Plopp for plotting: https://scipp.github.io/plopp/ + + Args: + Q_index (int or None): Index of the Q value to plot. If + None, plot all Q values. Default is None. + plot_components (bool): Whether to plot the individual + components. Default is True. + add_background (bool): Whether to add background components + to the sample model components when plotting. Default is + True. + **kwargs: Additional keyword arguments passed to plopp + for customizing the plot. + + Raises: + ValueError: If Q_index is out of bounds, or if there is no + data to plot, or if there are no Q values available for + plotting. + RuntimeError: If not in a Jupyter notebook environment. + TypeError: If plot_components or add_background is not True + or False. + + Returns: + InteractiveFigure: A Plopp InteractiveFigure containing the + plot of the data and model. + """ + + if Q_index is not None: + Q_index = self._verify_Q_index(Q_index) + return self.analysis_list[Q_index].plot_data_and_model( + plot_components=plot_components, + add_background=add_background, + **kwargs, + ) + + if self.experiment.binned_data is None: + raise ValueError('No data to plot. 
Please load data first.') + + if not _in_notebook(): + raise RuntimeError('plot_data() can only be used in a Jupyter notebook environment.') + + if self.Q is None: + raise ValueError( + 'No Q values available for plotting. Please check the experiment data.' + ) + + if not isinstance(plot_components, bool): + raise TypeError('plot_components must be True or False.') + + if not isinstance(add_background, bool): + raise TypeError('add_background must be True or False.') + + import plopp as pp + + plot_kwargs_defaults = { + 'title': self.display_name, + 'linestyle': {'Data': 'none', 'Model': '-'}, + 'marker': {'Data': 'o', 'Model': None}, + 'color': {'Data': 'black', 'Model': 'red'}, + 'markerfacecolor': {'Data': 'none', 'Model': 'none'}, + } + data_and_model = { + 'Data': self.experiment.binned_data, + 'Model': self._create_model_array(), + } + + if plot_components: + components = self._create_components_dataset(add_background=add_background) + for key in components.keys(): + data_and_model[key] = components[key] + plot_kwargs_defaults['linestyle'][key] = '--' + plot_kwargs_defaults['marker'][key] = None + + # Overwrite defaults with any user-provided kwargs + plot_kwargs_defaults.update(kwargs) + + fig = pp.slicer( + data_and_model, + **plot_kwargs_defaults, + ) + return fig + + def parameters_to_dataset(self) -> sc.Dataset: + """Creates a scipp dataset with copies of the Parameters in the + model. + + Ensures unit consistency across Q. + + Returns: + sc.Dataset: A dataset where each entry is a parameter, with + dimensions "Q" and values corresponding to the parameter + values. + + Raises: + UnitError: If there are inconsistent units for the same + parameter across different Q values. 
+ """ + + ds = sc.Dataset(coords={'Q': self.Q}) + + # Collect all parameter names + all_names = { + param.name + for analysis in self.analysis_list + for param in analysis.get_all_parameters() + } + + # Storage + values = {name: [] for name in all_names} + variances = {name: [] for name in all_names} + units = {} + + for analysis in self.analysis_list: + pars = {p.name: p for p in analysis.get_all_parameters()} + + for name in all_names: + if name in pars: + p = pars[name] + + # Unit consistency check + if name not in units: + units[name] = p.unit + elif units[name] != p.unit: + try: + p.convert_unit(units[name]) + except Exception as e: + raise UnitError( + f"Inconsistent units for parameter '{name}': " + f'{units[name]} vs {p.unit}' + ) from e + + values[name].append(p.value) + variances[name].append(p.variance) + else: + values[name].append(np.nan) + variances[name].append(np.nan) + + # Build dataset variables + for name in all_names: + ds[name] = sc.Variable( + dims=['Q'], + values=np.asarray(values[name], dtype=float), + variances=np.asarray(variances[name], dtype=float), + unit=units.get(name, None), + ) + + return ds + + def plot_parameters( + self, + names: str | list[str] | None = None, + **kwargs, + ) -> InteractiveFigure: + """Plot fitted parameters as a function of Q. + + Args: + names (str | list[str] | None): Name(s) of the + parameter(s) to plot. If None, plots all parameters. + kwargs: Additional keyword arguments passed to plopp.slicer for + customizing the plot (e.g., title, linestyle, marker, + color). + + Returns: + InteractiveFigure: A Plopp InteractiveFigure containing the + plot of the parameters. 
+ """ + + ds = self.parameters_to_dataset() + + if not names: + names = list(ds.keys()) + + if isinstance(names, str): + names = [names] + + if not isinstance(names, list) or not all(isinstance(name, str) for name in names): + raise TypeError('names must be a string or a list of strings.') + + for name in names: + if name not in ds: + raise ValueError(f"Parameter '{name}' not found in dataset.") + + data_to_plot = {name: ds[name] for name in names} + plot_kwargs_defaults = { + 'linestyle': {name: 'none' for name in names}, + 'marker': {name: 'o' for name in names}, + 'markerfacecolor': {name: 'none' for name in names}, + } + + plot_kwargs_defaults.update(kwargs) + + import plopp as pp + + fig = pp.plot( + data_to_plot, + **plot_kwargs_defaults, + ) + return fig + + ############# + # Private methods - updating models when things change + ############# + + def _on_experiment_changed(self) -> None: + """Update the Q values in the sample and instrument models when + the experiment changes. + + Also update all the Analysi1d objects with the new experiment. + """ + if self._call_updaters: + super()._on_experiment_changed() + for analysis in self.analysis_list: + analysis.experiment = self.experiment + + def _on_sample_model_changed(self) -> None: + """Update the Q values in the sample model when the sample model + changes. + + Also update all the Analysi1d objects with the new sample model. + """ + if self._call_updaters: + super()._on_sample_model_changed() + for analysis in self.analysis_list: + analysis.sample_model = self.sample_model + + def _on_instrument_model_changed(self) -> None: + """Update the Q values in the instrument model when the + instrument model changes. + + Also update all the Analysi1d objects with the new instrument + model. 
+ """ + if self._call_updaters: + super()._on_instrument_model_changed() + for analysis in self.analysis_list: + analysis.instrument_model = self.instrument_model + + ############# + # Private methods + ############# + + def _fit_single_Q(self, Q_index: int) -> FitResults: + """Fit data for a single Q index. + + Args: + Q_index (int): Index of the Q value to fit. + + Returns: + FitResults: The results of the fit for the specified + Q index. + """ + + Q_index = self._verify_Q_index(Q_index) + + return self.analysis_list[Q_index].fit() + + def _fit_all_Q_independently(self) -> list[FitResults]: + """Fit data for all Q indices independently. + + Returns: + list[FitResults]: A list of FitResults, one for each Q + index. + """ + return [analysis.fit() for analysis in self.analysis_list] + + def _fit_all_Q_simultaneously(self) -> FitResults: + """Fit data for all Q indices simultaneously. + + Returns: + FitResults: The results of the simultaneous fit across all + Q indices. + """ + + xs = [] + ys = [] + ws = [] + + for analysis in self.analysis_list: + x, y, weight = self._extract_x_y_weights_from_experiment(analysis.Q_index) + xs.append(x) + ys.append(y) + ws.append(weight) + + # Make sure the convolver is up to date for this Q index + analysis._convolver = analysis._create_convolver() + + mf = MultiFitter( + fit_objects=self.analysis_list, + fit_functions=self.get_fit_functions(), + ) + + results = mf.fit( + x=xs, + y=ys, + weights=ws, + ) + return results + + def get_fit_functions(self) -> list[callable]: + """Get fit functions for all Q indices, which can be used for + simultaneous fitting. + + Returns: + list[callable]: A list of fit functions, one for each + Q index. + """ + return [analysis.as_fit_function() for analysis in self.analysis_list] + + def _create_model_array(self) -> sc.DataArray: + """Create a scipp array for the model. + + Returns: + sc.DataArray: A DataArray containing the model values, with + dimensions "Q" and "energy". 
+ """ + + model = sc.array(dims=['Q', 'energy'], values=self.calculate()) + model_data_array = sc.DataArray( + data=model, + coords={'Q': self.Q, 'energy': self.experiment.energy}, + ) + return model_data_array + + def _create_components_dataset(self, add_background: bool = True) -> sc.Dataset: + """Create a scipp dataset containing the individual components + of the model for plotting. + + Args: + add_background (bool): Whether to add background components + to the sample model components when creating the + dataset. Default is True. + + Raises: + TypeError: If add_background is not True or False. + + Returns: + sc.Dataset: A scipp Dataset where each entry is a component + of the model, with dimensions "Q". + """ + if not isinstance(add_background, bool): + raise TypeError('add_background must be True or False.') + + datasets = [ + analysis._create_components_dataset_single_Q(add_background=add_background) + for analysis in self.analysis_list + ] + + return sc.concat(datasets, dim='Q') + + ############# + # Dunder methods + ############# diff --git a/src/easydynamics/analysis/analysis1d.py b/src/easydynamics/analysis/analysis1d.py new file mode 100644 index 00000000..07fdf6ec --- /dev/null +++ b/src/easydynamics/analysis/analysis1d.py @@ -0,0 +1,583 @@ +# SPDX-FileCopyrightText: 2025-2026 EasyDynamics contributors +# SPDX-License-Identifier: BSD-3-Clause + +import numpy as np +import scipp as sc +from easyscience.fitting.fitter import Fitter as EasyScienceFitter +from easyscience.fitting.minimizers.utils import FitResults +from easyscience.variable import DescriptorNumber +from easyscience.variable import Parameter +from plopp.backends.matplotlib.figure import InteractiveFigure + +from easydynamics.analysis.analysis_base import AnalysisBase +from easydynamics.convolution.convolution import Convolution +from easydynamics.experiment import Experiment +from easydynamics.sample_model import InstrumentModel +from easydynamics.sample_model import SampleModel +from 
easydynamics.sample_model.component_collection import ComponentCollection +from easydynamics.sample_model.components.model_component import ModelComponent + + +class Analysis1d(AnalysisBase): + """For analysing one-dimensional data, i.e. intensity as function of + energy for a single Q index. Is used primarily in the Analysis + class, but can also be used on its own for simpler analyses. + + Args: + display_name (str): Display name of the analysis. + unique_name (str or None): Unique name of the analysis. If None, + a unique name is automatically generated. + experiment (Experiment | None): The Experiment associated with + this Analysis. If None, a default Experiment is created. + sample_model (SampleModel | None): The SampleModel associated + with this Analysis. If None, a default SampleModel is + created. + instrument_model (InstrumentModel | None): The InstrumentModel + associated with this Analysis. If None, a default + InstrumentModel is created. + Q_index (int | None): The Q index to analyze. If None, the + analysis will not be able to calculate or fit until a + Q index is set. + extra_parameters (Parameter | list[Parameter] | None): + Extra parameters to be included in the analysis for advanced + users. If None, no extra parameters are added. + + Attributes: + experiment (Experiment): The Experiment associated with this + Analysis. + sample_model (SampleModel): The SampleModel associated with this + Analysis. + instrument_model (InstrumentModel): The InstrumentModel + associated with this Analysis. + Q (sc.Variable | None): The Q values from the associated + Experiment, if available. + energy (sc.Variable | None): The energy values from the + associated Experiment, if available. + temperature (Parameter | None): The temperature from the + associated SampleModel, if available. + Q_index (int | None): The Q index being analyzed. + extra_parameters (list[Parameter]): The extra parameters + included in this Analysis. 
+ """ + + def __init__( + self, + display_name: str = 'MyAnalysis', + unique_name: str | None = None, + experiment: Experiment | None = None, + sample_model: SampleModel | None = None, + instrument_model: InstrumentModel | None = None, + Q_index: int | None = None, + extra_parameters: Parameter | list[Parameter] | None = None, + ): + super().__init__( + display_name=display_name, + unique_name=unique_name, + experiment=experiment, + sample_model=sample_model, + instrument_model=instrument_model, + extra_parameters=extra_parameters, + ) + + self._Q_index = self._verify_Q_index(Q_index) + + self._fit_result = None + if self._Q_index is not None: + self._convolver = self._create_convolver() + else: + self._convolver = None + + ############# + # Properties + ############# + + @property + def Q_index(self) -> int | None: + """Get the Q index associated with this Analysis. + + Returns: + Experiment: The Experiment associated with this Analysis. + """ + + return self._Q_index + + @Q_index.setter + def Q_index(self, value: int | None) -> None: + """Set the Q index for single Q analysis. + + Args: + index (int | None): The Q index. + """ + + self._Q_index = self._verify_Q_index(value) + self._on_Q_index_changed() + + ############# + # Other methods + ############# + + def calculate(self) -> np.ndarray: + """Calculate the model prediction for the chosen Q index. Makes + sure the convolver is up to date before calculating. + + Returns: + np.ndarray: The calculated model prediction. + """ + + self._convolver = self._create_convolver() + + return self._calculate() + + def _calculate(self) -> np.ndarray: + """Calculate the model prediction for the chosen Q index. Does + not check if the convolver is up to date. + + Returns: + np.ndarray: The calculated model prediction. 
+ """ + + sample_intensity = self._evaluate_sample() + + background_intensity = self._evaluate_background() + + sample_plus_background = sample_intensity + background_intensity + + return sample_plus_background + + def fit(self) -> FitResults: + """Fit the model to the experimental data for the chosen Q + index. + + The energy grid is fixed for the duration of the fit. + Convolution objects are created once and reused during + parameter optimization for performance reasons. + + Returns: + FitResult: The result of the fit. + + Raises: + ValueError: If no experiment is associated with this + Analysis. + + Returns: + FitResults: The result of the fit. + """ + if self._experiment is None: + raise ValueError('No experiment is associated with this Analysis.') + + # Create convolver once to reuse during fitting + self._convolver = self._create_convolver() + + fitter = EasyScienceFitter( + fit_object=self, + fit_function=self.as_fit_function(), + ) + + x, y, weights = self._extract_x_y_weights_from_experiment(Q_index=self._require_Q_index()) + fit_result = fitter.fit(x=x, y=y, weights=weights) + + self._fit_result = fit_result + + return fit_result + + def as_fit_function(self, x=None, **kwargs) -> callable: + """Return self._calculate as a fit function. + + The EasyScience fitter requires x as input, but + self._calculate() already uses the correct energy from the + experiment. So we ignore the x input and just return the + calculated model. + + Args: + x: Ignored. The energy grid is taken from the experiment. + kwargs: Ignored. Included for compatibility with the + EasyScience fitter. + """ + + def fit_function(x, **kwargs): + return self._calculate() + + return fit_function + + def get_all_variables(self) -> list[DescriptorNumber]: + """Get all variables used in the analysis. + + Returns: + List[Descriptor]: A list of all variables. 
+ """ + variables = self.sample_model.get_all_variables(Q_index=self.Q_index) + + variables.extend(self.instrument_model.get_all_variables(Q_index=self.Q_index)) + + if self._extra_parameters: + variables.extend(self._extra_parameters) + + return variables + + def plot_data_and_model( + self, + plot_components: bool = True, + add_background=True, + **kwargs, + ) -> InteractiveFigure: + """Plot the experimental data and the model prediction for the + chosen Q index. Optionally also plot the individual components + of the model. + + Uses Plopp for plotting: https://scipp.github.io/plopp/ + + Args: + plot_components (bool): Whether to plot the individual + components of the model. Default is True. + add_background (bool): Whether to add the background to the + model prediction when plotting individual components. + kwargs: Keyword arguments to pass to the plotting + function. + + Returns: + InteractiveFigure: A plot of the data and model. + """ + import plopp as pp + + if self.experiment.data is None: + raise ValueError('No data to plot. Please load data first.') + + data = self.experiment.data['Q', self.Q_index] + model_array = self._create_sample_scipp_array() + + component_dataset = self._create_components_dataset_single_Q(add_background=add_background) + + # Create a dataset containing the data, model, and individual + # components for plotting. 
+ data_and_model = sc.Dataset({ + 'Data': data, + 'Model': model_array, + }) + + data_and_model = sc.merge(data_and_model, component_dataset) + plot_kwargs_defaults = { + 'title': self.display_name, + 'linestyle': {'Data': 'none', 'Model': '-'}, + 'marker': {'Data': 'o', 'Model': 'none'}, + 'color': {'Data': 'black', 'Model': 'red'}, + 'markerfacecolor': {'Data': 'none', 'Model': 'none'}, + } + + if plot_components: + for comp_name in component_dataset.keys(): + plot_kwargs_defaults['linestyle'][comp_name] = '--' + plot_kwargs_defaults['marker'][comp_name] = None + + # Overwrite defaults with any user-provided kwargs + plot_kwargs_defaults.update(kwargs) + + fig = pp.plot( + data_and_model, + **plot_kwargs_defaults, + ) + return fig + + ############# + # Private methods: small utilities + ############# + + def _require_Q_index(self) -> int: + """Get the Q index, ensuring it is set. Raises a ValueError if + the Q index is not set. + + Returns: + int: The Q index. + + Raises: + ValueError: If the Q index is not set. + """ + if self._Q_index is None: + raise ValueError('Q_index must be set.') + return self._Q_index + + def _on_Q_index_changed(self) -> None: + """Handle changes to the Q index. + + This method is called whenever the Q index is changed. It + updates the Convolution object for the new Q index. + """ + self._convolver = self._create_convolver() + + ############# + # Private methods: evaluation + ############# + + def _evaluate_components( + self, + components: ComponentCollection | ModelComponent, + convolver: Convolution | None = None, + convolve: bool = True, + ) -> np.ndarray: + """Calculate the contribution of a set of components, optionally + convolving with the resolution. + + If convolve is True and a + Convolution object is provided (for full model evaluation), we + use it to perform the convolution of the components with the + resolution. 
+ If convolve is True but no Convolution object is + provided, create a new Convolution object for the given + components (for individual components). + If convolve is False, evaluate the components directly without + convolution (for background). + + Args: + components (ComponentCollection | ModelComponent): + The components to evaluate. + convolver (Convolution | None): An optional Convolution + object to use for convolution. If None, a new + Convolution object will be created if convolve is True. + convolve (bool): + Whether to perform convolution with the resolution. + Default is True. + """ + + Q_index = self._require_Q_index() + energy = self.energy.values + energy_offset = self.instrument_model.get_energy_offset_at_Q(Q_index) + + # If there are no components, return zero + if isinstance(components, ComponentCollection) and components.is_empty: + return np.zeros_like(energy) + + # No convolution + if not convolve: + return components.evaluate(energy - energy_offset.value) + + # If a convolver is provided, use it. This allows reusing the + # same convolver for multiple evaluations during fitting for + # performance reasons. + if convolver is not None: + return convolver.convolution() + + # If no convolver is provided, create a new one. This is for + # evaluating individual components for plotting, where + # performance is not important. + + # We don't create a convolver if the resolution is empty. + resolution = self.instrument_model.resolution_model.get_component_collection(Q_index) + if resolution.is_empty: + return components.evaluate(energy - energy_offset.value) + + conv = Convolution( + sample_components=components, + resolution_components=resolution, + energy=energy, + temperature=self.temperature, + energy_offset=energy_offset, + ) + return conv.convolution() + + def _evaluate_sample(self) -> np.ndarray: + """Evaluate the sample contribution for a given Q index. + + Assumes that self._convolver is up to date. 
+ + Returns: + np.ndarray: The evaluated sample contribution. + """ + Q_index = self._require_Q_index() + components = self.sample_model.get_component_collection(Q_index=Q_index) + return self._evaluate_components( + components=components, + convolver=self._convolver, + convolve=True, + ) + + def _evaluate_sample_component( + self, + component: ModelComponent, + ) -> np.ndarray: + """Evaluate a single sample component for the chosen Q index. + + Args: + component (ModelComponent): The sample component to + evaluate. + + Returns: + np.ndarray: The evaluated sample component contribution. + """ + return self._evaluate_components( + components=component, + convolver=None, + convolve=True, + ) + + def _evaluate_background(self) -> np.ndarray: + """Evaluate the background contribution for the chosen Q index. + + Returns: + np.ndarray: The evaluated background contribution. + """ + Q_index = self._require_Q_index() + background_components = self.instrument_model.background_model.get_component_collection( + Q_index=Q_index + ) + return self._evaluate_components( + components=background_components, + convolver=None, + convolve=False, + ) + + def _evaluate_background_component( + self, + component: ModelComponent, + ) -> np.ndarray: + """Evaluate a single background component for the chosen Q + index. + + Args: + component (ModelComponent): The background component to + evaluate. + + Returns: + np.ndarray: The evaluated background component contribution. + """ + + return self._evaluate_components( + components=component, + convolver=None, + convolve=False, + ) + + def _create_convolver(self) -> Convolution | None: + """Initialize and return a Convolution object for the chosen Q + index. If the necessary components for convolution are not + available, return None. + + Returns: + Convolution | None: The initialized Convolution object or + None if not available. 
+ """ + Q_index = self._require_Q_index() + + sample_components = self.sample_model.get_component_collection(Q_index) + if sample_components.is_empty: + return None + + resolution_components = self.instrument_model.resolution_model.get_component_collection( + Q_index + ) + if resolution_components.is_empty: + return None + energy = self.energy + # TODO: allow convolution options to be set. + convolver = Convolution( + sample_components=sample_components, + resolution_components=resolution_components, + energy=energy, + temperature=self.temperature, + energy_offset=self.instrument_model.get_energy_offset_at_Q(Q_index), + ) + return convolver + + ############# + # Private methods: create scipp arrays for plotting + ############# + + def _create_component_scipp_array( + self, + component: ModelComponent, + background: np.ndarray | None = None, + ) -> sc.DataArray: + """Create a scipp DataArray for a single component. Adds the + background if it is not None. + + Args: + component (ModelComponent): The component to evaluate + background (np.ndarray | None): Optional background to add + to the component. + + Returns: + sc.DataArray: The model calculation of the component. + """ + + values = self._evaluate_sample_component(component=component) + if background is not None: + values += background + return self._to_scipp_array(values=values) + + def _create_background_component_scipp_array( + self, + component: ModelComponent, + ) -> sc.DataArray: + """Create a scipp DataArray for a single background component. + + Args: + component (ModelComponent): The component to evaluate. + + Returns: + sc.DataArray: The model calculation of the component. + """ + + values = self._evaluate_background_component(component=component) + return self._to_scipp_array(values=values) + + def _create_sample_scipp_array(self) -> sc.DataArray: + """Create a scipp DataArray for the full sample model including + background. 
+
+        Returns:
+            sc.DataArray: The model calculation of the full sample
+                model.
+        """
+        values = self._calculate()
+        return self._to_scipp_array(values=values)
+
+    def _create_components_dataset_single_Q(
+        self,
+        add_background: bool = True,
+    ) -> sc.Dataset:
+        """Create sc.DataArrays for all sample and background
+        components.
+
+        Args:
+            add_background (bool): Whether to add background components.
+
+        Returns:
+            sc.Dataset: A dataset of component names to
+                their corresponding sc.DataArrays.
+        """
+
+        scipp_arrays = {}
+        sample_components = self.sample_model.get_component_collection(
+            Q_index=self.Q_index
+        ).components
+
+        background_components = self.instrument_model.background_model.get_component_collection(
+            Q_index=self.Q_index
+        ).components
+        background = self._evaluate_background() if add_background else None
+        for component in sample_components:
+            scipp_arrays[component.display_name] = self._create_component_scipp_array(
+                component=component, background=background
+            )
+        for component in background_components:
+            scipp_arrays[component.display_name] = self._create_background_component_scipp_array(
+                component=component
+            )
+        return sc.Dataset(scipp_arrays)
+
+    def _to_scipp_array(self, values: np.ndarray) -> sc.DataArray:
+        """Convert a numpy array of values to a sc.DataArray with the
+        correct coordinates for energy and Q.
+
+        Args:
+            values (np.ndarray): The values to convert.
+
+        Returns:
+            sc.DataArray: The converted sc.DataArray.
+ """ + + return sc.DataArray( + data=sc.array(dims=['energy'], values=values), + coords={ + 'energy': self.energy, + 'Q': self.Q[self.Q_index], + }, + ) diff --git a/src/easydynamics/analysis/analysis_base.py b/src/easydynamics/analysis/analysis_base.py new file mode 100644 index 00000000..07136062 --- /dev/null +++ b/src/easydynamics/analysis/analysis_base.py @@ -0,0 +1,355 @@ +# SPDX-FileCopyrightText: 2025-2026 EasyDynamics contributors +# SPDX-License-Identifier: BSD-3-Clause + +import numpy as np +import scipp as sc +from easyscience.base_classes.model_base import ModelBase as EasyScienceModelBase +from easyscience.variable import Parameter + +from easydynamics.experiment import Experiment +from easydynamics.sample_model import InstrumentModel +from easydynamics.sample_model import SampleModel + + +class AnalysisBase(EasyScienceModelBase): + """Base class for analysis in EasyDynamics. This class is not meant + to be used directly. + + An Analysis consists of an Experiment, a SampleModel, and an + InstrumentModel. The Experiment contains the data to be fitted, the + SampleModel contains the model for the sample, and the + InstrumentModel contains the model for the instrument, including + background and resolution + + Args: + display_name (str): Display name of the analysis. + unique_name (str or None): Unique name of the analysis. If None, + a unique name is automatically generated. + experiment (Experiment | None): The Experiment associated with + this Analysis. If None, a default Experiment is created. + sample_model (SampleModel | None): The SampleModel associated + with this Analysis. If None, a default SampleModel is + created. + instrument_model (InstrumentModel | None): The InstrumentModel + associated with this Analysis. If None, a default + InstrumentModel is created. + extra_parameters (Parameter | list[Parameter] | None): + Extra parameters to be included in the analysis for advanced + users. If None, no extra parameters are added. 
+ + Attributes: + experiment (Experiment): The Experiment associated with this + Analysis. + sample_model (SampleModel): The SampleModel associated with this + Analysis. + instrument_model (InstrumentModel): The InstrumentModel + associated with this Analysis. + Q (sc.Variable | None): The Q values from the associated + Experiment, if available. + energy (sc.Variable | None): The energy values from the + associated Experiment, if available. + temperature (Parameter | None): The temperature from the + associated SampleModel, if available. + extra_parameters (list[Parameter]): The extra parameters + included in this Analysis. + """ + + def __init__( + self, + display_name: str = 'MyAnalysis', + unique_name: str | None = None, + experiment: Experiment | None = None, + sample_model: SampleModel | None = None, + instrument_model: InstrumentModel | None = None, + extra_parameters: Parameter | list[Parameter] | None = None, + ): + super().__init__(display_name=display_name, unique_name=unique_name) + + if experiment is None: + self._experiment = Experiment() + elif isinstance(experiment, Experiment): + self._experiment = experiment + else: + raise TypeError('experiment must be an instance of Experiment or None.') + + if sample_model is None: + self._sample_model = SampleModel() + elif isinstance(sample_model, SampleModel): + self._sample_model = sample_model + else: + raise TypeError('sample_model must be an instance of SampleModel or None.') + + if instrument_model is None: + self._instrument_model = InstrumentModel() + elif isinstance(instrument_model, InstrumentModel): + self._instrument_model = instrument_model + else: + raise TypeError('instrument_model must be an instance of InstrumentModel or None.') + + if extra_parameters is not None: + if isinstance(extra_parameters, Parameter): + self._extra_parameters = [extra_parameters] + elif isinstance(extra_parameters, list) and all( + isinstance(p, Parameter) for p in extra_parameters + ): + self._extra_parameters = 
extra_parameters + else: + raise TypeError('extra_parameters must be a Parameter or a list of Parameters.') + else: + self._extra_parameters = [] + + self._on_experiment_changed() + + ############# + # Properties + ############# + + @property + def experiment(self) -> Experiment: + """Get the Experiment associated with this Analysis. + + Returns: + Experiment: The Experiment associated with this Analysis. + """ + + return self._experiment + + @experiment.setter + def experiment(self, value: Experiment) -> None: + """Set the Experiment for this Analysis. + + Raises: + TypeError: if value is not an Experiment. + """ + + if not isinstance(value, Experiment): + raise TypeError('experiment must be an instance of Experiment') + self._experiment = value + self._on_experiment_changed() + + @property + def sample_model(self) -> SampleModel: + """Get the SampleModel associated with this Analysis. + + Returns: + SampleModel: The SampleModel associated with this Analysis. + """ + + return self._sample_model + + @sample_model.setter + def sample_model(self, value: SampleModel) -> None: + """Set the SampleModel for this Analysis. + + Raises: + TypeError: if value is not a SampleModel. + """ + if not isinstance(value, SampleModel): + raise TypeError('sample_model must be an instance of SampleModel') + self._sample_model = value + self._on_sample_model_changed() + + @property + def instrument_model(self) -> InstrumentModel: + """Get the InstrumentModel associated with this Analysis. + + Returns: + InstrumentModel: The InstrumentModel associated with this + Analysis. + """ + return self._instrument_model + + @instrument_model.setter + def instrument_model(self, value: InstrumentModel) -> None: + """Set the InstrumentModel for this Analysis. + + Raises: + TypeError: if value is not an InstrumentModel. 
+ """ + if not isinstance(value, InstrumentModel): + raise TypeError('instrument_model must be an instance of InstrumentModel') + self._instrument_model = value + self._on_instrument_model_changed() + + @property + def Q(self) -> sc.Variable | None: + """Get the Q values from the associated Experiment, if + available. + + Returns: + sc.Variable: The Q values from the associated Experiment, + if available. + None: If the Experiment does not have any data. + """ + return self.experiment.Q + + @Q.setter + def Q(self, value) -> None: + """Q cannot be set, as it is a read-only property derived from + the Experiment. + + Raises: + AttributeError: If trying to set Q. + """ + raise AttributeError('Q is a read-only property derived from the Experiment.') + + @property + def energy(self) -> sc.Variable | None: + """Get the energy values from the associated Experiment, if + available. + + Returns: + sc.Variable: The energy values from the associated + Experiment, if available. + None: If the Experiment does not have any data. + """ + + return self.experiment.energy + + @energy.setter + def energy(self, value) -> None: + """Energy cannot be set, as it is a read-only property derived + from the Experiment. + + Raises: + AttributeError: If trying to set energy. + """ + + raise AttributeError('energy is a read-only property derived from the Experiment.') + + @property + def temperature(self) -> Parameter | None: + """Get the temperature from the associated SampleModel, if + available. + + Returns: + Parameter: The temperature from the associated SampleModel, + if available. + None: If the SampleModel does not have a temperature. + """ + + return self.sample_model.temperature + + @temperature.setter + def temperature(self, value) -> None: + """Temperature cannot be set, as it is a read-only property + derived from the SampleModel. + + Raises: + AttributeError: If trying to set temperature. 
+ """ + + raise AttributeError('temperature is a read-only property derived from the SampleModel.') + + @property + def extra_parameters(self) -> list[Parameter]: + """Get the extra parameters included in this Analysis. + + Returns: + list[Parameter]: The extra parameters included in this + Analysis. + """ + return self._extra_parameters + + @extra_parameters.setter + def extra_parameters(self, value: Parameter | list[Parameter]) -> None: + """Set the extra parameters for this Analysis. + + Args: + value (Parameter | list[Parameter]): The extra parameters to + include in this Analysis. + + Raises: + TypeError: If value is not a Parameter or a list of + Parameters. + """ + if isinstance(value, Parameter): + self._extra_parameters = [value] + elif isinstance(value, list) and all(isinstance(p, Parameter) for p in value): + self._extra_parameters = value + else: + raise TypeError('extra_parameters must be a Parameter or a list of Parameters.') + + ############# + # Other methods + ############# + + ############# + # Private methods + ############# + + def _on_experiment_changed(self) -> None: + """Update the Q values in the sample and instrument models when + the experiment changes. + """ + self.sample_model.Q = self.Q + self.instrument_model.Q = self.Q + + def _on_sample_model_changed(self) -> None: + """Update the Q values in the sample model when the sample model + changes. + """ + self.sample_model.Q = self.Q + + def _on_instrument_model_changed(self) -> None: + """Update the Q values in the instrument model when the + instrument model changes. + """ + self.instrument_model.Q = self.Q + + def _verify_Q_index(self, Q_index: int | None) -> int | None: + """Verify that the Q index is valid. + + Args: + Q_index (int | None): The Q index to verify. + + Returns: + int | None: The verified Q index. + + Raises: + IndexError: If the Q index is not valid. 
+ """ + if Q_index is not None: + if ( + not isinstance(Q_index, int) + or Q_index < 0 + or (self.Q is not None and Q_index >= len(self.Q)) + ): + raise IndexError('Q_index must be a valid index for the Q values.') + return Q_index + + def _extract_x_y_weights_from_experiment( + self, Q_index: int + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Extract the x, y, and weights arrays from the experiment for + the given Q index. + + Args: + Q_index (int): The Q index to extract the data for. + + Returns: + tuple[np.ndarray, np.ndarray, np.ndarray]: The x, y, and + weights arrays extracted from the experiment for the + given Q index. + """ + data = self.experiment.data['Q', Q_index] + x = data.coords['energy'].values + y = data.values + e = data.variances**0.5 + if np.any(e == 0): + raise ValueError('Cannot compute weights: some variances are zero.') + weights = 1.0 / e + return x, y, weights + + ############# + # Dunder methods + ############# + + def __repr__(self) -> str: + """Return a string representation of the Analysis. + + Returns: + str: A string representation of the Analysis. 
+ """ + return f' {self.__class__.__name__} (display_name={self.display_name}, \ + unique_name={self.unique_name})' diff --git a/src/easydynamics/convolution/analytical_convolution.py b/src/easydynamics/convolution/analytical_convolution.py index cfa56c9f..031d5975 100644 --- a/src/easydynamics/convolution/analytical_convolution.py +++ b/src/easydynamics/convolution/analytical_convolution.py @@ -3,6 +3,7 @@ import numpy as np import scipp as sc +from easyscience.variable import Parameter from scipy.special import voigt_profile from easydynamics.convolution.convolution_base import ConvolutionBase @@ -12,8 +13,7 @@ from easydynamics.sample_model import Voigt from easydynamics.sample_model.component_collection import ComponentCollection from easydynamics.sample_model.components.model_component import ModelComponent - -Numerical = float | int +from easydynamics.utils.utils import Numeric class AnalyticalConvolution(ConvolutionBase): @@ -49,12 +49,14 @@ def __init__( energy_unit: str | sc.Unit = 'meV', sample_components: ComponentCollection | ModelComponent | None = None, resolution_components: ComponentCollection | ModelComponent | None = None, + energy_offset: Numeric | Parameter = 0.0, ): super().__init__( energy=energy, energy_unit=energy_unit, sample_components=sample_components, resolution_components=resolution_components, + energy_offset=energy_offset, ) def convolution( @@ -77,16 +79,8 @@ def convolution( If component pair cannot be handled analytically. 
""" - # prepare list of components - if isinstance(self.sample_components, ComponentCollection): - sample_components = self.sample_components.components - else: - sample_components = [self.sample_components] - - if isinstance(self.resolution_components, ComponentCollection): - resolution_components = self.resolution_components.components - else: - resolution_components = [self.resolution_components] + sample_components = self.sample_components.components + resolution_components = self.resolution_components.components total = np.zeros_like(self.energy.values, dtype=float) @@ -199,7 +193,7 @@ def _convolute_delta_any( The evaluated convolution values at self.energy. """ return sample_component.area.value * resolution_components.evaluate( - self.energy.values - sample_component.center.value + self.energy_with_offset.values - sample_component.center.value ) def _convolute_gaussian_gaussian( @@ -420,7 +414,7 @@ def _gaussian_eval( """ normalization = 1 / (np.sqrt(2 * np.pi) * width) - exponent = -0.5 * ((self.energy.values - center) / width) ** 2 + exponent = -0.5 * ((self.energy_with_offset.values - center) / width) ** 2 return area * normalization * np.exp(exponent) @@ -443,7 +437,7 @@ def _lorentzian_eval(self, area: float, center: float, width: float) -> np.ndarr """ normalization = width / np.pi - denominator = (self.energy.values - center) ** 2 + width**2 + denominator = (self.energy_with_offset.values - center) ** 2 + width**2 return area * normalization / denominator @@ -471,4 +465,6 @@ def _voigt_eval( The evaluated Voigt profile values at self.energy. 
""" - return area * voigt_profile(self.energy.values - center, gaussian_width, lorentzian_width) + return area * voigt_profile( + self.energy_with_offset.values - center, gaussian_width, lorentzian_width + ) diff --git a/src/easydynamics/convolution/convolution.py b/src/easydynamics/convolution/convolution.py index 542515e9..b4fa19e3 100644 --- a/src/easydynamics/convolution/convolution.py +++ b/src/easydynamics/convolution/convolution.py @@ -14,8 +14,7 @@ from easydynamics.sample_model import Lorentzian from easydynamics.sample_model import Voigt from easydynamics.sample_model.components.model_component import ModelComponent - -Numerical = float | int +from easydynamics.utils.utils import Numeric class Convolution(NumericalConvolutionBase): @@ -77,9 +76,10 @@ def __init__( energy: np.ndarray | sc.Variable, sample_components: ComponentCollection | ModelComponent, resolution_components: ComponentCollection | ModelComponent, - upsample_factor: Numerical = 5, - extension_factor: Numerical = 0.2, - temperature: Parameter | Numerical | None = None, + energy_offset: Numeric | Parameter = 0.0, + upsample_factor: Numeric = 5, + extension_factor: Numeric = 0.2, + temperature: Parameter | Numeric | None = None, temperature_unit: str | sc.Unit = 'K', energy_unit: str | sc.Unit = 'meV', normalize_detailed_balance: bool = True, @@ -90,6 +90,7 @@ def __init__( energy=energy, sample_components=sample_components, resolution_components=resolution_components, + energy_offset=energy_offset, upsample_factor=upsample_factor, extension_factor=extension_factor, temperature=temperature, @@ -140,7 +141,9 @@ def _convolve_delta_functions(self) -> np.ndarray: 'No detailed balance correction is applied to delta functions.' 
return sum( delta.area.value - * self._resolution_components.evaluate(self.energy.values - delta.center.value) + * self._resolution_components.evaluate( + self.energy_with_offset.values - delta.center.value + ) for delta in self._delta_sample_components.components ) @@ -245,6 +248,7 @@ def _set_convolvers(self) -> None: if self._analytical_sample_components.components: self._analytical_convolver = AnalyticalConvolution( energy=self.energy, + energy_offset=self.energy_offset, sample_components=self._analytical_sample_components, resolution_components=self._resolution_components, ) @@ -254,6 +258,7 @@ def _set_convolvers(self) -> None: if self._numerical_sample_components.components: self._numerical_convolver = NumericalConvolution( energy=self.energy, + energy_offset=self.energy_offset, sample_components=self._numerical_sample_components, resolution_components=self._resolution_components, upsample_factor=self.upsample_factor, diff --git a/src/easydynamics/convolution/convolution_base.py b/src/easydynamics/convolution/convolution_base.py index 34eab3f4..d0235eaf 100644 --- a/src/easydynamics/convolution/convolution_base.py +++ b/src/easydynamics/convolution/convolution_base.py @@ -3,11 +3,11 @@ import numpy as np import scipp as sc +from easyscience.variable import Parameter from easydynamics.sample_model.component_collection import ComponentCollection from easydynamics.sample_model.components.model_component import ModelComponent - -Numerical = float | int +from easydynamics.utils.utils import Numeric class ConvolutionBase: @@ -31,8 +31,9 @@ def __init__( sample_components: ComponentCollection | ModelComponent = None, resolution_components: ComponentCollection | ModelComponent = None, energy_unit: str | sc.Unit = 'meV', + energy_offset: Numeric | Parameter = 0.0, ): - if isinstance(energy, Numerical): + if isinstance(energy, Numeric): energy = np.array([float(energy)]) if not isinstance(energy, (np.ndarray, sc.Variable)): @@ -44,8 +45,17 @@ def __init__( if 
isinstance(energy, np.ndarray): energy = sc.array(dims=['energy'], values=energy, unit=energy_unit) + if isinstance(energy_offset, Numeric): + energy_offset = Parameter( + name='energy_offset', value=float(energy_offset), unit=energy_unit + ) + + if not isinstance(energy_offset, Parameter): + raise TypeError('Energy_offset must be a number or a Parameter.') + self._energy = energy self._energy_unit = energy_unit + self._energy_offset = energy_offset if sample_components is not None and not ( isinstance(sample_components, ComponentCollection) @@ -54,6 +64,8 @@ def __init__( raise TypeError( f'`sample_components` is an instance of {type(sample_components).__name__}, but must be a ComponentCollection or ModelComponent.' # noqa: E501 ) + if isinstance(sample_components, ModelComponent): + sample_components = ComponentCollection(components=[sample_components]) self._sample_components = sample_components if resolution_components is not None and not ( @@ -63,8 +75,50 @@ def __init__( raise TypeError( f'`resolution_components` is an instance of {type(resolution_components).__name__}, but must be a ComponentCollection or ModelComponent.' # noqa: E501 ) + if isinstance(resolution_components, ModelComponent): + resolution_components = ComponentCollection(components=[resolution_components]) self._resolution_components = resolution_components + @property + def energy_offset(self) -> Parameter: + """Get the energy offset.""" + return self._energy_offset + + @energy_offset.setter + def energy_offset(self, energy_offset: Numeric | Parameter) -> None: + """Set the energy offset. + Args: + energy_offset : Number or Parameter + The energy offset to apply to the convolution. + + Raises: + TypeError: If energy_offset is not a number or a Parameter. 
+ """ + if not isinstance(energy_offset, Parameter | Numeric): + raise TypeError('Energy_offset must be a number or a Parameter.') + + if isinstance(energy_offset, Numeric): + self._energy_offset.value = float(energy_offset) + + if isinstance(energy_offset, Parameter): + self._energy_offset = energy_offset + + @property + def energy_with_offset(self) -> sc.Variable: + """Get the energy with the offset applied.""" + energy_with_offset = self.energy.copy() + energy_with_offset.values = self.energy.values - self.energy_offset.value + return energy_with_offset + + @energy_with_offset.setter + def energy_with_offset(self, value) -> None: + """Energy with offset is a read-only property derived from + energy and energy_offset. + """ + raise AttributeError( + 'Energy with offset is a read-only property derived from energy and energy_offset.' + ) + @property def energy(self) -> sc.Variable: """Get the energy.""" @@ -84,7 +138,7 @@ def energy(self, energy: np.ndarray) -> None: scipp Variable. """ - if isinstance(energy, Numerical): + if isinstance(energy, Numeric): energy = np.array([float(energy)]) if not isinstance(energy, (np.ndarray, sc.Variable)): @@ -112,18 +166,34 @@ def energy_unit(self, unit_str: str) -> None: ) # noqa: E501 def convert_energy_unit(self, energy_unit: str | sc.Unit) -> None: - """Convert the energy to the specified unit + """Convert the energy and energy_offset to the specified unit. + Args: energy_unit : str or sc.Unit The unit of the energy. Raises: TypeError: If energy_unit is not a string or scipp unit. + UnitError: If energy cannot be converted to the specified + unit. 
""" if not isinstance(energy_unit, (str, sc.Unit)): raise TypeError('Energy unit must be a string or scipp unit.') - self.energy = sc.to_unit(self.energy, energy_unit) + old_energy = self.energy.copy() + try: + self.energy = sc.to_unit(self.energy, energy_unit) + except Exception as e: + self.energy = old_energy + raise e + + old_energy_offset = self.energy_offset + try: + self.energy_offset.convert_unit(energy_unit) + except Exception as e: + self.energy_offset = old_energy_offset + raise e + self._energy_unit = energy_unit @property @@ -146,6 +216,9 @@ def sample_components(self, sample_components: ComponentCollection | ModelCompon raise TypeError( f'`sample_components` is an instance of {type(sample_components).__name__}, but must be a ComponentCollection or ModelComponent.' # noqa: E501 ) + + if isinstance(sample_components, ModelComponent): + sample_components = ComponentCollection(components=[sample_components]) self._sample_components = sample_components @property @@ -171,4 +244,7 @@ def resolution_components( raise TypeError( f'`resolution_components` is an instance of {type(resolution_components).__name__}, but must be a ComponentCollection or ModelComponent.' 
# noqa: E501 ) + + if isinstance(resolution_components, ModelComponent): + resolution_components = ComponentCollection(components=[resolution_components]) self._resolution_components = resolution_components diff --git a/src/easydynamics/convolution/numerical_convolution.py b/src/easydynamics/convolution/numerical_convolution.py index 125c4451..1b8ca6d1 100644 --- a/src/easydynamics/convolution/numerical_convolution.py +++ b/src/easydynamics/convolution/numerical_convolution.py @@ -10,8 +10,7 @@ from easydynamics.sample_model.component_collection import ComponentCollection from easydynamics.sample_model.components.model_component import ModelComponent from easydynamics.utils.detailed_balance import _detailed_balance_factor as detailed_balance_factor - -Numerical = float | int +from easydynamics.utils.utils import Numeric class NumericalConvolution(NumericalConvolutionBase): @@ -53,9 +52,10 @@ def __init__( energy: np.ndarray | sc.Variable, sample_components: ComponentCollection | ModelComponent, resolution_components: ComponentCollection | ModelComponent, - upsample_factor: Numerical = 5, - extension_factor: float = 0.2, - temperature: Parameter | float | None = None, + energy_offset: Numeric | Parameter = 0.0, + upsample_factor: Numeric = 5, + extension_factor: Numeric = 0.2, + temperature: Parameter | Numeric | None = None, temperature_unit: str | sc.Unit = 'K', energy_unit: str | sc.Unit = 'meV', normalize_detailed_balance: bool = True, @@ -64,6 +64,7 @@ def __init__( energy=energy, sample_components=sample_components, resolution_components=resolution_components, + energy_offset=energy_offset, upsample_factor=upsample_factor, extension_factor=extension_factor, temperature=temperature, @@ -97,13 +98,15 @@ def convolution( # Evaluate sample model. If called via the Convolution class, # delta functions are already filtered out. 
sample_vals = self.sample_components.evaluate( - self._energy_grid.energy_dense - self._energy_grid.energy_even_length_offset + self._energy_grid.energy_dense + - self._energy_grid.energy_even_length_offset + - self.energy_offset.value ) # Detailed balance correction if self.temperature is not None: detailed_balance_factor_correction = detailed_balance_factor( - energy=self._energy_grid.energy_dense, + energy=self._energy_grid.energy_dense - self.energy_offset.value, temperature=self.temperature, energy_unit=self.energy.unit, divide_by_temperature=self.normalize_detailed_balance, diff --git a/src/easydynamics/convolution/numerical_convolution_base.py b/src/easydynamics/convolution/numerical_convolution_base.py index ffcf0058..5a60f5d8 100644 --- a/src/easydynamics/convolution/numerical_convolution_base.py +++ b/src/easydynamics/convolution/numerical_convolution_base.py @@ -63,6 +63,7 @@ def __init__( energy: np.ndarray | sc.Variable, sample_components: ComponentCollection | ModelComponent, resolution_components: ComponentCollection | ModelComponent, + energy_offset: Numerical | Parameter = 0.0, upsample_factor: Numerical = 5, extension_factor: float = 0.2, temperature: Parameter | float | None = None, @@ -75,6 +76,7 @@ def __init__( sample_components=sample_components, resolution_components=resolution_components, energy_unit=energy_unit, + energy_offset=energy_offset, ) if temperature is not None and not isinstance(temperature, (Numerical, Parameter)): @@ -239,9 +241,9 @@ def _create_energy_grid( The dense grid created by upsampling and extending energy. The EnergyGrid has the following attributes: - energy_dense : np.ndarray + energy_dense : np.ndarray The upsampled and extended energy array. - energy_dense_centered : np.ndarray + energy_dense_centered : np.ndarray The centered version of energy_dense (used for resolution evaluation). 
energy_dense_step : float diff --git a/src/easydynamics/experiment/experiment.py b/src/easydynamics/experiment/experiment.py index b3df2a11..ff48706f 100644 --- a/src/easydynamics/experiment/experiment.py +++ b/src/easydynamics/experiment/experiment.py @@ -1,6 +1,4 @@ import os -import warnings -from typing import Optional import plopp as pp import scipp as sc @@ -8,6 +6,8 @@ from scipp.io import load_hdf5 as sc_load_hdf5 from scipp.io import save_hdf5 as sc_save_hdf5 +from easydynamics.utils.utils import _in_notebook + class Experiment(NewBase): """Holds data from an experiment as a sc.DataArray along with @@ -29,7 +29,7 @@ def __init__( ) if data is None: - self._data: Optional[sc.DataArray] = None + self._data = None elif isinstance(data, str): self.load_hdf5(filename=data) elif isinstance(data, sc.DataArray): @@ -54,7 +54,7 @@ def data(self) -> sc.DataArray | None: return self._data @data.setter - def data(self, value: sc.DataArray): + def data(self, value: sc.DataArray) -> None: """Set the dataset associated with this experiment.""" if not isinstance(value, sc.DataArray): raise TypeError(f'Data must be a sc.DataArray, not {type(value).__name__}') @@ -70,7 +70,7 @@ def binned_data(self) -> sc.DataArray | None: return self._binned_data @binned_data.setter - def binned_data(self, value: sc.DataArray): + def binned_data(self, value: sc.DataArray) -> None: """Set the binned dataset associated with this experiment.""" raise AttributeError('binned_data is a read-only property. 
Use rebin() to rebin the data') @@ -78,25 +78,23 @@ def binned_data(self, value: sc.DataArray): def Q(self) -> sc.Variable | None: """Get the Q values from the dataset.""" if self._data is None: - warnings.warn('No data loaded.', UserWarning) return None return self._binned_data.coords['Q'] @Q.setter - def Q(self, value: sc.Variable): + def Q(self, value: sc.Variable) -> None: """Set the Q values for the dataset.""" raise AttributeError('Q is a read-only property derived from the data.') @property - def energy(self) -> sc.Variable: + def energy(self) -> sc.Variable | None: """Get the energy values from the dataset.""" if self._data is None: - warnings.warn('No data loaded.', UserWarning) return None return self._binned_data.coords['energy'] @energy.setter - def energy(self, value: sc.Variable): + def energy(self, value: sc.Variable) -> None: """Set the energy values for the dataset.""" raise AttributeError('energy is a read-only property derived from the data.') @@ -215,7 +213,7 @@ def plot_data(self, slicer=False, **kwargs) -> None: if self._binned_data is None: raise ValueError('No data to plot. Please load data first.') - if not self._in_notebook(): + if not _in_notebook(): raise RuntimeError('plot_data() can only be used in a Jupyter notebook environment.') from IPython.display import display @@ -241,26 +239,6 @@ def plot_data(self, slicer=False, **kwargs) -> None: # private methods ########### - @staticmethod - def _in_notebook() -> bool: - """Check if the code is running in a Jupyter notebook. - - Returns: - bool: True if in a Jupyter notebook, False otherwise. 
- """ - try: - from IPython import get_ipython - - shell = get_ipython().__class__.__name__ - if shell == 'ZMQInteractiveShell': - return True # Jupyter notebook or JupyterLab - elif shell == 'TerminalInteractiveShell': - return False # Terminal IPython - else: - return False - except (NameError, ImportError): - return False # Standard Python (no IPython) - @staticmethod def _validate_coordinates(data: sc.DataArray) -> None: """Validate that required coordinates are present in the data. diff --git a/src/easydynamics/sample_model/__init__.py b/src/easydynamics/sample_model/__init__.py index 5929fc50..1f1602aa 100644 --- a/src/easydynamics/sample_model/__init__.py +++ b/src/easydynamics/sample_model/__init__.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2025-2026 EasyDynamics contributors # SPDX-License-Identifier: BSD-3-Clause +from .background_model import BackgroundModel from .component_collection import ComponentCollection from .components import DampedHarmonicOscillator from .components import DeltaFunction @@ -9,6 +10,9 @@ from .components import Polynomial from .components import Voigt from .diffusion_model.brownian_translational_diffusion import BrownianTranslationalDiffusion +from .instrument_model import InstrumentModel +from .resolution_model import ResolutionModel +from .sample_model import SampleModel __all__ = [ 'ComponentCollection', @@ -19,4 +23,8 @@ 'DampedHarmonicOscillator', 'Polynomial', 'BrownianTranslationalDiffusion', + 'SampleModel', + 'ResolutionModel', + 'BackgroundModel', + 'InstrumentModel', ] diff --git a/src/easydynamics/sample_model/component_collection.py b/src/easydynamics/sample_model/component_collection.py index 586a6649..f04d7ae8 100644 --- a/src/easydynamics/sample_model/component_collection.py +++ b/src/easydynamics/sample_model/component_collection.py @@ -67,13 +67,28 @@ def __init__( self.append_component(comp) def append_component(self, component: ModelComponent | 'ComponentCollection') -> None: - match component: - case 
ModelComponent(): - components = (component,) - case ComponentCollection(components=components): - pass - case _: - raise TypeError('Component must be a ModelComponent or ComponentCollection.') + """Append a model component or the components from another + ComponentCollection to this ComponentCollection. + + Parameters + ---------- + component : ModelComponent or ComponentCollection + The component to append. + Raises + ------ + TypeError + If the component is not a ModelComponent or + ComponentCollection. + """ + if not isinstance(component, (ModelComponent, ComponentCollection)): + raise TypeError( + 'Component must be an instance of ModelComponent or ComponentCollection. ' + f'Got {type(component).__name__} instead.' + ) + if isinstance(component, ModelComponent): + components = (component,) + if isinstance(component, ComponentCollection): + components = component.components for comp in components: if comp in self._components: @@ -116,6 +131,17 @@ def components(self, components: List[ModelComponent]) -> None: self._components = components + @property + def is_empty(self) -> bool: + return not self._components + + @is_empty.setter + def is_empty(self, value: bool) -> None: + raise AttributeError( + 'is_empty is a read-only property that indicates ' + 'whether the collection has components.' + ) + def list_component_names(self) -> List[str]: """List the names of all components in the model. 
diff --git a/src/easydynamics/sample_model/diffusion_model/brownian_translational_diffusion.py b/src/easydynamics/sample_model/diffusion_model/brownian_translational_diffusion.py index d277d227..2853132a 100644 --- a/src/easydynamics/sample_model/diffusion_model/brownian_translational_diffusion.py +++ b/src/easydynamics/sample_model/diffusion_model/brownian_translational_diffusion.py @@ -218,6 +218,7 @@ def create_component_collections( lorentzian_component.width.make_dependent_on( dependency_expression=dependency_expression, dependency_map=dependency_map, + desired_unit=self.unit, ) # Make the area dependent on Q @@ -227,9 +228,6 @@ def create_component_collections( dependency_map=area_dependency_map, ) - # Resolving the dependency can do weird things to the units, - # so we make sure it's correct. - lorentzian_component.width.convert_unit(self.unit) component_collection_list[i].append_component(lorentzian_component) return component_collection_list diff --git a/src/easydynamics/sample_model/diffusion_model/jump_translational_diffusion.py b/src/easydynamics/sample_model/diffusion_model/jump_translational_diffusion.py index 8bb65480..d6ab64b6 100644 --- a/src/easydynamics/sample_model/diffusion_model/jump_translational_diffusion.py +++ b/src/easydynamics/sample_model/diffusion_model/jump_translational_diffusion.py @@ -50,11 +50,11 @@ def __init__( unit : str or sc.Unit, optional Energy unit for the underlying Lorentzian components. Defaults to "meV". - scale : float , optional + scale : float, optional Scale factor for the diffusion model. - diffusion_coefficient : float , optional + diffusion_coefficient : float, optional Diffusion coefficient D in m^2/s. Defaults to 1.0. - relaxation_time : float , optional + relaxation_time : float, optional Relaxation time t in ps. Defaults to 1.0. 
""" super().__init__( @@ -254,6 +254,7 @@ def create_component_collections( lorentzian_component.width.make_dependent_on( dependency_expression=dependency_expression, dependency_map=dependency_map, + desired_unit=self.unit, ) # Make the area dependent on Q @@ -263,9 +264,6 @@ def create_component_collections( dependency_map=area_dependency_map, ) - # Resolving the dependency can do weird things to the units, - # so we make sure it's correct. - lorentzian_component.width.convert_unit(self.unit) component_collection_list[i].append_component(lorentzian_component) return component_collection_list diff --git a/src/easydynamics/sample_model/instrument_model.py b/src/easydynamics/sample_model/instrument_model.py index bef6bd92..33b6aacb 100644 --- a/src/easydynamics/sample_model/instrument_model.py +++ b/src/easydynamics/sample_model/instrument_model.py @@ -258,6 +258,32 @@ def free_resolution_parameters(self) -> None: """Free all parameters in the resolution model.""" self.resolution_model.free_all_parameters() + def get_energy_offset_at_Q(self, Q_index: int) -> Parameter: + """Get the energy offset Parameter at a specific Q index. + + Parameters + ---------- + Q_index : int + The index of the Q value to get the energy offset for. + + Returns + ------- + Parameter + The energy offset Parameter at the specified Q index. + + Raises + ------ + IndexError + If Q_index is out of bounds. 
+ """ + if self._Q is None: + raise ValueError('No Q values are set in the InstrumentModel.') + + if Q_index < 0 or Q_index >= len(self._Q): + raise IndexError(f'Q_index {Q_index} is out of bounds for Q of length {len(self._Q)}') + + return self._energy_offsets[Q_index] + # -------------------------------------------------------------- # Private methods # -------------------------------------------------------------- diff --git a/src/easydynamics/sample_model/model_base.py b/src/easydynamics/sample_model/model_base.py index b6b8bcdd..570234a2 100644 --- a/src/easydynamics/sample_model/model_base.py +++ b/src/easydynamics/sample_model/model_base.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: 2025-2026 EasyDynamics contributors # SPDX-License-Identifier: BSD-3-Clause -import warnings from copy import copy import numpy as np @@ -192,7 +191,17 @@ def Q(self) -> np.ndarray | None: @Q.setter def Q(self, value: Q_type | None) -> None: """Set the Q values of the SampleModel.""" - self._Q = _validate_and_convert_Q(value) + old_Q = self._Q + new_Q = _validate_and_convert_Q(value) + + if ( + old_Q is not None + and new_Q is not None + and len(old_Q) == len(new_Q) + and all(np.isclose(old_Q, new_Q)) + ): + return # No change in Q, so do nothing + self._Q = new_Q self._on_Q_change() # ------------------------------------------------------------------ @@ -241,26 +250,42 @@ def get_all_variables(self, Q_index: int | None = None) -> list[Parameter]: all_vars = self._component_collections[Q_index].get_all_variables() return all_vars + def get_component_collection(self, Q_index: int) -> ComponentCollection: + """Get the ComponentCollection at the given Q index. + + Parameters + ---------- + Q_index : int + The index of the desired ComponentCollection. + + Returns + ------- + ComponentCollection + The ComponentCollection at the specified Q index. 
+ """ + if not isinstance(Q_index, int): + raise TypeError(f'Q_index must be an int, got {type(Q_index).__name__}') + if Q_index < 0 or Q_index >= len(self._component_collections): + raise IndexError( + f'Q_index {Q_index} is out of bounds for component collections ' + f'of length {len(self._component_collections)}' + ) + return self._component_collections[Q_index] + # ------------------------------------------------------------------ # Private methods # ------------------------------------------------------------------ def _generate_component_collections(self) -> None: """Generate ComponentCollections for each Q value.""" - # TODO regenerate automatically if Q or components have changed if self._Q is None: - warnings.warn('Q is not set. No component collections generated', UserWarning) self._component_collections = [] return - self._component_collections = [ComponentCollection() for _ in self._Q] - - # Add copies of components from self._components to each - # component collection - for collection in self._component_collections: - for component in self._components.components: - collection.append_component(copy(component)) + self._component_collections = [] + for _ in self._Q: + self._component_collections.append(copy(self._components)) def _on_Q_change(self) -> None: """Handle changes to the Q values.""" diff --git a/src/easydynamics/utils/utils.py b/src/easydynamics/utils/utils.py index 576b451d..e3cc842d 100644 --- a/src/easydynamics/utils/utils.py +++ b/src/easydynamics/utils/utils.py @@ -68,3 +68,23 @@ def _validate_unit(unit: str | sc.Unit | None) -> sc.Unit | None: if isinstance(unit, str): unit = sc.Unit(unit) return unit + + +def _in_notebook() -> bool: + """Check if the code is running in a Jupyter notebook. + + Returns: + bool: True if in a Jupyter notebook, False otherwise. 
+ """ + try: + from IPython import get_ipython + + shell = get_ipython().__class__.__name__ + if shell == 'ZMQInteractiveShell': + return True # Jupyter notebook or JupyterLab + elif shell == 'TerminalInteractiveShell': + return False # Terminal IPython + else: + return False + except (NameError, ImportError): + return False # Standard Python (no IPython) diff --git a/tests/unit/easydynamics/analysis/test_analysis.py b/tests/unit/easydynamics/analysis/test_analysis.py new file mode 100644 index 00000000..812bcac9 --- /dev/null +++ b/tests/unit/easydynamics/analysis/test_analysis.py @@ -0,0 +1,630 @@ +from unittest.mock import MagicMock +from unittest.mock import PropertyMock +from unittest.mock import patch + +import numpy as np +import pytest +import scipp as sc + +from easydynamics.analysis.analysis import Analysis +from easydynamics.experiment import Experiment +from easydynamics.sample_model import InstrumentModel +from easydynamics.sample_model import SampleModel +from easydynamics.sample_model.components.gaussian import Gaussian + + +class TestAnalysis: + @pytest.fixture + def analysis(self): + Q = sc.array(dims=['Q'], values=[1, 2, 3], unit='1/Angstrom') + energy = sc.array(dims=['energy'], values=[10.0, 20.0, 30.0], unit='meV') + data = sc.array( + dims=['Q', 'energy'], + values=[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], + variances=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], + ) + + data_array = sc.DataArray(data=data, coords={'Q': Q, 'energy': energy}) + + experiment = Experiment(data=data_array) + sample_model = SampleModel(components=Gaussian(), display_name='Gaussian') + instrument_model = InstrumentModel() + + analysis = Analysis( + display_name='TestAnalysis', + experiment=experiment, + sample_model=sample_model, + instrument_model=instrument_model, + extra_parameters=None, + ) + + return analysis + + def test_init(self, analysis): + # WHEN THEN + + # EXPECT + assert analysis.display_name == 'TestAnalysis' + assert 
isinstance(analysis._experiment, Experiment) + assert isinstance(analysis._sample_model, SampleModel) + assert isinstance(analysis._instrument_model, InstrumentModel) + assert analysis._extra_parameters == [] + assert np.array_equal(analysis.Q.values, [1, 2, 3]) + assert len(analysis.analysis_list) == 3 + + def test_init_raises_with_invalid_experiment(self): + # WHEN / THEN / EXPECT + with pytest.raises( + TypeError, + match='experiment must be an instance of Experiment', + ): + Analysis(experiment='invalid_experiment') + + def test_analysis_list_contains_all_Q_indices(self, analysis): + # WHEN THEN + + # EXPECT + assert len(analysis.analysis_list) == 3 + for i in range(3): + assert analysis.analysis_list[i].Q_index == i + + def test_analysis_list_setter_raises(self, analysis): + # WHEN / THEN / EXPECT + with pytest.raises( + AttributeError, + match='analysis_list is read-only', + ): + analysis.analysis_list = 'invalid_analysis_list' + + def test_calculate_with_Q_index(self, analysis): + # WHEN + analysis.analysis_list[1].calculate = MagicMock(return_value=np.array([4.0, 5.0, 6.0])) + + # THEN + result = analysis.calculate(Q_index=1) + + # EXPECT + analysis.analysis_list[1].calculate.assert_called_once() + np.testing.assert_array_equal(result, np.array([4.0, 5.0, 6.0])) + + def test_calculate_without_Q_index(self, analysis): + # WHEN + for i in range(3): + analysis.analysis_list[i].calculate = MagicMock( + return_value=np.array([1.0, 2.0, 3.0]) + i + ) + + # THEN + result = analysis.calculate() + + # EXPECT + for i in range(3): + analysis.analysis_list[i].calculate.assert_called_once() + np.testing.assert_array_equal(result[i], np.array([1.0, 2.0, 3.0]) + i) + + def test_calculate_with_invalid_Q_index(self, analysis): + # WHEN / THEN / EXPECT + with pytest.raises( + IndexError, + match='must be a valid index', + ): + analysis.calculate(Q_index=3) + + def test_fit_no_Q_values_raises(self, analysis): + # WHEN + analysis.experiment = Experiment() + + # THEN EXPECT + 
with pytest.raises( + ValueError, + match='No Q values available for fitting', + ): + analysis.fit() + + def test_fit_fit_method_independent_with_Q_index(self, analysis): + # WHEN + analysis.analysis_list[1].fit = MagicMock(return_value='fit_result_Q1') + + # THEN + result = analysis.fit(fit_method='independent', Q_index=1) + + # EXPECT + analysis.analysis_list[1].fit.assert_called_once() + assert result == 'fit_result_Q1' + + def test_fit_fit_method_independent_without_Q_index(self, analysis): + # WHEN + for i in range(3): + analysis.analysis_list[i].fit = MagicMock(return_value=f'fit_result_Q{i}') + + # THEN + result = analysis.fit(fit_method='independent') + + # EXPECT + for i in range(3): + analysis.analysis_list[i].fit.assert_called_once() + assert result[i] == f'fit_result_Q{i}' + + def test_fit_fit_method_simultaneous(self, analysis): + # WHEN + analysis._fit_all_Q_simultaneously = MagicMock(return_value='simultaneous_fit_result') + + # THEN + result = analysis.fit(fit_method='simultaneous') + + # EXPECT + analysis._fit_all_Q_simultaneously.assert_called_once() + assert result == 'simultaneous_fit_result' + + def test_fit_with_invalid_fit_method(self, analysis): + # WHEN / THEN / EXPECT + with pytest.raises( + ValueError, + match="Invalid fit method. 
Choose 'independent' or 'simultaneous'.", + ): + analysis.fit(fit_method='invalid_fit_method') + + def test_plot_data_and_model_not_in_notebook_raises(self, analysis): + # WHEN / THEN / EXPECT + with patch('easydynamics.analysis.analysis._in_notebook', return_value=False): + with pytest.raises( + RuntimeError, + match=' can only be used in a Jupyter notebook environment', + ): + analysis.plot_data_and_model() + + def test_plot_data_and_model_Q_index(self, analysis): + + # WHEN + analysis.analysis_list[1].plot_data_and_model = MagicMock(return_value='plot_Q1') + + kwargs = { + 'marker': {'amplitude': 'x', 'width': 's'}, + 'title': 'My Plot', + } + + # THEN + result = analysis.plot_data_and_model( + Q_index=1, plot_components=True, add_background=True, **kwargs + ) + + # EXPECT + analysis.analysis_list[1].plot_data_and_model.assert_called_once_with( + plot_components=True, add_background=True, **kwargs + ) + assert result == 'plot_Q1' + + def test_plot_data_and_model_no_data_raises(self, analysis): + # WHEN + analysis.experiment = Experiment() + + # THEN EXPECT + with pytest.raises( + ValueError, + match='No data to plot', + ): + analysis.plot_data_and_model() + + def test_plot_data_and_model_invalid_plot_components_raises(self, analysis): + # WHEN / THEN / EXPECT + + with ( + patch('easydynamics.analysis.analysis._in_notebook', return_value=True), + ): + with pytest.raises( + TypeError, + match='plot_components must be True or False', + ): + analysis.plot_data_and_model(plot_components='not_a_boolean') + + def test_plot_data_and_model_invalid_add_background_raises(self, analysis): + # WHEN / THEN / EXPECT + with ( + patch('easydynamics.analysis.analysis._in_notebook', return_value=True), + ): + with pytest.raises( + TypeError, + match='add_background must be True or False', + ): + analysis.plot_data_and_model(add_background='not_a_boolean') + + def test_plot_data_and_model_defaults(self, analysis): + + # WHEN + fake_fig = object() + + analysis._create_model_array = 
MagicMock(return_value='MODEL') + with ( + patch('plopp.slicer', return_value=fake_fig) as mock_slicer, + patch.object( + type(analysis.experiment), + 'binned_data', + new_callable=PropertyMock, + ) as mock_binned, + patch('easydynamics.analysis.analysis._in_notebook', return_value=True), + ): + mock_binned.return_value = 'DATA' + # THEN + fig = analysis.plot_data_and_model(plot_components=False) + + # EXPECT + mock_slicer.assert_called_once() + assert fig == fake_fig + # Inspect arguments passed to slicer + args, kwargs = mock_slicer.call_args + + data_passed = args[0] + assert 'Data' in data_passed + assert 'Model' in data_passed + + assert data_passed['Data'] == 'DATA' + assert data_passed['Model'] == 'MODEL' + + # Check the default kwargs + assert kwargs['title'] == 'TestAnalysis' + assert kwargs['linestyle'] == {'Data': 'none', 'Model': '-'} + assert kwargs['marker'] == {'Data': 'o', 'Model': None} + assert kwargs['color'] == {'Data': 'black', 'Model': 'red'} + assert kwargs['markerfacecolor'] == { + 'Data': 'none', + 'Model': 'none', + } + + def test_plot_data_and_model_plot_components_true(self, analysis): + + # WHEN + fake_fig = object() + + analysis._create_model_array = MagicMock(return_value='MODEL') + with ( + patch('plopp.slicer', return_value=fake_fig) as mock_slicer, + patch.object( + type(analysis.experiment), + 'binned_data', + new_callable=PropertyMock, + ) as mock_binned, + patch('easydynamics.analysis.analysis._in_notebook', return_value=True), + ): + mock_binned.return_value = 'DATA' + # THEN + fig = analysis.plot_data_and_model(plot_components=True) + + # EXPECT + mock_slicer.assert_called_once() + assert fig == fake_fig + # Inspect arguments passed to slicer + args, kwargs = mock_slicer.call_args + + data_passed = args[0] + assert 'Data' in data_passed + assert 'Model' in data_passed + + assert data_passed['Data'] == 'DATA' + assert data_passed['Model'] == 'MODEL' + # Check the default kwargs + assert kwargs['title'] == 'TestAnalysis' + 
assert kwargs['linestyle'] == {'Data': 'none', 'Model': '-', 'Gaussian': '--'} + assert kwargs['marker'] == {'Data': 'o', 'Model': None, 'Gaussian': None} + assert kwargs['color'] == {'Data': 'black', 'Model': 'red'} + assert kwargs['markerfacecolor'] == { + 'Data': 'none', + 'Model': 'none', + } + + def test_parameters_to_dataset(self, analysis): + # WHEN + analysis.sample_model.append_component(Gaussian(display_name='Gaussian2', area=0.5)) + # THEN + parameters_dataset = analysis.parameters_to_dataset() + + # EXPECT + assert isinstance(parameters_dataset, sc.Dataset) + parameter_names = [ + 'Gaussian area', + 'Gaussian center', + 'Gaussian width', + 'Gaussian2 area', + 'Gaussian2 center', + 'Gaussian2 width', + 'energy_offset', + ] + for parameter_name in parameter_names: + assert parameter_name in parameters_dataset + assert 'Q' in parameters_dataset[parameter_name].dims + + def test_parameters_to_dataset_different_units(self, analysis): + + # WHEN + analysis.sample_model.append_component(Gaussian(display_name='Gaussian2', area=0.5)) + + # Convert the unit of a component to eV. 
+ analysis.sample_model.get_component_collection(Q_index=1).components[0].convert_unit('eV') + + # THEN + parameters_dataset = analysis.parameters_to_dataset() + + # EXPECT + assert isinstance(parameters_dataset, sc.Dataset) + parameter_names = [ + 'Gaussian area', + 'Gaussian center', + 'Gaussian width', + 'Gaussian2 area', + 'Gaussian2 center', + 'Gaussian2 width', + 'energy_offset', + ] + for parameter_name in parameter_names: + assert parameter_name in parameters_dataset + assert 'Q' in parameters_dataset[parameter_name].dims + + @pytest.mark.parametrize( + 'parameter_names', + [ + 123, # not str or list + ['parameter_name', 123], # list contains non-string + {'a': 1}, # completely wrong type + ], + ids=[ + 'not_string_or_list', + 'list_contains_non_string', + 'wrong_container_type', + ], + ) + def test_plot_parameters_raises_with_invalid_parameter_names(self, analysis, parameter_names): + + with pytest.raises( + TypeError, + match='names must be a string or a list of strings', + ): + analysis.plot_parameters(names=parameter_names) + + def test_plot_parameters_raises_with_nonexistent_parameter_names(self, analysis): + with pytest.raises( + ValueError, + match='not found in dataset', + ): + analysis.plot_parameters(names='nonexistent_parameter') + + def test_plot_parameters(self, analysis): + + # WHEN + + # Mock all the methods that are called. 
+ fake_fig = object() + user_kwargs = { + 'title': 'My Plot', + 'marker': {'amplitude': 'x', 'width': 's'}, + } + + fake_dataset = { + 'amplitude': object(), + 'width': object(), + } + + analysis.parameters_to_dataset = MagicMock(return_value=fake_dataset) + + with patch('plopp.plot', return_value=fake_fig) as mock_plot: + # THEN + result = analysis.plot_parameters(**user_kwargs) + + # EXPECT + mock_plot.assert_called_once() + + # Inspect arguments + args, kwargs = mock_plot.call_args + + dataset_passed = args[0] + + assert dataset_passed == fake_dataset + + # Check default kwargs + assert 'linestyle' in kwargs + assert kwargs['linestyle'] == { + 'amplitude': 'none', + 'width': 'none', + } + + assert 'markerfacecolor' in kwargs + assert kwargs['markerfacecolor'] == { + 'amplitude': 'none', + 'width': 'none', + } + + # Check that user kwargs override defaults + assert kwargs['marker'] == user_kwargs['marker'] + assert kwargs['title'] == 'My Plot' + + # and that we return the figure + assert result is fake_fig + + def test_on_experiment_changed(self, analysis): + # WHEN + # Create a new experiment. + Q = sc.array(dims=['Q'], values=[2, 3, 4], unit='1/Angstrom') + energy = sc.array(dims=['energy'], values=[20.0, 30.0, 40.0], unit='meV') + data = sc.array( + dims=['Q', 'energy'], + values=[[2.0, 3.0, 4.0], [5.0, 6.0, 7.0], [8.0, 9.0, 10.0]], + variances=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], + ) + + data_array = sc.DataArray(data=data, coords={'Q': Q, 'energy': energy}) + + new_experiment = Experiment(data=data_array) + + # THEN (this call _on_experiment_changed internally) + analysis.experiment = new_experiment + + # EXPECT + assert np.array_equal(analysis.Q.values, [2, 3, 4]) + assert len(analysis.analysis_list) == 3 + for analysis in analysis.analysis_list: + assert analysis.experiment is new_experiment + + def test_on_sample_model_changed(self, analysis): + # WHEN + # Create a new sample model. 
+ new_sample_model = SampleModel() + + # THEN (this call _on_sample_model_changed internally) + analysis.sample_model = new_sample_model + + # EXPECT + assert analysis.sample_model is new_sample_model + for analysis in analysis.analysis_list: + assert analysis.sample_model is new_sample_model + + def test_on_instrument_model_changed(self, analysis): + # WHEN + # Create a new instrument model. + new_instrument_model = InstrumentModel() + + # THEN (this call _on_instrument_model_changed internally) + analysis.instrument_model = new_instrument_model + + # EXPECT + assert analysis.instrument_model is new_instrument_model + for analysis in analysis.analysis_list: + assert analysis.instrument_model is new_instrument_model + + def test_fit_single_Q_valid(self, analysis): + # WHEN + analysis.analysis_list[1].fit = MagicMock(return_value='fit_result_Q1') + + # THEN + result = analysis._fit_single_Q(Q_index=1) + + # EXPECT + analysis.analysis_list[1].fit.assert_called_once() + assert result == 'fit_result_Q1' + + def test_fit_single_Q_invalid_Q_index(self, analysis): + # WHEN / THEN / EXPECT + with pytest.raises( + IndexError, + match='must be a valid index', + ): + analysis._fit_single_Q(Q_index=3) + + def test_fit_all_Q_independently(self, analysis): + # WHEN + for i in range(3): + analysis.analysis_list[i].fit = MagicMock(return_value=f'fit_result_Q{i}') + + # THEN + result = analysis._fit_all_Q_independently() + + # EXPECT + for i in range(3): + analysis.analysis_list[i].fit.assert_called_once() + assert result[i] == f'fit_result_Q{i}' + + def test_fit_all_Q_simultaneously(self, analysis): + # WHEN + # Mock the MultiFitter and its fit method + + fake_fit_result = object() + + fake_fitter_instance = MagicMock() + fake_fitter_instance.fit.return_value = fake_fit_result + + # Also mock the get_fit_functions method to return a list of fit + # functions for each Q index + analysis.get_fit_functions = MagicMock( + return_value=['fit_function_Q0', 'fit_function_Q1', 
'fit_function_Q2'] + ) + with patch( + 'easydynamics.analysis.analysis.MultiFitter', + return_value=fake_fitter_instance, + ) as mock_fitter: + result = analysis._fit_all_Q_simultaneously() + + # EXPECT + # Check that the correct objects are passed to the MultiFitter + expected_fit_objects = analysis.analysis_list + expected_fit_functions = analysis.get_fit_functions() + mock_fitter.assert_called_once() + args, kwargs = mock_fitter.call_args + assert kwargs['fit_objects'] == expected_fit_objects + assert kwargs['fit_functions'] == expected_fit_functions + + # And check that the correct x, y, and weights arrays are passed + # to the fit method of the MultiFitter + expected_xs = [] + expected_ys = [] + expected_ws = [] + for analysis in analysis.analysis_list: + data = analysis.experiment.data['Q', analysis.Q_index] + + expected_xs.append(data.coords['energy'].values) + expected_ys.append(data.values) + expected_ws.append(1.0 / np.sqrt(data.variances)) + fake_fitter_instance.fit.assert_called_once() + + args, kwargs = fake_fitter_instance.fit.call_args + np.testing.assert_array_equal(kwargs['x'], expected_xs) + np.testing.assert_array_equal(kwargs['y'], expected_ys) + np.testing.assert_array_equal(kwargs['weights'], expected_ws) + + # And that the result from the fit method is returned + assert result == fake_fit_result + + def test_get_fit_functions(self, analysis): + # WHEN + + # THEN + fit_functions = analysis.get_fit_functions() + + # EXPECT + assert isinstance(fit_functions, list) + assert len(fit_functions) == len(analysis.analysis_list) + for fit_function in fit_functions: + assert callable(fit_function) + + def test_create_model_array(self, analysis): + # WHEN + analysis.calculate = MagicMock( + return_value=np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) + ) + + # THEN + model_array = analysis._create_model_array() + + # EXPECT + analysis.calculate.assert_called_once() + assert isinstance(model_array, sc.DataArray) + assert 'Q' in 
model_array.dims + assert 'energy' in model_array.dims + assert sc.identical(model_array.coords['Q'], analysis.Q) + assert sc.identical( + model_array.coords['energy'], analysis.experiment.data.coords['energy'] + ) + np.testing.assert_array_equal( + model_array.values, + np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]), + ) + + def test_create_components_dataset_raises(self, analysis): + # WHEN / THEN / EXPECT + with pytest.raises( + TypeError, + match='add_background must be True or False', + ): + analysis._create_components_dataset(add_background='123') + + def test_create_components_dataset(self, analysis): + # WHEN + # Add another component so that there are two components + analysis.sample_model.append_component(Gaussian(display_name='Gaussian2', area=0.5)) + + # THEN + components_dataset = analysis._create_components_dataset(add_background=True) + + # THEN EXPECT + assert isinstance(components_dataset, sc.Dataset) + component_names = [comp.display_name for comp in analysis.sample_model.components] + for component_name in component_names: + assert component_name in components_dataset + assert 'Q' in components_dataset[component_name].dims + assert 'energy' in components_dataset[component_name].dims diff --git a/tests/unit/easydynamics/analysis/test_analysis1d.py b/tests/unit/easydynamics/analysis/test_analysis1d.py new file mode 100644 index 00000000..3ddf74b3 --- /dev/null +++ b/tests/unit/easydynamics/analysis/test_analysis1d.py @@ -0,0 +1,780 @@ +from collections import Counter +from unittest.mock import MagicMock +from unittest.mock import patch + +import numpy as np +import pytest +import scipp as sc +from easyscience.variable import Parameter + +from easydynamics.analysis.analysis1d import Analysis1d +from easydynamics.experiment import Experiment +from easydynamics.sample_model import InstrumentModel +from easydynamics.sample_model import SampleModel +from easydynamics.sample_model.component_collection import ComponentCollection +from 
easydynamics.sample_model.components.gaussian import Gaussian +from easydynamics.sample_model.components.polynomial import Polynomial + + +class TestAnalysis1d: + @pytest.fixture + def analysis1d(self): + Q = sc.array(dims=['Q'], values=[1, 2, 3], unit='1/Angstrom') + energy = sc.array(dims=['energy'], values=[10.0, 20.0, 30.0], unit='meV') + data = sc.array( + dims=['Q', 'energy'], + values=[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], + variances=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], + ) + + data_array = sc.DataArray(data=data, coords={'Q': Q, 'energy': energy}) + + experiment = Experiment(data=data_array) + sample_model = SampleModel(components=Gaussian()) + instrument_model = InstrumentModel() + + analysis1d = Analysis1d( + display_name='TestAnalysis', + experiment=experiment, + sample_model=sample_model, + instrument_model=instrument_model, + Q_index=0, + extra_parameters=None, + ) + + return analysis1d + + def test_init(self, analysis1d): + # WHEN THEN + + # EXPECT + assert analysis1d.display_name == 'TestAnalysis' + assert isinstance(analysis1d._experiment, Experiment) + assert isinstance(analysis1d._sample_model, SampleModel) + assert isinstance(analysis1d._instrument_model, InstrumentModel) + assert analysis1d._extra_parameters == [] + assert np.array_equal(analysis1d.Q.values, [1, 2, 3]) + assert analysis1d.Q_index == 0 + + def test_init_no_experiment(self): + # WHEN + analysis1d = Analysis1d(display_name='TestAnalysisNoExperiment') + + # THEN EXPECT + assert isinstance(analysis1d._experiment, Experiment) + assert analysis1d._convolver is None + + def test_Q_index_setter(self, analysis1d): + # WHEN + analysis1d.Q_index = 1 + + # THEN / EXPECT + assert analysis1d.Q_index == 1 + + @pytest.mark.parametrize( + 'invalid_Q_index, expected_exception, expected_message', + [ + (-1, IndexError, 'Q_index must be'), + (10, IndexError, 'Q_index must be'), + ('invalid', IndexError, 'Q_index must be '), + (np.nan, IndexError, 'Q_index must be '), + 
([1, 2], IndexError, 'Q_index must be '), + ], + ids=[ + 'Negative index', + 'Index out of range', + 'Non-integer string', + 'NaN value', + 'List instead of integer', + ], + ) + def test_Q_index_setter_incorrect_Q( + self, analysis1d, invalid_Q_index, expected_exception, expected_message + ): + # WHEN / THEN / EXPECT + with pytest.raises(expected_exception, match=expected_message): + analysis1d.Q_index = invalid_Q_index + + def test_calculate_updates_convolver_and_calls_calculate(self, analysis1d): + # WHEN + + # mock the _create_convolver and _calculate methods to verify + # they are called + fake_convolver = object() + expected_result = np.array([42.0]) + + analysis1d._create_convolver = MagicMock(return_value=fake_convolver) + analysis1d._calculate = MagicMock(return_value=expected_result) + + # THEN + result = analysis1d.calculate() + + # EXPECT + + analysis1d._create_convolver.assert_called_once() + assert analysis1d._convolver is fake_convolver + analysis1d._calculate.assert_called_once() + np.testing.assert_array_equal(result, expected_result) + + def test__calculate_adds_sample_and_background(self, analysis1d): + sample = np.array([1.0, 2.0, 3.0]) + background = np.array([0.5, 0.5, 0.5]) + + analysis1d._evaluate_sample = MagicMock(return_value=sample) + analysis1d._evaluate_background = MagicMock(return_value=background) + + result = analysis1d._calculate() + + np.testing.assert_array_equal(result, sample + background) + + analysis1d._evaluate_sample.assert_called_once() + analysis1d._evaluate_background.assert_called_once() + + def test_fit_raises_if_no_experiment(self, analysis1d): + # WHEN THEN + analysis1d._experiment = None + + # EXPECT + with pytest.raises(ValueError, match='No experiment'): + analysis1d.fit() + + def test_fit_calls_fitter_with_correct_arguments(self, analysis1d): + + # WHEN + + # Mock all the methods that are called during fit to verify they + # are called with the correct arguments + fake_x = np.array([1, 2, 3]) + fake_y = 
np.array([10, 20, 30]) + fake_weights = np.array([0.1, 0.2, 0.3]) + + analysis1d._extract_x_y_weights_from_experiment = MagicMock( + return_value=(fake_x, fake_y, fake_weights) + ) + + analysis1d._create_convolver = MagicMock(return_value='fake_convolver') + + fake_fit_result = object() + fake_fitter_instance = MagicMock() + fake_fitter_instance.fit.return_value = fake_fit_result + + with patch( + 'easydynamics.analysis.analysis1d.EasyScienceFitter', + return_value=fake_fitter_instance, + ) as mock_fitter: + analysis1d.as_fit_function = MagicMock(return_value='fit_func') + + # THEN + result = analysis1d.fit() + + # EXPECT + + # Check that all the mocked methods were called with the correct + # arguments + analysis1d._create_convolver.assert_called_once() + + mock_fitter.assert_called_once_with( + fit_object=analysis1d, + fit_function='fit_func', + ) + + analysis1d._extract_x_y_weights_from_experiment.assert_called_once() + + fake_fitter_instance.fit.assert_called_once_with( + x=fake_x, + y=fake_y, + weights=fake_weights, + ) + + # And that the result is returned + assert analysis1d._fit_result is fake_fit_result + assert result is fake_fit_result + + def test_as_fit_function_calls_calculate(self, analysis1d): + # WHEN + expected = np.array([1.0, 2.0, 3.0]) + analysis1d._calculate = MagicMock(return_value=expected) + + # THEN + fit_func = analysis1d.as_fit_function() + + # EXPECT + assert callable(fit_func) + + # THEN + # call the fit function with some x values + result = fit_func(x=[1, 2, 3]) # should be ignored + + # EXPECT + analysis1d._calculate.assert_called_once() + + assert result is expected + + def test_get_all_variables(self, analysis1d): + # WHEN + extra_par1 = Parameter(name='extra_par1', value=1.0) + extra_par2 = Parameter(name='extra_par2', value=2.0) + analysis1d._extra_parameters = [extra_par1, extra_par2] + + # THEN + variables = analysis1d.get_all_variables() + + # EXPECT + assert isinstance(variables, list) + sample_vars = 
analysis1d.sample_model.get_all_variables(Q_index=analysis1d.Q_index) + instrument_vars = analysis1d.instrument_model.get_all_variables(Q_index=analysis1d.Q_index) + extra_vars = [extra_par1, extra_par2] + expected_vars = sample_vars + instrument_vars + extra_vars + assert Counter(variables) == Counter(expected_vars) + + def test_plot_raises_if_no_data(self, analysis1d): + analysis1d.experiment._data = None + + with pytest.raises(ValueError, match='No data'): + analysis1d.plot_data_and_model() + + def test_plot_calls_plopp_with_correct_arguments(self, analysis1d): + # WHEN + + # Mock the data and model components to be plotted + fake_model = sc.DataArray(data=sc.array(dims=['energy'], values=[1, 2, 3])) + analysis1d._create_sample_scipp_array = MagicMock(return_value=fake_model) + + fake_components = sc.Dataset({ + 'Component1': sc.DataArray(data=sc.array(dims=['energy'], values=[0.1, 0.2, 0.3])) + }) + analysis1d._create_components_dataset_single_Q = MagicMock(return_value=fake_components) + + fake_fig = object() + + with patch('plopp.plot', return_value=fake_fig) as mock_plot: + # THEN + result = analysis1d.plot_data_and_model() + + # EXPECT + + # Ensure component dataset created + analysis1d._create_components_dataset_single_Q.assert_called_once() + + # Ensure plot called + mock_plot.assert_called_once() + + # Inspect arguments + args, kwargs = mock_plot.call_args + + dataset_passed = args[0] + + assert 'Data' in dataset_passed + assert 'Model' in dataset_passed + assert 'Component1' in dataset_passed + + assert result is fake_fig + + ############# + # Private methods: small utilities + ############# + + def test_require_Q_index(self, analysis1d): + # WHEN THEN + Q_index = analysis1d._require_Q_index() + + # EXPECT + assert Q_index == analysis1d.Q_index + + def test_require_Q_index_raises_if_no_Q_index(self, analysis1d): + # WHEN THEN + analysis1d._Q_index = None + + # EXPECT + with pytest.raises(ValueError, match='Q_index must be set'): + 
analysis1d._require_Q_index() + + def test_on_Q_index_changed(self, analysis1d): + # WHEN + analysis1d._create_convolver = MagicMock() + + # THEN + analysis1d._on_Q_index_changed() + + # EXPECT + analysis1d._create_convolver.assert_called_once() + + ############# + # Private methods: evaluation + ############# + + def test_evaluate_components_no_components(self, analysis1d): + # WHEN + components = ComponentCollection() + + # THEN + result = analysis1d._evaluate_components(components=components) + + # EXPECT + assert isinstance(result, np.ndarray) + assert result.shape == (len(analysis1d.experiment.energy),) + assert np.all(result == 0.0) + + def test_evaluate_components_no_convolution(self, analysis1d): + # WHEN + components = Polynomial(coefficients=[1.0]) + # THEN + result = analysis1d._evaluate_components( + components=components, convolver=None, convolve=False + ) + # EXPECT + assert np.array_equal(result, np.array([1.0, 1.0, 1.0])) + + def test_evaluate_components_convolution(self, analysis1d): + # WHEN + components = Gaussian() + convolver = MagicMock() + convolver.convolution = MagicMock(return_value=np.array([1, 2, 3])) + + # THEN + result = analysis1d._evaluate_components( + components=components, convolver=convolver, convolve=True + ) + + # EXPECT + convolver.convolution.assert_called_once() + assert result is convolver.convolution.return_value + + def test_evaluate_components_empty_resolution(self, analysis1d): + # WHEN + components = MagicMock() + components.evaluate = MagicMock(return_value=np.array([1.0, 2.0, 3.0])) + + # The default analysis1d has no resolution model components, so + # no convolution should be applied even if convolve=True + + # THEN + result = analysis1d._evaluate_components( + components=components, convolver=None, convolve=True + ) + + # EXPECT + components.evaluate.assert_called_once() + assert np.array_equal(result, np.array([1.0, 2.0, 3.0])) + + def test_evaluate_with_resolution(self, analysis1d): + # WHEN (set up the 
resolution model and create a component to + # evaluate) + analysis1d.instrument_model.resolution_model.components = Gaussian() + components = Gaussian() + + with patch('easydynamics.analysis.analysis1d.Convolution') as MockConvolution: + # THEN + analysis1d._evaluate_components( + components=components, + convolver=None, + convolve=True, + ) + + # EXPECT + # Ensure constructor called once + MockConvolution.assert_called_once() + + # The convolver should be created with the correct arguments + resolution_components = ( + analysis1d.instrument_model.resolution_model.get_component_collection( + analysis1d.Q_index + ) + ) + + energy_offset = analysis1d.instrument_model.get_energy_offset_at_Q(analysis1d.Q_index) + + # Extract call arguments + _, kwargs = MockConvolution.call_args + + assert kwargs['sample_components'] == components + assert kwargs['resolution_components'] == resolution_components + assert kwargs['temperature'] == analysis1d.temperature + assert kwargs['energy_offset'] == energy_offset + + # check that the energy array passed to the convolver is the + # same as the analysis1d energy array + np.testing.assert_array_equal( + kwargs['energy'], + analysis1d.energy.values, + ) + + # and check that convolution() was called + MockConvolution.return_value.convolution.assert_called_once_with() + + def test_evaluate_sample(self, analysis1d): + # WHEN + analysis1d.sample_model.get_component_collection = MagicMock() + analysis1d._evaluate_components = MagicMock() + + # THEN + analysis1d._evaluate_sample() + + # EXPECT + + # The correct component collection is requested with the correct + # Q_index + analysis1d.sample_model.get_component_collection.assert_called_once_with( + Q_index=analysis1d.Q_index + ) + + # The components are evaluated with the correct convolver and + # convolve=True + analysis1d._evaluate_components.assert_called_once_with( + components=analysis1d.sample_model.get_component_collection(), + convolver=analysis1d._convolver, + convolve=True, + ) + 
+ def test_evaluate_sample_component(self, analysis1d): + # WHEN + analysis1d._evaluate_components = MagicMock() + component = object() + + # THEN + analysis1d._evaluate_sample_component(component=component) + + # EXPECT + + # The components are evaluated with the correct convolver and + # convolve=True + analysis1d._evaluate_components.assert_called_once_with( + components=component, + convolver=None, + convolve=True, + ) + + def test_evaluate_background(self, analysis1d): + # WHEN + analysis1d.instrument_model.background_model.get_component_collection = MagicMock() + analysis1d._evaluate_components = MagicMock() + + # THEN + analysis1d._evaluate_background() + + # EXPECT + + # The correct component collection is requested with the correct + # Q_index + analysis1d.instrument_model.background_model.get_component_collection.assert_called_once_with( + Q_index=analysis1d.Q_index + ) + + # The components are evaluated with the correct convolver and + # convolve=True + analysis1d._evaluate_components.assert_called_once_with( + components=analysis1d.instrument_model.background_model.get_component_collection(), + convolver=None, + convolve=False, + ) + + def test_evaluate_background_component(self, analysis1d): + # WHEN + analysis1d._evaluate_components = MagicMock() + component = object() + + # THEN + analysis1d._evaluate_background_component(component=component) + + # EXPECT + + # The components are evaluated with the correct convolver and + # convolve=True + analysis1d._evaluate_components.assert_called_once_with( + components=component, + convolver=None, + convolve=False, + ) + + def test_create_convolver(self, analysis1d): + # WHEN + # Mock sample components + sample_components = MagicMock() + sample_components.is_empty = False + + # Mock resolution components + resolution_components = MagicMock() + resolution_components.is_empty = False + + # And all the other inputs to the convolver + analysis1d.sample_model.get_component_collection = MagicMock( + 
return_value=sample_components + ) + + analysis1d.instrument_model.resolution_model.get_component_collection = MagicMock( + return_value=resolution_components + ) + + analysis1d.instrument_model.get_energy_offset_at_Q = MagicMock(return_value=123.0) + + with patch('easydynamics.analysis.analysis1d.Convolution') as MockConvolution: + # THEN + result = analysis1d._create_convolver() + + # EXPECT + # Check the convolver was created with the correct arguments + MockConvolution.assert_called_once() + + _, kwargs = MockConvolution.call_args + + assert kwargs['sample_components'] is sample_components + assert kwargs['resolution_components'] is resolution_components + assert sc.identical(kwargs['energy'], analysis1d.energy) + assert kwargs['temperature'] is analysis1d.temperature + assert kwargs['energy_offset'] == 123.0 + + assert result == MockConvolution.return_value + + def test_create_convolver_returns_none_if_no_resolution_components(self, analysis1d): + # WHEN + analysis1d.instrument_model.resolution_model.clear_components() + + # THEN + convolver = analysis1d._create_convolver() + + # EXPECT + assert convolver is None + + def test_create_convolver_returns_none_if_no_sample_components(self, analysis1d): + # WHEN + analysis1d.sample_model.clear_components() + + # THEN + convolver = analysis1d._create_convolver() + + # EXPECT + assert convolver is None + + ############# + # Private methods: create scipp arrays for plotting + ############# + + @pytest.mark.parametrize( + 'background', + [ + None, + np.array([0.5, 0.5, 0.5]), + ], + ids=[ + 'No background', + 'With background', + ], + ) + def test_create_component_scipp_array(self, analysis1d, background): + """ + Test that _create_component_scipp_array correctly evaluates + the component, adds the background and calls _to_scipp_array + with the correct values. + """ + # WHEN + + # Mock the functions that will be called. 
+ analysis1d._evaluate_sample_component = MagicMock(return_value=np.array([1.0, 2.0, 3.0])) + + analysis1d._to_scipp_array = MagicMock() + + component = object() + + # THEN + analysis1d._create_component_scipp_array(component=component, background=background) + + # EXPECT + analys1d._evaluate_sample_component.assert_called_once_with(component=component) + + expected_values = np.array([1.0, 2.0, 3.0]) + if background is not None: + expected_values += background + + analysis1d._to_scipp_array.assert_called_once() + + # Extract the actual call + _, kwargs = analysis1d._to_scipp_array.call_args + + np.testing.assert_array_equal( + kwargs['values'], + expected_values, + ) + + def test_create_background_component_scipp_array(self, analysis1d): + """Test that _create_background_component_scipp_array correctly + evaluates the background component and calls + _to_scipp_array with the correct values.""" + + # WHEN + + # Mock the functions that will be called. + analysis1d._evaluate_background_component = MagicMock( + return_value=np.array([1.0, 2.0, 3.0]) + ) + analysis1d._to_scipp_array = MagicMock() + + component = object() + + # THEN + analysis1d._create_background_component_scipp_array(component=component) + + # EXPECT + analysis1d._evaluate_background_component.assert_called_once_with(component=component) + + analysis1d._to_scipp_array.assert_called_once() + + # Extract the actual call + _, kwargs = analysis1d._to_scipp_array.call_args + + np.testing.assert_array_equal( + kwargs['values'], + np.array([1.0, 2.0, 3.0]), + ) + + def test_create_sample_scipp_array(self, analysis1d): + """Test that _create_sample_scipp_array correctly + evaluates the full model and calls _to_scipp_array with the + correct values.""" + + # WHEN + + # Mock the functions that will be called.
+ analysis1d._calculate = MagicMock(return_value=np.array([1.0, 2.0, 3.0])) + analysis1d._to_scipp_array = MagicMock() + + # THEN + analysis1d._create_sample_scipp_array() + + # EXPECT + analysis1d._calculate.assert_called_once() + + analysis1d._to_scipp_array.assert_called_once() + + # Extract the actual call + _, kwargs = analysis1d._to_scipp_array.call_args + + np.testing.assert_array_equal( + kwargs['values'], + np.array([1.0, 2.0, 3.0]), + ) + + @pytest.mark.parametrize( + 'add_background', + [True, False], + ids=['With background', 'Without background'], + ) + def test_create_components_dataset_single_Q( + self, + analysis1d, + add_background, + ): + """Test orchestration of _create_components_dataset_single_Q.""" + + # WHEN + + # Choose a particular Q_index, but without using the setter to + # avoid validation logic + analysis1d._Q_index = 5 + + # Mock all the things + + # ---- Sample component ---- + sample_component = MagicMock() + sample_component.display_name = 'sample_comp' + + sample_collection = MagicMock() + sample_collection.components = [sample_component] + + analysis1d.sample_model.get_component_collection = MagicMock( + return_value=sample_collection + ) + + # ---- Background component ---- + background_component = MagicMock() + background_component.display_name = 'background_comp' + + background_collection = MagicMock() + background_collection.components = [background_component] + + analysis1d.instrument_model.background_model.get_component_collection = MagicMock( + return_value=background_collection + ) + + # ---- Background evaluation ---- + background_value = np.array([10.0, 20.0, 30.0]) + analysis1d._evaluate_background = MagicMock(return_value=background_value) + + # ---- Return scipp DataArrays ---- + fake_sample_da = sc.DataArray(data=sc.array(dims=['energy'], values=[1.0, 2.0, 3.0])) + + analysis1d._create_component_scipp_array = MagicMock(return_value=fake_sample_da) + + fake_background_da = sc.DataArray(data=sc.array(dims=['energy'], 
values=[4.0, 5.0, 6.0])) + + analysis1d._create_background_component_scipp_array = MagicMock( + return_value=fake_background_da + ) + + # THEN + dataset = analysis1d._create_components_dataset_single_Q(add_background=add_background) + + # EXPECT + + # The correct component collections are requested with the + # correct Q_index + analysis1d.sample_model.get_component_collection.assert_called_once_with( + Q_index=analysis1d.Q_index + ) + + analysis1d.instrument_model.background_model.get_component_collection.assert_called_once_with( + Q_index=analysis1d.Q_index + ) + + # Background is evaluated if add_background=True, and not + # evaluated if False + if add_background: + analysis1d._evaluate_background.assert_called_once() + expected_background = background_value + else: + analysis1d._evaluate_background.assert_not_called() + expected_background = None + + # The sample component scipp array is created with the correct + # component and background + analysis1d._create_component_scipp_array.assert_called_once() + _, kwargs = analysis1d._create_component_scipp_array.call_args + + assert kwargs['component'] is sample_component + + if expected_background is None: + assert kwargs['background'] is None + else: + np.testing.assert_array_equal( + kwargs['background'], + expected_background, + ) + + # Background component creation + analysis1d._create_background_component_scipp_array.assert_called_once_with( + component=background_component + ) + + # Dataset content + assert isinstance(dataset, sc.Dataset) + assert 'sample_comp' in dataset + assert 'background_comp' in dataset + + def test_to_scipp_array(self, analysis1d): + # WHEN + numpy_array = np.array([1.0, 2.0, 3.0]) + + # THEN + scipp_array = analysis1d._to_scipp_array(numpy_array) + + # EXPECT + assert isinstance(scipp_array, sc.DataArray) + np.testing.assert_array_equal(scipp_array.values, numpy_array) + + np.testing.assert_array_equal( + scipp_array.coords['energy'].values, analysis1d.experiment.energy.values + ) + + 
np.testing.assert_array_equal( + scipp_array.coords['Q'].values, + analysis1d.experiment.Q[analysis1d.Q_index].values, + ) diff --git a/tests/unit/easydynamics/analysis/test_analysis_base.py b/tests/unit/easydynamics/analysis/test_analysis_base.py new file mode 100644 index 00000000..f4e937ad --- /dev/null +++ b/tests/unit/easydynamics/analysis/test_analysis_base.py @@ -0,0 +1,368 @@ +# from unittest.mock import Mock + +from unittest.mock import PropertyMock +from unittest.mock import patch + +import numpy as np +import pytest +import scipp as sc +from easyscience.variable import Parameter + +from easydynamics.analysis.analysis_base import AnalysisBase +from easydynamics.experiment import Experiment +from easydynamics.sample_model import InstrumentModel +from easydynamics.sample_model import SampleModel + + +class TestAnalysisBase: + @pytest.fixture + def analysis_base(self): + experiment = Experiment() + sample_model = SampleModel() + instrument_model = InstrumentModel() + analysis_base = AnalysisBase( + display_name='TestAnalysis', + experiment=experiment, + sample_model=sample_model, + instrument_model=instrument_model, + ) + return analysis_base + + def test_init(self, analysis_base): + # WHEN THEN + + # EXPECT + assert analysis_base.display_name == 'TestAnalysis' + assert isinstance(analysis_base._experiment, Experiment) + assert isinstance(analysis_base._sample_model, SampleModel) + assert isinstance(analysis_base._instrument_model, InstrumentModel) + assert analysis_base._extra_parameters == [] + + def test_init_extra_parameter(self): + extra_parameter = Parameter(name='param1', value=1.0) + analysis = AnalysisBase(extra_parameters=extra_parameter) + assert analysis._extra_parameters == [extra_parameter] + + def test_init_extra_parameters(self): + extra_parameters = [ + Parameter(name='param1', value=1.0), + Parameter(name='param2', value=2.0), + ] + analysis = AnalysisBase(extra_parameters=extra_parameters) + assert analysis._extra_parameters == 
extra_parameters + + def test_init_calls_on_experiment_changed(self): + with patch.object(AnalysisBase, '_on_experiment_changed') as mock_on_experiment_changed: + AnalysisBase() + mock_on_experiment_changed.assert_called_once() + + @pytest.mark.parametrize( + 'kwargs, expected_exception, expected_message', + [ + ( + {'experiment': 123}, + TypeError, + 'experiment must be an instance of Experiment', + ), + ( + {'sample_model': 'not a model'}, + TypeError, + 'sample_model must be an instance of SampleModel', + ), + ( + {'instrument_model': 'not a model'}, + TypeError, + 'instrument_model must be an instance of InstrumentModel', + ), + ( + {'extra_parameters': 123}, + TypeError, + 'extra_parameters must be a Parameter or a list of Parameters.', + ), + ( + {'extra_parameters': [123]}, + TypeError, + 'extra_parameters must be a Parameter or a list of Parameters.', + ), + ], + ids=[ + 'invalid experiment', + 'invalid sample_model', + 'invalid instrument_model', + 'invalid extra_parameters', + 'invalid extra_parameters list', + ], + ) + def test_init_invalid_inputs(self, kwargs, expected_exception, expected_message): + with pytest.raises(expected_exception, match=expected_message): + AnalysisBase(**kwargs) + + def test_experiment_setter_calls_on_experiment_changed(self, analysis_base): + with patch.object(analysis_base, '_on_experiment_changed') as mock_on_experiment_changed: + new_experiment = Experiment() + analysis_base.experiment = new_experiment + mock_on_experiment_changed.assert_called_once() + + def test_experiment_setter_invalid_type(self, analysis_base): + with pytest.raises(TypeError, match='experiment must be an instance of Experiment'): + analysis_base.experiment = 'not an experiment' + + def test_experiment_setter_valid(self, analysis_base): + new_experiment = Experiment() + analysis_base.experiment = new_experiment + assert analysis_base.experiment == new_experiment + + def test_sample_model_setter_invalid_type(self, analysis_base): + with 
pytest.raises(TypeError, match='sample_model must be an instance of SampleModel'): + analysis_base.sample_model = 'not a sample model' + + def test_sample_model_setter_valid(self, analysis_base): + new_sample_model = SampleModel() + analysis_base.sample_model = new_sample_model + assert analysis_base.sample_model == new_sample_model + + def test_sample_model_setter_calls_on_sample_model_changed(self, analysis_base): + with patch.object( + analysis_base, '_on_sample_model_changed' + ) as mock_on_sample_model_changed: + new_sample_model = SampleModel() + analysis_base.sample_model = new_sample_model + mock_on_sample_model_changed.assert_called_once() + + def test_instrument_model_setter_invalid_type(self, analysis_base): + with pytest.raises( + TypeError, match='instrument_model must be an instance of InstrumentModel' + ): + analysis_base.instrument_model = 'not an instrument model' + + def test_instrument_model_setter_valid(self, analysis_base): + new_instrument_model = InstrumentModel() + analysis_base.instrument_model = new_instrument_model + assert analysis_base.instrument_model == new_instrument_model + + def test_instrument_model_setter_calls_on_instrument_model_changed(self, analysis_base): + with patch.object( + analysis_base, '_on_instrument_model_changed' + ) as mock_on_instrument_model_changed: + new_instrument_model = InstrumentModel() + analysis_base.instrument_model = new_instrument_model + mock_on_instrument_model_changed.assert_called_once() + + def test_Q_property(self, analysis_base): + # Create a mock Q value + fake_Q = [1, 2, 3] + + # Patch the 'experiment' attribute's Q property + with patch.object( + type(analysis_base.experiment), 'Q', new_callable=PropertyMock + ) as mock_Q: + mock_Q.return_value = fake_Q + result = analysis_base.Q # Access the property + assert result == fake_Q + mock_Q.assert_called_once() + + def test_Q_setter_raises(self, analysis_base): + with pytest.raises( + AttributeError, + match='Q is a read-only property derived 
from the Experiment.', + ): + analysis_base.Q = [1, 2, 3] + + def test_energy_property(self, analysis_base): + # Create a mock energy value + fake_energy = [10, 20, 30] + + # Patch the 'experiment' attribute's energy property + with patch.object( + type(analysis_base.experiment), 'energy', new_callable=PropertyMock + ) as mock_energy: + mock_energy.return_value = fake_energy + result = analysis_base.energy # Access the property + assert result == fake_energy + mock_energy.assert_called_once() + + def test_energy_setter_raises(self, analysis_base): + with pytest.raises( + AttributeError, + match='energy is a read-only property derived from the Experiment.', + ): + analysis_base.energy = [10, 20, 30] + + def test_temperature_property_no_temperature(self, analysis_base): + # Patch the 'experiment' attribute's temperature property to + # return None + with patch.object( + type(analysis_base.sample_model), 'temperature', new_callable=PropertyMock + ) as mock_temperature: + mock_temperature.return_value = None + result = analysis_base.temperature # Access the property + assert result is None + mock_temperature.assert_called_once() + + def test_temperature_property(self, analysis_base): + # Create a mock temperature value + fake_temperature = 300 + + # Patch the 'sample_model' attribute's temperature property + with patch.object( + type(analysis_base.sample_model), 'temperature', new_callable=PropertyMock + ) as mock_temperature: + mock_temperature.return_value = fake_temperature + result = analysis_base.temperature # Access the property + assert result == fake_temperature + mock_temperature.assert_called_once() + + def test_temperature_setter_raises(self, analysis_base): + with pytest.raises( + AttributeError, + match='temperature is a read-only property', + ): + analysis_base.temperature = 300 + + @pytest.mark.parametrize( + 'extra_parameters', + [ + Parameter(name='param1', value=1.0), + [ + Parameter(name='param1', value=1.0), + Parameter(name='param2', value=2.0), + 
], + ], + ids=[ + 'single parameter', + 'list of parameters', + ], + ) + def test_extra_parameters_property(self, analysis_base, extra_parameters): + # WHEN + analysis_base.extra_parameters = extra_parameters + + # THEN + analysis_base.extra_parameters = extra_parameters + + # EXPECT + expected = ( + [extra_parameters] if isinstance(extra_parameters, Parameter) else extra_parameters + ) + + assert analysis_base.extra_parameters == expected + + @pytest.mark.parametrize( + 'invalid_extra_parameters', + [ + 'not a parameter', + [Parameter(name='param1', value=1.0), 'not a parameter'], + ], + ids=[ + 'single invalid parameter', + 'list with invalid parameter', + ], + ) + def test_extra_parameters_setter_invalid_type(self, analysis_base, invalid_extra_parameters): + with pytest.raises( + TypeError, + match='extra_parameters must be a Parameter or a list of Parameters.', + ): + analysis_base.extra_parameters = invalid_extra_parameters + + def test_on_experiment_changed_updates_Q(self, analysis_base): + # WHEN + fake_Q = [1, 2, 3] + + # Patch the Q property of analysis_base + with patch.object( + type(analysis_base.experiment), 'Q', new_callable=PropertyMock + ) as mock_Q: + mock_Q.return_value = fake_Q + + # THEN + analysis_base._on_experiment_changed() + + # EXPECT + # assert that the Q attribute was set + np.testing.assert_array_equal(analysis_base.Q, fake_Q) + np.testing.assert_array_equal(analysis_base.sample_model.Q, fake_Q) + np.testing.assert_array_equal(analysis_base.instrument_model.Q, fake_Q) + + def test_on_sample_model_changed_updates_Q(self, analysis_base): + # WHEN + fake_Q = [1, 2, 3] + + # Patch the Q property of analysis_base + with patch.object( + type(analysis_base.experiment), 'Q', new_callable=PropertyMock + ) as mock_Q: + mock_Q.return_value = fake_Q + + # THEN + analysis_base._on_sample_model_changed() + + # EXPECT + np.testing.assert_array_equal(analysis_base.sample_model.Q, fake_Q) + + def test_on_instrument_model_changed_updates_Q(self, 
analysis_base): + fake_Q = [1, 2, 3] + + # Patch the Q property of analysis_base + with patch.object( + type(analysis_base.experiment), 'Q', new_callable=PropertyMock + ) as mock_Q: + mock_Q.return_value = fake_Q + + analysis_base._on_instrument_model_changed() + np.testing.assert_array_equal(analysis_base.instrument_model.Q, fake_Q) + + def test_verify_Q_index_valid(self, analysis_base): + # WHEN + valid_Q_index = 0 + + # THEN + result = analysis_base._verify_Q_index(valid_Q_index) + + # EXPECT + assert result == valid_Q_index + + def test_verify_Q_index_invalid(self, analysis_base): + # WHEN + invalid_Q_index = -1 + + # THEN / EXPECT + with pytest.raises(IndexError, match='Q_index must be a valid index'): + analysis_base._verify_Q_index(invalid_Q_index) + + def test_extract_x_y_weights_from_experiment(self, analysis_base): + # WHEN + Q = sc.array(dims=['Q'], values=[1, 2, 3], unit='1/Angstrom') + energy = sc.array(dims=['energy'], values=[10.0, 20.0, 30.0], unit='meV') + data = sc.array( + dims=['Q', 'energy'], + values=[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], + variances=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], + ) + + data_array = sc.DataArray(data=data, coords={'Q': Q, 'energy': energy}) + + experiment = Experiment(data=data_array) + analysis_base.experiment = experiment + + Q_index = 0 + + # THEN + x, y, weights = analysis_base._extract_x_y_weights_from_experiment(Q_index=Q_index) + + # EXPECT + assert np.array_equal(x, analysis_base.experiment.energy.values) + assert np.array_equal(y, analysis_base.experiment.data.values[Q_index]) + assert np.array_equal( + weights, + 1 / analysis_base.experiment.data.variances[Q_index] ** 0.5, + ) + + def test_repr(self, analysis_base): + # WHEN + repr_str = repr(analysis_base) + + # THEN EXPECT + assert 'AnalysisBase' in repr_str + assert 'display_name=TestAnalysis' in repr_str + assert 'unique_name=' in repr_str diff --git a/tests/unit/easydynamics/convolution/test_convolution.py 
b/tests/unit/easydynamics/convolution/test_convolution.py index b82b89cd..a2cfb193 100644 --- a/tests/unit/easydynamics/convolution/test_convolution.py +++ b/tests/unit/easydynamics/convolution/test_convolution.py @@ -51,6 +51,22 @@ def default_convolution(self): ) return conv + @pytest.fixture + def convolution_with_components(self): + energy = np.linspace(-10, 10, 5001) + sample_components = Gaussian(display_name='Gaussian1', area=2.0, center=0.1, width=0.4) + + resolution_components = Gaussian( + display_name='GaussianRes', area=3.0, center=0.2, width=0.5 + ) + + conv = Convolution( + energy=energy, + sample_components=sample_components, + resolution_components=resolution_components, + ) + return conv + def test_init(self, default_convolution): "Test initialization of Convolution with default parameters." # WHEN THEN EXPECT @@ -85,6 +101,42 @@ def test_init(self, default_convolution): assert default_convolution._convolution_plan_is_valid is True assert default_convolution._reactions_enabled is True + def test_init_components(self, convolution_with_components): + "Test initialization of Convolution with explicit sample and resolution components."
+ # WHEN THEN EXPECT + assert isinstance(convolution_with_components, Convolution) + assert isinstance(convolution_with_components.energy, sc.Variable) + assert np.allclose(convolution_with_components.energy.values, np.linspace(-10, 10, 5001)) + assert isinstance(convolution_with_components._sample_components, ComponentCollection) + assert isinstance(convolution_with_components._resolution_components, ComponentCollection) + assert convolution_with_components.upsample_factor == 5 + assert convolution_with_components.extension_factor == 0.2 + assert convolution_with_components.temperature is None + assert convolution_with_components.energy_unit == 'meV' + assert convolution_with_components.normalize_detailed_balance is True + assert isinstance(convolution_with_components._energy_grid, EnergyGrid) + + assert isinstance( + convolution_with_components._analytical_sample_components, + ComponentCollection, + ) + assert ( + convolution_with_components._analytical_sample_components.components[0] + is convolution_with_components.sample_components.components[0] + ) + assert isinstance( + convolution_with_components._numerical_sample_components, + ComponentCollection, + ) + assert convolution_with_components._numerical_sample_components.is_empty + + assert isinstance( + convolution_with_components._delta_sample_components, ComponentCollection + ) + assert convolution_with_components._delta_sample_components.is_empty + assert convolution_with_components._convolution_plan_is_valid is True + assert convolution_with_components._reactions_enabled is True + def test_convolution_plan_is_built_when_invalid(self, default_convolution): """ Test that convolution plan is built when invalid. 
diff --git a/tests/unit/easydynamics/convolution/test_convolution_base.py b/tests/unit/easydynamics/convolution/test_convolution_base.py index be6249c7..94272f95 100644 --- a/tests/unit/easydynamics/convolution/test_convolution_base.py +++ b/tests/unit/easydynamics/convolution/test_convolution_base.py @@ -4,8 +4,11 @@ import numpy as np import pytest import scipp as sc +from easyscience.variable import Parameter +from scipp import UnitError from easydynamics.convolution.convolution_base import ConvolutionBase +from easydynamics.sample_model import Gaussian from easydynamics.sample_model.component_collection import ComponentCollection @@ -30,6 +33,27 @@ def test_init(self, convolution_base): assert isinstance(convolution_base._sample_components, ComponentCollection) assert isinstance(convolution_base._resolution_components, ComponentCollection) + def test_init_with_model_component(self): + # WHEN + energy = np.linspace(-10, 10, 100) + sample_component = Gaussian() + resolution_component = Gaussian() + + convolution_base = ConvolutionBase( + energy=energy, + sample_components=sample_component, + resolution_components=resolution_component, + ) + + # THEN EXPECT + assert isinstance(convolution_base, ConvolutionBase) + assert isinstance(convolution_base.energy, sc.Variable) + assert np.allclose(convolution_base.energy.values, np.linspace(-10, 10, 100)) + assert isinstance(convolution_base.sample_components, ComponentCollection) + assert isinstance(convolution_base.resolution_components, ComponentCollection) + assert convolution_base.sample_components.components[0] == sample_component + assert convolution_base.resolution_components.components[0] == resolution_component + def test_init_energy_numerical_none_offset(self): # WHEN energy = 1 @@ -55,6 +79,7 @@ def test_init_energy_numerical_none_offset(self): 'sample_components': ComponentCollection(), 'resolution_components': ComponentCollection(), 'energy_unit': 'meV', + 'energy_offset': 0, }, 'Energy must be', ), @@ -64,6 
+89,7 @@ def test_init_energy_numerical_none_offset(self): 'sample_components': 'invalid', 'resolution_components': ComponentCollection(), 'energy_unit': 'meV', + 'energy_offset': 0, }, ( '`sample_components` is an instance of str, ' @@ -76,6 +102,7 @@ def test_init_energy_numerical_none_offset(self): 'sample_components': ComponentCollection(), 'resolution_components': 'invalid', 'energy_unit': 'meV', + 'energy_offset': 0, }, ( '`resolution_components` is an instance of str, ' @@ -88,9 +115,20 @@ def test_init_energy_numerical_none_offset(self): 'sample_components': ComponentCollection(), 'resolution_components': ComponentCollection(), 'energy_unit': 123, + 'energy_offset': 0, }, 'Energy_unit must be ', ), + ( + { + 'energy': np.linspace(-10, 10, 100), + 'sample_components': ComponentCollection(), + 'resolution_components': ComponentCollection(), + 'energy_unit': 'meV', + 'energy_offset': 'invalid', + }, + 'Energy_offset must be ', + ), ], ) def test_input_type_validation_raises(self, kwargs, expected_message): @@ -164,6 +202,62 @@ def test_convert_energy_unit_invalid_type_raises(self, convolution_base): ): convolution_base.convert_energy_unit(123) + def test_convert_energy_unit_invalid_unit_rollback(self, convolution_base): + # WHEN THEN + with pytest.raises( + UnitError, + match='Conversion from `meV` to `s` is not valid.', + ): + convolution_base.convert_energy_unit('s') + + # EXPECT + assert convolution_base.energy_unit == 'meV' + assert np.allclose(convolution_base.energy.values, np.linspace(-10, 10, 100)) + + def test_convert_energy_unit_invalid_offset_unit_rollback(self, convolution_base): + # WHEN + convolution_base.energy_offset = Parameter(name='energy_offset', value=5, unit='s') + + # THEN + with pytest.raises( + UnitError, + match='Conversion from `s` to `meV` is not valid.', + ): + convolution_base.convert_energy_unit('meV') + + # EXPECT + assert convolution_base.energy_unit == 'meV' + assert convolution_base.energy_offset.unit == 's' + + def 
test_energy_offset_property(self, convolution_base): + # WHEN THEN EXPECT + assert convolution_base.energy_offset.value == 0 + + # THEN + convolution_base.energy_offset = 5 + assert convolution_base.energy_offset.value == 5 + + # THEN + convolution_base.energy_offset = Parameter(name='energy_offset', value=10, unit='meV') + assert convolution_base.energy_offset.value == 10 + assert convolution_base.energy_offset.unit == 'meV' + + def test_energy_offset_setter_invalid_type_raises(self, convolution_base): + # WHEN THEN EXPECT + with pytest.raises( + TypeError, + match='Energy_offset must be a number or a Parameter.', + ): + convolution_base.energy_offset = 'invalid' + + def test_energy_with_offset_setter_raises(self, convolution_base): + # WHEN THEN EXPECT + with pytest.raises( + AttributeError, + match='is a read-only property', + ): + convolution_base.energy_with_offset = 5 + def test_sample_components_property(self, convolution_base): # WHEN THEN EXPECT assert isinstance(convolution_base.sample_components, ComponentCollection) diff --git a/tests/unit/easydynamics/convolution/test_numerical_convolution.py b/tests/unit/easydynamics/convolution/test_numerical_convolution.py index e388f17f..9201d07c 100644 --- a/tests/unit/easydynamics/convolution/test_numerical_convolution.py +++ b/tests/unit/easydynamics/convolution/test_numerical_convolution.py @@ -60,12 +60,13 @@ def test_convolution(self, default_numerical_convolution, upsample_factor): """ # WHEN THEN default_numerical_convolution.upsample_factor = upsample_factor + default_numerical_convolution.energy_offset = 0.4 result = default_numerical_convolution.convolution() # EXPECT expected_area = 2.0 * 3.0 # area of sample_components * area of resolution_components expected_center = ( - 0.1 + 0.2 + 0.1 + 0.2 + 0.4 ) # center of sample_components + center of resolution_components expected_width = np.sqrt(0.4**2 + 0.5**2) # sqrt(width_sample^2 + width_res^2) expected_result = Gaussian( diff --git 
a/tests/unit/easydynamics/experiment/test_experiment.py b/tests/unit/easydynamics/experiment/test_experiment.py index 067a2017..b62e3305 100644 --- a/tests/unit/easydynamics/experiment/test_experiment.py +++ b/tests/unit/easydynamics/experiment/test_experiment.py @@ -73,6 +73,8 @@ def test_init_no_data(self): # THEN EXPECT assert experiment.display_name == 'empty_experiment' assert experiment._data is None + assert experiment.energy is None + assert experiment.Q is None def test_init_invalid_data(self): "Test initialization with invalid data type" @@ -271,24 +273,6 @@ def test_Q_setter_raises(self, experiment): with pytest.raises(AttributeError): experiment.Q = experiment.Q - def test_Q_getter_warns_no_data(self): - "Test that getting Q data with no data raises Warning" - # WHEN - experiment = Experiment() - - # THEN EXPECT - with pytest.warns(UserWarning, match='No data loaded'): - _ = experiment.Q - - def test_energy_getter_warns_no_data(self): - "Test that getting energy data with no data raises Warning" - # WHEN - experiment = Experiment() - - # THEN EXPECT - with pytest.warns(UserWarning, match='No data loaded'): - _ = experiment.energy - ############## # test plotting ############## @@ -297,7 +281,7 @@ def test_plot_data_success(self, experiment): "Test plotting data successfully when in notebook environment" # WHEN with ( - patch.object(Experiment, '_in_notebook', return_value=True), + patch(f'{Experiment.__module__}._in_notebook', return_value=True), patch('plopp.plot') as mock_plot, patch('IPython.display.display') as mock_display, ): @@ -327,7 +311,7 @@ def test_plot_data_not_in_notebook_raises(self, experiment): "Test plotting data raises RuntimeError" 'when not in notebook environment' # WHEN - with patch.object(Experiment, '_in_notebook', return_value=False): + with patch(f'{Experiment.__module__}._in_notebook', return_value=False): # THEN EXPECT with pytest.raises( RuntimeError, @@ -339,62 +323,6 @@ def test_plot_data_not_in_notebook_raises(self, 
experiment): # test private methods ############## - def test_in_notebook_returns_true_for_jupyter(self, monkeypatch): - """Should return True when IPython shell is - ZMQInteractiveShell (Jupyter).""" - - # WHEN - class ZMQInteractiveShell: - __name__ = 'ZMQInteractiveShell' - - # THEN - monkeypatch.setattr('IPython.get_ipython', lambda: ZMQInteractiveShell()) - - # EXPECT - assert Experiment._in_notebook() is True - - def test_in_notebook_returns_false_for_terminal_ipython(self, monkeypatch): - """Should return False when IPython shell is - TerminalInteractiveShell.""" - - # WHEN - class TerminalInteractiveShell: - __name__ = 'TerminalInteractiveShell' - - # THEN - - monkeypatch.setattr('IPython.get_ipython', lambda: TerminalInteractiveShell()) - - # EXPECT - assert Experiment._in_notebook() is False - - def test_in_notebook_returns_false_for_unknown_shell(self, monkeypatch): - """Should return False when IPython shell type is - unrecognized.""" - - # WHEN - class UnknownShell: - __name__ = 'UnknownShell' - - # THEN - monkeypatch.setattr('IPython.get_ipython', lambda: UnknownShell()) - # EXPECT - assert Experiment._in_notebook() is False - - def test_in_notebook_returns_false_when_no_ipython(self, monkeypatch): - """Should return False when IPython is not installed or - available.""" - - # WHEN - def raise_import_error(*args, **kwargs): - raise ImportError - - # THEN - monkeypatch.setattr('builtins.__import__', raise_import_error) - - # EXPECT - assert Experiment._in_notebook() is False - def test_validate_coordinates(self, experiment): "Test that _validate_coordinates does not raise for valid data" # WHEN / THEN EXPECT diff --git a/tests/unit/easydynamics/sample_model/components/test_polynomial.py b/tests/unit/easydynamics/sample_model/components/test_polynomial.py index f2f73a74..db9910c1 100644 --- a/tests/unit/easydynamics/sample_model/components/test_polynomial.py +++ b/tests/unit/easydynamics/sample_model/components/test_polynomial.py @@ -58,6 +58,11 @@ def 
test_input_type_validation_raises(self, kwargs, expected_message): with pytest.raises(TypeError, match=expected_message): Polynomial(display_name='TestPolynomial', **kwargs) + def test_init_no_coefficients_raises(self): + # WHEN THEN EXPECT + with pytest.raises(ValueError, match='At least one coefficient must be provided.'): + Polynomial(display_name='TestPolynomial', coefficients=[]) + def test_negative_value_warns_in_evaluate(self): # WHEN THEN test_polynomial = Polynomial(display_name='TestPolynomial', coefficients=[-1.0]) diff --git a/tests/unit/easydynamics/sample_model/diffusion_model/test_diffusion_model.py b/tests/unit/easydynamics/sample_model/diffusion_model/test_diffusion_model.py index b8eb0956..f053e4cb 100644 --- a/tests/unit/easydynamics/sample_model/diffusion_model/test_diffusion_model.py +++ b/tests/unit/easydynamics/sample_model/diffusion_model/test_diffusion_model.py @@ -35,3 +35,12 @@ def test_scale_setter_raises(self, diffusion_model): # WHEN THEN EXPECT with pytest.raises(TypeError, match='scale must be a number.'): diffusion_model.scale = 'invalid' # Invalid type + + def test_repr(self, diffusion_model): + # WHEN THEN + repr_str = repr(diffusion_model) + + # EXPECT + assert 'DiffusionModelBase' in repr_str + assert 'display_name=TestDiffusionModel' in repr_str + assert 'unit=meV' in repr_str diff --git a/tests/unit/easydynamics/sample_model/test_component_collection.py b/tests/unit/easydynamics/sample_model/test_component_collection.py index 42a66f6a..115c2f2e 100644 --- a/tests/unit/easydynamics/sample_model/test_component_collection.py +++ b/tests/unit/easydynamics/sample_model/test_component_collection.py @@ -69,6 +69,11 @@ def test_init_with_invalid_components_raises(self): with pytest.raises(TypeError, match='Component must be.'): ComponentCollection(components=['NotAComponent']) + def test_init_with_invalid_list_of_components_raises(self): + # WHEN THEN EXPECT + with pytest.raises(TypeError, match='components must be a list of'): + 
ComponentCollection(components='NotAList') + def test_init_with_invalid_unit_raises(self): # WHEN THEN EXPECT with pytest.raises(TypeError, match='unit must be'): @@ -153,6 +158,25 @@ def test_component_setter_invalid_raises(self, component_collection): with pytest.raises(TypeError, match='components must be a list of'): component_collection.components = 'NotAList' + def test_is_empty(self): + # WHEN THEN + component_collection = ComponentCollection(display_name='EmptyModel') + # EXPECT + assert component_collection.is_empty is True + + # WHEN THEN + component = Gaussian( + display_name='TestComponent', area=1.0, center=0.0, width=1.0, unit='meV' + ) + component_collection.append_component(component) + # EXPECT + assert component_collection.is_empty is False + + def test_is_empty_setter(self, component_collection): + # WHEN THEN EXPECT + with pytest.raises(AttributeError, match='is_empty is a read-only property.'): + component_collection.is_empty = True + def test_list_component_names(self, component_collection): # WHEN THEN components = component_collection.list_component_names() diff --git a/tests/unit/easydynamics/sample_model/test_instrument_model.py b/tests/unit/easydynamics/sample_model/test_instrument_model.py index 00f036cd..54396cc2 100644 --- a/tests/unit/easydynamics/sample_model/test_instrument_model.py +++ b/tests/unit/easydynamics/sample_model/test_instrument_model.py @@ -189,6 +189,34 @@ def test_energy_offset_setter_raises(self, instrument_model): ): instrument_model.energy_offset = 'invalid_offset' + def test_get_energy_offset_at_Q(self, instrument_model): + # WHEN + + # THEN + offset_at_Q0 = instrument_model.get_energy_offset_at_Q(0) + + # EXPECT + assert offset_at_Q0.value == instrument_model.energy_offset.value + + def test_get_energy_offset_at_Q_invalid_index_raises(self, instrument_model): + # WHEN / THEN / EXPECT + with pytest.raises( + IndexError, + match='Q_index 5 is out of bounds', + ): + instrument_model.get_energy_offset_at_Q(5) + + def 
test_get_energy_offset_at_Q_no_Q_raises(self, instrument_model): + # WHEN + instrument_model.Q = None + + # THEN / EXPECT + with pytest.raises( + ValueError, + match='No Q values are set', + ): + instrument_model.get_energy_offset_at_Q(0) + def test_convert_unit_calls_all_children(self, instrument_model): # WHEN new_unit = 'eV' diff --git a/tests/unit/easydynamics/sample_model/test_model_base.py b/tests/unit/easydynamics/sample_model/test_model_base.py index 05591735..fbe44d73 100644 --- a/tests/unit/easydynamics/sample_model/test_model_base.py +++ b/tests/unit/easydynamics/sample_model/test_model_base.py @@ -105,14 +105,6 @@ def test_generate_component_collections_with_Q(self, model_base): assert isinstance(collection.components[1], Lorentzian) assert collection.components[1].display_name == 'TestLorentzian1' - def test_generate_component_collections_without_Q_warns(self, model_base): - # WHEN - model_base._Q = None - - # THEN / EXPECT - with pytest.warns(UserWarning, match='Q is not set'): - model_base._generate_component_collections() - def test_fix_free_all_parameters(self, model_base): # WHEN model_base.fix_all_parameters() @@ -182,6 +174,28 @@ def test_get_all_variables_with_nonint_Q_index_raises(self, model_base): ): model_base.get_all_variables(Q_index='invalid_index') + def test_get_component_collection(self, model_base): + # WHEN THEN + collection = model_base.get_component_collection(Q_index=0) + # EXPECT + assert collection is model_base._component_collections[0] + + def test_get_component_collection_invalid_index_type_raises(self, model_base): + # WHEN THEN EXPECT + with pytest.raises( + TypeError, + match='Q_index must be an int, got str', + ): + model_base.get_component_collection(Q_index='invalid_index') + + def test_get_component_collection_invalid_index_raises(self, model_base): + # WHEN THEN EXPECT + with pytest.raises( + IndexError, + match='Q_index 5 is out of bounds for ', + ): + model_base.get_component_collection(Q_index=5) + def 
test_append_and_remove_and_clear_component(self, model_base): # WHEN new_component = Gaussian(unique_name='NewGaussian') @@ -223,7 +237,7 @@ def test_append_component_collection(self, model_base): def test_append_component_invalid_type_raises(self, model_base): # WHEN / THEN / EXPECT - with pytest.raises(TypeError, match=' must be a ModelComponent or ComponentCollection'): + with pytest.raises(TypeError, match=' must be '): model_base.append_component('invalid_component') def test_unit_property(self, model_base): diff --git a/tests/unit/easydynamics/sample_model/test_sample_model.py b/tests/unit/easydynamics/sample_model/test_sample_model.py index e5f7a9a7..16919c91 100644 --- a/tests/unit/easydynamics/sample_model/test_sample_model.py +++ b/tests/unit/easydynamics/sample_model/test_sample_model.py @@ -98,6 +98,14 @@ def test_init_raises_with_invalid_temperature(self): ): SampleModel(temperature='invalid_temperature') + def test_init_raises_with_negative_temperature(self): + # WHEN / THEN / EXPECT + with pytest.raises( + ValueError, + match='temperature must be non-negative', + ): + SampleModel(temperature=-5.0) + def test_init_raises_with_invalid_divide_by_temperature(self): # WHEN / THEN / EXPECT with pytest.raises( diff --git a/tests/unit/easydynamics/utils/test_utils.py b/tests/unit/easydynamics/utils/test_utils.py index 97a6c36c..cb3eed27 100644 --- a/tests/unit/easydynamics/utils/test_utils.py +++ b/tests/unit/easydynamics/utils/test_utils.py @@ -5,6 +5,7 @@ import pytest import scipp as sc +from easydynamics.utils.utils import _in_notebook from easydynamics.utils.utils import _validate_and_convert_Q from easydynamics.utils.utils import _validate_unit @@ -112,3 +113,64 @@ def test_validate_unit_string_conversion(self): def test_validate_unit_invalid_type(self, unit_input): with pytest.raises(TypeError, match='unit must be None, a string, or a scipp Unit'): _validate_unit(unit_input) + + +# ----------------------------- + + +class TestInNotebook: + def 
test_in_notebook_returns_true_for_jupyter(self, monkeypatch): + """Should return True when IPython shell is + ZMQInteractiveShell (Jupyter).""" + + # WHEN + class ZMQInteractiveShell: + __name__ = 'ZMQInteractiveShell' + + # THEN + monkeypatch.setattr('IPython.get_ipython', lambda: ZMQInteractiveShell()) + + # EXPECT + assert _in_notebook() is True + + def test_in_notebook_returns_false_for_terminal_ipython(self, monkeypatch): + """Should return False when IPython shell is + TerminalInteractiveShell.""" + + # WHEN + class TerminalInteractiveShell: + __name__ = 'TerminalInteractiveShell' + + # THEN + + monkeypatch.setattr('IPython.get_ipython', lambda: TerminalInteractiveShell()) + + # EXPECT + assert _in_notebook() is False + + def test_in_notebook_returns_false_for_unknown_shell(self, monkeypatch): + """Should return False when IPython shell type is + unrecognized.""" + + # WHEN + class UnknownShell: + __name__ = 'UnknownShell' + + # THEN + monkeypatch.setattr('IPython.get_ipython', lambda: UnknownShell()) + # EXPECT + assert _in_notebook() is False + + def test_in_notebook_returns_false_when_no_ipython(self, monkeypatch): + """Should return False when IPython is not installed or + available.""" + + # WHEN + def raise_import_error(*args, **kwargs): + raise ImportError + + # THEN + monkeypatch.setattr('builtins.__import__', raise_import_error) + + # EXPECT + assert _in_notebook() is False