Skip to content

moseq_train.py

activate(train_schema_name, infer_schema_name=None, *, create_schema=True, create_tables=True, linking_module=None)

Activate this schema.

Parameters:

Name Type Description Default
train_schema_name str

A string containing the name of the moseq_train schema.

required
infer_schema_name str

A string containing the name of the moseq_infer schema.

None
create_schema bool

If True (default), schema will be created in the database.

True
create_tables bool

If True (default), tables related to the schema will be created in the database.

True
linking_module str

A string containing the module name or module containing the required dependencies to activate the schema.

None

Dependencies: Functions: get_kpms_root_data_dir(): Returns absolute path for root data director(y/ies) with all behavioral recordings, as (list of) string(s). get_kpms_processed_data_dir(): Optional. Returns absolute path for processed data.

Source code in element_moseq/moseq_train.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def activate(
    train_schema_name: str,
    infer_schema_name: str = None,
    *,
    create_schema: bool = True,
    create_tables: bool = True,
    linking_module: str = None,
):
    """Activate this schema.

    Args:
        train_schema_name (str): A string containing the name of the `moseq_train` schema.
        infer_schema_name (str): A string containing the name of the `moseq_infer` schema.
        create_schema (bool): If True (default), schema will be created in the database.
        create_tables (bool): If True (default), tables related to the schema will be created in the database.
        linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema.

    Raises:
        AssertionError: `linking_module` is not a module (or importable module
            name), or it does not define `get_kpms_root_data_dir`.

    Dependencies:
    Functions:
        get_kpms_root_data_dir(): Returns absolute path for root data director(y/ies)
                                 with all behavioral recordings, as (list of) string(s).
        get_kpms_processed_data_dir(): Optional. Returns absolute path for processed
                                      data.
    """

    if isinstance(linking_module, str):
        linking_module = importlib.import_module(linking_module)
    # Fixed error message: the parameter is named `linking_module`, not
    # 'dependency' as the original message claimed.
    assert inspect.ismodule(
        linking_module
    ), "The argument 'linking_module' must be a module's name or a module"

    assert hasattr(
        linking_module, "get_kpms_root_data_dir"
    ), "The linking module must specify a lookup function for a root data directory"

    global _linking_module
    _linking_module = linking_module

    # Activate the downstream `moseq_infer` schema first, since this schema
    # references its tables (e.g. `PoseEstimationMethod`).
    moseq_infer.activate(
        infer_schema_name,
        create_schema=create_schema,
        create_tables=create_tables,
        linking_module=linking_module,
    )

    schema.activate(
        train_schema_name,
        create_schema=create_schema,
        create_tables=create_tables,
        add_objects=_linking_module.__dict__,
    )

KeypointSet

Bases: Manual

Store the keypoint data and the video set directory for model training.

Attributes:

Name Type Description
kpset_id int)

Unique ID for each keypoint set.

PoseEstimationMethod foreign key)

Unique format method used to obtain the keypoints data.

kpset_dir str)

Path where the keypoint files are located together with the pose estimation config file, relative to root data directory.

kpset_desc str)

Optional. User-entered description.

Source code in element_moseq/moseq_train.py
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
@schema
class KeypointSet(dj.Manual):
    """Store the keypoint data and the video set directory for model training.

    Each entry points at a directory of keypoint files (plus the pose
    estimation `config` file) and is associated, via the `VideoFile` part
    table, with the videos from which the keypoints were derived.

    Attributes:
        kpset_id (int)                          : Unique ID for each keypoint set.
        PoseEstimationMethod (foreign key)      : Unique format method used to obtain the keypoints data.
        kpset_dir (str)                         : Path where the keypoint files are located together with the pose estimation `config` file, relative to root data directory.
        kpset_desc (str)                        : Optional. User-entered description.
    """

    definition = """
    kpset_id                        : int           # Unique ID for each keypoint set   
    ---
    -> moseq_infer.PoseEstimationMethod             # Unique format method used to obtain the keypoints data
    kpset_dir                       : varchar(255)  # Path where the keypoint files are located together with the pose estimation `config` file, relative to root data directory 
    kpset_desc=''                   : varchar(1000) # Optional. User-entered description
    """

    class VideoFile(dj.Part):
        """Store the IDs and file paths of each video file that will be used for model training.

        Part table of `KeypointSet` (referenced as `master` in the definition).

        Attributes:
            KeypointSet (foreign key) : Unique ID for each keypoint set.
            video_id (int)            : Unique ID for each video corresponding to each keypoint data file, relative to root data directory.
            video_path (str)          : Filepath of each video from which the keypoints are derived, relative to root data directory.
        """

        definition = """
        -> master
        video_id                    : int           # Unique ID for each video corresponding to each keypoint data file, relative to root data directory
        ---
        video_path                  : varchar(1000) # Filepath of each video from which the keypoints are derived, relative to root data directory
        """

VideoFile

Bases: Part

Store the IDs and file paths of each video file that will be used for model training.

Attributes:

Name Type Description
KeypointSet foreign key)

Unique ID for each keypoint set.

video_id int)

Unique ID for each video corresponding to each keypoint data file, relative to root data directory.

video_path str)

Filepath of each video from which the keypoints are derived, relative to root data directory.

Source code in element_moseq/moseq_train.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
class VideoFile(dj.Part):
    """Store the IDs and file paths of each video file that will be used for model training.

    Part table of `KeypointSet` (referenced as `master` in the definition below).

    Attributes:
        KeypointSet (foreign key) : Unique ID for each keypoint set.
        video_id (int)            : Unique ID for each video corresponding to each keypoint data file, relative to root data directory.
        video_path (str)          : Filepath of each video from which the keypoints are derived, relative to root data directory.
    """

    definition = """
    -> master
    video_id                    : int           # Unique ID for each video corresponding to each keypoint data file, relative to root data directory
    ---
    video_path                  : varchar(1000) # Filepath of each video from which the keypoints are derived, relative to root data directory
    """

Bodyparts

Bases: Manual

Store the body parts to use in the analysis.

Attributes:

Name Type Description
KeypointSet foreign key)

Unique ID for each KeypointSet key.

bodyparts_id int)

Unique ID for a set of bodyparts for a particular keypoint set.

anterior_bodyparts blob)

List of strings of anterior bodyparts

posterior_bodyparts blob)

List of strings of posterior bodyparts

use_bodyparts blob)

List of strings of bodyparts to be used

bodyparts_desc(varchar)

Optional. User-entered description.

Source code in element_moseq/moseq_train.py
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
@schema
class Bodyparts(dj.Manual):
    """Store the body parts to use in the analysis.

    A given `KeypointSet` may have several bodypart selections, each
    distinguished by `bodyparts_id`.

    Attributes:
        KeypointSet (foreign key)       : Unique ID for each `KeypointSet` key.
        bodyparts_id (int)              : Unique ID for a set of bodyparts for a particular keypoint set.
        anterior_bodyparts (blob)       : List of strings of anterior bodyparts.
        posterior_bodyparts (blob)      : List of strings of posterior bodyparts.
        use_bodyparts (blob)            : List of strings of bodyparts to be used.
        bodyparts_desc (varchar)        : Optional. User-entered description.
    """

    definition = """
    -> KeypointSet                              # Unique ID for each `KeypointSet` key
    bodyparts_id                : int           # Unique ID for a set of bodyparts for a particular keypoint set
    ---
    anterior_bodyparts          : blob          # List of strings of anterior bodyparts
    posterior_bodyparts         : blob          # List of strings of posterior bodyparts
    use_bodyparts               : blob          # List of strings of bodyparts to be used
    bodyparts_desc=''           : varchar(1000) # Optional. User-entered description
    """

