Source code for brainstat.tests.test_terms

""" Tests the fixed and mixed effects classes. """

import numpy as np
import pandas as pd

from brainstat.stats.terms import FixedEffect, MixedEffect


def test_fixed_init():
    """Check that FixedEffect builds the expected design matrix and names.

    Covers numeric input with and without an intercept column, one-hot
    columns produced from a categorical DataFrame column, and a scalar input.
    """
    data = np.random.random_sample((10, 1))

    with_intercept = FixedEffect(data, ["x0"])
    without_intercept = FixedEffect(data, ["x0"], add_intercept=False)

    # (effect, expected matrix shape, expected column names)
    expectations = (
        (with_intercept, [10, 2], ["intercept", "x0"]),
        (without_intercept, [10, 1], ["x0"]),
    )
    for effect, shape, names in expectations:
        assert np.array_equal(effect.m.shape, shape)
        assert np.array_equal(effect.names, names)

    # A categorical column yields one 0/1 indicator column per level.
    sex = pd.DataFrame({"Sex": ["M", "M", "M", "F", "F"]})
    categorical = FixedEffect(sex)
    assert np.array_equal(categorical.Sex_M, [1, 1, 1, 0, 0])
    assert np.array_equal(categorical.Sex_F, [0, 0, 0, 1, 1])

    # A bare scalar yields an intercept column.
    scalar = FixedEffect(1)
    assert np.array_equal(scalar.intercept, [1])
def test_fixed_overload():
    """Check the arithmetic operators (+, -, *) defined on FixedEffect."""
    data = np.random.random_sample((10, 3))
    col0, col1, col2 = data[:, 0], data[:, 1], data[:, 2]

    cols_01 = FixedEffect(data[:, :2], ["x0", "x1"], add_intercept=False)
    # NOTE(review): these labels do not match the column indices (1 and 2);
    # every assertion below compares values only, so the labels go unchecked.
    cols_12 = FixedEffect(data[:, 1:], ["x2", "x3"], add_intercept=False)
    only_2 = FixedEffect(col2, ["x2"], add_intercept=False)
    icpt_0 = FixedEffect(col0, ["x0"], add_intercept=True)
    icpt_1 = FixedEffect(col1, ["x1"], add_intercept=True)

    # Addition: the column shared by both operands (column 1) appears once.
    summed = cols_01 + cols_12
    assert np.array_equal(summed.m, data)

    # `1 + effect` must match the effect built directly with an intercept.
    left_intercept = 1 + FixedEffect(col0)
    assert np.array_equal(icpt_0.m, left_intercept.m)

    # Two intercept-bearing effects combine into a single intercept column.
    both = icpt_0 + icpt_1
    expected_both = np.concatenate((np.ones((10, 1)), data[:, 0:2]), axis=1)
    assert np.array_equal(both.m, expected_both)

    # Subtraction removes the columns present in the right-hand operand.
    difference = cols_01 - cols_12
    assert np.array_equal(difference.m, col0[:, None])

    # Multiplication scales each left-hand column element-wise by column 2.
    product = cols_01 * only_2
    assert np.array_equal(product.m, data[:, :2] * col2[:, None])
