Как простая модель логистической регрессии достигает 92% точности классификации по MNIST?

73

Несмотря на то, что все изображения в наборе данных MNIST центрированы с одинаковым масштабом и обращены вверх без поворотов, у них есть существенный разброс рукописного текста, который удивляет меня, как линейная модель достигает такой высокой точности классификации.

Насколько я могу визуализировать, учитывая значительные различия в почерке, цифры должны быть линейно неразделимы в пространстве размером 784, то есть должна быть небольшая сложная (хотя и не очень сложная) нелинейная граница, которая разделяет разные цифры. Аналогично хорошо цитируемому примеру где положительный и отрицательный классы не могут быть разделены никаким линейным классификатором. Мне кажется непонятным, как мультиклассовая логистическая регрессия дает такую ​​высокую точность с полностью линейными характеристиками (без полиномиальных особенностей).XOR

Например, для любого пикселя на изображении различные рукописные варианты цифр и могут сделать этот пиксель подсвеченным или нет. Следовательно, с набором изученных весов каждый пиксель может сделать цифру как так и . Только с комбинацией значений пикселей можно сказать, является ли цифра или . Это верно для большинства пар цифр. Итак, как же логистическая регрессия, которая слепо основывает свое решение независимо от всех значений пикселей (без учета каких-либо межпиксельных зависимостей вообще), способна достичь такой высокой точности.232323

Я знаю, что где-то ошибаюсь или просто переоцениваю различия в изображениях. Тем не менее, было бы здорово, если бы кто-то мог помочь мне с интуицией о том, как цифры «почти» линейно разделимы.

Нитиш Агарвал
источник
Взгляните на учебник «Статистическое обучение с редкостью: лассо и обобщения» 3.3.1. Пример: рукописные цифры web.stanford.edu/~hastie/StatLearnSparsity_files/SLS.pdf
Адриан,
Мне было любопытно: насколько хорошо что-то вроде штрафованной линейной модели (например, glmnet) справляется с этой проблемой? Если я вспоминаю, то, что вы сообщаете, - это непопулярная точность вне выборки.
Клифф AB

Ответы:

91

tl; dr Несмотря на то, что это набор данных для классификации изображений, он остается очень простой задачей, для которой можно легко найти прямое сопоставление от входных данных до предсказаний.


Ответ:

Это очень интересный вопрос, и благодаря простоте логистической регрессии вы действительно можете найти ответ.

Что логистическая регрессия делает для каждого изображения, принимает входных данных и умножает их на веса, чтобы сгенерировать его прогноз. Интересно то, что из-за прямого отображения между входом и выходом (то есть без скрытого слоя) значение каждого веса соответствует тому, насколько каждый из входов учитывается при вычислении вероятности каждого класса. Теперь, взяв веса для каждого класса и изменив их на (т.е. разрешение изображения), мы можем сказать, какие пиксели наиболее важны для вычисления каждого класса .78478428×28

Обратите внимание, опять же, что это веса .

Теперь взгляните на изображение выше и сфокусируйтесь на первых двух цифрах (то есть ноль и одна). Синие веса означают, что интенсивность этого пикселя вносит большой вклад в этот класс, а красные значения означают, что он вносит отрицательный вклад.

А теперь представьте, как человек рисует ? Он рисует пустую круглую форму между ними. Это именно то, что поднял вес. На самом деле, если кто-то рисует середину изображения, оно считается отрицательным как ноль. Поэтому для распознавания нулей вам не нужны сложные фильтры и высокоуровневые функции. Вы можете просто посмотреть на нарисованные позиции пикселей и судить по этому.0

То же самое для . Он всегда имеет прямую вертикальную линию в середине изображения. Все остальное считается отрицательно.1

Остальные цифры немного сложнее, но с небольшим воображением вы можете увидеть , , и . Остальные цифры немного сложнее, что фактически ограничивает логистическую регрессию от достижения высоких 90-х.2378

Благодаря этому вы можете видеть, что логистическая регрессия имеет очень хорошие шансы получить много изображений правильно, и поэтому она так высоко оценивается.


Код для воспроизведения приведенного выше рисунка немного устарел, но здесь вы идете:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))

W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b

y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 

correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Train model
batch_size = 64
with tf.Session() as sess:

    loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []

    sess.run(tf.global_variables_initializer()) 

    for step in range(1, 1001):

        x_batch, y_batch = mnist.train.next_batch(batch_size) 
        sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})

        l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
        l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
        loss_tr.append(l_tr)
        acc_tr.append(a_tr)
        loss_ts.append(l_ts)
        acc_ts.append(a_ts)

    weights = sess.run(W)      
    print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})) 

# Plotting:
for i in range(10):
    plt.subplot(2, 5, i+1)
    weight = weights[:,i].reshape([28,28])
    plt.title(i)
    plt.imshow(weight, cmap='RdBu')  # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
    frame1 = plt.gca()
    frame1.axes.get_xaxis().set_visible(False)
    frame1.axes.get_yaxis().set_visible(False)
Djib2011
источник
13
Спасибо за иллюстрацию. Эти весовые изображения показывают, насколько точна точность. Точечное умножение рукописного цифрового изображения на весовое изображение, соответствующее истинной метке изображения, в большинстве случаев «кажется» наивысшим по сравнению с точечным продуктом с другими весовыми метками (для меня все еще 92% выглядят как много) изображений в МНИСТ. Тем не менее, немного удивительно, что и или и редко ошибочно классифицируются друг с другом при проверке матрицы путаницы. В любом случае, это то, что есть. Данные никогда не лгут. :)2378
Нитиш Агарвал
13
Конечно, полезно, чтобы образцы MNIST центрировались, масштабировались и нормализовались по контрасту до того, как классификатор их увидит. Вам не нужно отвечать на вопросы типа «что если край нуля действительно проходит через середину окна?» потому что препроцессор уже прошел долгий путь к тому, чтобы все нули выглядели одинаково.
Хоббс
1
@EricDuminil Я добавил комментарий к сценарию с вашим предложением. Большое спасибо за вклад! : D
Djib2011
1
@NitishAgarwal, если вы считаете, что этот ответ является ответом на ваш вопрос, пометьте его как таковой.
Синтаксис
16
Для кого-то, кто интересуется, но не особенно знаком с этим типом обработки, этот ответ предоставляет фантастический интуитивный пример механики.
Крили - на забастовке -