PCATask

Bases: Manual

Staging table to define the PCA task and its output directory.

Attributes:

Name Type Description
Bodyparts foreign key)

Unique ID for each Bodyparts key

kpms_project_output_dir str)

Keypoint-MoSeq project output directory, relative to root data directory

Source code in element_moseq/moseq_train.py
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
@schema
class PCATask(dj.Manual):
    """
    Staging table to define the PCA task and its output directory.

    Attributes:
        Bodyparts (foreign key)         : Unique ID for each `Bodyparts` key.
        kpms_project_output_dir (str)   : Keypoint-MoSeq project output directory, relative to root data directory.
        task_mode (enum)                : 'trigger' computes the PCA preparation; 'load' imports existing results.
    """

    definition = """ 
    -> Bodyparts                                                # Unique ID for each `Bodyparts` key
    ---
    kpms_project_output_dir=''          : varchar(255)          # Keypoint-MoSeq project output directory, relative to root data directory
    task_mode='load'                 :enum('load','trigger') # Trigger or load the task

    """

PCAPrep

Bases: Imported

Table to set up the Keypoint-MoSeq project output directory (kpms_project_output_dir), creating the default config.yml and updating it in a new kpms_dj_config.yml.

Attributes:

Name Type Description
PCATask foreign key)

Unique ID for each PCATask key.

coordinates longblob)

Dictionary mapping filenames to keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, 2[or 3]).

confidences longblob)

Dictionary mapping filenames to likelihood scores as ndarrays of shape (n_frames, n_bodyparts).

formatted_bodyparts longblob)

List of bodypart names. The order of the names matches the order of the bodyparts in coordinates and confidences.

average_frame_rate float)

Average frame rate of the videos for model training.

frame_rates longblob)

List of the frame rates of the videos for model training.

Source code in element_moseq/moseq_train.py
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
@schema
class PCAPrep(dj.Imported):
    """
    Table to set up the Keypoint-MoSeq project output directory (`kpms_project_output_dir`), creating the default `config.yml` and updating it in a new `kpms_dj_config.yml`.

    Attributes:
        PCATask (foreign key)           : Unique ID for each `PCATask` key.
        coordinates (longblob)          : Dictionary mapping filenames to keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, 2[or 3]).
        confidences (longblob)          : Dictionary mapping filenames to `likelihood` scores as ndarrays of shape (n_frames, n_bodyparts).
        formatted_bodyparts (longblob)  : List of bodypart names. The order of the names matches the order of the bodyparts in `coordinates` and `confidences`.
        average_frame_rate (float)      : Average frame rate of the videos for model training.
        frame_rates (longblob)          : List of the frame rates of the videos for model training.
    """

    definition = """
    -> PCATask                          # Unique ID for each `PCATask` key
    ---
    coordinates             : longblob  # Dictionary mapping filenames to keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, 2[or 3])
    confidences             : longblob  # Dictionary mapping filenames to `likelihood` scores as ndarrays of shape (n_frames, n_bodyparts)           
    formatted_bodyparts     : longblob  # List of bodypart names. The order of the names matches the order of the bodyparts in `coordinates` and `confidences`.
    average_frame_rate      : float     # Average frame rate of the videos for model training
    frame_rates             : longblob  # List of the frame rates of the videos for model training
    """

    def make(self, key):
        """
        Make function to:
        1. Generate and update the `kpms_dj_config.yml` with both the videoset directory and the bodyparts.
        2. Create the keypoint coordinates and confidences scores to format the data for the PCA fitting.

        Args:
            key (dict): Primary key from the `PCATask` table.

        Raises:
            FileNotFoundError: No pose estimation config file (`config.yaml` or `config.yml`) is found in `kpset_dir`.
            NotImplementedError: `deeplabcut` is the only supported `pose_estimation_method`.

        High-Level Logic:
        1. Fetches the bodyparts, format method, and the directories for the Keypoint-MoSeq project output, the keypoint set, and the video set.
        2. Set variables for each of the full paths of the mentioned directories.
        3. Find the first existing pose estimation config file in the `kpset_dir` directory; if none is found, raise an error.
        4. Check that the `pose_estimation_method` is `deeplabcut` and set up the project output directory with the default `config.yml`.
        5. Create the `kpms_project_output_dir` (if it does not exist), and generate the kpms default `config.yml` with the default values from the pose estimation config.
        6. Create a copy of the kpms `config.yml` named `kpms_dj_config.yml` that is updated with both the `video_dir` and bodyparts.
        7. Load keypoint data from the keypoint files found in the `kpset_dir` that will serve as the training set.
        8. As a result of the keypoint loading, the coordinates and confidences scores are generated and used to format the data for modeling.
        9. Calculate the average frame rate and the frame rate list of the videoset from which the keypoint set is derived. These two attributes can be used to calculate the kappa value.
        10. Insert the results of this `make` function into the table.
        """
        from keypoint_moseq import setup_project, load_config, load_keypoints

        anterior_bodyparts, posterior_bodyparts, use_bodyparts = (
            Bodyparts & key
        ).fetch1(
            "anterior_bodyparts",
            "posterior_bodyparts",
            "use_bodyparts",
        )

        pose_estimation_method, kpset_dir = (KeypointSet & key).fetch1(
            "pose_estimation_method", "kpset_dir"
        )
        # Only the paths are needed below; the original also fetched the
        # (unused) `video_id` column.
        video_paths = (KeypointSet.VideoFile & key).fetch("video_path")

        kpms_root = moseq_infer.get_kpms_root_data_dir()
        kpms_processed = moseq_infer.get_kpms_processed_data_dir()

        kpms_project_output_dir, task_mode = (PCATask & key).fetch1(
            "kpms_project_output_dir", "task_mode"
        )

        if task_mode == "trigger":
            try:
                kpms_project_output_dir = find_full_path(
                    kpms_processed, kpms_project_output_dir
                )
            except FileNotFoundError:
                # Output directory does not exist yet; `setup_project` creates it.
                kpms_project_output_dir = kpms_processed / kpms_project_output_dir

            kpset_dir = find_full_path(kpms_root, kpset_dir)
            videos_dir = find_full_path(kpms_root, Path(video_paths[0]).parent)

            if pose_estimation_method == "deeplabcut":
                # BUGFIX: the original used `path_a or path_b`, but `Path`
                # objects are always truthy, so `config.yml` was never tried.
                # Select the first config file that actually exists instead.
                dlc_config = next(
                    (
                        cfg
                        for cfg in (
                            kpset_dir / "config.yaml",
                            kpset_dir / "config.yml",
                        )
                        if cfg.exists()
                    ),
                    None,
                )
                if dlc_config is None:
                    raise FileNotFoundError(
                        f"No DeepLabCut config file (`config.yaml` or `config.yml`) found in {kpset_dir}"
                    )
                setup_project(
                    project_dir=kpms_project_output_dir.as_posix(),
                    deeplabcut_config=dlc_config,
                )
            else:
                raise NotImplementedError(
                    "Currently, `deeplabcut` is the only pose estimation method supported by this Element. Please reach out at `support@datajoint.com` if you use another method."
                )

            kpms_config = load_config(
                kpms_project_output_dir.as_posix(),
                check_if_valid=True,
                build_indexes=False,
            )

            # Update the default config with the video directory and the
            # user-selected bodyparts, then persist it as `kpms_dj_config.yml`.
            kpms_config.update(
                video_dir=videos_dir.as_posix(),
                anterior_bodyparts=anterior_bodyparts,
                posterior_bodyparts=posterior_bodyparts,
                use_bodyparts=use_bodyparts,
            )
            kpms_reader.generate_kpms_dj_config(
                kpms_project_output_dir.as_posix(), **kpms_config
            )
        else:
            # `load` mode: everything must already exist on disk.
            kpms_project_output_dir = find_full_path(
                kpms_processed, kpms_project_output_dir
            )
            kpset_dir = find_full_path(kpms_root, kpset_dir)
            videos_dir = find_full_path(kpms_root, Path(video_paths[0]).parent)

        coordinates, confidences, formatted_bodyparts = load_keypoints(
            filepath_pattern=kpset_dir, format=pose_estimation_method
        )

        # Frame rates of the source videos; usable downstream to estimate kappa.
        frame_rate_list = []
        for fp in video_paths:
            cap = cv2.VideoCapture(find_full_path(kpms_root, fp).as_posix())
            frame_rate_list.append(int(cap.get(cv2.CAP_PROP_FPS)))
            cap.release()
        average_frame_rate = int(np.mean(frame_rate_list))

        self.insert1(
            dict(
                **key,
                coordinates=coordinates,
                confidences=confidences,
                formatted_bodyparts=formatted_bodyparts,
                average_frame_rate=average_frame_rate,
                frame_rates=frame_rate_list,
            )
        )

