dnet-docker/dnet-app/libs/dnet-wf-common/src/main/java/eu/dnetlib/wfs/utils/GraphUtils.java

160 lines
5.6 KiB
Java

package eu.dnetlib.wfs.utils;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.context.expression.MapAccessor;
import org.springframework.core.env.Environment;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.StandardEvaluationContext;
import eu.dnetlib.domain.wfs.WorkflowsConstants;
import eu.dnetlib.domain.wfs.graph.Node;
import eu.dnetlib.domain.wfs.graph.runtime.RuntimeArc;
import eu.dnetlib.domain.wfs.graph.runtime.RuntimeGraph;
import eu.dnetlib.domain.wfs.graph.runtime.RuntimeNode;
import eu.dnetlib.errors.WorkflowManagerException;
import eu.dnetlib.wfs.procs.RuntimeEnv;
/**
 * Static helpers to build, validate and traverse a workflow {@link RuntimeGraph}.
 */
public class GraphUtils {

    /**
     * Parser shared by every arc-condition evaluation. {@link SpelExpressionParser} instances are
     * reusable and thread-safe, so a single instance is enough.
     */
    private static final ExpressionParser CONDITION_PARSER = new SpelExpressionParser();

    /**
     * Builds a {@link RuntimeGraph} from the given node definitions and validates it.
     * <p>
     * A success node is always registered under {@link WorkflowsConstants#SUCCESS_NODE}. Every input node
     * becomes a start, join or plain runtime node (start takes precedence over join). A node that declares
     * no arcs gets a single unconditional arc to the success node; otherwise one runtime arc is created
     * for each declared arc.
     * <p>
     * The (misspelled) method name is kept unchanged so that existing callers keep compiling.
     *
     * @param nodes        the node definitions of the workflow
     * @param globalParams parameters handed to {@code Node.calculateInitialParams}
     * @param environment  environment handed to {@code Node.calculateInitialParams}
     * @return the validated runtime graph
     * @throws WorkflowManagerException if the resulting graph is not valid
     */
    public static RuntimeGraph prepareRuntineGraph(final List<Node> nodes, final Map<String, Object> globalParams, final Environment environment)
            throws WorkflowManagerException {

        final RuntimeGraph graph = new RuntimeGraph();

        graph.getNodes().put(WorkflowsConstants.SUCCESS_NODE, RuntimeNode.newSuccessNode());

        for (final Node node : nodes) {
            final String nodeName = node.getName();
            final String nodeType = node.getType();
            final boolean isStart = node.isStart();
            final boolean isJoin = node.isJoin();

            final Map<String, Object> params = node.calculateInitialParams(globalParams, environment);
            final Map<String, Object> envParams = node.findEnvParams();
            final Map<String, String> outputEnvMap = node.outputToEnvMap();

            if (isStart) {
                graph.getNodes().put(nodeName, RuntimeNode.newStartNode(nodeName, nodeType, params, envParams, outputEnvMap));
            } else if (isJoin) {
                graph.getNodes().put(nodeName, RuntimeNode.newJoinNode(nodeName, nodeType, params, envParams, outputEnvMap));
            } else {
                graph.getNodes().put(nodeName, RuntimeNode.newNode(nodeName, nodeType, params, envParams, outputEnvMap));
            }

            if ((node.getArcs() == null) || node.getArcs().isEmpty()) {
                // A node without outgoing arcs implicitly terminates the workflow.
                graph.getArcs().add(new RuntimeArc(nodeName, WorkflowsConstants.SUCCESS_NODE, null));
            } else {
                node.getArcs().forEach(a -> graph.getArcs().add(new RuntimeArc(nodeName, a.getTo(), a.getCondition())));
            }
        }

        checkValidity(graph);

        return graph;
    }

    /**
     * Verifies the structural consistency of a graph.
     * <p>
     * The checks are, in order: every arc has a non-blank source and a non-blank target; at least one arc
     * points to {@link WorkflowsConstants#SUCCESS_NODE}; the set of node keys and the set of names referenced
     * by the arcs are identical; every node has a non-blank name; at least one node is a start node.
     *
     * @param graph the graph to verify
     * @throws WorkflowManagerException on the first failed check
     */
    private static void checkValidity(final RuntimeGraph graph) throws WorkflowManagerException {

        final Set<String> nodesFromArcs = new HashSet<>();

        boolean foundSuccess = false;

        for (final RuntimeArc arc : graph.getArcs()) {
            // Both ends of the arc must be present (the target was previously never checked).
            if (StringUtils.isBlank(arc.getFrom()) || StringUtils.isBlank(arc.getTo())) {
                throw new WorkflowManagerException("Invalid arc: missing from e/o to");
            }
            if (StringUtils.equals(arc.getTo(), WorkflowsConstants.SUCCESS_NODE)) {
                foundSuccess = true;
            }
            nodesFromArcs.add(arc.getFrom());
            nodesFromArcs.add(arc.getTo());
        }

        if (!foundSuccess) {
            throw new WorkflowManagerException("Arc to success not found");
        }

        // Symmetric difference: names known only to the node map or only to the arcs.
        final Collection<String> diff = CollectionUtils.disjunction(graph.getNodes().keySet(), nodesFromArcs);
        if (!diff.isEmpty()) {
            throw new WorkflowManagerException("Missing or invalid nodes in arcs: " + diff);
        }

        boolean foundStart = false;

        for (final RuntimeNode n : graph.getNodes().values()) {
            if (StringUtils.isBlank(n.getName())) {
                throw new WorkflowManagerException("Invalid node: missing name");
            }
            if (n.isStart()) {
                foundStart = true;
            }
        }

        if (!foundStart) {
            throw new WorkflowManagerException("Start node not found");
        }
    }

