Мемоизация в Хаскеле?

136

Любые указатели на то, как эффективно решить следующую функцию в Haskell, для больших чисел (n > 108)

f(n) = max(n, f(n/2) + f(n/3) + f(n/4))

Я видел примеры запоминания в Хаскеле для решения чисел Фибоначчи, которые включали (лениво) вычисление всех чисел Фибоначчи до требуемого n. Но в этом случае для данного n нам нужно только вычислить очень мало промежуточных результатов.

Спасибо

Анхель де Висенте
источник
110
Только в том смысле, что это какая-то работа, которую я делаю дома :-)
Angel de Vicente

Ответы:

256

Мы можем сделать это очень эффективно, создав структуру, которую мы можем индексировать за сублинейное время.

Но сначала,

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

Давайте определимся f, но сделаем так, чтобы он использовал «открытую рекурсию», а не вызывал сам себя.

f :: (Int -> Int) -> Int -> Int
f mf 0 = 0
f mf n = max n $ mf (n `div` 2) +
                 mf (n `div` 3) +
                 mf (n `div` 4)

Вы можете получить незапятнанный fс помощьюfix f

Это позволит вам протестировать то f, что вы имеете в виду, для небольших значений f, вызвав, например:fix f 123 = 144

Мы могли бы запомнить это, определив:

f_list :: [Int]
f_list = map (f faster_f) [0..]

faster_f :: Int -> Int
faster_f n = f_list !! n

Это работает сносно хорошо, и заменяет то, что собиралось занять O (n ^ 3) время чем-то, что запоминает промежуточные результаты.

Но для индексации запомненного ответа все равно требуется линейное время mf. Это означает, что результаты как:

*Main Data.List> faster_f 123801
248604

терпимы, но результат не намного лучше, чем это. Мы можем сделать лучше!

Сначала давайте определим бесконечное дерево:

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

И тогда мы определим способ индексации в нем, чтобы вместо этого мы могли найти узел с индексом nза O (log n) :

index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

... и мы можем найти удобное дерево, полное натуральных чисел, поэтому нам не нужно возиться с этими индексами:

nats :: Tree Int
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

Поскольку мы можем индексировать, вы можете просто преобразовать дерево в список:

toList :: Tree a -> [a]
toList as = map (index as) [0..]

Вы можете проверить работу до сих пор, убедившись, что toList natsдает вам[0..]

Сейчас,

f_tree :: Tree Int
f_tree = fmap (f fastest_f) nats

fastest_f :: Int -> Int
fastest_f = index f_tree

работает так же, как и в приведенном выше списке, но вместо того, чтобы тратить время на линейное нахождение каждого узла, можно найти его в логарифмическом времени.

Результат значительно быстрее:

*Main> fastest_f 12380192300
67652175206

*Main> fastest_f 12793129379123
120695231674999

На самом деле это намного быстрее, что вы можете перейти и заменить Intна Integerвыше и получить смехотворно большие ответы почти мгновенно

*Main> fastest_f' 1230891823091823018203123
93721573993600178112200489

*Main> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358
Эдвард КМЕТТ
источник
3
Я попробовал этот код и, что интересно, f_faster показался медленнее, чем f. Я полагаю, что ссылки на список действительно замедлили ход событий. Определение nats и index показалось мне довольно загадочным, поэтому я добавил свой собственный ответ, который может прояснить ситуацию.
Питару
5
Случай бесконечного списка имеет дело со связанным списком длиной 111111111 пунктов. Случай дерева имеет дело с log n * количеством достигнутых узлов.
Эдвард КМЕТТ
2
то есть версия списка должна создавать thunks для всех узлов в списке, тогда как древовидная версия избегает создания многих из них.
Том Эллис
7
Я знаю, что это довольно старая запись, но ее не f_treeследует определять в whereпредложении, чтобы избежать сохранения ненужных путей в дереве между вызовами?
dfeuer
17
Причина, по которой он был добавлен в CAF, заключалась в том, что вы могли получать напоминания о вызовах. Если бы у меня был дорогой звонок, который я запоминал, то я бы, вероятно, оставил его в CAF, отсюда и методика, показанная здесь. В реальном приложении, конечно, существует компромисс между преимуществами и затратами на постоянное запоминание. Хотя, учитывая вопрос о том, как добиться запоминания, я думаю, что было бы неправильно вводить ответ с помощью техники, которая намеренно избегает запоминания между вызовами, и если ничего другого, то этот комментарий здесь не укажет на тот факт, что существуют тонкости. ;)
Эдвард КМЕТТ
17

