Разработка оценщиков scikit-learn#

Предлагаете ли вы оценщик для включения в scikit-learn, разрабатываете отдельный пакет, совместимый с scikit-learn, или реализуете пользовательские компоненты для своих проектов, эта глава подробно описывает, как разрабатывать объекты, которые безопасно взаимодействуют с конвейерами scikit-learn и инструментами выбора моделей.

В этом разделе подробно описывается публичный API, который следует использовать и реализовывать для совместимого с scikit-learn оценщика. Внутри scikit-learn мы экспериментируем и используем некоторые частные инструменты, и наша цель — всегда сделать их публичными, как только они станут достаточно стабильными, чтобы вы также могли использовать их в своих проектах.

API объектов scikit-learn#

Существует два основных типа оценщиков. Первую группу можно рассматривать как простые оценщики, которые включают большинство оценщиков, таких как LogisticRegression или RandomForestClassifier. И вторая группа — это мета-оценщики, которые являются оценщиками, обёртывающими другие оценщики. Pipeline и GridSearchCV являются двумя примерами мета-оценщиков.

Здесь мы начнем с нескольких терминов словаря, а затем проиллюстрируем, как вы можете реализовать свои собственные оценщики.

Элементы API scikit-learn более определенно описаны в Глоссарий общих терминов и элементов API.

Различные объекты#

Основные объекты в scikit-learn (один класс может реализовывать несколько интерфейсов):

Оценщик:

Базовый объект, реализует fit метод обучения на данных, либо:

estimator = estimator.fit(data, targets)

или:

estimator = estimator.fit(data)
Предиктор:

Для обучения с учителем или некоторых задач без учителя реализует:

prediction = predictor.predict(data)

Алгоритмы классификации обычно также предлагают способ количественной оценки уверенности в предсказании, либо с использованием decision_function или predict_proba:

probability = predictor.predict_proba(data)
Трансформер:

Для изменения данных контролируемым или неконтролируемым способом (например, путем добавления, изменения или удаления столбцов, но не путем добавления или удаления строк). Реализует:

new_data = transformer.transform(data)

Когда подгонка и преобразование могут быть выполнены гораздо эффективнее вместе, чем отдельно, реализует:

new_data = transformer.fit_transform(data)
Модель:

Модель, которая может дать качество соответствия метрика или вероятность непросмотренных данных, реализует (чем выше, тем лучше):

score = model.score(data)

Оценщики#

API имеет один основной объект: оценщик. Оценщик — это объект, который обучает модель на основе некоторых обучающих данных и способен выводить некоторые свойства на новых данных. Это может быть, например, классификатор или регрессор. Все оценщики реализуют метод fit:

estimator.fit(X, y)

Из всех методов, которые реализует оценщик, fit обычно тот, который вы хотите реализовать самостоятельно. Другие методы, такие как set_params, get_params, и т.д. реализованы в BaseEstimator, от которого вы должны наследовать. Возможно, вам потребуется наследовать от большего количества примесей, что мы объясним позже.

Создание экземпляра#

Это касается создания объекта. Объекта __init__ метод может принимать константы в качестве аргументов, определяющих поведение оценщика (например, alpha константа в SGDClassifier). Однако он не должен принимать фактические обучающие данные в качестве аргумента, так как это оставлено для fit() method:

clf2 = SGDClassifier(alpha=2.3)
clf3 = SGDClassifier([[1, 2], [2, 3]], [-1, 1]) # WRONG!

В идеале аргументы, принимаемые __init__ должны быть аргументами ключевых слов со значением по умолчанию. Другими словами, пользователь должен иметь возможность создать экземпляр оценщика без передачи каких-либо аргументов. В некоторых случаях, когда нет разумных значений по умолчанию для аргумента, они могут быть оставлены без значения по умолчанию. В самом scikit-learn у нас очень мало мест, только в некоторых мета-оценщиках, где аргумент под-оценщика(ов) является обязательным аргументом.