    /**
     * Verifies the structural consistency of a graph and, additionally, that every node with a non-null
     * type uses one of the allowed types. Nodes whose type is {@code null} are not checked against the set.
     *
     * @param graph      the graph to verify
     * @param validTypes the allowed node types
     * @throws WorkflowManagerException if the graph is structurally invalid or a node has a type that is
     *                                  not contained in {@code validTypes}
     */
    public static void checkValidity(final RuntimeGraph graph, final Set<String> validTypes) throws WorkflowManagerException {
        checkValidity(graph);

        for (final RuntimeNode n : graph.getNodes().values()) {
            if ((n.getType() != null) && !validTypes.contains(n.getType())) {
                throw new WorkflowManagerException("Invalid node type: " + n.getType());
            }
        }
    }

    /**
     * Returns the nodes of the graph that are flagged as start nodes.
     *
     * @param graph the graph to inspect
     * @return the set of start nodes, possibly empty
     */
    public static Set<RuntimeNode> startNodes(final RuntimeGraph graph) {
        return graph.getNodes()
                .values()
                .stream()
                .filter(RuntimeNode::isStart)
                .collect(Collectors.toSet());
    }

    /**
     * Counts the arcs of the graph whose target is the given node.
     *
     * @param graph the graph to inspect
     * @param node  the node whose incoming arcs are counted
     * @return the number of arcs pointing to {@code node}
     */
    public static long getNumberOfIncomingArcs(final RuntimeGraph graph, final RuntimeNode node) {
        final String nodeName = node.getName();

        return graph.getArcs()
                .stream()
                .map(RuntimeArc::getTo)
                // Null-safe comparison: an arc without a target must not break the count.
                .filter(to -> StringUtils.equals(to, nodeName))
                .count();
    }

    /**
     * Determines the nodes reachable from {@code current} through arcs whose condition holds in {@code env}.
     * <p>
     * All the conditions are evaluated first; only afterwards the selected arcs are marked as completed,
     * so a failing condition evaluation leaves every arc untouched.
     *
     * @param graph   the graph to traverse
     * @param current the node whose outgoing arcs are followed
     * @param env     the runtime environment against which the arc conditions are evaluated
     * @return the target nodes of the selected arcs
     */
    public static Set<RuntimeNode> nextNodes(final RuntimeGraph graph, final RuntimeNode current, final RuntimeEnv env) {
        final List<RuntimeArc> arcs = graph.getArcs()
                .stream()
                .filter(arc -> StringUtils.equals(arc.getFrom(), current.getName()))
                .filter(arc -> isValidArc(arc, env))
                .toList();

        final Set<RuntimeNode> res = new HashSet<>();

        for (final RuntimeArc arc : arcs) {
            arc.setCompleted(true);
            res.add(graph.getNodes().get(arc.getTo()));
        }

        return res;
    }

    /**
     * Tells whether an arc can be followed in the given environment.
     *
     * @param arc the arc to test
     * @param env the runtime environment
     * @return {@code true} if the arc has no condition or its condition evaluates to {@code true}
     */
    private static boolean isValidArc(final RuntimeArc arc, final RuntimeEnv env) {
        return generateFunction(arc.getCondition()).apply(env);
    }

    /**
     * Creates the function that evaluates an arc condition.
     * <p>
     * A blank condition always holds. Any other condition is parsed as a SpEL expression and evaluated with
     * the attributes of the environment as root object (map entries are readable as properties thanks to
     * {@link MapAccessor}). The returned function never yields {@code null}: an expression that evaluates
     * to {@code null} is treated as not satisfied.
     * <p>
     * NOTE(review): {@link StandardEvaluationContext} exposes the full SpEL feature set; confirm that arc
     * conditions can only originate from trusted workflow definitions.
     *
     * @param condition the SpEL condition, possibly blank
     * @return the function evaluating the condition against a runtime environment
     */
    private static Function<RuntimeEnv, Boolean> generateFunction(final String condition) {
        if (StringUtils.isBlank(condition)) {
            return env -> Boolean.TRUE;
        }

        return env -> {
            final StandardEvaluationContext context = new StandardEvaluationContext(env.getAttributes());
            context.addPropertyAccessor(new MapAccessor());

            final Boolean outcome = CONDITION_PARSER.parseExpression(condition).getValue(context, Boolean.class);

            // getValue may return null; avoid a NullPointerException when the result is unboxed.
            return Boolean.TRUE.equals(outcome);
        };
    }
}