make(key)

Make function to: 1. Generate and update the kpms_dj_config.yml with both the videoset directory and the bodyparts. 2. Create the keypoint coordinates and confidences scores to format the data for the PCA fitting.

Parameters:

Name Type Description Default
key dict

Primary key from the PCATask table.

required

Raises:

Type Description
NotImplementedError

Raised when pose_estimation_method is not deeplabcut, the only supported method.

High-Level Logic: 1. Fetches the bodyparts, format method, and the directories for the Keypoint-MoSeq project output, the keypoint set, and the video set. 2. Set variables for each of the full path of the mentioned directories. 3. Find the first existing pose estimation config file in the kpset_dir directory, if not found, raise an error. 4. Check that the pose_estimation_method is deeplabcut and set up the project output directory with the default config.yml. 5. Create the kpms_project_output_dir (if it does not exist), and generates the kpms default config.yml with the default values from the pose estimation config. 6. Create a copy of the kpms config.yml named kpms_dj_config.yml that will be updated with both the video_dir and bodyparts. 7. Load keypoint data from the keypoint files found in the kpset_dir that will serve as the training set. 8. As a result of the keypoint loading, the coordinates and confidences scores are generated and will be used to format the data for modeling. 9. Calculate the average frame rate and the frame rate list of the videoset from which the keypoint set is derived. These two attributes can be used to calculate the kappa value. 10. Insert the results of this make function into the table.

Source code in element_moseq/moseq_train.py
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
def make(self, key):
    """
    Make function to:
    1. Generate and update the `kpms_dj_config.yml` with both the videoset directory and the bodyparts.
    2. Create the keypoint coordinates and confidences scores to format the data for the PCA fitting.

    Args:
        key (dict): Primary key from the `PCATask` table.

    Raises:
        FileNotFoundError: No pose estimation config file (`config.yaml` or `config.yml`) is found in `kpset_dir`.
        NotImplementedError: `deeplabcut` is the only supported `pose_estimation_method`.

    High-Level Logic:
    1. Fetches the bodyparts, format method, and the directories for the Keypoint-MoSeq project output, the keypoint set, and the video set.
    2. Set variables for each of the full paths of the mentioned directories.
    3. Find the first existing pose estimation config file in the `kpset_dir` directory; if none is found, raise an error.
    4. Check that the `pose_estimation_method` is `deeplabcut` and set up the project output directory with the default `config.yml`.
    5. Create the `kpms_project_output_dir` (if it does not exist), and generate the kpms default `config.yml` with the default values from the pose estimation config.
    6. Create a copy of the kpms `config.yml` named `kpms_dj_config.yml` that is updated with both the `video_dir` and bodyparts.
    7. Load keypoint data from the keypoint files found in the `kpset_dir` that will serve as the training set.
    8. As a result of the keypoint loading, the coordinates and confidences scores are generated and used to format the data for modeling.
    9. Calculate the average frame rate and the frame rate list of the videoset from which the keypoint set is derived. These two attributes can be used to calculate the kappa value.
    10. Insert the results of this `make` function into the table.
    """
    from keypoint_moseq import setup_project, load_config, load_keypoints

    anterior_bodyparts, posterior_bodyparts, use_bodyparts = (
        Bodyparts & key
    ).fetch1(
        "anterior_bodyparts",
        "posterior_bodyparts",
        "use_bodyparts",
    )

    pose_estimation_method, kpset_dir = (KeypointSet & key).fetch1(
        "pose_estimation_method", "kpset_dir"
    )
    # Only the paths are needed below; the original also fetched the
    # (unused) `video_id` column.
    video_paths = (KeypointSet.VideoFile & key).fetch("video_path")

    kpms_root = moseq_infer.get_kpms_root_data_dir()
    kpms_processed = moseq_infer.get_kpms_processed_data_dir()

    kpms_project_output_dir, task_mode = (PCATask & key).fetch1(
        "kpms_project_output_dir", "task_mode"
    )

    if task_mode == "trigger":
        try:
            kpms_project_output_dir = find_full_path(
                kpms_processed, kpms_project_output_dir
            )
        except FileNotFoundError:
            # Output directory does not exist yet; `setup_project` creates it.
            kpms_project_output_dir = kpms_processed / kpms_project_output_dir

        kpset_dir = find_full_path(kpms_root, kpset_dir)
        videos_dir = find_full_path(kpms_root, Path(video_paths[0]).parent)

        if pose_estimation_method == "deeplabcut":
            # BUGFIX: the original used `path_a or path_b`, but `Path`
            # objects are always truthy, so `config.yml` was never tried.
            # Select the first config file that actually exists instead.
            dlc_config = next(
                (
                    cfg
                    for cfg in (
                        kpset_dir / "config.yaml",
                        kpset_dir / "config.yml",
                    )
                    if cfg.exists()
                ),
                None,
            )
            if dlc_config is None:
                raise FileNotFoundError(
                    f"No DeepLabCut config file (`config.yaml` or `config.yml`) found in {kpset_dir}"
                )
            setup_project(
                project_dir=kpms_project_output_dir.as_posix(),
                deeplabcut_config=dlc_config,
            )
        else:
            raise NotImplementedError(
                "Currently, `deeplabcut` is the only pose estimation method supported by this Element. Please reach out at `support@datajoint.com` if you use another method."
            )

        kpms_config = load_config(
            kpms_project_output_dir.as_posix(),
            check_if_valid=True,
            build_indexes=False,
        )

        # Update the default config with the video directory and the
        # user-selected bodyparts, then persist it as `kpms_dj_config.yml`.
        kpms_config.update(
            video_dir=videos_dir.as_posix(),
            anterior_bodyparts=anterior_bodyparts,
            posterior_bodyparts=posterior_bodyparts,
            use_bodyparts=use_bodyparts,
        )
        kpms_reader.generate_kpms_dj_config(
            kpms_project_output_dir.as_posix(), **kpms_config
        )
    else:
        # `load` mode: everything must already exist on disk.
        kpms_project_output_dir = find_full_path(
            kpms_processed, kpms_project_output_dir
        )
        kpset_dir = find_full_path(kpms_root, kpset_dir)
        videos_dir = find_full_path(kpms_root, Path(video_paths[0]).parent)

    coordinates, confidences, formatted_bodyparts = load_keypoints(
        filepath_pattern=kpset_dir, format=pose_estimation_method
    )

    # Frame rates of the source videos; usable downstream to estimate kappa.
    frame_rate_list = []
    for fp in video_paths:
        cap = cv2.VideoCapture(find_full_path(kpms_root, fp).as_posix())
        frame_rate_list.append(int(cap.get(cv2.CAP_PROP_FPS)))
        cap.release()
    average_frame_rate = int(np.mean(frame_rate_list))

    self.insert1(
        dict(
            **key,
            coordinates=coordinates,
            confidences=confidences,
            formatted_bodyparts=formatted_bodyparts,
            average_frame_rate=average_frame_rate,
            frame_rates=frame_rate_list,
        )
    )