Большинство аргументов соответствуют гиперпараметрам, описывающим модель или задачу оптимизации, которую решает оценщик. Другие параметры могут определять поведение оценщика, например, определять местоположение кэша для хранения некоторых данных. Эти начальные аргументы (или параметры) всегда запоминаются оценщиком. Также обратите внимание, что они не должны документироваться в разделе «Атрибуты», а скорее в разделе «Параметры» для этого оценщика.

Кроме того, каждый ключевой аргумент, принимаемый __init__ должен соответствовать атрибуту экземпляра. Scikit-learn полагается на это, чтобы найти соответствующие атрибуты для установки на оценщике при выполнении выбора модели.

Таким образом, __init__ должен выглядеть как:

def __init__(self, param1=1, param2=2):
    self.param1 = param1
    self.param2 = param2

Не должно быть никакой логики, даже проверки входных данных, и параметры не должны изменяться; что также означает, что в идеале они не должны быть изменяемыми объектами, такими как списки или словари. Если они изменяемы, их следует скопировать перед изменением. Соответствующая логика должна быть размещена там, где используются параметры, обычно в fit. Следующее неверно:

def __init__(self, param1=1, param2=2, param3=3):
    # WRONG: parameters should not be modified
    if param1 > 1:
        param2 += 1
    self.param1 = param1
    # WRONG: the object's attributes should have exactly the name of
    # the argument in the constructor
    self.param3 = param2

Причина отложенной проверки заключается в том, что если __init__ включает проверку входных данных, тогда та же проверка должна быть выполнена в set_params, который используется в алгоритмах, таких как GridSearchCV.

Также ожидается, что параметры с завершающими _ являются не должен быть установлен внутри __init__ метод. Подробнее об атрибутах, которые не являются аргументами инициализации, будет рассказано далее.

Обучение#

Следующее, что вы, вероятно, захотите сделать, это оценить некоторые параметры в модели. Это реализовано в fit() метод, и здесь происходит обучение. Например, здесь выполняется вычисление для обучения или оценки коэффициентов линейной модели.

The fit() метод принимает обучающие данные в качестве аргументов, которые могут быть одним массивом в случае обучения без учителя или двумя массивами в случае обучения с учителем. Другие метаданные, которые идут с обучающими данными, такие как sample_weight, также может быть передан в fit ), а не просто

Обратите внимание, что модель обучается с использованием X и y, но объект не содержит ссылки на X и y. Однако есть некоторые исключения из этого, как в случае предвычисленных ядер, где эти данные должны храниться для использования методом predict.

Параметры

X

array-like формы (n_samples, n_features)

y

array-like формы (n_samples,)

kwargs

необязательные параметры, зависящие от данных

Количество образцов, т.е. X.shape[0] должен быть таким же, как y.shape[0]. Если это требование не выполняется, возникает исключение типа ValueError должно быть вызвано.

y может игнорироваться в случае обучения без учителя. Однако, чтобы сделать возможным использование оценщика как части конвейера, который может смешивать как контролируемые, так и неконтролируемые преобразователи, даже неконтролируемые оценщики должны принимать y=None аргумент ключевого слова во второй позиции, который просто игнорируется оценщиком. По той же причине, fit_predict, fit_transform, score и partial_fit методы должны принимать y аргумент на втором месте, если они реализованы.

Метод должен возвращать объект (self). Этот шаблон полезен для возможности реализации быстрых однострочников в сессии IPython, таких как:

y_predicted = SGDClassifier(alpha=10).fit(X_train, y_train).predict(X_test)

В зависимости от природы алгоритма, fit иногда также может принимать дополнительные ключевые аргументы. Однако любой параметр, которому можно присвоить значение до доступа к данным, должен быть __init__ аргумент ключевого слова. В идеале, параметры подгонки должны быть ограничены переменными, непосредственно зависящими от данных. Например, матрица Грама или матрица сходства, которые предварительно вычислены из матрицы данных X зависят от данных. Критерий остановки по допуску tol не зависит напрямую от данных (хотя оптимальное значение согласно некоторой функции оценки, вероятно, зависит).

