Skip to content

utils.py

General-purpose utilities

user_choice(prompt, choices=('yes', 'no'), default=None)

Prompts the user for confirmation. The default value, if any, is capitalized.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `prompt` | | Information to display to the user. | *required* |
| `choices` | | An iterable of possible choices. | `('yes', 'no')` |
| `default` | | Default choice. | `None` |

Returns:

| Type | Description |
|------|-------------|
| | The user's choice. |

Source code in `datajoint/utils.py` (lines 17–34)
def user_choice(prompt, choices=("yes", "no"), default=None):
    """
    Prompts the user for confirmation.  The default value, if any, is capitalized.

    The prompt is repeated until the (lower-cased) reply is one of ``choices``.
    An empty reply selects ``default``; if ``default`` is None, an empty reply
    just re-prompts.

    :param prompt: Information to display to the user.
    :param choices: an iterable of possible choices. It is materialized once, so
        one-shot iterators such as generators are accepted. Because the reply is
        lower-cased before comparison, choices are presumably expected to be
        lower-case.
    :param default: default choice; must be None or one of ``choices``.
    :return: the user's choice
    """
    # Materialize once: the choices are scanned several times below (validation,
    # the displayed list, and every membership test in the loop). A generator would
    # be exhausted after the first scan, and the loop would then re-prompt forever.
    choices = tuple(choices)
    assert default is None or default in choices
    choice_list = ", ".join(
        choice.title() if choice == default else choice for choice in choices
    )
    response = None
    while response not in choices:
        response = input(prompt + " [" + choice_list + "]: ")
        # An empty reply falls back to the default (which may be None -> re-prompt).
        response = response.lower() if response else default
    return response

get_master(full_table_name)

If the table name is that of a part table, then return what the master table name would be. This follows DataJoint's table naming convention where a master and a part must be in the same schema and the part table is prefixed with the master table name + __.

Example: `` `ephys`.`session` `` is the master and `` `ephys`.`session__recording` `` is one of its parts.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `full_table_name` | `str` | Full table name including part. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `str` | Supposed master full table name, or an empty string if not a part table name. |

Source code in `datajoint/utils.py` (lines 37–53)
def get_master(full_table_name: str) -> str:
    """
    If the table name is that of a part table, then return what the master table name would be.
    This follows DataJoint's table naming convention where a master and a part must be in the
    same schema and the part table is prefixed with the master table name + ``__``.

    Example:
       `ephys`.`session`    -- master
       `ephys`.`session__recording`  -- part

    :param full_table_name: Full table name including part, in the form `schema`.`table`.
    :type full_table_name: str
    :return: Supposed master full table name or empty string if not a part table name.
    :rtype: str
    """
    # The dot between the quoted schema and table names is escaped so it matches only a
    # literal ".", and fullmatch rejects any text after the closing backtick.
    # ``\w+`` is greedy, so with several ``__`` separators the master is everything before
    # the last one (e.g. the master of `s`.`__m__p` is `s`.`__m`, keeping its leading ``__``).
    match = re.fullmatch(r"(?P<master>`\w+`\.`\w+)__(?P<part>\w+)`", full_table_name)
    return match["master"] + "`" if match else ""

to_camel_case(s)

Convert names with underscore (`_`) separation into CamelCase names.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `s` | | String in under_score notation. | *required* |

Returns:

| Type | Description |
|------|-------------|
| | String in CamelCase notation. |

Example: `to_camel_case("table_name")` returns `"TableName"`.

Source code in `datajoint/utils.py` (lines 56–69)
def to_camel_case(s):
    """
    Turn an under_score separated name into a CamelCase name.

    Any ASCII letter that either starts the string or follows a run of
    underscores / non-word characters is upper-cased, and that run is removed.
    Digits are not separators, so a letter after a digit is left unchanged.

    :param s: string in under_score notation
    :returns: string in CamelCase notation
    Example:
    >>> to_camel_case("table_name")
    'TableName'
    """
    # Each match is "separators + one letter"; keep only the letter, upper-cased.
    return re.sub(
        r"(^|[_\W])+[a-zA-Z]",
        lambda hit: hit.group()[-1].upper(),
        s,
    )

from_camel_case(s)

Convert names in camel case into underscore (_) separated names

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `s` | | String in CamelCase notation. | *required* |

Returns:

| Type | Description |
|------|-------------|
| | String in under_score notation. |

Example: `from_camel_case("TableName")` yields `"table_name"`.

Source code in `datajoint/utils.py` (lines 72–89)
def from_camel_case(s):
    """
    Turn a CamelCase name into an under_score separated name.

    Every upper-case ASCII letter is lower-cased. One that does not sit on a word
    boundary (i.e. it follows another word character, including ``_`` or a digit)
    additionally gets a ``_`` in front of it.

    :param s: string in CamelCase notation
    :returns: string in under_score notation
    :raises DataJointError: if ``s`` does not start with an upper-case ASCII letter.
        NOTE(review): because ``re.match`` is only anchored at the start and the
        trailing class may match zero characters, only the first character is
        actually validated, despite what the error message says.
    Example:
    >>> from_camel_case("TableName")
    'table_name'
    """
    if re.match(r"[A-Z][a-zA-Z0-9]*", s) is None:
        raise DataJointError(
            "ClassName must be alphanumeric in CamelCase, begin with a capital letter"
        )

    def _lower_with_separator(hit):
        # group 1 is set only for capitals inside a word (\B), which need a separator.
        separator = "_" if hit.group(1) else ""
        return separator + hit.group().lower()

    return re.sub(r"(\B[A-Z])|(\b[A-Z])", _lower_with_separator, s)

safe_write(filepath, blob)

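A two-step write: the blob is first written to a temporary `.saving` file next to `filepath`, which is then renamed to `filepath`. Nothing is written if `filepath` already exists.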

Parameters:

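| Name | Type | Description | Default |
|------|------|-------------|---------|
| `filepath` | | Full path of the destination file. | *required* |
| `blob` | | Binary data. | *required* |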
Source code in `datajoint/utils.py` (lines 92–104)
def safe_write(filepath, blob):
    """
    A two-step write: ``blob`` is written to a temporary ``<name>.saving`` file in the
    same directory and then moved onto ``filepath``. This way ``filepath`` only appears
    once its contents are complete (the move is atomic on POSIX within one filesystem).
    Does nothing if ``filepath`` already exists as a file.

    :param filepath: full path of the destination file; missing parent directories are created.
    :param blob: binary data
    """
    filepath = Path(filepath)
    if not filepath.is_file():
        filepath.parent.mkdir(parents=True, exist_ok=True)
        temp_file = filepath.with_suffix(filepath.suffix + ".saving")
        temp_file.write_bytes(blob)
        # ``replace`` (unlike ``rename``) also overwrites on Windows, so a file created by
        # a concurrent writer between the check above and this move does not make the
        # call fail with FileExistsError there.
        temp_file.replace(filepath)

safe_copy(src, dest, overwrite=False)

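Copy the contents of the `src` file into the `dest` file as a two-step process (via a temporary `.copying` file). Skip if `dest` already exists, unless `overwrite` is true; copying a file onto itself is always skipped.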

Source code in `datajoint/utils.py` (lines 107–116)
def safe_copy(src, dest, overwrite=False):
    """
    Copy the contents of src file into dest file as a two-step process. Skip if dest exists
    already, unless ``overwrite`` is true.

    The data is first copied to a temporary ``<name>.copying`` file next to ``dest`` and
    then moved onto ``dest``.

    :param src: path of the file to copy.
    :param dest: destination path; missing parent directories are created.
    :param overwrite: if true, replace an existing ``dest`` file; otherwise leave it untouched.
    """
    src, dest = Path(src), Path(dest)
    if dest.exists() and src.samefile(dest):
        return  # source and destination are the same file: nothing to do
    if overwrite or not dest.is_file():
        dest.parent.mkdir(parents=True, exist_ok=True)
        temp_file = dest.with_suffix(dest.suffix + ".copying")
        shutil.copyfile(str(src), str(temp_file))
        # ``replace`` overwrites an existing ``dest`` on every platform, whereas ``rename``
        # raises FileExistsError on Windows, which would break ``overwrite=True`` there.
        temp_file.replace(dest)

parse_sql(filepath)

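Yield SQL statements from an SQL file. Comment lines starting with `--` are skipped, and `delimiter` directives change the statement terminator.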

Source code in `datajoint/utils.py` (lines 119–135)
def parse_sql(filepath):
    """
    yield SQL statements from an SQL file

    Each line is stripped of surrounding whitespace. Empty lines and lines starting with
    ``--`` are skipped. A line starting with ``delimiter`` (in any letter case, e.g.
    ``DELIMITER $$``) sets the statement terminator to its second whitespace-separated
    token. All other lines are accumulated and joined with single spaces. A statement is
    yielded once a line ends with the current terminator, which stays in the yielded text.
    Trailing text that is never terminated is not yielded.

    :param filepath: path to the SQL file.
    :return: generator of SQL statement strings.
    """
    delimiter = ";"
    statement = []
    with Path(filepath).open("rt") as f:
        for line in f:
            line = line.strip()
            # Only empty lines are skipped: one-character lines such as a closing ")"
            # or a lone ";" terminator are meaningful parts of a statement.
            if not line or line.startswith("--"):
                continue
            # The mysql client's DELIMITER command is case-insensitive.
            if line.lower().startswith("delimiter"):
                delimiter = line.split()[1]
                continue
            statement.append(line)
            if line.endswith(delimiter):
                yield " ".join(statement)
                statement = []