PCAFit

Bases: Computed

Fit PCA model.

Attributes:

Name Type Description
PCAPrep foreign key)

PCAPrep Key.

pca_fit_time datetime)

datetime of the PCA fitting analysis.

Source code in element_moseq/moseq_train.py
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
@schema
class PCAFit(dj.Computed):
    """Fit a PCA model on the formatted keypoint data.

    Attributes:
        PCAPrep (foreign key)   : `PCAPrep` Key.
        pca_fit_time (datetime) : Datetime of the PCA fitting analysis.
    """

    definition = """
    -> PCAPrep                           # `PCAPrep` Key
    ---
    pca_fit_time=NULL        : datetime  # datetime of the PCA fitting analysis
    """

    def make(self, key):
        """Format the keypoint data, fit the PCA model, and store it as a
        `pca.p` file in the Keypoint-MoSeq project output directory.

        Args:
            key (dict): `PCAPrep` Key.

        High-Level Logic:
        1. Fetch `kpms_project_output_dir` from the `PCATask` table and resolve
           its full path under the processed-data root.
        2. Load the `kpms_dj_config` file (which holds the updated `video_dir`
           and bodyparts) and format the fetched coordinates and confidences
           for the PCA fitting.
        3. In `trigger` mode, fit the PCA model and save it as `pca.p` in the
           output directory; otherwise no fitting is performed.
        4. Insert the creation datetime as `pca_fit_time` into the table.
        """
        from keypoint_moseq import format_data, fit_pca, save_pca

        output_dir, task_mode = (PCATask & key).fetch1(
            "kpms_project_output_dir", "task_mode"
        )
        output_dir = moseq_infer.get_kpms_processed_data_dir() / output_dir

        dj_config = kpms_reader.load_kpms_dj_config(
            output_dir.as_posix(), check_if_valid=True, build_indexes=True
        )
        coordinates, confidences = (PCAPrep & key).fetch1("coordinates", "confidences")
        formatted_data, _ = format_data(
            **dj_config, coordinates=coordinates, confidences=confidences
        )

        fit_time = None
        if task_mode == "trigger":
            pca_model = fit_pca(**formatted_data, **dj_config)
            save_pca(pca_model, output_dir.as_posix())
            fit_time = datetime.now(timezone.utc)

        self.insert1(dict(**key, pca_fit_time=fit_time))

make(key)

Make function to format the keypoint data, fit the PCA model, and store it as a pca.p file in the Keypoint-MoSeq project output directory.

Parameters:

Name Type Description Default
key dict

PCAPrep Key

required

Raises:

High-Level Logic: 1. Fetch the kpms_project_output_dir from the PCATask table and define its full path. 2. Load the kpms_dj_config file that contains the updated video_dir and bodyparts, and format the keypoint data with the coordinates and confidences scores to be used in the PCA fitting. 3. Fit the PCA model and save it as a pca.p file in the output directory. 4. Insert the creation datetime as the pca_fit_time into the table.

Source code in element_moseq/moseq_train.py
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
def make(self, key):
    """Format the keypoint data, fit the PCA model, and store it as a `pca.p`
    file in the Keypoint-MoSeq project output directory.

    Args:
        key (dict): `PCAPrep` Key.

    High-Level Logic:
    1. Fetch `kpms_project_output_dir` from the `PCATask` table and resolve its
       full path under the processed-data root.
    2. Load the `kpms_dj_config` file (which holds the updated `video_dir` and
       bodyparts) and format the fetched coordinates and confidences for the
       PCA fitting.
    3. In `trigger` mode, fit the PCA model and save it as `pca.p` in the
       output directory; otherwise no fitting is performed.
    4. Insert the creation datetime as `pca_fit_time` into the table.
    """
    from keypoint_moseq import format_data, fit_pca, save_pca

    output_dir, task_mode = (PCATask & key).fetch1(
        "kpms_project_output_dir", "task_mode"
    )
    output_dir = moseq_infer.get_kpms_processed_data_dir() / output_dir

    dj_config = kpms_reader.load_kpms_dj_config(
        output_dir.as_posix(), check_if_valid=True, build_indexes=True
    )
    coordinates, confidences = (PCAPrep & key).fetch1("coordinates", "confidences")
    formatted_data, _ = format_data(
        **dj_config, coordinates=coordinates, confidences=confidences
    )

    fit_time = None
    if task_mode == "trigger":
        pca_model = fit_pca(**formatted_data, **dj_config)
        save_pca(pca_model, output_dir.as_posix())
        fit_time = datetime.now(timezone.utc)

    self.insert1(dict(**key, pca_fit_time=fit_time))

LatentDimension

Bases: Imported

Determine the latent dimension as part of the autoregressive hyperparameters (ar_hypparams) for the model fitting. The objective of the analysis is to inform the user about the number of principal components needed to explain a 90% variance threshold. Subsequently, the decision on how many components to utilize for the model fitting is left to the user.

Attributes:

Name Type Description
PCAFit foreign key)

PCAFit Key.

variance_percentage float)

Variance threshold. Fixed value to 90%.

latent_dimension int)

Number of principal components required to explain the specified variance.

latent_dim_desc varchar)

Automated description of the computation result.

Source code in element_moseq/moseq_train.py
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
@schema
class LatentDimension(dj.Imported):
    """
    Determine the latent dimension as part of the autoregressive hyperparameters (`ar_hypparams`) for the model fitting.
    The objective of the analysis is to inform the user about the number of principal components needed to explain a
    90% variance threshold. Subsequently, the decision on how many components to utilize for the model fitting is left
    to the user.

    Attributes:
        PCAFit (foreign key)               : `PCAFit` Key.
        variance_percentage (float)        : Variance threshold. Fixed value to 90%.
        latent_dimension (int)             : Number of principal components required to explain the specified variance.
        latent_dim_desc (varchar)          : Automated description of the computation result.
    """

    definition = """
    -> PCAFit                                   # `PCAFit` Key
    ---
    variance_percentage      : float            # Variance threshold. Fixed value to 90 percent.
    latent_dimension         : int              # Number of principal components required to explain the specified variance.
    latent_dim_desc          : varchar(1000)    # Automated description of the computation result.
    """

    def make(self, key):
        """
        Compute and store the latent dimension that explains a 90% variance threshold.

        Args:
            key (dict): `PCAFit` Key.

        Raises:
            FileNotFoundError: If no PCA model (`pca.p`) is found in the project directory.

        High-Level Logic:
        1. Fetch the Keypoint-MoSeq project output directory from the `PCATask` table and define the full path.
        2. Load the PCA model from file in this directory.
        3. Set the variance threshold to 90% and compute the cumulative sum of the explained variance ratio.
        4. Determine the number of components required to explain the specified variance:
            4.1 If all components together explain less than the threshold, set `latent_dimension` to the total \
                number of components and `variance_percentage` to the total explained variance.
            4.2 Otherwise, set `latent_dimension` to the smallest number of components whose cumulative \
                variance exceeds the threshold and `variance_percentage` to the threshold itself.
        5. Insert the results of this `make` function into the table.
        """
        from keypoint_moseq import load_pca

        kpms_project_output_dir = (PCATask & key).fetch1("kpms_project_output_dir")
        kpms_project_output_dir = (
            moseq_infer.get_kpms_processed_data_dir() / kpms_project_output_dir
        )

        # Bug fix: a `Path` object is always truthy, so the original
        # `if pca_path:` could never reach the FileNotFoundError branch.
        # Check for actual file existence instead.
        pca_path = kpms_project_output_dir / "pca.p"
        if pca_path.exists():
            pca = load_pca(kpms_project_output_dir.as_posix())
        else:
            raise FileNotFoundError(
                f"No pca model (`pca.p`) found in the project directory {kpms_project_output_dir}"
            )

        variance_threshold = 0.90

        # Cumulative explained variance; `explained_variance_ratio_` is an
        # ndarray of shape (n_components,).
        cs = np.cumsum(pca.explained_variance_ratio_)

        if cs[-1] < variance_threshold:
            # Even all components together cannot reach the threshold.
            latent_dimension = len(cs)
            variance_percentage = cs[-1] * 100
            latent_dim_desc = (
                f"All components together only explain {cs[-1]*100}% of variance."
            )
        else:
            # Smallest number of components whose cumulative variance exceeds
            # the threshold (computed once, reused in the description).
            latent_dimension = int((cs > variance_threshold).nonzero()[0].min() + 1)
            variance_percentage = variance_threshold * 100
            latent_dim_desc = f">={variance_threshold*100}% of variance explained by {latent_dimension} components."

        self.insert1(
            dict(
                **key,
                variance_percentage=variance_percentage,
                latent_dimension=latent_dimension,
                latent_dim_desc=latent_dim_desc,
            )
        )