Когда fit вызывается, любой предыдущий вызов fit должно игнорироваться. В общем случае вызов estimator.fit(X1) и затем estimator.fit(X2) должен быть таким же, как просто вызов estimator.fit(X2). Однако на практике это может быть не так, когда fit зависит от некоторого случайного процесса, см. random_state. Другое исключение из этого правила — когда гиперпараметр warm_start установлено в True для оценщиков, которые поддерживают это. warm_start=True означает, что предыдущее состояние обучаемых параметров оценщика повторно используется вместо использования стратегии инициализации по умолчанию.

Оцененные атрибуты#

В соответствии с соглашениями scikit-learn, атрибуты, которые вы хотите предоставить пользователям как публичные атрибуты и которые были оценены или изучены из данных, всегда должны иметь имя, оканчивающееся на подчеркивание, например, коэффициенты некоторого регрессионного оценщика будут храниться в coef_ атрибут после fit был вызван. Аналогично, атрибуты, которые вы изучаете в процессе и хотите сохранить, но не раскрывать пользователю, должны иметь ведущее подчеркивание, например _intermediate_coefs. Вам нужно задокументировать первую группу (с завершающим подчеркиванием) как «Атрибуты» и не нужно документировать вторую группу (с начальным подчеркиванием).

Ожидается, что оцененные атрибуты будут переопределены при вызове fit второй раз.

Универсальные атрибуты#

Оценщики, ожидающие табличные входные данные, должны установить n_features_in_ атрибут в fit время, чтобы указать количество признаков, которые оценщик ожидает для последующих вызовов predict или преобразовать. См. SLEP010 подробности.

Аналогично, если оценщикам предоставляются датафреймы, такие как pandas или polars, они должны установить feature_names_in_ атрибут для указания имен признаков входных данных, подробно описанный в SLEP007. Используя validate_data автоматически установит эти атрибуты для вас.

Создание собственного оценщика#

Если вы хотите реализовать новый оценщик, совместимый с scikit-learn, вам следует знать о нескольких внутренних компонентах scikit-learn в дополнение к API scikit-learn, описанному выше. Вы можете проверить, соответствует ли ваш оценщик интерфейсу и стандартам scikit-learn, запустив check_estimator на экземпляре. parametrize_with_checks декоратор pytest также может использоваться (см. его документацию для подробностей и возможных взаимодействий с pytest):

>>> from sklearn.utils.estimator_checks import check_estimator
>>> from sklearn.tree import DecisionTreeClassifier
>>> check_estimator(DecisionTreeClassifier())  # passes
[...]

Основная мотивация для создания класса, совместимого с интерфейсом оценщика scikit-learn, может заключаться в том, что вы хотите использовать его вместе с инструментами оценки и выбора моделей, такими как GridSearchCV и Pipeline.

Прежде чем детализировать требуемый интерфейс ниже, мы описываем два способа достижения правильного интерфейса более легко.

И вы можете проверить, что приведенный выше оценщик проходит все общие проверки:

>>> from sklearn.utils.estimator_checks import check_estimator
>>> check_estimator(TemplateClassifier())  # passes

get_params и set_params#

Все оценщики scikit-learn имеют get_params и set_params функций.

The get_params функция не принимает аргументов и возвращает словарь __init__ параметры оценщика вместе с их значениями.

Он принимает один аргумент ключевого слова, deep, который получает логическое значение, определяющее, должен ли метод возвращать параметры подоценщиков (актуально только для мета-оценщиков). Значение по умолчанию для deep является True. Например, рассмотрим следующий оценщик:

>>> from sklearn.base import BaseEstimator
>>> from sklearn.linear_model import LogisticRegression
>>> class MyEstimator(BaseEstimator):
...     def __init__(self, subestimator=None, my_extra_param="random"):
...         self.subestimator = subestimator
...         self.my_extra_param = my_extra_param

