Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 89 additions & 80 deletions inst/Regression/RegressionGAM.m
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
## Copyright (C) 2023 Mohammed Azmat Khan <azmat.dev0@gmail.com>
## Copyright (C) 2023-2024 Andreas Bertsatos <abertsatos@biol.uoa.gr>
## Copyright (C) 2026 Jayant Chauhan <0001jayant@gmail.com>
##
## This file is part of the statistics package for GNU Octave.
##
Expand Down Expand Up @@ -58,20 +59,15 @@
## @item @tab @qcode{"responsename"} @tab Response Variable Name, specified as
## a string. If omitted, the default value is @qcode{"Y"}.
##
## @item @tab @qcode{"formula"} @tab a model specification given as a string in
## the form @qcode{"Y ~ terms"} where @qcode{Y} represents the response variable
## and @qcode{terms} the predictor variables. The formula can be used to
## specify a subset of variables for training model. For example:
## @qcode{"Y ~ x1 + x2 + x3 + x4 + x1:x2 + x2:x3"} specifies four linear terms
## for the first four columns of the predictor data, and @qcode{x1:x2} and
## @qcode{x2:x3} specify the two interaction terms for 1st-2nd and 2nd-3rd
## columns respectively. Only these terms will be used for training the model,
## but @var{X} must have at least as many columns as referenced in the formula.
## If Predictor Variable names have been defined, then the terms in the formula
## must refer to those names. When @qcode{"formula"} is specified, all terms used
## for training the model are referenced in the @qcode{IntMatrix} field of the
## @var{obj} class object as a matrix containing the column indexes for each
## term including both the predictors and the interactions used.
## @item @tab @qcode{"formula"} @tab A model specification given as a string
## using standard Wilkinson notation in the form @qcode{"Y ~ terms"}. In
## addition to basic main effects and interactions (@code{+}, @code{:}), it
## fully supports advanced operators including crossing (@code{*}), nesting
## (@code{/}), power/limits (@code{^}), and deletion (@code{-}). The formula
## is evaluated internally via @code{parseWilkinsonFormula}. All expanded terms
## used for training the model are referenced in the @qcode{IntMatrix} field
## of the @var{obj} class object as a matrix containing the column indexes for
## each term including both the predictors and the interactions used.
##
## @item @tab @qcode{"interactions"} @tab a logical matrix, a positive integer
## scalar, or the string @qcode{"all"} for defining the interactions between
Expand Down Expand Up @@ -478,7 +474,7 @@
" must be a logical value."));
endif
## Check model for interactions
if (tmpInt && isempty (this.IntMat))
if (tmpInt && isempty (this.IntMatrix))
error (strcat ("RegressionGAM.predict: trained model", ...
" does not include any interactions."));
endif
Expand All @@ -503,8 +499,8 @@
if (incInt)
if (! isempty (this.Interactions))
## Append interaction terms to the predictor matrix
for i = 1:rows (this.IntMat)
tindex = logical (this.IntMat(i,:));
for i = 1:rows (this.IntMatrix)
tindex = logical (this.IntMatrix(i,:));
Xterms = Xfit(:,tindex);
Xinter = ones (rows (Xfit), 1);
for c = 1:sum (tindex)
Expand All @@ -516,8 +512,8 @@
else
## Add selected predictors and interaction terms
XN = [];
for i = 1:rows (this.IntMat)
tindex = logical (this.IntMat(i,:));
for i = 1:rows (this.IntMatrix)
tindex = logical (this.IntMatrix(i,:));
Xterms = Xfit(:,tindex);
Xinter = ones (rows (Xfit), 1);
for c = 1:sum (tindex)
Expand Down Expand Up @@ -631,78 +627,66 @@ function savemodel (obj, fname)
error (strcat ("RegressionGAM: columns in Interactions logical", ...
" matrix must equal to the number of predictors."));
endif
intMat = this.Interactions
intMat = this.Interactions;
elseif (isnumeric (this.Interactions))
## Need to measure the effect of all interactions to keep the best
## performing. Just check that the given number is not higher than
## p*(p-1)/2, where p is the number of predictors.
p = this.NumPredictors;
if (this.Interactions > p * (p - 1) / 2)
error (strcat ("RegressionGAM: number of interaction terms", ...
" requested is larger than all possible", ...
" combinations of predictors in X."));
endif
## Get all combinations except all zeros
allMat = flip (fullfact(p)([2:end],:), 2);
## Only keep interaction terms
iterms = find (sum (allMat, 2) != 1);
intMat = allMat(iterms);
## Calculate all binary combinations (excluding all zeros)
allMat = dec2bin (1:(2^p - 1)) - '0';
## Only keep true interaction terms (combinations with 2 or more features)
iterms = find (sum (allMat, 2) > 1);
intMat = allMat(iterms, :);
elseif (strcmpi (this.Interactions, "all"))
## Calculate all p*(p-1)/2 interaction terms
allMat = flip (fullfact(p)([2:end],:), 2);
## Only keep interaction terms
iterms = find (sum (allMat, 2) != 1);
intMat = allMat(iterms);
p = this.NumPredictors;
## Calculate all binary combinations (excluding all zeros)
allMat = dec2bin (1:(2^p - 1)) - '0';
## Only keep true interaction terms (combinations with 2 or more features)
iterms = find (sum (allMat, 2) > 1);
intMat = allMat(iterms, :);
endif
endfunction

