Я искал альтернативные способы сохранить обученную модель в PyTorch. Пока что я нашел две альтернативы.
- torch.save () для сохранения модели и torch.load () для загрузки модели.
- model.state_dict () для сохранения обученной модели и model.load_state_dict () для загрузки сохраненной модели.
Я сталкивался с этим обсуждением, где подход 2 рекомендуется по подходу 1
У меня вопрос, почему второй подход предпочтительнее? Только потому, что модули torch.nn имеют эти две функции, и нам рекомендуется их использовать?
python
serialization
deep-learning
pytorch
tensor
Васи ахмад
источник
источник
torch.save(model, f)
иtorch.save(model.state_dict(), f)
. Сохраненные файлы имеют одинаковый размер. Теперь я в замешательстве. Кроме того, я обнаружил, что использование pickle для сохранения model.state_dict () очень медленно. Я думаю, что лучший способ - использовать,torch.save(model.state_dict(), f)
так как вы управляете созданием модели, а torch обрабатывает загрузку весов модели, таким образом устраняя возможные проблемы. Ссылка: обсуждения.pytorch.org/t/saving-torch-models/838/4pickle
?Ответы:
Я нашел эту страницу в их репозитории на github, я просто вставлю сюда содержимое.
Рекомендуемый подход к сохранению модели
Существует два основных подхода к сериализации и восстановлению модели.
Первый (рекомендуется) сохраняет и загружает только параметры модели:
Тогда позже:
Вторая сохраняет и загружает всю модель:
Тогда позже:
Однако в этом случае сериализованные данные привязываются к конкретным классам и конкретной используемой структуре каталогов, поэтому они могут ломаться различными способами при использовании в других проектах или после некоторых серьезных рефакторингов.
источник
pickle
?Это зависит от того, что вы хотите сделать.
Случай № 1: Сохраните модель, чтобы использовать ее для вывода : вы сохраняете модель, восстанавливаете ее, а затем переводите модель в режим оценки. Это сделано потому, что у вас обычно есть
BatchNorm
иDropout
слои, которые по умолчанию находятся в режиме поезда на строительстве:Случай № 2. Сохранение модели для возобновления обучения позже . Если вам нужно продолжить обучение модели, которую вы собираетесь сохранить, вам нужно сохранить больше, чем просто модель. Вам также нужно сохранить состояние оптимизатора, эпох, счета и т. Д. Вы бы сделали это так:
Чтобы возобновить обучение, вы должны сделать что-то вроде:,
state = torch.load(filepath)
а затем, чтобы восстановить состояние каждого отдельного объекта, что-то вроде этого:Так как вы возобновляете обучение, НЕ звоните, как
model.eval()
только вы восстанавливаете состояния при загрузке.Случай № 3: Модель для использования кем-то другим, не имеющим доступа к вашему коду : В Tensorflow вы можете создать
.pb
файл, который определяет как архитектуру, так и вес модели. Это очень удобно, особенно при использованииTensorflow serve
. Эквивалентный способ сделать это в Pytorch будет:Этот способ по-прежнему не является пуленепробиваемым, и поскольку в pytorch все еще происходит множество изменений, я бы не рекомендовал это делать.
источник
torch.load
возвращается только OrderedDict. Как вы получаете модель для того, чтобы делать прогнозы?Библиотека Python pickle реализует двоичные протоколы для сериализации и десериализации объекта Python.
Когда вы
import torch
(или когда вы используете PyTorch) это будетimport pickle
для вас, и вам не нужно вызыватьpickle.dump()
иpickle.load()
напрямую, которые являются методами для сохранения и загрузки объекта.На самом деле
torch.save()
иtorch.load()
завернуpickle.dump()
иpickle.load()
для вас.А
state_dict
другой упомянутый ответ заслуживает еще несколько заметок.Что
state_dict
у нас внутри PyTorch? На самом деле есть дваstate_dict
с.Модель PyTorch
torch.nn.Module
имеетmodel.parameters()
вызов для получения обучаемых параметров (w и b). Эти обучаемые параметры, однажды установленные случайным образом, будут обновляться с течением времени по мере нашего обучения. Изучаемые параметры являются первымиstate_dict
.Второе
state_dict
- это диктатор состояния оптимизатора. Вы помните, что оптимизатор используется для улучшения наших усваиваемых параметров. Но оптимизаторstate_dict
исправлен. Там нечему учиться.Поскольку
state_dict
объекты являются словарями Python, их можно легко сохранять, обновлять, изменять и восстанавливать, добавляя большую модульность моделям и оптимизаторам PyTorch.Давайте создадим супер простую модель, чтобы объяснить это:
Этот код выведет следующее:
Обратите внимание, что это минимальная модель. Вы можете попробовать добавить стек последовательных
Обратите внимание, что только слои с усваиваемыми параметрами (сверточные слои, линейные слои и т. Д.) И зарегистрированные буферы (слои группового набора) имеют записи в модели
state_dict
.Неизучаемые вещи принадлежат объекту оптимизатора
state_dict
, который содержит информацию о состоянии оптимизатора, а также используемые гиперпараметры.Остальная часть истории такая же; на этапе вывода (это этап, когда мы используем модель после обучения) для прогнозирования; мы делаем прогноз на основе параметров, которые мы узнали. Поэтому для вывода нам просто нужно сохранить параметры
model.state_dict()
.И использовать позже model.load_state_dict (torch.load (filepath)) model.eval ()
Примечание: не забывайте последнюю строку,
model.eval()
это важно после загрузки модели.Также не пытайтесь сохранить
torch.save(model.parameters(), filepath)
. Этоmodel.parameters()
просто объект генератора.С другой стороны,
torch.save(model, filepath)
сохраняет сам объект модели, но имейте в виду, что модель не имеет оптимизатораstate_dict
. Посмотрите другой отличный ответ @Jadiel de Armas, чтобы сохранить информацию о состоянии оптимизатора.источник
Общепринятым соглашением PyTorch является сохранение моделей с использованием расширения файлов .pt или .pth.
Сохранить / загрузить всю модель Сохранить:
Загрузить:
Класс модели должен быть определен где-то
источник
Если вы хотите сохранить модель и хотите продолжить обучение позже:
Одиночный графический процессор: Сохранить:
Загрузить:
Несколько GPU: Сохранить
Загрузить:
источник