Параметр deep управляет тем, будут ли параметры subestimator должны быть сообщены. Таким образом, когда deep=True, вывод будет:

>>> my_estimator = MyEstimator(subestimator=LogisticRegression())
>>> for param, value in my_estimator.get_params(deep=True).items():
...     print(f"{param} -> {value}")
my_extra_param -> random
subestimator__C -> 1.0
subestimator__class_weight -> None
subestimator__dual -> False
subestimator__fit_intercept -> True
subestimator__intercept_scaling -> 1
subestimator__l1_ratio -> 0.0
subestimator__max_iter -> 100
subestimator__n_jobs -> None
subestimator__penalty -> deprecated
subestimator__random_state -> None
subestimator__solver -> lbfgs
subestimator__tol -> 0.0001
subestimator__verbose -> 0
subestimator__warm_start -> False
subestimator -> LogisticRegression()

Если мета-оцениватель принимает несколько подоценивателей, часто эти подоцениватели имеют имена (например, именованные шаги в Pipeline объект), в этом случае ключ должен стать __C, __class_weight, и т.д.

Когда deep=False, вывод будет:

>>> for param, value in my_estimator.get_params(deep=False).items():
...     print(f"{param} -> {value}")
my_extra_param -> random
subestimator -> LogisticRegression()

С другой стороны, set_params принимает параметры __init__ в качестве аргументов ключевых слов, распаковывает их в словарь вида 'parameter': value и устанавливает параметры оценщика с помощью этого словаря. Возвращает сам оценщик.

The set_params функция используется для установки параметров во время поиска по сетке, например.

Клонирование#

Как уже упоминалось, когда аргументы конструктора изменяемы, их следует копировать перед изменением. Это также относится к аргументам конструктора, которые являются оценщиками. Вот почему мета-оценщики, такие как GridSearchCV создать копию данного оценщика перед его изменением.

Однако в scikit-learn при копировании оценщика мы получаем необученный оценщик, где копируются только аргументы конструктора (с некоторыми исключениями, например, атрибуты, связанные с определенной внутренней механикой, такой как маршрутизация метаданных).

Функция, отвечающая за это поведение, clone.

Оценщики могут настраивать поведение base.clone путем переопределения base.BaseEstimator.__sklearn_clone__ метод. __sklearn_clone__ должен возвращать экземпляр оценщика. __sklearn_clone__ полезен, когда оценщику нужно сохранить некоторое состояние, когда base.clone вызывается у оценщика. Например, FrozenEstimator использует это.

Типы оценщиков#

Среди простых оценщиков (в отличие от мета-оценщиков) наиболее распространенными типами являются преобразователи, классификаторы, регрессоры и алгоритмы кластеризации.

Transformers наследуют от TransformerMixin, и реализовать transform метод. Это оцениватели, которые принимают входные данные и преобразуют их определённым образом. Обратите внимание, что они никогда не должны изменять количество входных выборок, и выход transform должен соответствовать его входным образцам в том же заданном порядке.

Регрессоры наследуют от RegressorMixin, и реализовать predict метод. Они должны принимать числовые y в их fit метод. Регрессоры используют r2_score по умолчанию в их score метод.

Классификаторы наследуют от ClassifierMixin. Если применимо, классификаторы могут реализовывать decision_function для возврата исходных значений решений, на основе которых predict может принять решение. Если вычисление вероятностей поддерживается, классификаторы также могут реализовать predict_proba и predict_log_proba.

Классификаторы должны принимать y (целевые) аргументы для fit которые являются последовательностями (списками, массивами) либо строк, либо целых чисел. Они не должны предполагать, что метки классов — это непрерывный диапазон целых чисел; вместо этого они должны хранить список классов в classes_ атрибут или свойство. Порядок меток классов в этом атрибуте должен соответствовать порядку, в котором predict_proba, predict_log_proba и decision_function возвращают их значения. Самый простой способ добиться этого — поместить:

self.classes_, y = np.unique(y, return_inverse=True)

в fit. Это возвращает новый y который содержит индексы классов, а не метки, в диапазоне [0, n_classes).

Классификатора predict метод должен возвращать массивы, содержащие метки классов из classes_. В классификаторе, который реализует decision_function, это может быть достигнуто с помощью:

def predict(self, X):
    D = self.decision_function(X)
    return self.classes_[np.argmax(D, axis=1)]

The multiclass модуль содержит полезные функции для работы с многоклассовыми и многометочными задачами.

Алгоритмы кластеризации наследуют от ClusterMixin. В идеале они должны принимать y параметр в их fit метод, но его следует игнорировать. Алгоритмы кластеризации должны устанавливать labels_ атрибут, хранящий метки, назначенные каждому образцу. Если применимо, они также могут реализовывать predict метод, возвращающий метки, назначенные новым выборкам.

Если необходимо проверить тип данного оценщика, например, в мета-оценщике, можно проверить, реализует ли данный объект transform метод для трансформеров, а в остальных случаях использовать вспомогательные функции, такие как is_classifier или is_regressor.

Теги оценщиков#

Примечание

Scikit-learn ввел теги оценщиков в версии 0.21 как частный API и в основном использовал их в тестах. Однако эти теги со временем расширились, и многие сторонние разработчики также нуждаются в их использовании. Поэтому в версии 1.6 API для тегов был переработан и представлен как публичный API.

Теги оценщиков — это аннотации оценщиков, которые позволяют программно проверять их возможности, такие как поддержка разреженных матриц, поддерживаемые типы выходных данных и поддерживаемые методы. Теги оценщиков являются экземпляром Tags возвращаемый методом __sklearn_tags__. Эти теги используются в разных местах, например, is_regressor или общие проверки, выполняемые check_estimator и parametrize_with_checks, где теги определяют, какие проверки запускать и какие входные данные подходят. Теги могут зависеть от параметров оценщика или даже от архитектуры системы и в общем случае могут быть определены только во время выполнения, поэтому являются атрибутами экземпляра, а не атрибутами класса. См. Tags для получения дополнительной информации об отдельных тегах.

Маловероятно, что значения по умолчанию для каждого тега подойдут для вашего конкретного оценщика. Вы можете изменить значения по умолчанию, определив __sklearn_tags__() метод, который возвращает новые значения для тегов вашего оценщика. Например:

class MyMultiOutputEstimator(BaseEstimator):

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.target_tags.single_output = False
        tags.non_deterministic = True
        return tags

Вы можете создать новый подкласс Tags если вы хотите добавить новые теги к существующему набору. Обратите внимание, что все атрибуты, которые вы добавляете в дочернем классе, должны иметь значение по умолчанию. Это может быть в форме:

from dataclasses import dataclass, fields

@dataclass
class MyTags(Tags):
    my_tag: bool = True

class MyEstimator(BaseEstimator):
    def __sklearn_tags__(self):
        tags_orig = super().__sklearn_tags__()
        as_dict = {
            field.name: getattr(tags_orig, field.name)
            for field in fields(tags_orig)
        }
        tags = MyTags(**as_dict)
        tags.my_tag = True
        return tags

API для разработчиков set_output#

С SLEP018, scikit-learn представляет set_output API для настройки преобразователей на вывод pandas DataFrames. set_output API автоматически определяется, если трансформер определяет get_feature_names_out и подклассы base.TransformerMixin. get_feature_names_out используется для получения названий столбцов вывода pandas.

base.OneToOneFeatureMixin и base.ClassNamePrefixFeaturesOutMixin являются полезными примесями для определения get_feature_names_out. base.OneToOneFeatureMixin полезен, когда преобразователь имеет взаимно однозначное соответствие между входными и выходными признаками, например, StandardScaler. base.ClassNamePrefixFeaturesOutMixin полезно, когда преобразователю нужно сгенерировать собственные имена признаков, например, PCA.