make(key)

Make function to compute and store the latent dimension that explains a 90% variance threshold.

Parameters:

Name Type Description Default
key dict

PCAFit Key.

required

Raises:

High-Level Logic: 1. Fetches the Keypoint-MoSeq project output directory from the PCATask table and define the full path. 2. Load the PCA model from file in this directory. 2. Set a specified variance threshold to 90% and compute the cumulative sum of the explained variance ratio. 3. Determine the number of components required to explain the specified variance. 3.1 If the cumulative sum of the explained variance ratio is less than the specified variance threshold, it sets the latent_dimension to the total number of components and variance_percentage to the cumulative sum of the explained variance ratio. 3.2 If the cumulative sum of the explained variance ratio is greater than the specified variance threshold, it sets the latent_dimension to the number of components that explain the specified variance and variance_percentage to the specified variance threshold. 4. Insert the results of this make function into the table.

Source code in element_moseq/moseq_train.py
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
def make(self, key):
    """
    Compute and store the latent dimension that explains a 90% variance threshold.

    Args:
        key (dict): `PCAFit` Key.

    Raises:
        FileNotFoundError: If no PCA model (`pca.p`) is found in the project directory.

    High-Level Logic:
    1. Fetch the Keypoint-MoSeq project output directory from the `PCATask` table and define the full path.
    2. Load the PCA model from file in this directory.
    3. Set the variance threshold to 90% and compute the cumulative sum of the explained variance ratio.
    4. Determine the number of components required to explain the specified variance:
        4.1 If all components together explain less than the threshold, set `latent_dimension` to the total \
            number of components and `variance_percentage` to the total explained variance.
        4.2 Otherwise, set `latent_dimension` to the smallest number of components whose cumulative \
            variance exceeds the threshold and `variance_percentage` to the threshold itself.
    5. Insert the results of this `make` function into the table.
    """
    from keypoint_moseq import load_pca

    kpms_project_output_dir = (PCATask & key).fetch1("kpms_project_output_dir")
    kpms_project_output_dir = (
        moseq_infer.get_kpms_processed_data_dir() / kpms_project_output_dir
    )

    # Bug fix: a `Path` object is always truthy, so the original
    # `if pca_path:` could never reach the FileNotFoundError branch.
    pca_path = kpms_project_output_dir / "pca.p"
    if pca_path.exists():
        pca = load_pca(kpms_project_output_dir.as_posix())
    else:
        raise FileNotFoundError(
            f"No pca model (`pca.p`) found in the project directory {kpms_project_output_dir}"
        )

    variance_threshold = 0.90

    # Cumulative explained variance; `explained_variance_ratio_` is an
    # ndarray of shape (n_components,).
    cs = np.cumsum(pca.explained_variance_ratio_)

    if cs[-1] < variance_threshold:
        # Even all components together cannot reach the threshold.
        latent_dimension = len(cs)
        variance_percentage = cs[-1] * 100
        latent_dim_desc = (
            f"All components together only explain {cs[-1]*100}% of variance."
        )
    else:
        # Smallest number of components whose cumulative variance exceeds
        # the threshold (computed once, reused in the description).
        latent_dimension = int((cs > variance_threshold).nonzero()[0].min() + 1)
        variance_percentage = variance_threshold * 100
        latent_dim_desc = f">={variance_threshold*100}% of variance explained by {latent_dimension} components."

    self.insert1(
        dict(
            **key,
            variance_percentage=variance_percentage,
            latent_dimension=latent_dimension,
            latent_dim_desc=latent_dim_desc,
        )
    )

PreFitTask

Bases: Manual

Insert the parameters for the model (AR-HMM) pre-fitting.

Attributes:

Name Type Description
PCAFit (foreign key)

`PCAFit` task.

pre_latent_dim (int)

Latent dimension to use for the model pre-fitting.

pre_kappa (int)

Kappa value to use for the model pre-fitting.

pre_num_iterations (int)

Number of Gibbs sampling iterations to run in the model pre-fitting.

pre_fit_desc (varchar)

User-defined description of the pre-fitting task.

Source code in element_moseq/moseq_train.py
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
@schema
class PreFitTask(dj.Manual):
    """Insert the parameters for the model (AR-HMM) pre-fitting.

    Attributes:
        PCAFit (foreign key)                : `PCAFit` task.
        pre_latent_dim (int)                : Latent dimension to use for the model pre-fitting.
        pre_kappa (int)                     : Kappa value to use for the model pre-fitting.
        pre_num_iterations (int)            : Number of Gibbs sampling iterations to run in the model pre-fitting.
        model_name (varchar)                : Name of the model to be loaded if `task_mode='load'`.
        task_mode (enum)                    : 'load' (default) to load computed analysis results, 'trigger' to trigger the computation.
        pre_fit_desc(varchar)               : User-defined description of the pre-fitting task.
    """

    # `pre_latent_dim`, `pre_kappa`, and `pre_num_iterations` are part of the
    # primary key, so multiple pre-fit parameter sets can share one `PCAFit`.
    definition = """
    -> PCAFit                                           # `PCAFit` Key
    pre_latent_dim               : int                  # Latent dimension to use for the model pre-fitting
    pre_kappa                    : int                  # Kappa value to use for the model pre-fitting
    pre_num_iterations           : int                  # Number of Gibbs sampling iterations to run in the model pre-fitting
    ---
    model_name                   : varchar(100)         # Name of the model to be loaded if `task_mode='load'`
    task_mode='load'             :enum('trigger','load')# 'load': load computed analysis results, 'trigger': trigger computation
    pre_fit_desc=''              : varchar(1000)        # User-defined description of the pre-fitting task
    """

PreFit

Bases: Computed

Fit AR-HMM model.

Attributes:

Name Type Description
PreFitTask (foreign key)

`PreFitTask` Key.

model_name (varchar)

Name of the model as "kpms_project_output_dir/model_name".

pre_fit_duration (float)

Time duration (seconds) of the model fitting computation.

