{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-}
module What4.Protocol.PolyRoot
( Root
, approximate
, fromYicesText
, parseYicesRoot
) where
import Control.Applicative
import Control.Lens
import qualified Data.Attoparsec.Text as Atto
import qualified Data.Map as Map
import Data.Ratio
import Data.Text (Text)
import qualified Data.Text as Text
import qualified Data.Vector as V
import Text.PrettyPrint.ANSI.Leijen as PP hiding ((<$>))
-- | Run a parser between a literal @<@ and a literal @>@, returning the
-- inner parser's result.
atto_angle :: Atto.Parser a -> Atto.Parser a
atto_angle inner = do
  _ <- Atto.char '<'
  result <- inner
  _ <- Atto.char '>'
  pure result
-- | Run a parser between a literal @(@ and a literal @)@, returning the
-- inner parser's result.
atto_paren :: Atto.Parser a -> Atto.Parser a
atto_paren inner = do
  _ <- Atto.char '('
  result <- inner
  _ <- Atto.char ')'
  pure result
-- | A univariate polynomial in @x@, stored as a dense coefficient vector.
-- The element at index @i@ is the coefficient of @x^i@ (see 'eval').
newtype SingPoly coef = SingPoly (V.Vector coef)
  deriving (Functor, Foldable, Traversable, Show)
-- | Render a polynomial as @c_n*x^n + ... + c_1*x + c_0@, highest degree
-- first. Zero terms are omitted, negative coefficients are parenthesised,
-- and the all-zero polynomial prints as @0@.
instance (Ord coef, Num coef, Pretty coef) => Pretty (SingPoly coef) where
  pretty (SingPoly v) =
    case V.findIndex (/= 0) v of
      Nothing -> text "0"
      Just lowest ->
        -- The list is never empty: index 'lowest' holds a non-zero
        -- coefficient, so foldr1 is safe here.
        foldr1 (\a b -> a <+> text "+" <+> b)
          [ term i (v V.! i)
          | i <- [V.length v - 1, V.length v - 2 .. lowest]
          , v V.! i /= 0
          ]
    where
      coefDoc c
        | c < 0 = parens (pretty c)
        | otherwise = pretty c
      -- The constant term carries no @*x@ suffix.
      term 0 c = coefDoc c
      term 1 c = coefDoc c <> text "*x"
      term i c = coefDoc c <> text "*x^" <> pretty i
-- | Build a polynomial from its coefficients, lowest degree first.
fromList :: [c] -> SingPoly c
fromList coeffs = SingPoly (V.fromList coeffs)
-- | Build a dense polynomial from a sparse map of exponent -> coefficient.
--
-- Zero coefficients are dropped first, so the resulting vector's last
-- entry is non-zero. If every coefficient is zero, for example after
-- cancellation as in @1*x + (-1*x)@, the empty polynomial is returned
-- instead of calling 'Map.findMax' on an empty map. Both 'eval' and
-- 'pretty' treat the empty polynomial as @0@.
fromMap :: (Eq c, Num c) => Map.Map Int c -> SingPoly c
fromMap m0
  | Map.null m = SingPoly V.empty
  | otherwise = SingPoly (V.generate (n + 1) f)
  where
    m = Map.filter (/= 0) m0
    -- Only evaluated on the non-empty branch, so findMax is safe here.
    (n, _) = Map.findMax m
    f i = Map.findWithDefault 0 i m
-- | Parse a non-negative monomial @c@, @c*x@ or @c*x^k@, returning the
-- coefficient and the exponent (0 when @*x@ is absent, 1 when @^k@ is
-- absent).
pos_mono :: Integral c => Atto.Parser (c, Int)
pos_mono = do
  coef <- Atto.decimal
  power <- xPower <|> pure 0
  pure (coef, power)
  where
    xPower :: Atto.Parser Int
    xPower = Atto.char '*' *> Atto.char 'x' *> (caret <|> pure 1)
    caret :: Atto.Parser Int
    caret = Atto.char '^' *> Atto.decimal
-- | Parse a monomial. A negative one is written in parentheses as
-- @(-c*x^k)@; a non-negative one is written bare, as accepted by
-- 'pos_mono'.
mono :: Integral c => Atto.Parser (c, Int)
mono = negated <|> pos_mono
  where
    negated = atto_paren $ do
      _ <- Atto.char '-'
      (c, p) <- pos_mono
      pure (negate c, p)
-- | Parse a Yices polynomial: one or more monomials separated by
-- exactly @" + "@. Coefficients of repeated exponents are summed.
-- Parsing stops at the last monomial that can be fully read; attoparsec
-- backtracks over a trailing separator it cannot complete.
parseYicesPoly :: Integral c => Atto.Parser (SingPoly c)
parseYicesPoly = mono >>= \(coef, power) -> collect (Map.singleton power coef)
  where
    plusSep :: Atto.Parser Char
    plusSep = Atto.char ' ' *> Atto.char '+' *> Atto.char ' '
    collect acc = moreTerms acc <|> pure (fromMap acc)
    -- Force the accumulated map before extending it.
    moreTerms acc = acc `seq` do
      _ <- plusSep
      (coef, power) <- mono
      collect (Map.insertWith (+) power coef acc)
-- | Evaluate a polynomial at a point by summing @a_i * x^i@ from the
-- lowest degree upward. Powers of @x@ are built incrementally, and both
-- accumulators are forced at every step.
eval :: forall c . Num c => SingPoly c -> c -> c
eval (SingPoly coeffs) x = loop 0 1 0
  where
    n :: Int
    n = V.length coeffs
    -- k: current index, pow: x^k, acc: partial sum of lower-degree terms.
    loop :: Int -> c -> c -> c
    loop k pow acc =
      pow `seq` acc `seq`
        if k >= n
          then acc
          else loop (k + 1) (pow * x) (acc + pow * (coeffs V.! k))