Вы можете отказаться от set_output API, установив auto_wrap_output_keys=None при определении пользовательского подкласса:

class MyTransformer(TransformerMixin, BaseEstimator, auto_wrap_output_keys=None):

    def fit(self, X, y=None):
        return self
    def transform(self, X, y=None):
        return X
    def get_feature_names_out(self, input_features=None):
        ...

Значение по умолчанию для auto_wrap_output_keys является ("transform",), который автоматически оборачивает fit_transform и transform. TransformerMixin использует __init_subclass__ механизм для потребления auto_wrap_output_keys и передает все остальные ключевые аргументы своему суперклассу. Суперклассы __init_subclass__ должен не зависят от auto_wrap_output_keys.

Для преобразователей, которые возвращают несколько массивов в transform, автообёртка будет оборачивать только первый массив и не изменять другие массивы.

См. Введение API set_output для примера использования API.

API для разработчиков check_is_fitted#

По умолчанию check_is_fitted проверяет, есть ли в экземпляре какие-либо атрибуты с завершающим подчеркиванием, например. coef_. Оценщик может изменить поведение, реализовав __sklearn_is_fitted__ метод, не принимающий входных данных и возвращающий логическое значение. Если этот метод существует, check_is_fitted просто возвращает свой вывод.

См. __sklearn_is_fitted__ как API для разработчиков для примера использования API.

API разработчика для HTML-представления#

Предупреждение

API HTML-представления является экспериментальным, и API может изменяться.

Оценщики, наследующие от BaseEstimator отображают HTML-представление себя в интерактивных средах программирования, таких как Jupyter notebooks. Например, мы можем отобразить эту HTML-диаграмму:

from sklearn.base import BaseEstimator

BaseEstimator()

Необработанное HTML-представление получается путем вызова функции estimator_html_repr на экземпляре оценщика.

Чтобы настроить URL-ссылку на документацию оценщика (т.е. при нажатии на значок "?"), переопределите _doc_link_module и _doc_link_template атрибуты. Кроме того, вы можете предоставить _doc_link_url_param_generator метод. Установите _doc_link_module названию модуля верхнего уровня, содержащего ваш оценщик. Если значение не совпадает с названием модуля верхнего уровня, HTML-представление не будет содержать ссылку на документацию. Для оценщиков scikit-learn это значение установлено в "sklearn".

The _doc_link_template используется для построения конечного URL. По умолчанию он может содержать две переменные: estimator_module (полное имя модуля, содержащего оценщик) и estimator_name (имя класса оценщика). Если вам нужно больше переменных, вы должны реализовать _doc_link_url_param_generator метод, который должен возвращать словарь переменных и их значений. Этот словарь будет использоваться для отрисовки _doc_link_template.

Руководство по стилю кодирования#

Ниже приведены некоторые рекомендации о том, как новый код должен быть написан для включения в scikit-learn, и которые могут быть подходящими для принятия во внешних проектах. Конечно, есть особые случаи, и будут исключения из этих правил. Однако следование этим правилам при отправке нового кода облегчает рецензирование, поэтому новый код может быть интегрирован за меньшее время.

Единообразно отформатированный код упрощает совместное владение кодом. Проект scikit-learn старается строго следовать официальным рекомендациям Python, подробно изложенным в PEP8 которые подробно описывают, как должен быть отформатирован и отступлен код. Пожалуйста, прочитайте его и следуйте ему.

Кроме того, мы добавляем следующие рекомендации:

  • Используйте подчёркивания для разделения слов в неклассовых именах: n_samples вместо nsamples.

  • Избегайте нескольких операторов на одной строке. Предпочитайте перевод строки после оператора управления потоком (if/for).

  • Используйте абсолютные импорты

  • Модульные тесты должны использовать импорты точно так же, как клиентский код. Если sklearn.foo экспортирует класс или функцию, реализованную в sklearn.foo.bar.baz, тест должен импортировать его из sklearn.foo.

  • Пожалуйста, не используйте import * в любом случае. Считается вредным по мнению официальные рекомендации Python. Это затрудняет чтение кода, так как происхождение символов больше не явно указано, но что более важно, это препятствует использованию инструментов статического анализа, таких как pyflakes для автоматического поиска ошибок в scikit-learn.

  • Используйте стандарт numpy docstring во всех ваших строках документации.