## Determine interactions from formula
##
## Parses this.Formula with parseWilkinsonFormula and maps the resulting term
## matrix onto the columns of the predictor data by matching variable names
## against this.PredictorNames.
##
## Returns INTMAT, a logical matrix with this.NumPredictors columns and one
## row per non-intercept term; a true entry marks a predictor taking part in
## that term.  A 0-by-NumPredictors logical matrix is returned when the
## formula yields no predictor variables or no terms.
##
## Raises an error if parseWilkinsonFormula rejects the formula, or if the
## formula references a variable that is not listed in this.PredictorNames.
function intMat = parseFormula (this)
  ## Delegate all syntax handling to the shared Wilkinson parser and re-raise
  ## its failures under the RegressionGAM prefix.
  try
    schema = parseWilkinsonFormula (this.Formula, 'matrix');
  catch ME
    error ("RegressionGAM: Invalid formula. %s", ME.message);
  end_try_catch

  ## NOTE(review): schema.Terms is assumed to be a (terms x variables) matrix
  ## whose columns follow schema.VariableNames - confirm against the parser.
  termMat = schema.Terms;
  varNames = schema.VariableNames;

  ## Drop the response variable so that only predictors remain
  if (! isempty (schema.ResponseIdx))
    respIdx = schema.ResponseIdx;
    varNames(respIdx) = [];
    termMat(:, respIdx) = [];
  endif

  ## Nothing to map: return an empty matrix with the expected column count
  if (isempty (varNames) || isempty (termMat))
    intMat = false (0, this.NumPredictors);
    return;
  endif

  ## Re-index the parser's variable columns to the column order of X
  intMat = zeros (rows (termMat), this.NumPredictors);
  for i = 1:numel (varNames)
    colIdx = find (strcmp (this.PredictorNames, varNames{i}));
    if (isempty (colIdx))
      error ("RegressionGAM: Formula contains unknown predictor '%s'.", varNames{i});
    endif
    intMat(:, colIdx) = termMat(:, i);
  endfor

  ## Remove intercept row (all zeros): it references no predictor column
  interceptRow = (sum (intMat, 2) == 0);
  intMat(interceptRow, :) = [];

  intMat = logical (intMat);
endfunction

