Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion metatomic-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub mod c_api;
mod metadata;
use crate::c_api::mta_status_t;

pub use self::metadata::{ModelMetadata, PairListOptions};
pub use self::metadata::{Device, DType, ModelCapabilities, ModelMetadata, PairListOptions};

mod quantities;
pub use self::quantities::{Quantity, SampleKind, Gradients};
Expand Down
252 changes: 246 additions & 6 deletions metatomic-core/src/model.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,133 @@
use metatensor::{Labels, TensorMap};
use std::ffi::c_void;

use crate::{Error, Quantity, System};
use metatensor::{Labels, TensorMap};

use crate::c_api::mta_model_t;
use crate::{Error, ModelCapabilities, ModelMetadata, PairListOptions, Quantity, System};
use crate::c_api::{mta_model_t, mta_status_t, mta_string_t, mta_string_free};

/// TODO
/// A loaded atomistic model, ready to be executed on a set of systems.
///
/// `Model` wraps a [`mta_model_t`] vtable provided by a plugin. It gives
/// access to the model's metadata and capabilities, and can be run with
/// [`execute_model`]. When a `Model` is dropped, the underlying plugin
/// resources are released via the `unload` callback.
pub struct Model(pub(crate) mta_model_t);

impl Drop for Model {
fn drop(&mut self) {
if let Some(unload) = self.0.unload {
unsafe { unload(self.0.data) };
}
}
}

impl Model {
/// Create a new `Model` from the corresponding C API struct.
///
/// The `Model` takes ownership of `model` and will call its `unload`
/// callback when dropped.
pub fn new(model: mta_model_t) -> Self {
return Model(model);
}

/// Extract the underlying C API struct.
/// Extract the underlying C API struct, transferring ownership to the caller.
///
/// The caller is responsible for eventually calling the `unload` callback
/// on the returned [`mta_model_t`] to free its resources. The `Model`'s
/// own `Drop` implementation is skipped.
pub fn into_raw(self) -> mta_model_t {
return self.0;
let model = std::mem::ManuallyDrop::new(self);
return unsafe { std::ptr::read(&model.0) };
}

fn call_string_callback(
callback: unsafe extern "C" fn(*const c_void, *mut mta_string_t) -> mta_status_t,
data: *const c_void,
) -> Result<String, Error> {
let mut output = mta_string_t::null();
let status = unsafe { callback(data, &mut output) };
if status != mta_status_t::MTA_SUCCESS {
unsafe { mta_string_free(output) };
return Err(Error::CallbackError(status));
}
let json_str = output.as_str().to_owned();
unsafe { mta_string_free(output) };
return Ok(json_str);
}

/// Get the metadata describing this model (name, authors, description,
/// references, ...).
pub fn metadata(&self) -> Result<ModelMetadata, Error> {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All functions should have some documentation

let callback = self.0.metadata.ok_or_else(|| {
Error::Internal("model is missing a 'metadata' callback".into())
})?;
let json_str = Self::call_string_callback(callback, self.0.data)?;
let json = json::parse(&json_str).map_err(|e| {
Error::Serialization(format!("model returned invalid JSON for metadata: {}", e))
})?;
return ModelMetadata::try_from(&json);
}

/// Get the capabilities of this model: which outputs it can compute, which
/// atomic types it supports, its interaction range, length unit, supported
/// devices, and data type.
pub fn capabilities(&self) -> Result<ModelCapabilities, Error> {
let callback = self.0.capabilities.ok_or_else(|| {
Error::Internal("model is missing a 'capabilities' callback".into())
})?;
let json_str = Self::call_string_callback(callback, self.0.data)?;
let json = json::parse(&json_str).map_err(|e| {
Error::Serialization(format!("model returned invalid JSON for capabilities: {}", e))
})?;
return ModelCapabilities::try_from(&json);
}

/// Get the pair lists (neighbor lists) this model needs as input.
///
/// The engine must compute these and attach them to every system with
/// `mta_system_add_pairs` before calling [`execute_model`].
pub fn requested_pair_lists(&self) -> Result<Vec<PairListOptions>, Error> {
let callback = self.0.requested_pair_lists.ok_or_else(|| {
Error::Internal("model is missing a 'requested_pair_lists' callback".into())
})?;
let json_str = Self::call_string_callback(callback, self.0.data)?;
let json = json::parse(&json_str).map_err(|e| {
Error::Serialization(format!("model returned invalid JSON for requested_pair_lists: {}", e))
})?;
if !json.is_array() {
return Err(Error::Serialization(
"model returned invalid JSON for requested_pair_lists, expected an array".into()
));
}
let mut result = Vec::new();
for item in json.members() {
result.push(PairListOptions::try_from(item)?);
}
return Ok(result);
}

/// Get the additional per-system inputs this model needs.
///
/// The engine must attach these to every system with
/// `mta_system_add_custom_data` before calling [`execute_model`].
pub fn requested_inputs(&self) -> Result<Vec<Quantity>, Error> {
let callback = self.0.requested_inputs.ok_or_else(|| {
Error::Internal("model is missing a 'requested_inputs' callback".into())
})?;
let json_str = Self::call_string_callback(callback, self.0.data)?;
let json = json::parse(&json_str).map_err(|e| {
Error::Serialization(format!("model returned invalid JSON for requested_inputs: {}", e))
})?;
if !json.is_array() {
return Err(Error::Serialization(
"model returned invalid JSON for requested_inputs, expected an array".into()
));
}
let mut result = Vec::new();
for item in json.members() {
result.push(Quantity::try_from(item)?);
}
return Ok(result);
}
}

