Haskellでプログラミングコンテスト チャレンジブック part3:「迷路の最短路」(幅優先探索・Data.Sequence・STArray)

蟻本の問題から「迷路の最短路」



入力:(N M ≦ 100)

N M
#S######.#
......#..#
.#.##.##.#
.#........
##.##.####
....#....#
.#######.#
....#.....
.####.###.
....#...G#

S:入口
G:出口
.:床(この上は通過可能)
#:壁(通り抜けられない)

出力:
入口Sから出口Gまでの最短路の距離


幅優先探索の問題です。


(1)Sの位置posSを調べ、(posS,0)を状態キューに挿入
(2)キューの先頭を取り出し(pos,turn)とする
(3a)posがGならturnを出力
(3b)posが到達済みなら(2)に戻る
(3c)posが未到達なら、状態キューにその位置の周囲の行けるマスnextpos(最大4マス)をすべて取り出し、(nextpos,turn+1)を状態キューに挿入し、posに到達したフラグを立てて(2)に戻る

という、普通のBFSです。


これをHaskellで書くにあたって手間取ったのは以下の2つ。

(1)キューをどうするか
(2)到達済みフラグの管理をどうするか

です。





(1)はData.Sequence型を使いました。これは両端への挿入・取り出しがO(1)となっている汎用シーケンスです。
挿入は「|>」「<|」演算子を使えばいいのですが、取り出しが多少面倒で、コレは取り出す関数を定義しておいたほうがいいと思います。

dequeue :: Seq a -> (a,Seq a)
dequeue que = case viewl que of
                EmptyL -> error "Empty Queue"
                x :< xs -> (x,xs)

いったんviewl関数を適用し、viewL型の変数に直してからパターンマッチ。キューから取り出せなくなったときの処理もかんがえるならエラーではなくMaybe (a,Seq a)型にしておいたほうがいいですね。


いちおう、今回は使いませんが逆から取り出す関数も。

dequeuer :: Seq a -> (a,Seq a)
dequeuer que = case viewr que of
                EmptyR -> error "Empty Queue"
                xs :> x -> (x,xs)

(2)はちょっと苦労しました。
以前Lake Countingを解いたときは配列の漸進的更新を使っていましたが、調べるとあの更新法だと配列の更新コストがO(N)となってしまうようです。


C++などで上のアルゴリズムを実装するとO(NM)ですが、Haskellで配列の漸進的更新をもちいてフラグテーブルを管理するコードを書いてしまうと、上記の理由で計算量がO(N2 M)になってしまいます。
まあN,M≦100ならこれでも通りそうですが、
N,M≦1000とかだと厳しそうです。


そこで、STArray型を使うことにしました。
これはモナドを使って安全に配列の破壊的更新(こんどこそ破壊的更新です)を行うものです。

ただ、STArray型をそのまま扱うよりも、次のようにArray型を受け取り、STArrayに変換してまたArray型に直して返すような関数を定義して用いると便利だと思います。

update :: (Ix a) => Array a b -> a -> b -> Array a b
update ary x y = runSTArray $ do
                   ary' <- unsafeThaw ary -- コンパイルオプション -O の最適化のもとで 計算量 O(1)
                   writeArray ary' x y  -- 配列を計算量 O(1) で破壊的更新
                   return ary' -- 最後にunsafeFreezeを用いているが、これも -O のもとで O(1) 

これで、全体として(-Oオプションのもとで)O(1)で配列の更新ができます。(できるはずです)




と、いうことで、出来上がったコードがこちら。
上のでかいclass宣言とかIOとかは全部無視してmatrix以降が本題のプログラムです。