-- | A root of a univariate polynomial, located by a closed interval.
-- 'approximate' searches this interval for a sign change of the
-- polynomial. For Yices output, the interval is presumably an isolating
-- interval; this has not been verified against the Yices documentation.
data Root c = Root
  { rootPoly :: !(SingPoly c)
    -- ^ Polynomial that the root belongs to.
  , rootLbound :: !c
    -- ^ Lower end of the interval.
  , rootUbound :: !c
    -- ^ Upper end of the interval.
  }
  deriving Show
-- | Represent an exact value @r@ as the root of @x - r@, using the
-- degenerate interval @[r, r]@.
rootFromRational :: Num c => c -> Root c
rootFromRational r = Root (fromList [negate r, 1]) r r
-- | Render a root as @<poly, (lower, upper)>@.
instance (Ord c, Num c, Pretty c) => Pretty (Root c) where
  pretty root =
    langle <> pretty (rootPoly root) <> comma <+> interval <> rangle
    where
      interval = parens (pretty (rootLbound root) <> comma <+> pretty (rootUbound root))
-- | Approximate a root as a rational number.
--
-- Returns an interval endpoint directly if the interval is degenerate or
-- the polynomial vanishes at an endpoint. Otherwise it bisects in
-- 'Double' precision between the endpoint with a negative value and the
-- endpoint with a positive value. Bisection continues until the midpoint
-- can no longer be distinguished from an end; the end with the smaller
-- absolute residual is returned.
--
-- Calls 'error' if the polynomial has the same strict sign at both ends.
approximate :: Root Rational -> Rational
approximate r
  | lo == hi = lo
  | fLo == 0 = lo
  | fHi == 0 = hi
  | fLo < 0 && fHi > 0 = search (fromRational lo) (fromRational hi)
  | fLo > 0 && fHi < 0 = search (fromRational hi) (fromRational lo)
  | otherwise = error "Closest root given bad root."
  where
    poly = rootPoly r
    lo = rootLbound r
    hi = rootUbound r
    fLo = eval poly lo
    fHi = eval poly hi

    -- Exact value of the polynomial at a Double point.
    valueAt :: Double -> Rational
    valueAt = eval poly . toRational

    -- Invariant: the polynomial is negative at 'neg' and positive at 'pos'.
    search :: Double -> Double -> Rational
    search neg pos
      | mid == neg || mid == pos =
          toRational (if abs (valueAt neg) <= abs (valueAt pos) then neg else pos)
      | otherwise =
          case compare (valueAt mid) 0 of
            EQ -> toRational mid
            LT -> search mid pos
            GT -> search neg mid
      where
        mid = (neg + pos) / 2
-- | Parse @x, y@ (separated by exactly a comma and one space) and combine
-- the two results with @f@.
atto_pair :: (a -> b -> r) -> Atto.Parser a -> Atto.Parser b -> Atto.Parser r
atto_pair f x y = do
  a <- x
  _ <- Atto.char ','
  _ <- Atto.char ' '
  b <- y
  pure (f a b)
-- | Parse a decimal integer with an optional leading @-@. A leading @+@
-- is not accepted.
atto_sdecimal :: Integral c => Atto.Parser c
atto_sdecimal = negative <|> Atto.decimal
  where
    negative = negate <$> (Atto.char '-' *> Atto.decimal)
-- | Parse a rational written as @n@ or @n/d@, where @n@ may carry a
-- leading @-@ and @d@ is an unsigned decimal.
--
-- A zero denominator (@n/0@) makes the parser fail. Without this check,
-- '(%)' would throw the "zero denominator" exception later, when the
-- strict 'Root' fields are forced, e.g. inside 'fromYicesText'.
-- If @/@ is not followed by digits, only @n@ is consumed.
atto_rational :: Integral c => Atto.Parser (Ratio c)
atto_rational = do
  num <- atto_sdecimal
  mden <- optional (Atto.char '/' *> Atto.decimal)
  case mden of
    Nothing -> pure (num % 1)
    Just 0 -> fail "atto_rational: zero denominator"
    Just den -> pure (num % den)
-- | Parse a Yices root. Two forms are accepted:
--
-- * @<poly, (lower, upper)>@: an integer-coefficient polynomial with
--   rational interval bounds.
-- * A bare rational, which becomes the root of @x - r@.
parseYicesRoot :: Atto.Parser (Root Rational)
parseYicesRoot = algebraicRoot <|> fmap rootFromRational atto_rational
  where
    algebraicRoot :: Atto.Parser (Root Rational)
    algebraicRoot = atto_angle (atto_pair build integerPoly bounds)

    -- Coefficients are read as integers and then widened to rationals.
    integerPoly :: Atto.Parser (SingPoly Rational)
    integerPoly = fmap fromInteger <$> (parseYicesPoly :: Atto.Parser (SingPoly Integer))

    build :: SingPoly Rational -> (Rational, Rational) -> Root Rational
    build poly (lo, hi) = Root poly lo hi

    bounds :: Atto.Parser (Rational, Rational)
    bounds = atto_paren (atto_pair (,) atto_rational atto_rational)
-- | Parse a complete Yices root from text. Returns 'Nothing' if parsing
-- fails or any input is left over. A successful result is forced to
-- weak head normal form before being returned.
fromYicesText :: Text -> Maybe (Root Rational)
fromYicesText txt = go (Atto.parse parseYicesRoot txt)
  where
    go result =
      case result of
        Atto.Done rest root | Text.null rest -> Just $! root
        -- Signal end of input so the parser can finish.
        Atto.Partial k -> go (k Text.empty)
        _ -> Nothing