Из того, что я собрал до сих пор, существует несколько различных способов сбросить график TensorFlow в файл и затем загрузить его в другую программу, но я не смог найти четких примеров / информации о том, как они работают. Я уже знаю следующее:
- Сохраните переменные модели в файл контрольной точки (.ckpt) с помощью
tf.train.Saver()
и восстановите их позже ( источник ) - Сохраните модель в файл .pb и загрузите ее обратно с помощью
tf.train.write_graph()
иtf.import_graph_def()
( источник ) - Загрузите модель из файла .pb, переобучите ее и выгружайте в новый файл .pb с помощью Bazel ( источник )
- Заморозьте график, чтобы сохранить график и веса вместе ( источник )
- Используйте
as_graph_def()
для сохранения модели, а для весов / переменных сопоставьте их с константами ( источник )
Однако мне не удалось прояснить несколько вопросов, касающихся этих различных методов:
- Что касается файлов контрольных точек, они сохраняют только обученные веса модели? Могут ли файлы контрольных точек быть загружены в новую программу и использоваться для запуска модели, или они просто служат в качестве способов сохранения весов в модели в определенное время / этап?
- Что касается
tf.train.write_graph()
, сохраняются ли веса / переменные? - Что касается Базела, может он только для переподготовки сохранять / загружать из .pb файлов? Есть ли простая команда Bazel для выгрузки графика в .pb?
- Что касается замораживания, можно ли загрузить замороженный график с помощью
tf.import_graph_def()
? - Демонстрация Android для TensorFlow загружается в модель Google Inception из файла .pb. Если бы я хотел заменить свой собственный файл .pb, как бы я это сделал? Нужно ли мне менять какой-либо собственный код / методы?
- В общем, в чем именно разница между всеми этими методами? Или, в более широком смысле, в чем разница между
as_graph_def()
/.ckpt/.pb?
Короче говоря, я ищу способ сохранить как график (например, различные операции и т. Д.), Так и его веса / переменные в файл, который затем можно использовать для загрузки графика и весов в другую программу. , для использования (не обязательно для продолжения / переподготовки).
Документация по этой теме не очень проста, поэтому мы будем благодарны за любые ответы / информацию.
python
tensorflow
protocol-buffers
Разноцветный
источник
источник
Ответы:
Есть много способов подойти к проблеме сохранения модели в TensorFlow, что может немного сбить с толку. Рассмотрим по очереди каждый из ваших подвопросов:
Файлы контрольных точек (например , производится путем вызова
saver.save()
наtf.train.Saver
объект) содержат только веса, и любые другие переменные , определенные в одной и той же программе. Чтобы использовать их в другой программе, вы должны воссоздать связанную структуру графа (например, запустив код для ее повторного построения или вызвавtf.import_graph_def()
), которая сообщает TensorFlow, что делать с этими весами. Обратите внимание, что при вызовеsaver.save()
также создается файл, содержащийMetaGraphDef
граф и подробности того, как связать веса из контрольной точки с этим графом. Смотрите руководство для более подробной информации.tf.train.write_graph()
пишет только структуру графа; не веса.Bazel не имеет отношения к чтению или написанию графиков TensorFlow. (Возможно, я неправильно понял ваш вопрос: не стесняйтесь прояснить его в комментарии.)
Замороженный график можно загрузить с помощью
tf.import_graph_def()
. В этом случае веса (обычно) встроены в график, поэтому вам не нужно загружать отдельную контрольную точку.Основное изменение будет заключаться в обновлении имен тензоров, которые вводятся в модель, и имен тензоров, которые извлекаются из модели. В демке TensorFlow Android, это будет соответствовать
inputName
иoutputName
строки, которые передаютсяTensorFlowClassifier.initializeTensorFlow()
.Это
GraphDef
структура программы, которая обычно не меняется в процессе обучения. Контрольная точка - это моментальный снимок состояния тренировочного процесса, который обычно изменяется на каждом этапе тренировочного процесса. В результате TensorFlow использует разные форматы хранения для этих типов данных, а низкоуровневый API предоставляет разные способы их сохранения и загрузки. Библиотеки более высокого уровня, такие какMetaGraphDef
библиотеки, Keras и skflow на основе этих механизмов , чтобы обеспечить более удобные способы сохранения и восстановления целой модели.источник
tf.train.write_graph()
и затем выполнить его?GraphDef
сохраненномуtf.train.write_graph()
, вам также необходимо запомнить имена тензоров, которые вы хотите кормить и получать при выполнении графика (пункт 5 выше).Вы можете попробовать следующий код:
источник