123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129 |
- using System;
- using System.Linq;
- using System.Collections.Generic;
- namespace GraphProcessor
- {
- public static class GraphUtils
- {
- enum State
- {
- White,
- Grey,
- Black,
- }
- class TarversalNode
- {
- public BaseNode node;
- public List<TarversalNode> inputs = new List<TarversalNode>();
- public List<TarversalNode> outputs = new List<TarversalNode>();
- public State state = State.White;
- public TarversalNode(BaseNode node) { this.node = node; }
- }
- // A structure made for easy graph traversal
- class TraversalGraph
- {
- public List<TarversalNode> nodes = new List<TarversalNode>();
- public List<TarversalNode> outputs = new List<TarversalNode>();
- }
- static TraversalGraph ConvertGraphToTraversalGraph(BaseGraph graph)
- {
- TraversalGraph g = new TraversalGraph();
- Dictionary<BaseNode, TarversalNode> nodeMap = new Dictionary<BaseNode, TarversalNode>();
- foreach (var node in graph.nodes)
- {
- var tn = new TarversalNode(node);
- g.nodes.Add(tn);
- nodeMap[node] = tn;
- if (graph.graphOutputs.Contains(node))
- g.outputs.Add(tn);
- }
- foreach (var tn in g.nodes)
- {
- tn.inputs = tn.node.GetInputNodes().Where(n => nodeMap.ContainsKey(n)).Select(n => nodeMap[n]).ToList();
- tn.outputs = tn.node.GetOutputNodes().Where(n => nodeMap.ContainsKey(n)).Select(n => nodeMap[n]).ToList();
- }
- return g;
- }
- public static List<BaseNode> DepthFirstSort(BaseGraph g)
- {
- var graph = ConvertGraphToTraversalGraph(g);
- List<BaseNode> depthFirstNodes = new List<BaseNode>();
- foreach (var n in graph.nodes)
- DFS(n);
- void DFS(TarversalNode n)
- {
- if (n.state == State.Black)
- return;
-
- n.state = State.Grey;
- if (n.node is ParameterNode parameterNode && parameterNode.accessor == ParameterAccessor.Get)
- {
- foreach (var setter in graph.nodes.Where(x=>
- x.node is ParameterNode p &&
- p.parameterGUID == parameterNode.parameterGUID &&
- p.accessor == ParameterAccessor.Set))
- {
- if (setter.state == State.White)
- DFS(setter);
- }
- }
- else
- {
- foreach (var input in n.inputs)
- {
- if (input.state == State.White)
- DFS(input);
- }
- }
- n.state = State.Black;
- // Only add the node when his children are completely visited
- depthFirstNodes.Add(n.node);
- }
- return depthFirstNodes;
- }
- public static void FindCyclesInGraph(BaseGraph g, Action<BaseNode> cyclicNode)
- {
- var graph = ConvertGraphToTraversalGraph(g);
- List<TarversalNode> cyclicNodes = new List<TarversalNode>();
- foreach (var n in graph.nodes)
- DFS(n);
- void DFS(TarversalNode n)
- {
- if (n.state == State.Black)
- return;
-
- n.state = State.Grey;
- foreach (var input in n.inputs)
- {
- if (input.state == State.White)
- DFS(input);
- else if (input.state == State.Grey)
- cyclicNodes.Add(n);
- }
- n.state = State.Black;
- }
- cyclicNodes.ForEach((tn) => cyclicNode?.Invoke(tn.node));
- }
- }
- }
|