Ответ Эдварда настолько удивителен, что я продублировал его и предоставил реализации memoListи memoTreeкомбинаторы, которые запоминают функцию в открытой рекурсивной форме.

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

f :: (Integer -> Integer) -> Integer -> Integer
f mf 0 = 0
f mf n = max n $ mf (div n 2) +
                 mf (div n 3) +
                 mf (div n 4)


-- Memoizing using a list

-- The memoizing functionality depends on this being in eta reduced form!
memoList :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoList f = memoList_f
  where memoList_f = (memo !!) . fromInteger
        memo = map (f memoList_f) [0..]

faster_f :: Integer -> Integer
faster_f = memoList f


-- Memoizing using a tree

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

toList :: Tree a -> [a]
toList as = map (index as) [0..]

-- The memoizing functionality depends on this being in eta reduced form!
memoTree :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoTree f = memoTree_f
  where memoTree_f = index memo
        memo = fmap (f memoTree_f) nats

fastest_f :: Integer -> Integer
fastest_f = memoTree f
Том Эллис
источник
12

Не самый эффективный способ, но запоминает

f = 0 : [ g n | n <- [1..] ]
    where g n = max n $ f!!(n `div` 2) + f!!(n `div` 3) + f!!(n `div` 4)

при запросе f !! 144проверяется, что f !! 143существует, но его точное значение не рассчитывается. Это все еще установлено как некоторый неизвестный результат вычисления. Единственные точные рассчитанные значения являются необходимыми.

Итак, изначально, насколько было рассчитано, программа ничего не знает.

f = .... 

Когда мы делаем запрос f !! 12, он начинает выполнять сопоставление с шаблоном:

f = 0 : g 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Теперь начинается расчет

f !! 12 = g 12 = max 12 $ f!!6 + f!!4 + f!!3

Это рекурсивно выдвигает другое требование на f, поэтому мы вычислим

f !! 6 = g 6 = max 6 $ f !! 3 + f !! 2 + f !! 1
f !! 3 = g 3 = max 3 $ f !! 1 + f !! 1 + f !! 0
f !! 1 = g 1 = max 1 $ f !! 0 + f !! 0 + f !! 0
f !! 0 = 0

Теперь мы можем сделать немного

f !! 1 = g 1 = max 1 $ 0 + 0 + 0 = 1

Что означает, что программа теперь знает:

f = 0 : 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Продолжая сочиться:

f !! 3 = g 3 = max 3 $ 1 + 1 + 0 = 3

Что означает, что программа теперь знает:

f = 0 : 1 : g 2 : 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Теперь мы продолжим наш расчет f!!6:

f !! 6 = g 6 = max 6 $ 3 + f !! 2 + 1
f !! 2 = g 2 = max 2 $ f !! 1 + f !! 0 + f !! 0 = max 2 $ 1 + 0 + 0 = 2
f !! 6 = g 6 = max 6 $ 3 + 2 + 1 = 6

Что означает, что программа теперь знает:

f = 0 : 1 : 2 : 3 : g 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Теперь мы продолжим наш расчет f!!12:

f !! 12 = g 12 = max 12 $ 6 + f!!4 + 3
f !! 4 = g 4 = max 4 $ f !! 2 + f !! 1 + f !! 1 = max 4 $ 2 + 1 + 1 = 4
f !! 12 = g 12 = max 12 $ 6 + 4 + 3 = 13

Что означает, что программа теперь знает:

f = 0 : 1 : 2 : 3 : 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : 13 : ...

Так что расчет сделан довольно лениво. Программа знает, что какое-то значение для f !! 8существует, что оно равно g 8, но она не знает, что это g 8такое.

колокольчик-рапунцель
источник
Спасибо за это. Как бы вы создали и использовали двумерное пространство решений? Это был бы список списков? иg n m = (something with) f!!a!!b
Викингстеве
1
Конечно, вы могли бы. Для реального решения, тем не менее, я бы, вероятно, использовал библиотеку напоминаний, такую
rampion
Это O (n ^ 2), к сожалению.
номер
8

Это дополнение к прекрасному ответу Эдварда Кметта.

Когда я попробовал его код, определения natsи indexказались довольно загадочными, поэтому я написал альтернативную версию, которая мне показалась более легкой для понимания.

Я определяю indexи natsс точки зрения index'и nats'.

index' t nопределяется в диапазоне [1..]. (Вспомните, что index tопределено в диапазоне [0..].) Он работает, ищет дерево, обрабатывая его nкак цепочку битов и считывая биты в обратном порядке. Если бит равен 1, он принимает правую ветвь. Если бит равен 0, он принимает левую ветвь. Он останавливается, когда достигает последнего бита (который должен быть 1).