Хороший пример кода, который нам нравится, можно найти здесь.

Проверка входных данных#

Модуль sklearn.utils содержит различные функции для проверки и преобразования входных данных. Иногда np.asarray достаточно для валидации; делать не использовать np.asanyarray или np.atleast_2d, поскольку они позволяют NumPy np.matrix через, который имеет другой API (например, * означает скалярное произведение на np.matrix, но произведение Адамара на np.ndarray).

В других случаях обязательно вызовите check_array на любой аргумент типа array-like, переданный в функцию API scikit-learn. Точные параметры для использования зависят в основном от того, используется ли scipy.sparse матрицы должны приниматься.

Для получения дополнительной информации обратитесь к Утилиты для разработчиков страница.

Случайные числа#

Если ваш код зависит от генератора случайных чисел, не используйте numpy.random.random() или аналогичные процедуры. Для обеспечения повторяемости при проверке ошибок процедура должна принимать ключевое слово random_state и использовать это для построения numpy.random.RandomState объект. См. sklearn.utils.check_random_state в Утилиты для разработчиков.

Вот простой пример кода, использующий некоторые из приведённых выше рекомендаций:

from sklearn.utils import check_array, check_random_state

def choose_random_sample(X, random_state=0):
    """Choose a random point from X.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        An array representing the data.
    random_state : int or RandomState instance, default=0
        The seed of the pseudo random number generator that selects a
        random sample. Pass an int for reproducible output across multiple
        function calls.
        See :term:`Glossary `.

    Returns
    -------
    x : ndarray of shape (n_features,)
        A random point selected from X.
    """
    X = check_array(X)
    random_state = check_random_state(random_state)
    i = random_state.randint(X.shape[0])
    return X[i]

Если вы используете случайность в оценщике вместо отдельной функции, применяются дополнительные рекомендации.

Во-первых, оценщик должен принимать random_state аргумент для его __init__ со значением по умолчанию None. Он должен сохранять значение этого аргумента, неизмененный, в атрибуте random_state. fit может вызывать check_random_state по этому атрибуту, чтобы получить настоящий генератор случайных чисел. Если по какой-то причине случайность требуется после fitГенератор случайных чисел должен храниться в атрибуте random_state_. Следующий пример должен прояснить это:

class GaussianNoise(BaseEstimator, TransformerMixin):
    """This estimator ignores its input and returns random Gaussian noise.

    It also does not adhere to all scikit-learn conventions,
    but showcases how to handle randomness.
    """

    def __init__(self, n_components=100, random_state=None):
        self.random_state = random_state
        self.n_components = n_components

    # the arguments are ignored anyway, so we make them optional
    def fit(self, X=None, y=None):
        self.random_state_ = check_random_state(self.random_state)

    def transform(self, X):
        n_samples = X.shape[0]
        return self.random_state_.randn(n_samples, self.n_components)

Причина такой настройки — воспроизводимость: когда оценщик fit дважды к одним и тем же данным, он должен производить идентичную модель оба раза, отсюда валидация в fit, а не __init__.

Числовые утверждения в тестах#

При проверке квази-равенства массивов непрерывных значений используйте sklearn.utils._testing.assert_allclose.

Относительная погрешность автоматически выводится из предоставленных типов данных массивов (особенно для типов float32 и float64), но вы можете переопределить через rtol.

При сравнении массивов с нулевыми элементами, пожалуйста, укажите ненулевое значение для абсолютного допуска через atol.

Для получения дополнительной информации обратитесь к документации sklearn.utils._testing.assert_allclose.