{-# LANGUAGE TypeSynonymInstances #-}
{-# OPTIONS_GHC -O2 #-}
import Control.Monad.ST
import Data.STRef
import Data.Array.ST
import Data.Array
import Data.Sequence
import qualified Data.IntMap as Map


class Scan a where scan' :: String -> a
instance Scan Int where scan' n = read n
instance Scan Char where scan' (x:_) = x
instance Scan Float where scan' f = read f
instance Scan Double where scan' d = read d
instance Scan Integer where scan' n = read n
instance Scan String where scan' x = x
instance (Scan a,Scan b) => Scan (a,b) where scan' x = scan'' (words x)
                                                 where
                                                   scan'' (x:y:_) = (scan' x,scan' y)
instance (Scan a,Scan b,Scan c) => Scan (a,b,c) where scan' x = scan'' (words x)
                                                          where
                                                            scan'' (x:y:z:_) = (scan' x,scan' y,scan' z)
instance (Scan a,Scan b,Scan c,Scan d) => Scan (a,b,c,d) where scan' x = scan'' (words x)
                                                                   where
                                                                     scan'' (w:x:y:z:_) = (scan' w,scan' x,scan' y,scan' z)
instance (Scan a,Scan b,Scan c,Scan d,Scan e) => Scan (a,b,c,d,e) where scan' x = scan'' (words x)
                                                                            where
                                                                              scan'' (v:w:x:y:z:_) = (scan' v,scan' w,scan' x,scan' y,scan' z)
class Ans a where showans :: a -> String
instance Ans Int where showans x = show x
instance Ans Char where showans x = [x]
instance Ans Float where showans x = show x
instance Ans Double where showans x = show x
instance Ans Integer where showans x = show x
instance Ans String where showans x = x
instance (Ans a, Ans b) => Ans (a,b) where
    showans (x,y) = showans x ++ " " ++ showans y
instance (Ans a, Ans b,Ans c) => Ans (a,b,c) where
    showans (x,y,z) = showans x ++ " " ++ showans y ++ " " ++ showans z

scan :: (Scan a) => IO a
scan = do n <- getLine
          return (scan' n)

scans :: (Scan a) => Int -> IO [a]
scans 0 = return []
scans n = do x <- scan
             xs <- scans (n-1)
             return (x:xs)

putAnsLn :: (Ans a) => a -> IO ()
putAnsLn ans = putStrLn (showans ans)


matrix :: Int -> Int -> [[a]] -> Array (Int,Int) a
matrix n m lis = array ((0,0),(n-1,m-1)) (matrix_rec 0 lis [])
    where
      matrix_rec _ [] l = l
      matrix_rec i (x:xs) l = matrix_rec (i+1) xs ((Prelude.zip [(i,j)|j<-[0..m]] x)++l)

dequeue :: Seq a -> (a,Seq a)
dequeue que = case viewl que of
                EmptyL -> error "Empty Queue"
                x :< xs -> (x,xs)

memorize ary pos = runSTArray $ do
                     ary' <- unsafeThaw ary -- コンパイルオプション -O の最適化のもとで 計算量 O(1)
                     writeArray ary' pos True -- 配列を計算量 O(1) で破壊的更新
                     return ary' -- 最後にunsafeFreezeを用いているが、これも -O のもとで O(1) 

visited ary pos = ary!pos

solve n m maze = rec (singleton (start,0)) (array ((0,0),(n-1,m-1)) [((i,j),False)|i<-[0..n-1],j<-[0..m-1]])
    where
      start = srec (0,0)
      srec p@(i,j) |(maze!p)=='S' = p
                   |j<m = srec (i,j+1)
                   |otherwise = srec (i+1,0)
      enqueue (i,j) turn que = enqueued
          where
            que' = if i>1 then que |> ((i-1,j),turn) else que
            que'' = if j>1 then que' |> ((i,j-1),turn) else que'
            que''' = if i<n-1 then que'' |> ((i+1,j),turn) else que''
            enqueued = if j<m-1 then que''' |> ((i,j+1),turn) else que'''
      rec que table = let ((pos,turn),rest) = dequeue que
                      in
                        if (maze!pos) == 'G' then
                            turn
                        else if (maze!pos) == '#' || (visited table pos) then
                                 rec rest table
                             else
                                 rec (enqueue pos (turn+1) rest) (memorize table pos)

main = do (n,m) <- scan::IO (Int, Int)
          mazedata <- scans n::IO [String]
          putDebugLn (solve n m (matrix n m mazedata))


これで、上の入力には正しい答え22を返します。




思ったより難しかったです。
キューやらテーブルやらはC++で書いたらSTLや配列で一発なので、こういうところで詰まると苦労しますね。
特にこのupdate関数はライブラリ化しておく必要もありそうです。

あ、ちなみに、未確認ですがCodeforcesでは-Oオプションが付いているそうなので、そこは安心ですね。



ではまた(・ω・)ノシ