index' (Tree l m r) 1 = m
index' (Tree l m r) n = case n `divMod` 2 of
                          (n', 0) -> index' l n'
                          (n', 1) -> index' r n'

Так же, как natsопределено для indexтак, что index nats n == nвсегда верно, nats'определено для index'.

nats' = Tree l 1 r
  where
    l = fmap (\n -> n*2)     nats'
    r = fmap (\n -> n*2 + 1) nats'
    nats' = Tree l 1 r

Теперь, natsи indexпросто nats'и , index'но со значениями сдвинуты на 1:

index t n = index' t (n+1)
nats = fmap (\n -> n-1) nats'
Pitarou
источник
Спасибо. Я запомнил многомерную функцию, и это действительно помогло мне понять, что на самом деле делали index и nats.
Китцил
8

Как указано в ответе Эдварда Кметта, чтобы ускорить процесс, вам необходимо кэшировать дорогостоящие вычисления и иметь быстрый доступ к ним.

Чтобы сохранить функцию не монадической, решение этой задачи заключается в создании бесконечного ленивого дерева с соответствующим способом его индексации (как показано в предыдущих статьях). Если вы отказываетесь от немонадной природы функции, вы можете использовать стандартные ассоциативные контейнеры, доступные в Haskell, в сочетании с «подобными состоянию» монадами (такими как State или ST).

Хотя основным недостатком является то, что вы получаете немонадную функцию, вам больше не нужно индексировать структуру самостоятельно, и вы можете просто использовать стандартные реализации ассоциативных контейнеров.

Для этого сначала нужно переписать свою функцию, чтобы она принимала любую монаду:

fm :: (Integral a, Monad m) => (a -> m a) -> a -> m a
fm _    0 = return 0
fm recf n = do
   recs <- mapM recf $ div n <$> [2, 3, 4]
   return $ max n (sum recs)

Для ваших тестов вы все равно можете определить функцию, которая не запоминает, используя Data.Function.fix, хотя она немного более многословна:

noMemoF :: (Integral n) => n -> n
noMemoF = runIdentity . fix fm

Затем вы можете использовать State monad в сочетании с Data.Map, чтобы ускорить процесс:

import qualified Data.Map.Strict as MS

withMemoStMap :: (Integral n) => n -> n
withMemoStMap n = evalState (fm recF n) MS.empty
   where
      recF i = do
         v <- MS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ MS.insert i v'
               return v'

С небольшими изменениями вы можете вместо этого адаптировать код для работы с Data.HashMap:

import qualified Data.HashMap.Strict as HMS

withMemoStHMap :: (Integral n, Hashable n) => n -> n
withMemoStHMap n = evalState (fm recF n) HMS.empty
   where
      recF i = do
         v <- HMS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ HMS.insert i v'
               return v'

Вместо постоянных структур данных вы также можете использовать изменяемые структуры данных (например, Data.HashTable) в сочетании с монадой ST:

import qualified Data.HashTable.ST.Linear as MHM

withMemoMutMap :: (Integral n, Hashable n) => n -> n
withMemoMutMap n = runST $
   do ht <- MHM.new
      recF ht n
   where
      recF ht i = do
         k <- MHM.lookup ht i
         case k of
            Just k' -> return k'
            Nothing -> do 
               k' <- fm (recF ht) i
               MHM.insert ht i k'
               return k'

По сравнению с реализацией без какого-либо запоминания, любая из этих реализаций позволяет вам при огромных входах получать результаты в микросекундах вместо того, чтобы ждать несколько секунд.

Используя Criterion в качестве эталона, я мог заметить, что реализация с Data.HashMap на самом деле работала немного лучше (около 20%), чем Data.Map и Data.HashTable, для которых сроки были очень похожи.

Я нашел результаты теста немного удивительными. Сначала я чувствовал, что HashTable превзойдет реализацию HashMap, потому что она изменчива. В этой последней реализации может быть скрыт некоторый дефект производительности.

Quentin
источник
2
GHC делает очень хорошую работу по оптимизации вокруг неизменных структур. Интуиция от C не всегда срабатывает.
Джон Тайри
3

Пару лет спустя я посмотрел на это и понял, что есть простой способ запомнить это в линейном времени, используя zipWithвспомогательную функцию:

dilate :: Int -> [x] -> [x]
dilate n xs = replicate n =<< xs

dilateимеет удобное свойство, что dilate n xs !! i == xs !! div i n.

Итак, предположим, что нам дано f (0), это упрощает вычисление до

fs = f0 : zipWith max [1..] (tail $ fs#/2 .+. fs#/3 .+. fs#/4)
  where (.+.) = zipWith (+)
        infixl 6 .+.
        (#/) = flip dilate
        infixl 7 #/

Очень похоже на наше оригинальное описание проблемы и дает линейное решение ( sum $ take n fsпотребуется O (n)).

колокольчик-рапунцель
источник
2
так что это генеративное (corecursive?) или динамическое программирование решение. Взятие O (1) времени на каждое сгенерированное значение, как это делает обычный Фибоначчи. Большой! И решение EKMETT похоже на логарифмическое большое число Фибоначчи, которое намного быстрее достигает больших чисел, пропуская большую часть промежуточных значений. Это правильно?
Уилл Несс
или, может быть, он ближе к числам Хэмминга, с тремя обратными указателями на производимую последовательность и различными скоростями для каждого из них, продвигающихся вдоль него. действительно красиво.
Уилл Несс
2

Еще одно дополнение к ответу Эдварда Кметта: отдельный пример:

data NatTrie v = NatTrie (NatTrie v) v (NatTrie v)

memo1 arg_to_index index_to_arg f = (\n -> index nats (arg_to_index n))
  where nats = go 0 1
        go i s = NatTrie (go (i+s) s') (f (index_to_arg i)) (go (i+s') s')
          where s' = 2*s
        index (NatTrie l v r) i
          | i <  0    = f (index_to_arg i)
          | i == 0    = v
          | otherwise = case (i-1) `divMod` 2 of
             (i',0) -> index l i'
             (i',1) -> index r i'

memoNat = memo1 id id 

Используйте его следующим образом, чтобы запомнить функцию с одним целочисленным аргументом (например, fibonacci):

fib = memoNat f
  where f 0 = 0
        f 1 = 1
        f n = fib (n-1) + fib (n-2)

Будут кэшироваться только значения для неотрицательных аргументов.

Чтобы также кэшировать значения для отрицательных аргументов, используйте memoIntследующее:

memoInt = memo1 arg_to_index index_to_arg
  where arg_to_index n
         | n < 0     = -2*n
         | otherwise =  2*n + 1
        index_to_arg i = case i `divMod` 2 of
           (n,0) -> -n
           (n,1) ->  n

Для кэширования значений для функций с двумя целочисленными аргументами используйте memoIntIntследующее:

memoIntInt f = memoInt (\n -> memoInt (f n))
Нил Янг
источник
2

Решение без индексации и не основано на Эдварде КМЕТТ.

Я выделяю общие поддеревья для общего родителя ( f(n/4)разделяется между f(n/2)и f(n/4), и f(n/6)разделяется между f(2)и f(3)). Сохраняя их как одну переменную в родительском объекте, вычисление поддерева выполняется один раз.

data Tree a =
  Node {datum :: a, child2 :: Tree a, child3 :: Tree a}

f :: Int -> Int
f n = datum root
  where root = f' n Nothing Nothing


-- Pass in the arg
  -- and this node's lifted children (if any).
f' :: Integral a => a -> Maybe (Tree a) -> Maybe (Tree a)-> a
f' 0 _ _ = leaf
    where leaf = Node 0 leaf leaf
f' n m2 m3 = Node d c2 c3
  where
    d = if n < 12 then n
            else max n (d2 + d3 + d4)
    [n2,n3,n4,n6] = map (n `div`) [2,3,4,6]
    [d2,d3,d4,d6] = map datum [c2,c3,c4,c6]
    c2 = case m2 of    -- Check for a passed-in subtree before recursing.
      Just c2' -> c2'
      Nothing -> f' n2 Nothing (Just c6)
    c3 = case m3 of
      Just c3' -> c3'
      Nothing -> f' n3 (Just c6) Nothing
    c4 = child2 c2
    c6 = f' n6 Nothing Nothing

    main =
      print (f 123801)
      -- Should print 248604.

Код не может легко распространяться на общую функцию запоминания (по крайней мере, я бы не знал, как это сделать), и вам действительно нужно подумать, как перекрываются подзадачи, но стратегия должна работать для общих нескольких нецелых параметров. , (Я придумал это для двух строковых параметров.)

Записка сбрасывается после каждого расчета. (Опять же, я думал о двух строковых параметрах.)

Я не знаю, является ли это более эффективным, чем другие ответы. Технически каждый поиск - это всего лишь один или два шага («Посмотрите на вашего ребенка или ребенка вашего ребенка»), но может потребоваться много дополнительной памяти.

Изменить: Это решение еще не является правильным. Разделение неполное.

Редактировать: Теперь он должен делиться своими детьми должным образом, но я понял, что эта проблема имеет много нетривиального обмена: n/2/2/2и n/3/3может быть такой же. Проблема не подходит для моей стратегии.

leewz
источник