Source code in element_moseq/moseq_train.py
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
@schema
class PreFit(dj.Computed):
    """Fit AR-HMM model.

    Attributes:
        PreFitTask (foreign key)                : `PreFitTask` Key.
        model_name (varchar)                    : Name of the model as "kpms_project_output_dir/model_name".
        pre_fit_duration (float)                : Time duration (seconds) of the model fitting computation.
    """

    definition = """
    -> PreFitTask                               # `PreFitTask` Key
    ---
    model_name=''                : varchar(100) # Name of the model as "kpms_project_output_dir/model_name"
    pre_fit_duration=NULL        : float        # Time duration (seconds) of the model fitting computation
    """

    def make(self, key):
        """
        Fit the AR-HMM model using the latent trajectory defined by `model['states']['x']`.

        Args:
            key (dict) : dictionary with the `PreFitTask` Key.

        Raises:
            FileNotFoundError: If no PCA model (`pca.p`) is found in the project directory.

        High-level Logic:
        1. Fetch the `kpms_project_output_dir` and define the full path.
        2. Fetch the model parameters from the `PreFitTask` table.
        3. Update the `dj_config.yml` with the latent dimension and kappa for the AR-HMM fitting.
        4. Load the pca model from file in the `kpms_project_output_dir`.
        5. Fetch `coordinates` and `confidences` scores to format the data for the model initialization. \
            # Data - contains the data for model fitting. \
            # Metadata - contains the recordings and start/end frames for the data.
        6. Initialize the model that creates a `model` dict containing states, parameters, hyperparameters, noise prior, and random seed.
        7. Update the model dict with the selected kappa for the AR-HMM fitting.
        8. Fit the AR-HMM model using the `pre_num_iterations` and create a subdirectory in `kpms_project_output_dir` with the model's latest checkpoint file.
        9. Calculate the duration of the model fitting computation and insert it in the `PreFit` table.
        """
        from keypoint_moseq import (
            load_pca,
            format_data,
            init_model,
            update_hypparams,
            fit_model,
        )

        kpms_processed = moseq_infer.get_kpms_processed_data_dir()

        kpms_project_output_dir = find_full_path(
            kpms_processed, (PCATask & key).fetch1("kpms_project_output_dir")
        )

        pre_latent_dim, pre_kappa, pre_num_iterations, task_mode, model_name = (
            PreFitTask & key
        ).fetch1(
            "pre_latent_dim",
            "pre_kappa",
            "pre_num_iterations",
            "task_mode",
            "model_name",
        )
        if task_mode == "trigger":
            kpms_dj_config = kpms_reader.load_kpms_dj_config(
                kpms_project_output_dir.as_posix(),
                check_if_valid=True,
                build_indexes=True,
            )

            # Persist the chosen hyperparameters so the config on disk
            # reflects what this fit actually used.
            kpms_dj_config.update(
                dict(latent_dim=int(pre_latent_dim), kappa=float(pre_kappa))
            )
            kpms_reader.generate_kpms_dj_config(
                kpms_project_output_dir.as_posix(), **kpms_dj_config
            )

            # Bug fix: a `Path` object is always truthy, so the original
            # `if pca_path:` could never reach the FileNotFoundError branch.
            pca_path = kpms_project_output_dir / "pca.p"
            if pca_path.exists():
                pca = load_pca(kpms_project_output_dir.as_posix())
            else:
                raise FileNotFoundError(
                    f"No pca model (`pca.p`) found in the project directory {kpms_project_output_dir}"
                )

            coordinates, confidences = (PCAPrep & key).fetch1(
                "coordinates", "confidences"
            )
            data, metadata = format_data(coordinates, confidences, **kpms_dj_config)

            model = init_model(data=data, metadata=metadata, pca=pca, **kpms_dj_config)

            model = update_hypparams(
                model, kappa=float(pre_kappa), latent_dim=int(pre_latent_dim)
            )

            # Timezone-aware timestamps, consistent with the rest of the module
            # (PCAFit uses `datetime.now(timezone.utc)`); only the difference
            # is stored, as a duration in seconds.
            start_time = datetime.now(timezone.utc)
            model, model_name = fit_model(
                model=model,
                data=data,
                metadata=metadata,
                project_dir=kpms_project_output_dir.as_posix(),
                ar_only=True,
                num_iters=pre_num_iterations,
            )
            end_time = datetime.now(timezone.utc)

            duration_seconds = (end_time - start_time).total_seconds()
        else:
            # 'load' mode: the model was fit elsewhere; no duration to record.
            duration_seconds = None

        self.insert1(
            {
                **key,
                "model_name": (
                    kpms_project_output_dir.relative_to(kpms_processed) / model_name
                ).as_posix(),
                "pre_fit_duration": duration_seconds,
            }
        )

make(key)

Make function to fit the AR-HMM model using the latent trajectory defined by `model['states']['x']`.

Parameters:

Name Type Description Default
key dict)

dictionary with the PreFitTask Key.

required

Raises:

High-level Logic: 1. Fetch the kpms_project_output_dir and define the full path. 2. Fetch the model parameters from the PreFitTask table. 3. Update the dj_config.yml with the latent dimension and kappa for the AR-HMM fitting. 4. Load the pca model from file in the kpms_project_output_dir. 5. Fetch coordinates and confidences scores to format the data for the model initialization. # Data - contains the data for model fitting. # Metadata - contains the recordings and start/end frames for the data. 6. Initialize the model that create a model dict containing states, parameters, hyperparameters, noise prior, and random seed. 7. Update the model dict with the selected kappa for the AR-HMM fitting. 8. Fit the AR-HMM model using the pre_num_iterations and create a subdirectory in kpms_project_output_dir with the model's latest checkpoint file. 9. Calculate the duration of the model fitting computation and insert it in the PreFit table.

Source code in element_moseq/moseq_train.py
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
def make(self, key):
    """
    Fit the AR-HMM model using the latent trajectory defined by `model['states']['x']`.

    Args:
        key (dict) : dictionary with the `PreFitTask` Key.

    Raises:
        FileNotFoundError: If no PCA model (`pca.p`) is found in the project directory.

    High-level Logic:
    1. Fetch the `kpms_project_output_dir` and define the full path.
    2. Fetch the model parameters from the `PreFitTask` table.
    3. Update the `dj_config.yml` with the latent dimension and kappa for the AR-HMM fitting.
    4. Load the pca model from file in the `kpms_project_output_dir`.
    5. Fetch `coordinates` and `confidences` scores to format the data for the model initialization. \
        # Data - contains the data for model fitting. \
        # Metadata - contains the recordings and start/end frames for the data.
    6. Initialize the model that creates a `model` dict containing states, parameters, hyperparameters, noise prior, and random seed.
    7. Update the model dict with the selected kappa for the AR-HMM fitting.
    8. Fit the AR-HMM model using the `pre_num_iterations` and create a subdirectory in `kpms_project_output_dir` with the model's latest checkpoint file.
    9. Calculate the duration of the model fitting computation and insert it in the `PreFit` table.
    """
    from keypoint_moseq import (
        load_pca,
        format_data,
        init_model,
        update_hypparams,
        fit_model,
    )

    kpms_processed = moseq_infer.get_kpms_processed_data_dir()

    kpms_project_output_dir = find_full_path(
        kpms_processed, (PCATask & key).fetch1("kpms_project_output_dir")
    )

    pre_latent_dim, pre_kappa, pre_num_iterations, task_mode, model_name = (
        PreFitTask & key
    ).fetch1(
        "pre_latent_dim",
        "pre_kappa",
        "pre_num_iterations",
        "task_mode",
        "model_name",
    )
    if task_mode == "trigger":
        kpms_dj_config = kpms_reader.load_kpms_dj_config(
            kpms_project_output_dir.as_posix(),
            check_if_valid=True,
            build_indexes=True,
        )

        # Persist the chosen hyperparameters so the config on disk
        # reflects what this fit actually used.
        kpms_dj_config.update(
            dict(latent_dim=int(pre_latent_dim), kappa=float(pre_kappa))
        )
        kpms_reader.generate_kpms_dj_config(
            kpms_project_output_dir.as_posix(), **kpms_dj_config
        )

        # Bug fix: a `Path` object is always truthy, so the original
        # `if pca_path:` could never reach the FileNotFoundError branch.
        pca_path = kpms_project_output_dir / "pca.p"
        if pca_path.exists():
            pca = load_pca(kpms_project_output_dir.as_posix())
        else:
            raise FileNotFoundError(
                f"No pca model (`pca.p`) found in the project directory {kpms_project_output_dir}"
            )

        coordinates, confidences = (PCAPrep & key).fetch1(
            "coordinates", "confidences"
        )
        data, metadata = format_data(coordinates, confidences, **kpms_dj_config)

        model = init_model(data=data, metadata=metadata, pca=pca, **kpms_dj_config)

        model = update_hypparams(
            model, kappa=float(pre_kappa), latent_dim=int(pre_latent_dim)
        )

        # Timezone-aware timestamps, consistent with the rest of the module;
        # only the difference is stored, as a duration in seconds.
        start_time = datetime.now(timezone.utc)
        model, model_name = fit_model(
            model=model,
            data=data,
            metadata=metadata,
            project_dir=kpms_project_output_dir.as_posix(),
            ar_only=True,
            num_iters=pre_num_iterations,
        )
        end_time = datetime.now(timezone.utc)

        duration_seconds = (end_time - start_time).total_seconds()
    else:
        # 'load' mode: the model was fit elsewhere; no duration to record.
        duration_seconds = None

    self.insert1(
        {
            **key,
            "model_name": (
                kpms_project_output_dir.relative_to(kpms_processed) / model_name
            ).as_posix(),
            "pre_fit_duration": duration_seconds,
        }
    )

