{-# LANGUAGE RecordWildCards #-}
module Numeric.Recommender.ALS where
import Control.Parallel.Strategies
import Data.Bifunctor
import Data.Default.Class
import qualified Data.IntMap as IntMap
import qualified Data.IntSet as IntSet
import Data.List (sortBy)
import Data.Maybe
import Data.Tuple
import qualified Data.Vector.Storable
import Numeric.LinearAlgebra
import System.Random
import Prelude hiding ((<>))
-- | Hyper-parameters consumed by 'buildModel'.
data ALSParams = ALSParams
  { lambda, alpha :: Double
    -- ^ @lambda@: regularization weight (also the upper bound of the
    -- uniform range used for the random initial item features);
    -- @alpha@: scale applied to the 0/1 ratings to obtain the extra
    -- confidence of observed entries.
  , seed, nFactors, parChunk :: Int
    -- ^ @seed@: seed for the 'StdGen' used at initialisation;
    -- @nFactors@: number of latent factors;
    -- @parChunk@: chunk size for the parallel evaluation strategy.
  } deriving Show
-- | Default hyper-parameters: @lambda = 0.1@, @alpha = 10@, @seed = 0@,
-- @nFactors = 10@, @parChunk = 10@.
instance Default ALSParams where
  def = ALSParams
    { lambda = 0.1
    , alpha = 10
    , seed = 0
    , nFactors = 10
    , parChunk = 10
    }
-- | Model state after one ALS sweep.
data ALSResult = ALSResult
  { cost :: Double
    -- ^ Value of 'costFunction' for these features. The field is not
    -- strict, so it is only evaluated if demanded.
  , itemFeature, userFeature :: !(Matrix Double)
    -- ^ As produced by 'buildModel': @itemFeature@ is factors x items
    -- (one column per item), @userFeature@ is users x factors (one row
    -- per user).
  }
-- | A trained model: codecs between users/items and their dense indices,
-- the encoded training interactions, and the successive ALS iterations.
data ALSModel u i = ALSModel
  { encodeUser :: u -> Maybe Int
    -- ^ Dense index of a user; 'Nothing' if it was not in the training data.
  , decodeUser :: Int -> u
    -- ^ Inverse of 'encodeUser' (when built by 'buildModel', calls
    -- 'error' for an unknown index).
  , encodeItem :: i -> Maybe Int
    -- ^ Dense index of an item; 'Nothing' if it was not in the training data.
  , decodeItem :: Int -> i
    -- ^ Inverse of 'encodeItem' (when built by 'buildModel', calls
    -- 'error' for an unknown index).
  , pairs :: [(Int, Int)]
    -- ^ Training interactions as encoded @(user, item)@ indices.
  , results :: [ALSResult]
    -- ^ When built by 'buildModel': an infinite list whose element @k@ is
    -- the state after @k + 1@ sweeps.
  }
-- | Fit an implicit-feedback alternating-least-squares model to observed
-- @(user, item)@ interactions.
--
-- Users and items are converted to 'Int' with the supplied functions and
-- then compacted to dense indices @0 .. n-1@. The interactions form a dense
-- @nUsers x nItems@ 0/1 preference matrix @ratings@; every entry carries
-- confidence @1 + alpha * rating@ (observed entries weigh @1 + alpha@).
--
-- Each sweep first solves, for every user @i@, the normal equations
--
-- > (M M^T + M_i diag(w_i) M_i^T + lambda * n_i * I) x_i = (M_i diag(w_i) + M_i) r_i
--
-- where @M@ is the current item-feature matrix, @M_i@ its columns for the
-- items user @i@ interacted with, @w_i@ / @r_i@ the matching entries of
-- @alpha * ratings@ / @ratings@, and @n_i@ the user's interaction count.
-- It then solves the symmetric system for every item. Each system is
-- solved with 'linearSolve', falling back to 'linearSolveSVD' when that
-- reports failure.
--
-- The 'results' of the returned model is an infinite list: element @k@
-- holds the features after @k + 1@ sweeps together with their
-- 'costFunction' value. Training on an empty collection is an error,
-- raised lazily when the features are first demanded.
buildModel
  :: (Functor f, Foldable f)
  => ALSParams
  -> (u -> Int)   -- ^ convert a user to an integer key
  -> (Int -> u)   -- ^ inverse of the user conversion
  -> (i -> Int)   -- ^ convert an item to an integer key
  -> (Int -> i)   -- ^ inverse of the item conversion
  -> f (u, i)     -- ^ observed @(user, item)@ interactions
  -> ALSModel u i
buildModel ALSParams{..} fromUser toUser fromItem toItem xs =
  ALSModel encUser decUser encItem decItem (foldr (:) [] xs') $
    map (\(u, m) -> ALSResult (costFunction ratings u m weighted lambda sumsU sumsM) m u)
        results
  where
    -- Map a function over a list, evaluating the results in parallel,
    -- parChunk elements per spark.
    parMap' g = withStrategy (parListChunk parChunk rdeepseq) . map g
    rnd = mkStdGen seed

    -- Dense 0-based indices for users and items, plus their inverses.
    (encUser, decUser) =
      bimap (. fromUser) (toUser .) (compact (fmap (fromUser . fst) xs))
    (encItem, decItem) =
      bimap (. fromItem) (toItem .) (compact (fmap (fromItem . snd) xs))
    -- Every input pair received an index above, so fromJust cannot fail.
    xs' = fmap (bimap (fromJust . encUser) (fromJust . encItem)) xs
    nU = 1 + checkedMaximum (fmap fst xs')
    nM = 1 + checkedMaximum (fmap snd xs')
    checkedMaximum ys
      | null ys = error "Numeric.Recommender.ALS.buildModel: no interactions to train on"
      | otherwise = maximum ys

    -- Binary preference matrix: 1 where the user interacted with the item.
    selections = foldr (\(u, c) -> IntSet.insert (c + nM * u)) mempty xs'
    ratings = (nU >< nM) $
      map (\k -> if IntSet.member k selections then 1 else 0) [0 .. nU * nM - 1]
    -- Extra confidence on each entry; total confidence is 1 + weighted.
    weighted = scalar alpha * ratings

    -- Initial item features (nFactors x nM): the first nFactors entries in
    -- row-major order are 1, the rest uniform in [0, lambda].
    -- NOTE(review): those ones land in row 0, not in a whole row/column;
    -- confirm this is the intended initialisation.
    mIni = (nFactors >< nM) $ replicate nFactors 1 ++
      take (nFactors * (nM - 1)) (randomRs (0, lambda) rnd)

    -- Interaction counts per user (row sums) and per item (column sums);
    -- they scale the regularization term.
    sumsU = vector $ map (Data.Vector.Storable.foldr (+) 0) $ toRows ratings
    sumsM = vector $ map (Data.Vector.Storable.foldr (+) 0) $ toColumns ratings

    -- The sparsity pattern never changes between sweeps, so gather each
    -- user's observed items (and each item's observed users), together with
    -- the matching weight/rating entries, once instead of on every sweep.
    observed i j = atIndex ratings (i, j) > 0.1
    userObs =
      [ (i, js, gather weighted, gather ratings)
      | i <- [0 .. nU - 1]
      , let js = filter (observed i) [0 .. nM - 1]
            gather x = vector [atIndex x (i, j) | j <- js]
      ]
    itemObs =
      [ (j, us, gather weighted, gather ratings)
      | j <- [0 .. nM - 1]
      , let us = filter (\i -> observed i j) [0 .. nU - 1]
            gather x = vector [atIndex x (i, j) | i <- us]
      ]

    -- Solve a x = b, falling back to the SVD-based least-squares solver
    -- when 'linearSolve' reports failure.
    solveOrSVD a b = flatten (fromMaybe (linearSolveSVD a b) (linearSolve a b))

    -- One ALS sweep: solve user features given item features m, then item
    -- features given those user features.
    f m = (u, m2)
      where
        mtm = m <> tr m
        u = fromRows $ parMap' solveUser userObs
        solveUser (i, js, w', r') =
          let m' = m ¿ js
              m'' = tr $ tr m' * asColumn w'
              x1 = mtm + (m'' <> tr m' +
                     (scalar lambda * scalar (atIndex sumsU i) * ident nFactors))
              x2 = asColumn $ (m'' + m') #> r'
          in solveOrSVD x1 x2
        tuu = tr u <> u
        m2 = fromColumns $ parMap' solveItem itemObs
        solveItem (j, us, w', r') =
          let u' = u ? us
              u'' = tr $ asColumn w' * u'
              x1 = tuu + u'' <> u' +
                     (scalar lambda * scalar (atIndex sumsM j) * ident nFactors)
              x2 = asColumn $ (u'' + tr u') #> r'
          in solveOrSVD x1 x2

    -- Successive sweeps, each started from the previous item features.
    results = iterate ((\x -> x `seq` f x) . snd) (f mIni)

    -- Assign each distinct key a dense index (0, 1, ...); return the lookup
    -- and its inverse (the inverse calls 'error' on an unknown index).
    compact
      :: Foldable t
      => t Int
      -> (Int -> Maybe Int, Int -> Int)
    compact ys =
      ( flip IntMap.lookup mp
      , fromMaybe (error "missing value") . flip IntMap.lookup pm
      )
      where
        -- insertWith (flip const) keeps the index of a key already present.
        mp = foldr (\x a -> IntMap.insertWith (flip const) x (IntMap.size a) a) mempty ys
        pm = IntMap.fromList (map swap (IntMap.toList mp))
-- | Objective minimised by the ALS sweeps:
--
-- > sum ((w + 1) .* (r - u m) .^ 2) + l * (sum_i nui_i * |row_i u|^2 + sum_j nmj_j * |col_j m|^2)
--
-- where @.*@ and @.^@ are element-wise. @r@ is the preference matrix, @u@
-- the user features (one row per user), @m@ the item features (one column
-- per item), @w@ the extra confidence per entry, @l@ the regularization
-- weight, and @nui@ / @nmj@ the per-user / per-item regularization scales
-- ('buildModel' passes interaction counts).
costFunction
  :: Matrix Double -> Matrix Double -> Matrix Double -> Matrix Double
  -> Double -> Vector Double -> Vector Double -> Double
costFunction r u m w l nui nmj = fitTerm + l * (userReg + itemReg)
  where
    residual = r - (u <> m)
    -- Confidence-weighted squared reconstruction error.
    fitTerm = sumElements ((w + 1) * (residual * residual))
    -- Squared feature norms scaled by the per-user / per-item weights.
    userReg = sumElements (nui <# (u ^ 2))
    itemReg = sumElements ((m ^ 2) #> nmj)
-- | Rank every item for every user that appears in the model's training
-- pairs, using the features after @n@ ALS sweeps (@n >= 1@).
--
-- The result is keyed by the user's encoded index (see 'encodeUser'). Each
-- user's list covers all items, highest score first (ties keep ascending
-- item index), where the score is the matching entry of
-- @userFeature <> itemFeature@. The 'Bool' is 'True' when the user did
-- /not/ interact with that item in the training pairs.
recommend
  :: ALSModel u i
  -> Int   -- ^ number of sweeps to use (1-based index into 'results')
  -> IntMap.IntMap [(i, Bool)]
recommend ALSModel{..} n = IntMap.mapWithKey rank seenBy
  where
    chosen
      | n < 1 = error "Numeric.Recommender.ALS.recommend: n must be at least 1"
      | otherwise = results !! (n - 1)
    -- users x items matrix of predicted preferences.
    scores = userFeature chosen <> itemFeature chosen
    -- Items each user interacted with during training. Mapping over this
    -- map ranks each user exactly once (rather than once per interaction).
    seenBy = IntMap.fromListWith IntSet.union
      [ (usr, IntSet.singleton itm) | (usr, itm) <- pairs ]
    rank usr seen =
      [ (decodeItem itm, not (IntSet.member itm seen))
      | (itm, _) <- sortBy (\(_, a) (_, b) -> compare b a)
                      (zip [0 ..] (concat (toLists (scores ? [usr]))))
      ]