def test_mixed_init():
    """Check MixedEffect construction: variance/mean shapes and term names.

    Covers a random effect with and without the identity term "I", random
    plus fixed effects with and without an intercept, a categorical random
    effect, and a scalar input.
    """
    n = 10
    data = np.random.random_sample((n, 1))

    # The second positional argument is the fixed-effect data (as the
    # four-argument constructions below show), so the random-effect name is
    # passed by keyword; passing ["x0"] positionally would feed it in as
    # fixed-effect data instead of naming the random effect.
    ran_with_identity = MixedEffect(data, name_ran=["x0"])
    ran_without_identity = MixedEffect(data, name_ran=["x0"], add_identity=False)
    ran_and_fix = MixedEffect(data, data, ["x0"], ["y0"])
    ran_and_fix_no_intercept = MixedEffect(
        data, data, ["x0"], ["y0"], add_intercept=False
    )

    # Variance terms have n**2 rows; "I" is appended unless disabled.
    assert np.array_equal(ran_with_identity.variance.shape, [n**2, 2])
    assert np.array_equal(ran_with_identity.variance.names, ["x0", "I"])
    assert np.array_equal(ran_without_identity.variance.shape, [n**2, 1])
    assert np.array_equal(ran_without_identity.variance.names, ["x0"])

    # Mean terms have n rows; an intercept is prepended unless disabled.
    assert np.array_equal(ran_and_fix.mean.shape, [n, 2])
    assert np.array_equal(ran_and_fix.mean.names, ["intercept", "y0"])
    assert np.array_equal(ran_and_fix_no_intercept.mean.shape, [n, 1])
    assert np.array_equal(ran_and_fix_no_intercept.mean.names, ["y0"])

    sex = pd.DataFrame({"Sex": ["M", "M", "M", "F", "F"]})
    categorical = MixedEffect(sex)
    assert np.array_equal(categorical.variance.shape, [25, 2])
    assert np.array_equal(categorical.variance.names, ["Sex_", "I"])

    scalar = MixedEffect(1)
    assert np.array_equal(scalar.variance.shape, [1, 1])
    assert np.array_equal(scalar.variance.names, ["I"])
def test_mixed_overload():
    """Check the arithmetic operators (+, -, *) defined on MixedEffect."""
    n = 3
    data = np.random.random_sample((n, 4))

    effect_a = MixedEffect(data[:, 0], name_ran=["x0"])
    effect_b = MixedEffect(data[:, 1], name_ran=["x1"])

    # Flattened n-by-n identity as a single column, plus the variance
    # columns built from the first two data columns.
    identity = np.identity(n).flatten()[:, None]
    variances = as_variance(data[:, :2])
    var_a = variances[:, [0]]
    var_b = variances[:, [1]]

    # Sum: both variance columns followed by one identity column.
    total = effect_a + effect_b
    assert np.array_equal(total.variance.m, np.hstack((variances, identity)))

    # Subtracting an effect from itself leaves an empty effect.
    assert (effect_a - effect_a).empty

    # Product: columns var_a*var_b, var_b*I, var_a*I, I (element-wise).
    product = effect_a * effect_b
    expected_product = np.hstack(
        (var_a * var_b, var_b * identity, var_a * identity, identity)
    )
    assert np.array_equal(product.variance.m, expected_product)
def test_identity_detection():
    """Check that the identity term lands in variance column 1, after the
    other random term, regardless of operand order."""
    without_identity = MixedEffect(np.random.rand(3, 1), add_identity=False)
    identity_only = MixedEffect(1, name_ran="test_identity")
    flat_identity = np.identity(3).flatten()

    combinations = (
        identity_only + without_identity,
        without_identity + identity_only,
    )
    for combined in combinations:
        assert np.all(combined.variance.m.to_numpy()[:, 1] == flat_identity)
def as_variance(M):
    """Expand each column of ``M`` into its flattened outer product.

    Column ``k`` of the result equals ``np.outer(c, c).ravel()`` with
    ``c = M[:, k]``; it is used to build expected variance terms in the
    MixedEffect overload test.

    Parameters
    ----------
    M : array-like, shape (n,) or (n, k)
        Column vectors. A 1-D input is treated as a single column; a
        DataFrame is converted to its underlying array.

    Returns
    -------
    numpy.ndarray, shape (n * n, k)
        Row ``i * n + j`` holds ``M[i, :] * M[j, :]``. A zero-column input
        yields an ``(n * n, 0)`` array instead of raising.
    """
    M = np.asarray(M)
    if M.ndim == 1:
        M = M[:, None]
    n, k = M.shape
    # Broadcast (n, 1, k) * (1, n, k) -> (n, n, k); the C-order reshape puts
    # element (i, j) at row i * n + j, the same order as np.outer(...).ravel().
    return (M[:, None, :] * M[None, :, :]).reshape(n * n, k)