## Fit the model
Expand Down Expand Up @@ -866,10 +850,37 @@ function savemodel (obj, fname)
%! formula = "Y ~ A + B + C + D + A:C";
%! intMat = logical ([1,0,0,0;0,1,0,0;0,0,1,0;0,0,0,1;1,0,1,0]);
%! a = RegressionGAM (x, y, "predictors", pnames, "formula", formula);
%! assert (a.IntMatrix, double (intMat))
%! assert (a.IntMatrix, intMat)
%! assert ({a.ResponseName, a.PredictorNames}, {"Y", pnames})
%! assert (a.Formula, formula)

%!test
%! ## Test that predict() executes correctly when interactions are present
%! ## Small 4-sample, 2-predictor data set; only output types and sizes are
%! ## checked below, not the fitted values.
%! X = [1, 2; 3, 4; 5, 6; 7, 8];
%! Y = [10; 20; 30; 40];
%! ## Formula with both main effects and their pairwise interaction
%! ## (presumably x1, x2 are the default predictor names - verify)
%! mdl = RegressionGAM (X, Y, "formula", "Y ~ x1 + x2 + x1:x2");
%! ## Single-output call with default options
%! ypred = predict (mdl, X);
%! assert (isnumeric (ypred));
%! assert (size (ypred), [4, 1]);
%! ## Three-output call that explicitly requests the interaction terms
%! [ypred2, ySD, yInt] = predict (mdl, X, "includeinteractions", true);
%! assert (size (ypred2), [4, 1]);
%! assert (size (ySD), [4, 1]);
%! assert (size (yInt), [4, 2]);

%!test
%! ## Test advanced Wilkinson notation parsing in RegressionGAM
%! ## X has three columns, but the formula below only references x1 and x2
%! X = [1, 2, 3; 4, 5, 6; 7, 8, 9; 3, 2, 1];
%! Y = [1; 2; 3; 4];
%! ## The * operator should automatically expand to main effects + interaction
%! a = RegressionGAM (X, Y, 'Formula', 'Y ~ x1 * x2');
%! assert (class (a), "RegressionGAM");
%! ## NumPredictors is expected to count all columns of X, not only those used
%! assert (a.NumPredictors, 3);
%! ## Verify the IntMatrix correctly captured x1, x2, and x1:x2
%! ## (one row per term, one column per predictor; third column stays unused)
%! expected_int = logical ([1, 0, 0;
%! 0, 1, 0;
%! 1, 1, 0]);
%! assert (a.IntMatrix, expected_int);

## Test input validation for constructor
%!error<RegressionGAM: too few input arguments.> RegressionGAM ()
%!error<RegressionGAM: too few input arguments.> RegressionGAM (ones(10,2))
Expand All @@ -883,13 +894,11 @@ function savemodel (obj, fname)
%! RegressionGAM (ones(10,2), ones (10,1), "formula", {"y~x1+x2"})
%!error<RegressionGAM: Formula must be a string.>
%! RegressionGAM (ones(10,2), ones (10,1), "formula", [0, 1, 0])
%!error<RegressionGAM: invalid syntax in Formula.> ...
%!error<RegressionGAM: Formula contains unknown predictor 'something'.> ...
%! RegressionGAM (ones(10,2), ones (10,1), "formula", "something")
%!error<RegressionGAM: no predictor terms in Formula.> ...
%! RegressionGAM (ones(10,2), ones (10,1), "formula", "something~")
%!error<RegressionGAM: no predictor terms in Formula.> ...
%!error<RegressionGAM: Invalid formula.> ...
%! RegressionGAM (ones(10,2), ones (10,1), "formula", "something~")
%!error<RegressionGAM: some predictors have not been identified> ...
%!error<RegressionGAM: Invalid formula.> ...
%! RegressionGAM (ones(10,2), ones (10,1), "formula", "something~x1:")
%!error<RegressionGAM: invalid Interactions parameter.> ...
%! RegressionGAM (ones(10,2), ones (10,1), "interactions", "some")
Expand Down