FullFitTask

Bases: Manual

Insert the parameters for the full (Keypoint-SLDS model) fitting. The full model will generally require a lower value of kappa to yield the same target syllable durations.

Attributes:

Name Type Description
PCAFit (foreign key)

`PCAFit` Key.

full_latent_dim (int)

Latent dimension to use for the model full fitting.

full_kappa (int)

Kappa value to use for the model full fitting.

full_num_iterations (int)

Number of Gibbs sampling iterations to run in the model full fitting.

full_fit_desc (varchar)

User-defined description of the model full fitting task.

Source code in element_moseq/moseq_train.py
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
@schema
class FullFitTask(dj.Manual):
    """Insert the parameters for the full (Keypoint-SLDS model) fitting.
       The full model will generally require a lower value of kappa to yield the same target syllable durations.

    Attributes:
        PCAFit (foreign key)                 : `PCAFit` Key.
        full_latent_dim (int)                : Latent dimension to use for the model full fitting.
        full_kappa (int)                     : Kappa value to use for the model full fitting.
        full_num_iterations (int)            : Number of Gibbs sampling iterations to run in the model full fitting.
        model_name (varchar)                 : Name of the model to be loaded if `task_mode='load'`.
        task_mode (enum)                     : 'load' (default) to load computed analysis results, 'trigger' to trigger the computation.
        full_fit_desc(varchar)               : User-defined description of the model full fitting task.

    """

    # `full_latent_dim`, `full_kappa`, and `full_num_iterations` are part of
    # the primary key, so multiple full-fit parameter sets can share one `PCAFit`.
    definition = """
    -> PCAFit                                           # `PCAFit` Key
    full_latent_dim              : int                  # Latent dimension to use for the model full fitting
    full_kappa                   : int                  # Kappa value to use for the model full fitting
    full_num_iterations          : int                  # Number of Gibbs sampling iterations to run in the model full fitting
    ---
    model_name                   : varchar(100)         # Name of the model to be loaded if `task_mode='load'`
    task_mode='load'             :enum('load','trigger')# Trigger or load the task
    full_fit_desc=''             : varchar(1000)        # User-defined description of the model full fitting task   
    """

FullFit

Bases: Computed

Fit the full (Keypoint-SLDS) model.

Attributes:

Name Type Description
FullFitTask (foreign key)

`FullFitTask` Key.

model_name (varchar)

Name of the model as "kpms_project_output_dir/model_name".

full_fit_duration (float)

Time duration (seconds) of the full fitting computation.

Source code in element_moseq/moseq_train.py
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
@schema
class FullFit(dj.Computed):
    """Fit the full (Keypoint-SLDS) model.

    Attributes:
        FullFitTask (foreign key)            : `FullFitTask` Key.
        model_name (varchar)                 : Name of the model as "kpms_project_output_dir/model_name".
        full_fit_duration (float)            : Time duration (seconds) of the full fitting computation.
    """

    definition = """
    -> FullFitTask                               # `FullFitTask` Key
    ---
    model_name                    : varchar(100) # Name of the model as "kpms_project_output_dir/model_name"
    full_fit_duration=NULL        : float        # Time duration (seconds) of the full fitting computation 
    """

    def make(self, key):
        """
            Fit the full (keypoint-SLDS) model.

            Args:
                key (dict): dictionary with the `FullFitTask` Key.

            Raises:
                FileNotFoundError: If no PCA model (`pca.p`) is found in the project directory.

            High-level Logic:
            1. Fetch the `kpms_project_output_dir` and define the full path.
            2. Fetch the model parameters from the `FullFitTask` table.
            3. Update the `dj_config.yml` with the selected latent dimension and kappa for the full-fitting.
            4. Load the pca model from file in the `kpms_project_output_dir`.
            5. Fetch the `coordinates` and `confidences` scores to format the data for the model initialization.
            6. Initialize the model that creates a `model` dict containing states, parameters, hyperparameters, noise prior, and random seed.
            7. Update the model dict with the selected kappa for the Keypoint-SLDS fitting.
            8. Fit the Keypoint-SLDS model using the `full_num_iterations` and create a subdirectory in `kpms_project_output_dir` with the model's latest checkpoint file.
            9. Reindex syllable labels by their frequency in the most recent model snapshot in the checkpoint file. \
                This function permutes the states and parameters of a saved checkpoint so that syllables are labeled \
                in order of frequency (i.e. so that 0 is the most frequent, 1 is the second most, and so on).
            10. Calculate the duration of the model fitting computation and insert it in the `FullFit` table.
        """
        from keypoint_moseq import (
            load_pca,
            format_data,
            init_model,
            update_hypparams,
            fit_model,
            reindex_syllables_in_checkpoint,
        )

        kpms_processed = moseq_infer.get_kpms_processed_data_dir()

        kpms_project_output_dir = find_full_path(
            kpms_processed, (PCATask & key).fetch1("kpms_project_output_dir")
        )

        full_latent_dim, full_kappa, full_num_iterations, task_mode, model_name = (
            FullFitTask & key
        ).fetch1(
            "full_latent_dim",
            "full_kappa",
            "full_num_iterations",
            "task_mode",
            "model_name",
        )
        if task_mode == "trigger":
            kpms_dj_config = kpms_reader.load_kpms_dj_config(
                kpms_project_output_dir.as_posix(),
                check_if_valid=True,
                build_indexes=True,
            )
            # Persist the chosen hyperparameters so the config on disk
            # reflects what this fit actually used.
            kpms_dj_config.update(
                dict(latent_dim=int(full_latent_dim), kappa=float(full_kappa))
            )
            kpms_reader.generate_kpms_dj_config(
                kpms_project_output_dir.as_posix(), **kpms_dj_config
            )

            # Bug fix: a `Path` object is always truthy, so the original
            # `if pca_path:` could never reach the FileNotFoundError branch.
            pca_path = kpms_project_output_dir / "pca.p"
            if pca_path.exists():
                pca = load_pca(kpms_project_output_dir.as_posix())
            else:
                raise FileNotFoundError(
                    f"No pca model (`pca.p`) found in the project directory {kpms_project_output_dir}"
                )

            coordinates, confidences = (PCAPrep & key).fetch1(
                "coordinates", "confidences"
            )
            data, metadata = format_data(coordinates, confidences, **kpms_dj_config)
            model = init_model(data=data, metadata=metadata, pca=pca, **kpms_dj_config)
            model = update_hypparams(
                model, kappa=float(full_kappa), latent_dim=int(full_latent_dim)
            )

            # `datetime.utcnow()` is deprecated (returns a naive datetime);
            # use an aware timestamp, consistent with PCAFit above.
            start_time = datetime.now(timezone.utc)
            model, model_name = fit_model(
                model=model,
                data=data,
                metadata=metadata,
                project_dir=kpms_project_output_dir.as_posix(),
                ar_only=False,
                num_iters=full_num_iterations,
            )
            end_time = datetime.now(timezone.utc)
            duration_seconds = (end_time - start_time).total_seconds()

            # Relabel syllables by frequency (0 = most frequent) in the
            # latest checkpoint of the just-fit model.
            reindex_syllables_in_checkpoint(
                kpms_project_output_dir.as_posix(), Path(model_name).parts[-1]
            )
        else:
            # 'load' mode: the model was fit elsewhere; no duration to record.
            duration_seconds = None

        self.insert1(
            {
                **key,
                "model_name": (
                    kpms_project_output_dir.relative_to(kpms_processed) / model_name
                ).as_posix(),
                "full_fit_duration": duration_seconds,
            }
        )

