{-# LANGUAGE NoMonomorphismRestriction, TupleSections, DeriveFunctor #-}

module Math.Combinatorics.GraphAuts (isVertexTransitive, isEdgeTransitive,
                                     isArcTransitive, is2ArcTransitive, is3ArcTransitive, is4ArcTransitive, isnArcTransitive,
                                     graphAuts, incidenceAuts, graphAuts7, graphAuts8, incidenceAuts2,
                                     isGraphAut, isIncidenceAut,
                                     graphIsos, incidenceIsos,
                                     isGraphIso, isIncidenceIso) where

import Data.Either (lefts, rights, partitionEithers)
import qualified Data.List as L
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Maybe
import Data.Ord (comparing)
import qualified Data.Foldable as Foldable
import qualified Data.Sequence as Seq

import Math.Common.ListSet
import Math.Core.Utils (combinationsOf, intersectAsc, pairs, picks, (^-))
import Math.Combinatorics.Graph
import Math.Algebra.Group.PermutationGroup
import Math.Algebra.Group.SchreierSims as SS

-- The code for finding automorphisms - "graphAuts" - follows later on in file


-- |A graph is vertex-transitive if its automorphism group acts transitively on the vertices. Thus, given any two distinct vertices, there is an automorphism mapping one to the other.
isVertexTransitive :: (Ord t) => Graph t -> Bool
isVertexTransitive (G [] []) = True -- the null graph is trivially vertex transitive
isVertexTransitive g@(G allVs@(v:_) _) =
    -- the orbit of one vertex must be the whole vertex list
    -- (assumes orbitV and the vertex list use the same, ascending, order)
    orbitV (graphAuts g) v == allVs

-- |A graph is edge-transitive if its automorphism group acts transitively on the edges. Thus, given any two distinct edges, there is an automorphism mapping one to the other.
isEdgeTransitive :: (Ord t) => Graph t -> Bool
isEdgeTransitive (G _ []) = True -- no edges: trivially edge transitive
isEdgeTransitive g@(G _ allEs@(e:_)) =
    -- the orbit of one edge must be every edge (assumes orbitE and the edge list use the same order)
    orbitE (graphAuts g) e == allEs

-- Act on an arc (a list of vertices) by applying the permutation to each vertex in turn.
-- Unlike the action on edges/blocks, the image is not sorted, because arcs are directed.
arc ->^ g = [x .^ g | x <- arc]

-- Godsil & Royle 59-60
-- |A graph is arc-transitive (or flag-transitive) if its automorphism group acts transitively on arcs. (An arc is an ordered pair of adjacent vertices.)
isArcTransitive :: (Ord t) => Graph t -> Bool
isArcTransitive (G _ []) = True -- empty graphs are trivially arc transitive
isArcTransitive g@(G _ es) = orbit (->^) firstArc (graphAuts g) == allArcs
    -- equivalently: closure [firstArc] [ ->^ h | h <- auts] == allArcs
    -- every edge gives two arcs, one in each direction; sorted so the orbit can be compared directly
    where allArcs@(firstArc:_) = L.sort (es ++ [reverse e | e <- es])

-- Alternative arc-transitivity test: checks that the graph is vertex-transitive and that the
-- stabiliser of the first vertex v is transitive on the neighbours of v.
isArcTransitive' (G _ []) = True -- no edges, hence no arcs: trivially arc transitive (as in isArcTransitive)
isArcTransitive' g@(G (v:vs) es) =
    orbitP auts v == v:vs && -- isVertexTransitive g
    orbitP stab n == n:ns
    where auts = graphAuts g
          stab = filter (\p -> v .^ p == v) auts -- relies on v being the first base for the SGS returned by graphAuts
          -- stab = dropWhile (\p -> v .^ p /= v) auts -- we know that graphAuts are returned in this order
          -- Only forced once vertex-transitivity holds. The graph has an edge (the clause above handles
          -- the edgeless case), so every vertex, including v, has a neighbour and this match succeeds.
          n:ns = nbrs g v

-- execution time of both of the above is dominated by the time to calculate the graph auts, so their performance is similar

-- For example, k n, kb n n, q n, the other Platonic solids, and the Petersen, Heawood, Pappus and Desargues graphs are all arc-transitive

-- find arcs of length l from x using dfs - results returned in order
-- an arc is a sequence of vertices connected by edges, no doubling back, but self-crossings allowed
findArcs g@(G _ _) x l = map reverse (extend [x] 0)
  where
    -- path is kept most-recent-vertex-first; depth counts its edges
    extend path depth
        | depth == l = [path]
        | otherwise  = concat [extend (w:path) (depth+1) | w <- nextSteps path]
    -- from the start vertex any neighbour may be taken; after that, never step straight back
    nextSteps [z] = nbrs g z
    nextSteps (z1:z2:_) = [w | w <- nbrs g z1, w /= z2]

-- note that a graph with triangles can't be 3-arc transitive, etc, because an aut can't map a self-crossing arc to a non-self-crossing arc

-- |A graph is n-arc-transitive if its automorphism group is transitive on n-arcs. (An n-arc is an ordered sequence (v0,...,vn) of adjacent vertices, with crossings allowed but not doubling back.)
isnArcTransitive :: (Ord t) => Int -> Graph t -> Bool
isnArcTransitive _ (G [] []) = True
isnArcTransitive n g@(G (v:vs) es) =
    -- Given vertex-transitivity it suffices to check the n-arcs starting at v, under the stabiliser of v.
    orbitP auts v == v:vs && -- isVertexTransitive g
    case findArcs g v n of
        -- No n-arcs start at v (e.g. an edgeless graph, or K2 with n = 2). The graph is vertex-transitive,
        -- so there are then no n-arcs anywhere, and transitivity on the empty set holds vacuously.
        [] -> True
        arcs@(a:_) -> orbit (->^) a stab == arcs
        -- closure [a] [ ->^ h | h <- stab] == arcs
    where auts = graphAuts g
          stab = filter (\p -> v .^ p == v) auts -- relies on v being the first base for the SGS returned by graphAuts
          -- stab = dropWhile (\p -> v .^ p /= v) auts -- we know that graphAuts are returned in this order

-- | A graph is 2-arc-transitive if its automorphism group is transitive on 2-arcs (see 'isnArcTransitive').
is2ArcTransitive :: (Ord t) => Graph t -> Bool
is2ArcTransitive = isnArcTransitive 2

-- | A graph is 3-arc-transitive if its automorphism group is transitive on 3-arcs (see 'isnArcTransitive').
is3ArcTransitive :: (Ord t) => Graph t -> Bool
is3ArcTransitive = isnArcTransitive 3

-- | A graph is 4-arc-transitive if its automorphism group is transitive on 4-arcs (see 'isnArcTransitive').
-- For example, the incidence graphs of the projective planes PG(2,Fq) are 4-arc-transitive.
is4ArcTransitive :: (Ord t) => Graph t -> Bool
is4ArcTransitive = isnArcTransitive 4

-- Godsil & Royle 66-7
-- |A graph is distance transitive if given any two ordered pairs of vertices (u,u') and (v,v') with d(u,u') == d(v,v'),
-- there is an automorphism of the graph that takes (u,u') to (v,v')
isDistanceTransitive :: (Ord t) => Graph t -> Bool
isDistanceTransitive (G [] []) = True
isDistanceTransitive g@(G allVs@(v:_) _)
    | not (isConnected g) = error "isDistanceTransitive: only implemented for connected graphs"
    | otherwise =
        -- vertex-transitive, and the orbits of the stabiliser of v coincide with the distance partition
        -- from v (checked by counting: one orbit per distance 0..diameter)
        orbitP auts v == allVs
            && length stabOrbits == diameter g + 1
  where
    auts = graphAuts g
    -- relies on v being the first base point of the SGS returned by graphAuts
    stab = [p | p <- auts, v .^ p == v]
    movedOrbits = orbits stab
    -- points not in any orbit of stab are fixed, and count as singleton orbits
    fixedPoints = allVs L.\\ concat movedOrbits
    stabOrbits = movedOrbits ++ [[p] | p <- fixedPoints]


-- |Is the permutation an automorphism of the graph?
isGraphAut :: Ord t => Graph t -> Permutation t -> Bool
isGraphAut (G _ es) h = and [S.member (e -^ h) edgeSet | e <- es]
    -- every edge must map onto an edge (-^ re-sorts the image, so it can be looked up directly)
    where edgeSet = S.fromList es
-- this works best on sparse graphs, where p(edge) < 1/2
-- if p(edge) > 1/2, it would be better to test on the complement of the graph

-- |Is the permutation an automorphism of the incidence structure represented by the graph?
-- (Note that an incidence graph colours points as Left, blocks as Right, and a permutation
-- that swaps points and blocks, even if it is an automorphism of the graph, does not represent
-- an automorphism of the incidence structure. Instead, a point-block crossover is called a duality.)
isIncidenceAut :: (Ord p, Ord b) => Graph (Either p b) -> Permutation (Either p b) -> Bool
isIncidenceAut (G _ es) h = and [S.member (e ->^ h) edgeSet | e <- es]
    -- ->^ keeps the image in order (unlike -^), so an edge [Left p, Right b] only matches if h keeps
    -- points as points and blocks as blocks: dualities are therefore rejected
    where edgeSet = S.fromList es

-- Calculate a map consisting of neighbour lists for each vertex in the graph
-- If a vertex has no neighbours then it is left out of the map
adjLists (G _ es) = M.fromListWith (flip (++)) (concatMap bothWays es)
    where -- each edge [u,v] records v as a neighbour of u, then u as a neighbour of v;
          -- flip (++) appends, so each neighbour list follows the order of the edge list
          bothWays [u,v] = [(u,[v]), (v,[u])]

-- (The successive search-tree versions below are kept to show how we got to the final version.)

-- | A search tree node: @T isTerminal label children@.
-- The Bool flag says whether the node is a terminal (solution) node.
data SearchTree a = T Bool a [SearchTree a]
    deriving (Eq, Ord, Show, Functor)

-- Number of nodes on the leftmost path from the root down to a leaf.
leftDepth (T _ _ children) = case children of
    []          -> 1
    (first : _) -> 1 + leftDepth first

-- Number of children at each node along the leftmost path (the final leaf contributes nothing).
leftWidths (T _ _ children) = case children of
    []          -> []
    (first : _) -> length children : leftWidths first

-- Naive automorphism search tree: map the vertices one at a time, in the order of vs, to the
-- remaining targets. A branch is pruned as soon as the partial map disagrees on adjacency.
graphAutsEdgeSearchTree (G vs es) = extend [] vs vs
  where
    edgeSet = S.fromList es
    -- chosen holds the (source, target) pairs so far, most recent first
    extend chosen (x:xs) targets =
        T False chosen [ extend ((x,y):chosen) xs rest
                       | (y,rest) <- picks targets, consistent chosen (x,y) ]
    extend chosen [] [] = T True chosen []
    -- x' -> y' is consistent if, for every earlier x -> y, x' ~ x exactly when y' ~ y.
    -- [x,x'] is used unsorted: x was taken earlier from vs (assumed ascending), so x < x'.
    consistent chosen (x',y') =
        and [ S.member [x,x'] edgeSet == S.member (L.sort [y,y']) edgeSet | (x,y) <- chosen ]

graphAuts1 g = [fromPairs xys | xys <- terminals (graphAutsEdgeSearchTree g)] -- every automorphism, found by exhaustive search

-- Labels of all terminal (solution) nodes, left to right. The children of a terminal node are not searched.
terminals (T isTerminal label children)
    | isTerminal = [label]
    | otherwise  = concatMap terminals children

-- Using Lemma 9.1.1 from Seress p203 to prune the search tree
-- Because auts form a group, it is sufficient to expand only each leftmost branch of the tree in full.
-- For every other branch, it is sufficient to find a single representative, since the other elements
-- can then be obtained by multiplication in the group (using the leftmost elements).
-- In effect, we are finding a transversal generating set.
-- Note however, that this transversal generating set is relative to whatever base order the tree uses,
-- so for clarity, the tree should use natural vertex order.
transversalTerminals (T True xys _) = [xys]
transversalTerminals (T False _ []) = []
transversalTerminals (T False _ (leftmost:others)) =
    -- one representative from each non-leftmost branch, then the leftmost branch in full
    -- (alternative order: transversalTerminals leftmost ++ concatMap (take 1 . transversalTerminals) others)
    concatMap (take 1 . transversalTerminals) others ++ transversalTerminals leftmost

graphAuts2 g = [h | h <- map fromPairs (transversalTerminals (graphAutsEdgeSearchTree g)), h /= 1]
-- the identity (the all-fixing leftmost leaf, which transversalTerminals lists last) is filtered out

-- True exactly when the list has one element.
isSingleton xs = case xs of
    [_] -> True
    _   -> False

-- Common refinement of two partitions: every intersection of a cell of p1 with a cell of p2,
-- grouped by p1 cell, with empty intersections kept (cells presumably ascending, as intersectAsc expects).
intersectCells p1 p2 = [c1 `intersectAsc` c2 | c1 <- p1, c2 <- p2]
-- Intersection preserves ordering within cells but not between cells
-- eg the cell [1,2,3,4] could be refined to [2,4],[1,3]

-- Automorphism search tree pruned by distance partitions.
-- Each node is labelled with the (source, target) pairs chosen so far (each new pair is consed on the front).
-- The search state is a pair of aligned partitions (srcPart, trgPart): cell i of srcPart may only be mapped
-- onto cell i of trgPart. Mapping x to y refines srcPart by the distance partition from x, and trgPart by
-- the distance partition from y; a branch survives only if the refined cell sizes agree.
graphAutsDistancePartitionSearchTree g@(G vs es) = dfs [] ([vs],[vs]) where
    dfs xys (srcPart,trgPart)
        -- Discrete partition: the rest of the mapping is forced by the cell alignment
        | all isSingleton srcPart =
             let xys' = zip (concat srcPart) (concat trgPart)
             in T (isCompatible xys') (xys++xys') []
             -- Since the xys' are distance-compatible with the xys, they are certainly edge-compatible.
             -- However, we do need to check that the xys' are edge-compatible with each other.
        -- Otherwise branch: map the first vertex x of the first source cell to each y of the first target cell
        | otherwise = let (x:xs):srcCells = srcPart
                          yys   :trgCells = trgPart
                          srcPart' = intersectCells (xs : srcCells) (dps M.! x)
                      in T False xys -- the L.sort in the following line is so that we traverse vertices in natural order
                         [dfs ((x,y):xys) ((unzip . L.sort) (zip (filter (not . null) srcPart') (filter (not . null) trgPart')))
                         | (y,ys) <- picks yys,
                           let trgPart' = intersectCells (ys : trgCells) (dps M.! y),
                           -- prune unless both refinements have the same cell sizes (empty cells included),
                           -- which keeps the two filtered partitions aligned cell-for-cell
                           map length srcPart' == map length trgPart']
    -- Does the bijection xys preserve adjacency between every pair of its vertices? ([x,x'] is sorted since x < x')
    isCompatible xys = and [([x,x'] `S.member` es') == (L.sort [y,y'] `S.member` es') | (x,y) <- xys, (x',y') <- xys, x < x']
    es' = S.fromList es
    -- distance partition from each vertex (M.fromAscList assumes vs is in ascending order)
    dps = M.fromAscList [(v, distancePartitionS vs es' v) | v <- vs]

graphAuts3 g = [h | h <- map fromPairs (transversalTerminals (graphAutsDistancePartitionSearchTree g)), h /= 1] -- identity dropped

-- Whereas transversalTerminals produced a transversal generating set, here we produce a strong generating set.
-- In particular, if we have already found (3 4), and then we find (1 2 3),
-- then there is no need to look for (1 3 ...) or (1 4 ...), since it is clear that such elements exist
-- as products of those we have already found.
-- Walk a search tree whose node labels are (source, target) pair lists, returning generators (newest first).
strongTerminals = strongTerminals' [] where
    -- gs: generators found so far, newest at the front
    strongTerminals' gs (T False xys ts) =
        case listToMaybe $ reverse $ filter (\(x,y) -> x /= y) xys of -- the first vertex that isn't fixed
        -- Every choice so far fixes its vertex (the identity spine): explore all children, threading gs through
        Nothing -> L.foldl' (\hs t -> strongTerminals' hs t) gs ts
        -- x is the first moved vertex, sent to y. If y is already in the orbit of x under gs, elements
        -- for this branch are products of known generators; otherwise look for one new element.
        Just (x,y) -> if y `elem` (x .^^ gs)
                      then gs
                      -- Since we're not on the leftmost spine, we can stop as soon as we find one new element
                      else find1New gs ts
                      -- else L.foldl' (\hs t -> if hs /= gs then hs else strongTerminals' hs t) gs ts
    -- A solution leaf: its mapping becomes a generator. (The identity leaf is added too;
    -- graphAuts filters out the identity.)
    strongTerminals' gs (T True xys []) = fromPairs xys : gs
    -- Try the subtrees in turn until one of them contributes a new generator
    find1New gs (t:ts) = let hs = strongTerminals' gs t
                         in if take 1 gs /= take 1 hs -- we know a new element would be placed at the front
                            then hs
                            else find1New gs ts
    find1New gs [] = gs

-- |Given a graph g, @graphAuts g@ returns a strong generating set for the automorphism group of g.
graphAuts :: (Ord a) => Graph a -> [Permutation a]
graphAuts g = [h | h <- strongTerminals (graphAutsDistancePartitionSearchTree g), h /= 1]
-- the identity is dropped: it contributes nothing as a generator

-- Using colourings (M.Map vertex colour, M.Map colour [vertex]), in place of partitions ([[vertex]])
-- This turns out to be slower than using partitions.
-- Updating the colour partition incrementally seems to be much less efficient than just recalculating it each time
-- (Recalculating each time is O(n), incrementally updating is O(n^2)?)
-- Colouring-based automorphism search tree. A colouring is a pair (vertex -> colour, colour -> vertices),
-- where a colour is a list of distance labels, extended with ++ each time the colouring is refined.
-- Node labels are (source, target) pairs chosen so far, most recent first.
graphAutsDistanceColouringSearchTree g@(G vs es) = dfs [] unitCol unitCol where
    unitCol = (M.fromList $ map (,[]) vs, M.singleton [] vs) -- "unit colouring"
    dfs xys srcColouring@(srcVmap,srcCmap) trgColouring@(trgVmap,trgCmap)
        -- ( | M.map length srcCmap /= M.map length trgCmap = T False xys [] )
        | all isSingleton (M.elems srcCmap) = -- discrete colouring
             let xys' = zip (concat $ M.elems srcCmap) (concat $ M.elems trgCmap)
             in T (isCompatible xys') (reverse xys'++xys) []
             -- Since the xys' are distance-compatible with the xys, they are certainly edge-compatible.
             -- However, we do need to check that the xys' are edge-compatible with each other.
        -- Otherwise map the least remaining source vertex x to each target vertex y of the same colour,
        -- refining both colourings by the distance colourings from x and y (and removing x and y)
        | otherwise = let (x,c) = M.findMin srcVmap
                          (xVmap,xCmap) = dcs M.! x
                          ys = trgCmap M.! c
                          srcVmap' = M.delete x (intersectColouring srcVmap xVmap)
                          srcCmap' = colourPartition srcVmap'
                          -- srcCmap' = M.fromAscList [(c1++c2, cell) | (c1,srcCell) <- M.assocs srcCmap, (c2,xCell) <- M.assocs xCmap,
                          --                                            let cell = L.delete x (intersectAsc srcCell xCell),
                          --                                            (not . null) cell]
                      in T False xys
                         [dfs ((x,y):xys) (srcVmap',srcCmap') (trgVmap',trgCmap')
                         | y <- ys,
                           let (yVmap,yCmap) = dcs M.! y,
                           let trgVmap' = M.delete y (intersectColouring trgVmap yVmap),
                           let trgCmap' = colourPartition trgVmap',
                           -- let trgCmap' = M.fromAscList [(c1++c2, cell) | (c1,trgCell) <- M.assocs trgCmap, (c2,yCell) <- M.assocs yCmap,
                           --                                                let cell = L.delete y (intersectAsc trgCell yCell),
                           --                                                (not . null) cell],
                           -- prune unless both sides have the same colours with the same class sizes
                           M.map length srcCmap' == M.map length trgCmap' ]
    -- Does the bijection xys preserve adjacency between every pair of its vertices? ([x,x'] is sorted since x < x')
    isCompatible xys = and [([x,x'] `S.member` es') == (L.sort [y,y'] `S.member` es') | (x,y) <- xys, (x',y') <- xys, x < x']
    es' = S.fromList es
    -- distance colouring from each vertex (M.fromAscList assumes vs is in ascending order)
    dcs = M.fromAscList [(v, distanceColouring v) | v <- vs]
    -- colour each vertex by the index [c] of its cell in the distance partition from u, plus the inverse map
    distanceColouring u = let dp = distancePartitionS vs es' u
                              vmap = M.fromList [(v,[c]) | (cell,c) <- zip dp [0..], v <- cell]
                              cmap = M.fromList $ zip (map (:[]) [0..]) dp
                          in (vmap, cmap)

-- If we are going to recalculate the colour partition each time anyway,
-- then we don't need to carry it around, and can simplify the code
-- Simplified colouring-based search tree: only the vertex -> colour map is carried, and the colour
-- partition (colour -> vertices) is recomputed from it at every node.
-- Named graphAutsDistanceColouringSearchTree2. Under the same name as the preceding definition, with a
-- pattern that always matches, it was just an unreachable second equation of that function.
graphAutsDistanceColouringSearchTree2 g@(G vs es) = dfs [] initCol initCol where
    initCol = M.fromList $ map (,[]) vs
    dfs xys srcCol trgCol
        -- prune: the two colourings must have the same colours with the same class sizes
        | M.map length srcPart /= M.map length trgPart = T False xys []
        | all isSingleton (M.elems srcPart) =
             let xys' = zip (concat $ M.elems srcPart) (concat $ M.elems trgPart)
             in T (isCompatible xys') (reverse xys'++xys) []
             -- Since the xys' are distance-compatible with the xys, they are certainly edge-compatible.
             -- However, we do need to check that the xys' are edge-compatible with each other.
        -- map the least remaining source vertex x to each target vertex of the same colour
        | otherwise = let (x,c) = M.findMin srcCol
                          ys = trgPart M.! c
                          srcCol' = M.delete x $ intersectColouring srcCol (dcs M.! x)
                      in T False xys
                         [dfs ((x,y):xys) srcCol' trgCol'
                         | y <- ys,
                           let trgCol' = M.delete y (intersectColouring trgCol (dcs M.! y))]
        where srcPart = colourPartition srcCol
              trgPart = colourPartition trgCol
    -- Does the bijection xys preserve adjacency between every pair of its vertices? ([x,x'] is sorted since x < x')
    isCompatible xys = and [([x,x'] `S.member` es') == (L.sort [y,y'] `S.member` es') | (x,y) <- xys, (x',y') <- xys, x < x']
    es' = S.fromList es
    -- distance colouring from each vertex (M.fromAscList assumes vs is in ascending order)
    dcs = M.fromAscList [(v, distanceColouring v) | v <- vs]
    distanceColouring u = M.fromList [(v,[c]) | (cell,c) <- zip (distancePartitionS vs es' u) [0..], v <- cell]
-- Colour each vertex by the index of its cell in distancePartitionS from u (the distance partition),
-- wrapped in a singleton list so that colours can later be combined with ++.
distanceColouring (G vs es) u =
    M.fromList [ (v, [d]) | (d, cell) <- zip [0..] (distancePartitionS vs (S.fromList es) u), v <- cell ]

intersectColouring left right = M.intersectionWith (++) left right -- keep shared keys; concatenate their colours

colourPartition c = M.fromListWith (flip (++)) [(colour, [k]) | (k, colour) <- M.assocs c] -- colour -> its keys, ascending

-- Based on McKay’s Canonical Graph Labeling Algorithm, by Stephen G. Hartke and A. J. Radcliffe
-- (http://www.math.unl.edu/~aradcliffe1/Papers/Canonical.pdf)

equitableRefinement (G _ es) p = equitableRefinement' (S.fromList es) p -- refine p using the graph's edge set

-- Refine an ordered partition (a list of cells) until it is equitable: for every pair of cells ci and cj,
-- all vertices of ci have the same number of neighbours in cj. edgeset holds the edges as sorted
-- 2-element lists. Empty cells are dropped up front: they hold no vertices, and isShatter needs at
-- least one neighbour count to compare against.
equitableRefinement' edgeset partition = go (filter (not . null) partition) where
    go cells = let splits = L.zip (L.inits cells) (L.tails cells)
                   shatterPairs = [(L.zip ci counts,ls,rs) | (ls,ci:rs) <- splits, cj <- cells,
                                                             let counts = map (nbrCount cj) ci, isShatter counts]
               in case shatterPairs of -- by construction, the lexicographic least (i,j) comes first
                  [] -> cells
                  -- replace ci, in place, by its fragments (grouped by neighbour count in cj, ascending)
                  (vcs,ls,rs):_ -> let fragments = shatter vcs
                                   in go (ls ++ fragments ++ rs)
    -- cj shatters ci when the vertices of ci have differing neighbour counts in cj
    isShatter (c:cs) = any (/= c) cs
    isShatter [] = False
    shatter vcs = map (map fst) $ L.groupBy (\x y -> snd x == snd y) $ L.sortBy (comparing snd) $ vcs
    -- Memoizing here results in about a 10% speed improvement. Not worth it for the loss of generality (ie requiring HasTrie instances)
    -- nbrCount = memo2 nbrCount'
    -- How many neighbours in cell does vertex have
    nbrCount cell vertex = length (filter (isEdge vertex) cell)
    isEdge u v = L.sort [u,v] `S.member` edgeset

-- Partition search tree: at each node, refine to an equitable partition. A discrete partition is a leaf;
-- otherwise branch by individualising each vertex of the first non-singleton cell.
-- Node labels are (refined partition, individualised vertices with the most recent first).
equitablePartitionSearchTree (G _ es) p = grow [] p
  where
    edgeSet = S.fromList es
    grow chosen cells
        | all isSingleton refined = T True (refined, chosen) []
        | otherwise = T False (refined, chosen)
                        [grow (b:chosen) cells' | (b, cells') <- individualise refined]
      where refined = equitableRefinement' edgeSet cells
    -- For now, we just split the first non-singleton cell r into [x] and the rest, for each x in r
    individualise cells =
        let (before, r:after) = break (not . isSingleton) cells
        in [(x, before ++ [x] : xs : after) | (x, xs) <- picks r]

-- Using Data.Sequence instead of list for the partitions
-- Makes no difference to speed (in fact slightly slower)
-- Data.Sequence version of equitableRefinement': refine an ordered partition until, for every pair of
-- cells ci and cj, all vertices of ci have the same number of neighbours in cj. edgeset holds the edges
-- as sorted 2-element lists. Empty cells are dropped up front: they hold no vertices, and isShatter
-- needs at least one neighbour count to compare against.
equitableRefinementSeq' edgeset partition = go (Seq.filter (not . null) partition) where
    go cells = let splits = Seq.zip (Seq.inits cells) (Seq.tails cells)
                   shatterPairs = [(L.zip ci counts,ls,rs') | (ls,rs) <- Foldable.toList splits, (not . Seq.null) rs, let ci Seq.:< rs' = Seq.viewl rs,
                                                              cj <- Foldable.toList cells,
                                                              let counts = map (nbrCount cj) ci, isShatter counts]
               in case shatterPairs of -- by construction, the lexicographic least (i,j) comes first
                  [] -> cells
                  -- replace ci, in place, by its fragments (grouped by neighbour count in cj, ascending)
                  (vcs,ls,rs):_ -> let fragments = Seq.fromList (shatter vcs)
                                   in go (ls Seq.>< fragments Seq.>< rs)
    -- cj shatters ci when the vertices of ci have differing neighbour counts in cj
    isShatter (c:cs) = any (/= c) cs
    isShatter [] = False
    shatter vcs = map (map fst) $ L.groupBy (\x y -> snd x == snd y) $ L.sortBy (comparing snd) $ vcs
    -- How many neighbours in cell does vertex have
    nbrCount cell vertex = length (filter (isEdge vertex) cell)
    isEdge u v = L.sort [u,v] `S.member` edgeset

-- Data.Sequence version of equitablePartitionSearchTree; node labels hold the partition as a list.
equitablePartitionSeqSearchTree (G _ es) p = grow [] (Seq.fromList p)
  where
    edgeSet = S.fromList es
    grow chosen cells =
        let refined = equitableRefinementSeq' edgeSet cells
            label = (Foldable.toList refined, chosen)
        in if Foldable.all isSingleton refined
           then T True label []
           else T False label [grow (b:chosen) cells' | (b, cells') <- individualise refined]
    -- For now, we just split the first non-singleton cell we find
    individualise cells = case Seq.findIndexL (not . isSingleton) cells of
        Nothing -> error "Not possible, as we know there are non-singleton cells"
        Just i ->
            let (before, fromI) = Seq.splitAt i cells
                r Seq.:< after = Seq.viewl fromI
            in [(x, before Seq.>< ([x] Seq.<| xs Seq.<| after)) | (x, xs) <- picks r]

-- In this version, whenever we have an equitable partition, we separate out all the singleton cells and put them to one side.
-- (Since the partition is equitable, singleton cells have already done any work they are going to do in shattering other cells,
-- so they will no longer play any part.)
-- This seems to result in about 20% speedup.
equitablePartitionSearchTree2 (G _ es) p = grow [] ([], p)
  where
    edgeSet = S.fromList es
    -- done: singleton cells already set aside; open: the cells still being refined
    grow chosen (done, open) =
        let (newSingles, rest) = L.partition isSingleton (equitableRefinement' edgeSet open)
            done' = done ++ newSingles
        in case rest of
            [] -> T True (done', chosen) []
            -- We just split the first non-singleton cell
            -- (alternative label: done' ++ rest)
            c:others -> T False (rest ++ done', chosen)
                            [grow (x:chosen) (done', [x]:xs:others) | (x, xs) <- picks c]
-- TODO: On the first level, we can use a stronger partitioning function (eg distance partitions, + see nauty manual, vertex invariants)

-- Like equitablePartitionSearchTree, but after individualising x the partition is also intersected
-- with the distance partition from x (empty intersections are discarded).
equitableDistancePartitionSearchTree (G vs es) p = grow [] p
  where
    edgeSet = S.fromList es
    -- M.fromAscList assumes vs is in ascending order
    distParts = M.fromAscList [(v, distancePartitionS vs edgeSet v) | v <- vs]
    grow chosen cells
        | all isSingleton refined = T True (refined, chosen) []
        | otherwise = T False (refined, chosen)
                        [grow (b:chosen) cells' | (b, cells') <- individualise refined]
      where refined = equitableRefinement' edgeSet cells
    -- For now, we just split the first non-singleton cell we find
    individualise cells =
        let (before, r:after) = break (not . isSingleton) cells
        in [ (x, filter (not . null) (intersectCells (before ++ [x] : xs : after) (distParts M.! x)))
           | (x, xs) <- picks r ]

-- This is just fmap (\(p,bs) -> (p,bs,trace p)) t
equitablePartitionTracedSearchTree (G _ es) trace p = grow [] p
  where
    edgeSet = S.fromList es
    grow chosen cells =
        let refined = equitableRefinement' edgeSet cells
            label = (refined, chosen, trace refined)
        in if all isSingleton refined
           then T True label []
           else T False label [grow (b:chosen) cells' | (b, cells') <- individualise refined]
    -- For now, we just split the first non-singleton cell we find
    individualise cells =
        let (before, r:after) = break (not . isSingleton) cells
        in [(x, before ++ [x] : xs : after) | (x, xs) <- picks r]

-- Intended as a node invariant
trace1 p = [(size, length grp) | grp@(size:_) <- L.group (L.sort (map length p))] -- (cell size, how many cells of that size), ascending

-- start from the unit partition: a single cell holding every vertex
equitablePartitionGraphSearchTree g@(G vs _) = equitablePartitionSearchTree g [vs]

-- The incidence graph has vertices that are coloured left (points) or right (blocks).
-- We are not interested in dualities (automorphisms that swap points and blocks), so we look for colour-preserving automorphisms
-- Partition search tree for an incidence graph, whose vertices are Left points and Right blocks.
-- Starting from the partition [points, blocks] restricts the search to colour-preserving automorphisms.
-- An empty colour class (no points, or no blocks) is dropped instead of being passed on as an empty cell.
equitablePartitionIncidenceSearchTree g@(G vs es) = equitablePartitionSearchTree g lrPartition
    where (points, blocks) = partitionEithers vs
          lrPartition = filter (not . null) [map Left points, map Right blocks]

-- Follow the leftmost branch down to its leaf; return (the leaf's partition flattened into a vertex
-- ordering, the individualised vertices in the order they were chosen).
leftLeaf (T True (cells, chosen) []) = (concat cells, reverse chosen)
leftLeaf (T False _ (leftmost:_)) = leftLeaf leftmost
-- Labels along the leftmost path, root first, ending with the leaf's label.
leftSpine (T True label []) = [label]
leftSpine (T False label (leftmost:_)) = label : leftSpine leftmost
-- Every leaf, left to right, in the same (vertex ordering, chosen vertices) form that leftLeaf uses.
allLeaves (T True (cells, chosen) []) = [(concat cells, reverse chosen)]
allLeaves (T False _ children) = children >>= allLeaves

-- Transversal generating set from a partition search tree (compare transversalTerminals): one leaf from
-- each non-leftmost branch, plus the whole leftmost branch. Each leaf's vertex ordering is paired with
-- the leftmost leaf's ordering to give a permutation.
-- NOTE(review): leaves are not checked to be automorphisms (partitionBSGS does check this), and the
-- leftmost leaf itself yields the identity.
partitionTransversals tree = [fromPairs (zip canonical partition) | partition <- findTransversals tree] where
    -- the vertex ordering at the leftmost leaf: the first component of leftLeaf, as in partitionBSGS
    -- (the second component is the list of individualised vertices, not an ordering of all vertices)
    (canonical,_) = leftLeaf tree
    findTransversals (T False _ (t:ts)) = concatMap (take 1 . findTransversals) ts ++ findTransversals t
    -- a leaf's vertex ordering is its discrete partition, flattened
    findTransversals (T True (partition,_) []) = [concat partition]

graphAuts5 g = partitionTransversals (equitablePartitionGraphSearchTree g)
-- Earlier version of partitionBSGS: returns (base, strong generating set) for a partition search tree.
-- A leaf is accepted when relabelling g by its vertex ordering gives the same edge set as relabelling g
-- by the leftmost leaf's ordering p1; the generator recorded is then h*g1'.
partitionBSGS0 g@(G vs es) t = (bs, findLevels t) where
    (p1,bs) = leftLeaf t
    -- g1 sends the i-th vertex of p1 to the i-th vertex of vs
    g1 = fromPairs $ zip p1 vs
    g1' = g1^-1
    es1 = S.fromList $ edges $ fmap (.^ g1) g -- the edges of the isomorph corresponding to p1. (S.fromList makes it unnecessary to call nf.)
    -- Generators are collected bottom-up along the leftmost spine
    findLevels (T True (partition,_) []) = []
    findLevels (T False (partition,_) (t:ts)) =
        let hs = findLevels t
            -- TODO: It might be better to use the b that is added in t to find the cell that splits
            cell@(v:vs) = head $ filter (not . isSingleton) partition -- the cell that is going to split
        in findLevel v hs (zip vs ts)
    -- Skip a sibling branch when its vertex v' is already in the orbit of v under the generators so far;
    -- otherwise search that branch for at most one new generator
    findLevel v hs ((v',t'):vts) = if v' `elem` v .^^ hs
                                   then findLevel v hs vts
                                   else let h = find1New t' in findLevel v (h++hs) vts
    findLevel _ hs [] = hs
    find1New (T False _ ts) = take 1 $ concatMap find1New ts
    -- There is a leaf for every aut, but not necessarily an aut for every leaf, so we must check we have an aut
    -- (For example, incidenceGraphPG 2 f8 has leaf nodes which do not correspond to auts.)
    find1New (T True (partition,_) []) = let h = fromPairs $ zip (concat partition) vs
                                             g' = fmap (.^ h) g
                                         in if all (`S.member` es1) (edges g') then [h*g1'] else []
    -- isAut h = all (`S.member` es') [e -^ h | e <- es]
    -- es' = S.fromList es

-- Given a partition search tree, return a base and strong generating set for graph automorphism group.
partitionBSGS g@(G vs es) t = (bs, findLevels t) where
    -- canonical: the vertex ordering at the leftmost leaf; bs: the vertices individualised on the way there (the base)
    (canonical,bs) = leftLeaf t
    -- Generators are collected bottom-up along the leftmost spine
    findLevels (T True (partition,_) []) = []
    findLevels (T False (partition,_) (t:ts)) =
        let hs = findLevels t
            -- TODO: It might be better to use the b that is added in t to find the cell that splits
            cell@(v:vs) = head $ filter (not . isSingleton) partition -- the cell that is going to split
        in findLevel v hs (zip vs ts)
    -- Pairs the remaining vertices of the splitting cell with the non-leftmost children (this assumes
    -- child i individualises the i-th vertex of that cell, as in the equitable partition search trees).
    -- A child whose vertex v' is already in the orbit of v is skipped; otherwise at most one new
    -- generator is sought in its subtree.
    findLevel v hs ((v',t'):vts) = if v' `elem` v .^^ hs -- TODO: Memoize this orbit
                                   then findLevel v hs vts
                                   else let h = find1New t' in findLevel v (h++hs) vts
    findLevel _ hs [] = hs
    -- the first automorphism found among the leaves of a subtree, if any
    find1New (T False _ ts) = take 1 $ concatMap find1New ts
    -- Some leaf nodes correspond to different isomorphs of the graph, and hence don't yield automorphisms
    find1New (T True (partition,_) []) = let h = fromPairs $ zip canonical (concat partition)
                                         in filter isAut [h]
    isAut h = all (`S.member` es') [e -^ h | e <- es]
    es' = S.fromList es
-- The tree for g1 has leaf nodes of two different isomorphs, as does the tree for incidenceGraphPG 2 f8

-- Returns auts as Right, different isomorphs as Left
-- (Must be used with the tree which doesn't put singletons to end)
partitionBSGS3 g@(G vs es) t = (bs, findLevels t) where
    -- p1: the vertex ordering at the leftmost leaf; bs: the base
    (p1,bs) = leftLeaf t
    -- Collected bottom-up along the leftmost spine: Right h for automorphisms, Left h for other isomorphs
    findLevels (T True (partition,_) []) = []
    findLevels (T False (partition,_) (t:ts)) =
        let hs = findLevels t
            -- TODO: It might be better to use the b that is added in t to find the cell that splits
            cell@(v:vs) = head $ filter (not . isSingleton) partition -- the cell that is going to split
        in findLevel v hs (zip vs ts)
    -- Only the automorphisms (rights hs) are used for the orbit test
    findLevel v hs ((v',t'):vts) = if v' `elem` v .^^ rights hs
                                   then findLevel v hs vts
                                   else let h = find1New t' in findLevel v (h++hs) vts
    findLevel _ hs [] = hs
    -- NOTE: the first leaf reached ends the subtree search, whether it gives a Right or a Left
    find1New (T False _ ts) = take 1 $ concatMap find1New ts
    -- There is a leaf for every aut, but not necessarily an aut for every leaf, so we must check we have an aut
    -- (For example, incidenceGraphPG 2 f8 has leaf nodes which do not correspond to auts.)
    find1New (T True (partition,_) []) = let h = fromPairs $ zip p1 (concat partition)
                                         in if isAut h then [Right h] else [Left h]
    isAut h = all (`S.member` es') [e -^ h | e <- es]
    es' = S.fromList es
-- TODO: I think we are only justified in doing find1New (ie only finding 1) if we *do* find an aut.
-- If we don't, we should potentially keep looking in that subtree
-- (See section 6 of paper. If we find isomorphic leaves, then the two subtrees of their common parent are isomorphic,
-- so no need to continue searching the second.)

-- This is using a node invariant to do more pruning.
-- However, seems to be much slower on very regular graphs (where perhaps there is no pruning to be done)
-- (This suggests that perhaps using fmap is not good - perhaps a space leak?)
-- (Or perhaps it's just that calculating and comparing the node invariants is expensive)
-- TODO: Perhaps use something simpler, like just the number of cells in the partition
partitionBSGS2 g@(G vs es) t = (bs, findLevels t') where
    -- annotate every node with a node invariant of its partition
    t' = fmap (\(p,bs) -> (p,bs,trace1 p)) t
    trace1 = length -- the number of cells in the partition
    (canonical,bs) = leftLeaf t
    findLevels (T True (partition,_,_) []) = []
    findLevels (T False (partition,_,_) (t:ts)) =
        let (T _ (_,_,trace) _) = t
            hs = findLevels t
            -- TODO: It might be better to use the b that is added in t to find the cell that splits
            cell@(v:vs) = head $ filter (not . isSingleton) partition -- the cell that is going to split
            -- prune sibling branches whose invariant differs from the leftmost child's
            vts = filter (\(_,T _ (_,_,trace') _) -> trace == trace') $ zip vs ts
        in findLevel v hs vts
    -- skip a branch whose vertex is already in the orbit of v; otherwise look for one new generator in it
    findLevel v hs ((v',t'):vts) = if v' `elem` v .^^ hs
                                   then findLevel v hs vts
                                   else let h = find1New t' in findLevel v (h++hs) vts
    findLevel _ hs [] = hs
    find1New (T False _ ts) = take 1 $ concatMap find1New ts
    -- There is a leaf for every aut, but not necessarily an aut for every leaf, so we must check we have an aut
    -- (For example, incidenceGraphPG 2 f8 has leaf nodes which do not correspond to auts.)
    -- (The graph g1, below, shows a simple example where this will happen.)
    find1New (T True (partition,_,_) []) = let h = fromPairs $ zip canonical (concat partition)
                                           in filter isAut [h]
    isAut h = all (`S.member` es') [e -^ h | e <- es]
    es' = S.fromList es

graphAuts7 g = let tree = equitablePartitionGraphSearchTree g in partitionBSGS g tree -- (base, strong generating set)

-- This is faster on kneser graphs, but slower on incidenceGraphPG
graphAuts8 g = let tree = equitableDistancePartitionSearchTree g [vertices g] in partitionBSGS g tree

-- This is a graph where the node invariant should cause pruning.
-- The initial equitable partition will be [[1..8],[9,10]], because it can do no better than distinguish by degree
-- However, vertices 1..4 and vertices 5..8 are in fact different (there is no aut that takes one set to the other),
-- so the subtrees starting 1..4 have a different invariant to those starting 5..8
g1 = G [1..10]
       [ [1,2],[1,3],[1,9],[2,4],[2,10],[3,4],[3,9],[4,10]
       , [5,6],[5,8],[5,9],[6,7],[6,10],[7,8],[7,9],[8,10] ]

-- g1 with the vertex blocks 1..4 and 5..8 swapped (x <-> x+4), put into normal form
g1' = nf (fmap swapHalves g1)
    where swapHalves x | x <= 4    = x + 4
                       | x <= 8    = x - 4
                       | otherwise = x
-- G [1..10] [[1,2],[1,4],[1,9],[2,3],[2,10],[3,4],[3,9],[4,10],[5,6],[5,7],[5,9],[6,8],[6,10],[7,8],[7,9],[8,10]]

-- 12-vertex test graph: a 4-cycle 1-2-3-4 and a 6-cycle 5-6-9-10-7-8, where hub 11 is joined to
-- alternate cycle vertices 1,3,5,7,9 and hub 12 to the others, 2,4,6,8,10.
-- Vertices 1..10 all have degree 3 and the hubs have degree 5, so degree alone gives [[1..10],[11,12]],
-- even though 4-cycle and 6-cycle vertices cannot be swapped by any automorphism.
-- NOTE(review): presumably a second pruning test case in the style of g1; confirm intent.
g2 = G [1..12] [[1,2],[1,4],[1,11],[2,3],[2,12],[3,4],[3,11],[4,12],[5,6],[5,8],[5,11],[6,9],[6,12],[7,8],[7,10],[7,11],[8,12],[9,10],[9,11],[10,12]]

-- NOT WORKING: this fails to find the isomorphism between g1 and g1' above.
-- Instead of using the left leaf, we need to find the canonical isomorph, as described in the paper.
-- (In a graph where not all leaves lead to automorphisms, the two left leaves we happen to
-- reach may not be isomorphic.)
-- Attempt: take the vertex list (first component of leftLeaf) at the left leaf of each graph's
-- equitable-partition search tree, and pair the two lists position by position into a map f.
-- Return Just f only if relabelling g1 by f (then normalising with nf) gives exactly g2.
maybeGraphIso g1 g2
    | length vs1 == length vs2 && relabelled == g2 = Just f
    | otherwise                                    = Nothing
    where leftmost g = leftLeaf (equitablePartitionGraphSearchTree g)
          (vs1, _) = leftmost g1
          (vs2, _) = leftmost g2
          f = M.fromList (zip vs1 vs2)
          -- Only evaluated once the lengths agree, thanks to the short-circuiting (&&).
          relabelled = nf (fmap (f M.!) g1)


-- Search tree for the automorphisms of an incidence structure, given its incidence graph
-- (points as Left vertices, blocks as Right vertices).
-- This code is nearly identical to the corresponding graphAuts code, with two exceptions:
-- 1. We start by partitioning the vertices into Lefts and Rights.
--    This rules out left-right crossover auts, which, although they are auts of the graph,
--    are not auts of the incidence structure.
-- 2. When labelling the nodes, we filter out the Right blocks and unLeft the Left points.
-- Nodes are built as T flag labels children:
--   * internal nodes get flag False;
--   * a leaf's flag says whether its complete vertex map preserves adjacency and non-adjacency;
--   * labels are the point pairs mapped so far.
incidenceAutsDistancePartitionSearchTree (G vs es) = dfs [] (lrPart, lrPart) where
    dfs xys (srcPart,trgPart)
        | all isSingleton srcPart =
             let xys' = zip (concat srcPart) (concat trgPart)
             in T (isCompatible xys') (unLeft $ xys++xys') []
             -- Since the xys' are distance-compatible with the xys, they are certainly edge-compatible.
             -- However, we do need to check that the xys' are edge-compatible with each other.
        | otherwise = let (x:xs):srcCells = srcPart
                          yys   :trgCells = trgPart
                          srcPart' = intersectCells (xs : srcCells) (dps M.! x)
                      in T False (unLeft xys) -- the L.sort in the following line is so that we traverse vertices in natural order
                         [dfs ((x,y):xys) ((unzip . L.sort) (zip (filter (not . null) srcPart') (filter (not . null) trgPart')))
                         | (y,ys) <- picks yys,
                           let trgPart' = intersectCells (ys : trgCells) (dps M.! y),
                           map length srcPart' == map length trgPart']
    -- For every mapped pair x < x', x and x' are adjacent iff their images are
    -- (edges are looked up as sorted pairs on both sides).
    isCompatible xys = and [([x,x'] `S.member` es') == (L.sort [y,y'] `S.member` es') | (x,y) <- xys, (x',y') <- xys, x < x']
    -- Named so as not to shadow Data.Either's lefts/rights.
    (leftVs, rightVs) = partitionEithers vs
    -- Partition the vertices into left and right, to exclude crossover auts.
    -- Empty cells are dropped. With no points (e.g. the null graph), an empty first cell would
    -- otherwise make the (x:xs) pattern in dfs fail at runtime.
    lrPart = filter (not . null) [map Left leftVs, map Right rightVs]
    unLeft xys = [(x,y) | (Left x, Left y) <- xys] -- also filters out Rights
    es' = S.fromList es
    dps = M.fromList [(v, distancePartitionS vs es' v) | v <- vs]

-- |Given the incidence graph of an incidence structure between points and blocks
-- (for example, a set system), @incidenceAuts g@ returns a strong generating set for the
-- automorphism group of the incidence structure.
-- The generators are permutations of the points only; the identity permutation is omitted.
-- The incidence graph should have the points on the left (as @Left@ vertices)
-- and the blocks on the right (as @Right@ vertices).
incidenceAuts :: (Ord p, Ord b) => Graph (Either p b) -> [Permutation p]
incidenceAuts g = filter (/= idPerm) (strongTerminals searchTree)
    where searchTree = incidenceAutsDistancePartitionSearchTree g
          idPerm = p []

-- Experimental: incidence-structure automorphisms obtained by running partitionBSGS over the
-- equitable-partition incidence search tree.
-- NOTE(review): unlike incidenceAuts, nothing is filtered or unwrapped here, so the result is
-- presumably still expressed on the Either-valued vertices of the incidence graph; confirm.
-- TODO: Filter out the Rights and unLeft the Lefts, for both the base (bs) and the generators (gs).
-- For the generators, something like:
--   map (fromPairs . map (\(Left x, Left y) -> (x,y)) . filter (\(x,_) -> isLeft x) . toPairs) gs
-- (The former unused local helper `unLeft (Left x) = x` was partial and has been removed.)
incidenceAuts2 g = partitionBSGS g (equitablePartitionIncidenceSearchTree g)


-- !! Not yet using equitable partitions, so could probably be made more efficient.

-- graphIsos :: (Ord a, Ord b) => Graph a -> Graph b -> [[(a,b)]]
-- All isomorphisms from g1 to g2, each as a list of (vertex of g1, vertex of g2) pairs.
-- Both graphs are split into connected components (as induced subgraphs).
-- Each component of g1 is tried against every not-yet-used component of g2, and graphIsosCon
-- supplies the isomorphisms between each matched pair.
graphIsos g1 g2
    | length comps1 /= length comps2 = []
    | otherwise                      = matchUp comps1 comps2
    where
      pieces g = map (inducedSubgraph g) (components g)
      comps1 = pieces g1
      comps2 = pieces g2
      -- The component counts are equal, so both lists run out at the same time.
      matchUp [] [] = [[]]
      matchUp (c:cs) ds = do
          (d, ds') <- picks ds
          iso      <- graphIsosCon c d
          rest     <- matchUp cs ds'
          return (iso ++ rest)

-- Isomorphisms between two connected graphs g1 and g2 (raises an error if either is not connected).
-- Each result is a list of (vertex of g1, vertex of g2) pairs.
-- Method:
--   * Fix the first vertex v1 of g1 and try each vertex v2 of g2 as its image, starting from
--     the distance partitions about v1 and v2.
--   * At each step, take the first vertex of the first cell of g1's partition and try mapping it
--     to each vertex of the corresponding cell of g2's partition.
--   * Refine both partitions by the distance partitions about the newly mapped pair.
--   * Prune a branch as soon as the two partitions' cell sizes disagree.
--   * Once every cell is a singleton, pair the cells up and keep the complete map only if it
--     preserves adjacency and non-adjacency.
graphIsosCon g1 g2
    | isConnected g1 && isConnected g2
        = concat [dfs [] (distancePartition g1 v1) (distancePartition g2 v2)
                 | v1 <- take 1 (vertices g1), v2 <- vertices g2]
                 -- take 1 (rather than head) handles the case where g1 is the null graph:
                 -- the comprehension is then empty, so no isomorphisms are returned
    | otherwise = error "graphIsosCon: either or both graphs are not connected"
          -- dfs xys p1 p2: xys holds the pairs mapped so far (most recent first);
          -- p1/p2 are the current partitions of g1/g2, possibly containing empty cells
    where dfs xys p1 p2
              -- cell-size profiles (empty cells included) must agree, or no iso extends xys
              | map length p1 /= map length p2 = []
              | otherwise =
                  let p1' = filter (not . null) p1
                      p2' = filter (not . null) p2
                  in if all isSingleton p1'
                     -- discrete partition: the corresponding singleton cells complete the map
                     then let xys' = xys ++ zip (concat p1') (concat p2')
                          in if isCompatible xys' then [xys'] else []
                     -- otherwise branch on the images of the first vertex of the first cell
                     else let (x:xs):p1'' = p1'
                              ys:p2'' = p2'
                          in concat [dfs ((x,y):xys)
                                         (intersectCells (xs : p1'') (dps1 M.! x))
                                         (intersectCells (ys': p2'') (dps2 M.! y))
                                         | (y,ys') <- picks ys]
          -- for every mapped pair x < x': x,x' adjacent in g1 iff their images are adjacent in g2
          -- NOTE(review): assumes edges are stored as sorted pairs (hence L.sort on the g2 side); confirm
          isCompatible xys = and [([x,x'] `S.member` es1) == (L.sort [y,y'] `S.member` es2) | (x,y) <- xys, (x',y') <- xys, x < x']
          -- distance partition about each vertex, computed once from the vertex list and edge set;
          -- M.fromAscList does not check its precondition, so this relies on `vertices` being ascending
          dps1 = M.fromAscList [(v, distancePartitionS vs1 es1 v) | v <- vs1]
          dps2 = M.fromAscList [(v, distancePartitionS vs2 es2 v) | v <- vs2]
          -- earlier version, calling distancePartition on the whole graph for each vertex:
          -- dps1 = M.fromList [(v, distancePartition g1 v) | v <- vertices g1]
          -- dps2 = M.fromList [(v, distancePartition g2 v) | v <- vertices g2]
          vs1 = vertices g1
          vs2 = vertices g2
          es1 = S.fromList $ edges g1
          es2 = S.fromList $ edges g2

-- |Test whether two graphs are isomorphic, i.e. whether 'graphIsos' finds at least one isomorphism.
isGraphIso :: (Ord a, Ord b) => Graph a -> Graph b -> Bool
isGraphIso g1 g2 = case graphIsos g1 g2 of
    []    -> False
    _ : _ -> True
-- !! If we only want to know whether two graphs are isomorphic,
-- !! then the cost of computing distance partitions may not be warranted
-- !! (compare orderIsos01 with orderIsos in Math.Combinatorics.Poset).

-- incidenceIsos follows graphIsos, except that:
-- * Left/Right crossover isos are avoided, because incidenceIsosCon only maps a Left to a Left;
-- * only the action on the Lefts is returned, with the Left wrappers removed;
-- * connected components containing no Left vertices are discarded before matching.
-- incidenceIsos :: (Ord p1, Ord b1, Ord p2, Ord b2) =>
--     Graph (Either p1 b1) -> Graph (Either p2 b2) -> [[(p1,p2)]]

incidenceIsos g1 g2
    | length comps1 /= length comps2 = []
    | otherwise                      = matchUp comps1 comps2
    where
      hasPoints = not . null . lefts
      pieces g = map (inducedSubgraph g) (filter hasPoints (components g))
      comps1 = pieces g1
      comps2 = pieces g2
      -- The component counts are equal, so both lists run out at the same time.
      matchUp [] [] = [[]]
      matchUp (c:cs) ds = do
          (d, ds') <- picks ds
          iso      <- incidenceIsosCon c d
          rest     <- matchUp cs ds'
          return (iso ++ rest)

-- Isomorphisms between two connected incidence graphs (raises an error if either is not connected).
-- Works like graphIsosCon, with two differences:
-- * The search is seeded only when the first vertex of g1 is a point (Left), and that vertex is
--   mapped only to points of g2, which rules out point/block crossover isos.
-- * Each isomorphism found is reported only on its Left vertices, with the Left wrappers removed.
incidenceIsosCon g1 g2
    | isConnected g1 && isConnected g2
        = concat [dfs [] (distancePartition g1 v1) (distancePartition g2 v2)
                 | v1@(Left _) <- take 1 (vertices g1), v2@(Left _) <- vertices g2]
                 -- take 1 rather than head, because g1 may have no vertices.
                 -- If g1's first vertex is a Right, the pattern fails and nothing is produced.
    | otherwise = error "incidenceIsos: one or both graphs not connected"
          -- dfs xys p1 p2: xys holds the pairs mapped so far (most recent first);
          -- p1/p2 are the current partitions of g1/g2, possibly containing empty cells.
    where dfs xys p1 p2
              | map length p1 /= map length p2 = []
              | otherwise =
                  let p1' = filter (not . null) p1
                      p2' = filter (not . null) p2
                  in if all isSingleton p1'
                     then let xys' = xys ++ zip (concat p1') (concat p2')
                          in if isCompatible xys' then [[(x,y) | (Left x, Left y) <- xys']] else []
                     else let (x:xs):p1'' = p1'
                              ys:p2'' = p2'
                          in concat [dfs ((x,y):xys)
                                         (intersectCells (xs : p1'') (dps1 M.! x))
                                         (intersectCells (ys': p2'') (dps2 M.! y))
                                         | (y,ys') <- picks ys]
          -- For every mapped pair x < x', x and x' are adjacent in g1 iff their images are adjacent in g2.
          isCompatible xys = and [([x,x'] `S.member` es1) == (L.sort [y,y'] `S.member` es2) | (x,y) <- xys, (x',y') <- xys, x < x']
          -- Distance partition about each vertex. The vertex list and edge set are built once and
          -- passed to distancePartitionS, as in graphIsosCon. Previously distancePartition was
          -- called on the whole graph for every vertex (presumably rebuilding the edge set each time).
          dps1 = M.fromList [(v, distancePartitionS vs1 es1 v) | v <- vs1]
          dps2 = M.fromList [(v, distancePartitionS vs2 es2 v) | v <- vs2]
          vs1 = vertices g1
          vs2 = vertices g2
          es1 = S.fromList $ edges g1
          es2 = S.fromList $ edges g2

-- |Test whether the incidence structures represented by two incidence graphs
-- (points as @Left@ vertices, blocks as @Right@ vertices) are isomorphic,
-- i.e. whether 'incidenceIsos' finds at least one isomorphism.
-- NOTE(review): incidenceIsos discards connected components with no points, so structures that
-- differ only in such components (e.g. extra empty blocks) are reported as isomorphic; confirm
-- this is intended.
isIncidenceIso :: (Ord p1, Ord b1, Ord p2, Ord b2) =>
     Graph (Either p1 b1) -> Graph (Either p2 b2) -> Bool
isIncidenceIso g1 g2 = case incidenceIsos g1 g2 of
    []    -> False
    _ : _ -> True