Expand All @@ -29,3 +141,131 @@ pub fn execute_model(
) -> Result<Vec<TensorMap>, Error> {
todo!()
}

#[cfg(test)]
mod tests {
use super::*;
use crate::c_api::{mta_model_t, mta_status_t, mta_string_t};

// ── minimal callback implementations ────────────────────────────────────
// Each function below is a stand-in for what a real plugin would implement.
// They simply write a hard-coded JSON string into the output mta_string_t
// and return MTA_SUCCESS.

unsafe extern "C" fn metadata_callback(
_data: *const c_void,
out: *mut mta_string_t,
) -> mta_status_t {
*out = mta_string_t::new(r#"{
"type": "metatomic_model_metadata",
"name": "test-model",
"authors": ["Alice"],
"description": "A test model",
"references": {"model": [], "architecture": [], "implementation": []},
"extra": {}
}"#);
mta_status_t::MTA_SUCCESS
}

unsafe extern "C" fn capabilities_callback(
_data: *const c_void,
out: *mut mta_string_t,
) -> mta_status_t {
*out = mta_string_t::new(r#"{
"type": "metatomic_model_capabilities",
"outputs": [{"type": "metatomic_quantity", "name": "energy", "unit": "eV", "gradients": [], "sample_kind": "system"}],
"atomic_types": [1, 6],
"interaction_range": 5.0,
"length_unit": "Angstrom",
"supported_devices": ["cpu"],
"dtype": "float32"
}"#);
mta_status_t::MTA_SUCCESS
}

unsafe extern "C" fn pair_lists_callback(
_data: *const c_void,
out: *mut mta_string_t,
) -> mta_status_t {
*out = mta_string_t::new(format!(
r#"[{{"type": "metatomic_pair_options", "cutoff": "{:#x}", "full_list": true, "strict": true}}]"#,
3.5_f64.to_bits()
));
mta_status_t::MTA_SUCCESS
}

unsafe extern "C" fn inputs_callback(
_data: *const c_void,
out: *mut mta_string_t,
) -> mta_status_t {
*out = mta_string_t::new(r#"[{"type": "metatomic_quantity", "name": "charge", "unit": "e", "gradients": [], "sample_kind": "atom"}]"#);
mta_status_t::MTA_SUCCESS
}

// ── helper to build a fully-wired test model ─────────────────────────────
// Constructs an mta_model_t with all four query callbacks set,
// data pointer left null (the callbacks above don't need it).

fn make_test_model() -> Model {
Model::new(mta_model_t {
metadata: Some(metadata_callback),
capabilities: Some(capabilities_callback),
requested_pair_lists: Some(pair_lists_callback),
requested_inputs: Some(inputs_callback),
..mta_model_t::null()
})
}

// ── tests ────────────────────────────────────────────────────────────────

/// `metadata()` correctly deserializes the JSON returned by the callback.
#[test]
fn metadata_happy_path() {
let model = make_test_model();
let m = model.metadata().unwrap();
assert_eq!(m.name, "test-model");
assert_eq!(m.authors, vec!["Alice"]);
assert_eq!(m.description, "A test model");
}

/// `metadata()` returns `Error::Internal` when the model has no callback set.
#[test]
fn metadata_missing_callback() {
let model = Model::new(mta_model_t::null());
let err = model.metadata().unwrap_err();
assert!(matches!(err, Error::Internal(_)));
}

/// `capabilities()` correctly deserializes outputs, atomic types, and device info.
#[test]
fn capabilities_happy_path() {
let model = make_test_model();
let caps = model.capabilities().unwrap();
assert_eq!(caps.outputs.len(), 1);
assert_eq!(caps.outputs[0].name, "energy");
assert_eq!(caps.atomic_types, vec![1, 6]);
assert_eq!(caps.interaction_range.to_bits(), 5.0_f64.to_bits());
assert_eq!(caps.length_unit, "Angstrom");
}

/// `requested_pair_lists()` correctly parses the JSON array and preserves cutoff precision.
#[test]
fn requested_pair_lists_happy_path() {
let model = make_test_model();
let lists = model.requested_pair_lists().unwrap();
assert_eq!(lists.len(), 1);
assert_eq!(lists[0].cutoff.to_bits(), 3.5_f64.to_bits());
assert!(lists[0].full_list);
assert!(lists[0].strict);
}

/// `requested_inputs()` correctly parses the JSON array of Quantity objects.
#[test]
fn requested_inputs_happy_path() {
let model = make_test_model();
let inputs = model.requested_inputs().unwrap();
assert_eq!(inputs.len(), 1);
assert_eq!(inputs[0].name, "charge");
assert_eq!(inputs[0].unit, "e");
}
}
Loading