make(key)

Make function to fit the full (keypoint-SLDS) model

Parameters:

Name Type Description Default
key dict

dictionary with the FullFitTask Key.

required

Raises:

High-level Logic: 1. Fetch the kpms_project_output_dir and define the full path. 2. Fetch the model parameters from the FullFitTask table. 2. Update the dj_config.yml with the selected latent dimension and kappa for the full-fitting. 3. Initialize and fit the full model in a new model_name directory. 4. Load the pca model from file in the kpms_project_output_dir. 5. Fetch the coordinates and confidences scores to format the data for the model initialization. 6. Initialize the model that create a model dict containing states, parameters, hyperparameters, noise prior, and random seed. 7. Update the model dict with the selected kappa for the Keypoint-SLDS fitting. 8. Fit the Keypoint-SLDS model using the full_num_iterations and create a subdirectory in kpms_project_output_dir with the model's latest checkpoint file. 8. Reindex syllable labels by their frequency in the most recent model snapshot in the checkpoint file. This function permutes the states and parameters of a saved checkpoint so that syllables are labeled in order of frequency (i.e. so that 0 is the most frequent, 1 is the second most, and so on). 8. Calculate the duration of the model fitting computation and insert it in the PreFit table.

Source code in element_moseq/moseq_train.py
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
def make(self, key):
    """Fit the full (keypoint-SLDS) model and insert the fit duration.

    Args:
        key (dict): dictionary with the `FullFitTask` Key.

    Raises:
        FileNotFoundError: If no PCA model file (`pca.p`) exists in the
            `kpms_project_output_dir`.

    High-level Logic:
    1. Fetch the `kpms_project_output_dir` and define the full path.
    2. Fetch the model parameters from the `FullFitTask` table.
    3. Update the `dj_config.yml` with the selected latent dimension and kappa for the full fitting.
    4. Initialize and fit the full model in a new `model_name` directory.
    5. Load the PCA model from file in the `kpms_project_output_dir`.
    6. Fetch the `coordinates` and `confidences` scores to format the data for model initialization.
    7. Initialize the model, creating a `model` dict containing states, parameters, hyperparameters, noise prior, and random seed.
    8. Update the model dict with the selected kappa for the keypoint-SLDS fitting.
    9. Fit the keypoint-SLDS model using `full_num_iterations` and create a subdirectory in \
        `kpms_project_output_dir` with the model's latest checkpoint file.
    10. Reindex syllable labels by their frequency in the most recent model snapshot in the checkpoint file. \
        This permutes the states and parameters of a saved checkpoint so that syllables are labeled \
        in order of frequency (i.e. so that 0 is the most frequent, 1 is the second most, and so on).
    11. Calculate the duration of the model fitting computation and insert it in this (`FullFit`) table.
    """
    from keypoint_moseq import (
        load_pca,
        format_data,
        init_model,
        update_hypparams,
        fit_model,
        reindex_syllables_in_checkpoint,
    )

    kpms_processed = moseq_infer.get_kpms_processed_data_dir()

    kpms_project_output_dir = find_full_path(
        kpms_processed, (PCATask & key).fetch1("kpms_project_output_dir")
    )

    full_latent_dim, full_kappa, full_num_iterations, task_mode, model_name = (
        FullFitTask & key
    ).fetch1(
        "full_latent_dim",
        "full_kappa",
        "full_num_iterations",
        "task_mode",
        "model_name",
    )
    if task_mode == "trigger":
        # Rewrite `dj_config.yml` so the fit uses the task's latent_dim/kappa.
        kpms_dj_config = kpms_reader.load_kpms_dj_config(
            kpms_project_output_dir.as_posix(),
            check_if_valid=True,
            build_indexes=True,
        )
        kpms_dj_config.update(
            dict(latent_dim=int(full_latent_dim), kappa=float(full_kappa))
        )
        kpms_reader.generate_kpms_dj_config(
            kpms_project_output_dir.as_posix(), **kpms_dj_config
        )

        # BUGFIX: a `Path` object is always truthy, so the original
        # `if pca_path:` could never reach the error branch. Check that
        # the file actually exists before attempting to load it.
        pca_path = kpms_project_output_dir / "pca.p"
        if pca_path.exists():
            pca = load_pca(kpms_project_output_dir.as_posix())
        else:
            raise FileNotFoundError(
                f"No pca model (`pca.p`) found in the project directory {kpms_project_output_dir}"
            )

        coordinates, confidences = (PCAPrep & key).fetch1(
            "coordinates", "confidences"
        )
        data, metadata = format_data(coordinates, confidences, **kpms_dj_config)
        model = init_model(data=data, metadata=metadata, pca=pca, **kpms_dj_config)
        model = update_hypparams(
            model, kappa=float(full_kappa), latent_dim=int(full_latent_dim)
        )

        # Wall-clock timing of the keypoint-SLDS fit (ar_only=False).
        start_time = datetime.utcnow()
        model, model_name = fit_model(
            model=model,
            data=data,
            metadata=metadata,
            project_dir=kpms_project_output_dir.as_posix(),
            ar_only=False,
            num_iters=full_num_iterations,
        )
        end_time = datetime.utcnow()
        duration_seconds = (end_time - start_time).total_seconds()

        # Relabel syllables by frequency in the latest checkpoint snapshot
        # (0 = most frequent). Only the final path component of the name
        # returned by `fit_model` identifies the checkpoint directory.
        reindex_syllables_in_checkpoint(
            kpms_project_output_dir.as_posix(), Path(model_name).parts[-1]
        )
    else:
        # task_mode == "load": nothing was computed, so no duration.
        duration_seconds = None

    # NOTE(review): in trigger mode `model_name` comes from `fit_model`;
    # assumes it is a relative name so the `/` join below is valid — confirm.
    self.insert1(
        {
            **key,
            "model_name": (
                kpms_project_output_dir.relative_to(kpms_processed) / model_name
            ).as_posix(),
            "full_fit_duration": duration_seconds,
        }
    )