diff --git a/dhp-workflows/dhp-enrichment/src/main/java/eu/dnetlib/dhp/resulttocommunityfromorganization/PrepareResultCommunitySet.java b/dhp-workflows/dhp-enrichment/src/main/java/eu/dnetlib/dhp/resulttocommunityfromorganization/PrepareResultCommunitySet.java
index 1a008797da..19b9859648 100644
--- a/dhp-workflows/dhp-enrichment/src/main/java/eu/dnetlib/dhp/resulttocommunityfromorganization/PrepareResultCommunitySet.java
+++ b/dhp-workflows/dhp-enrichment/src/main/java/eu/dnetlib/dhp/resulttocommunityfromorganization/PrepareResultCommunitySet.java
@@ -9,7 +9,9 @@ import java.util.*;
 import org.apache.commons.io.IOUtils;
 import org.apache.hadoop.io.compress.GzipCodec;
 import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.function.FilterFunction;
 import org.apache.spark.api.java.function.MapFunction;
+import org.apache.spark.api.java.function.MapGroupsFunction;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.SparkSession;
@@ -70,28 +72,33 @@ public class PrepareResultCommunitySet {
 		String outputPath,
 		OrganizationMap organizationMap) {
 
-		Dataset<Relation> relation = readPath(spark, inputPath, Relation.class);
-		relation.createOrReplaceTempView("relation");
+		Dataset<Relation> relationAffiliation = readPath(spark, inputPath, Relation.class)
+			.filter(
+				(FilterFunction<Relation>) r -> !r.getDataInfo().getDeletedbyinference() &&
+					r.getRelClass().equalsIgnoreCase(ModelConstants.HAS_AUTHOR_INSTITUTION));
 
-		String query = "SELECT result_organization.source resultId, result_organization.target orgId, org_set merges "
-			+ "FROM (SELECT source, target "
-			+ "      FROM relation "
-			+ "      WHERE datainfo.deletedbyinference = false "
-			+ "      AND lower(relClass) = '"
-			+ ModelConstants.HAS_AUTHOR_INSTITUTION.toLowerCase()
-			+ "') result_organization "
-			+ "LEFT JOIN (SELECT source, collect_set(target) org_set "
-			+ "      FROM relation "
-			+ "      WHERE datainfo.deletedbyinference = false "
-			+ "      AND lower(relClass) = '"
-			+ ModelConstants.MERGES.toLowerCase()
-			+ "' "
-			+ "      GROUP BY source) organization_organization "
-			+ "ON result_organization.target = organization_organization.source ";
+		Dataset<Relation> relationOrganization = readPath(spark, inputPath, Relation.class)
+			.filter(
+				(FilterFunction<Relation>) r -> !r.getDataInfo().getDeletedbyinference() &&
+					r.getRelClass().equalsIgnoreCase(ModelConstants.MERGES));
 
-		Dataset<ResultOrganizations> result_organizationset = spark
-			.sql(query)
-			.as(Encoders.bean(ResultOrganizations.class));
+		Dataset<ResultOrganizations> result_organizationset = relationAffiliation
+			.joinWith(
+				relationOrganization,
+				relationAffiliation.col("target").equalTo(relationOrganization.col("source")),
+				"left")
+			.groupByKey((MapFunction<Tuple2<Relation, Relation>, String>) t2 -> t2._2().getSource(), Encoders.STRING())
+			.mapGroups((MapGroupsFunction<String, Tuple2<Relation, Relation>, ResultOrganizations>) (k, it) -> {
+				ResultOrganizations rOrgs = new ResultOrganizations();
+				rOrgs.setOrgId(k);
+				Tuple2<Relation, Relation> first = it.next();
+				rOrgs.setResultId(first._1().getSource());
+				ArrayList<String> merges = new ArrayList<>();
+				merges.add(first._2().getTarget());
+				it.forEachRemaining(t -> merges.add(t._2().getTarget()));
+				rOrgs.setMerges(merges);
+				return rOrgs;
+			}, Encoders.bean(ResultOrganizations.class));
 
 		result_organizationset
 			.map(mapResultCommunityFn(organizationMap), Encoders.bean(ResultCommunityList.class))
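
Note for reviewers: the patch swaps the Spark SQL string query for the typed Dataset API (filter, joinWith, groupByKey, mapGroups). The snippet below is a minimal, self-contained sketch of that join-then-group shape, not the project code: Rel and Grouped are hypothetical stand-ins for Relation and ResultOrganizations, the relClass values are hard-coded instead of coming from ModelConstants, and the group key is taken from the left side of the left join (with a null guard on the right side), which is one way to keep the key and the merged-organization list well defined when an organization merges nothing.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

import org.apache.spark.api.java.function.FilterFunction;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.api.java.function.MapGroupsFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;

import scala.Tuple2;

public class JoinGroupSketch {

	// Hypothetical stand-in for Relation: just source, target and relClass.
	public static class Rel implements java.io.Serializable {
		private String source;
		private String target;
		private String relClass;

		public String getSource() { return source; }
		public void setSource(String v) { source = v; }
		public String getTarget() { return target; }
		public void setTarget(String v) { target = v; }
		public String getRelClass() { return relClass; }
		public void setRelClass(String v) { relClass = v; }
	}

	// Hypothetical stand-in for ResultOrganizations.
	public static class Grouped implements java.io.Serializable {
		private String resultId;
		private String orgId;
		private ArrayList<String> merges;

		public String getResultId() { return resultId; }
		public void setResultId(String v) { resultId = v; }
		public String getOrgId() { return orgId; }
		public void setOrgId(String v) { orgId = v; }
		public ArrayList<String> getMerges() { return merges; }
		public void setMerges(ArrayList<String> v) { merges = v; }
	}

	static Rel rel(String source, String target, String relClass) {
		Rel r = new Rel();
		r.setSource(source);
		r.setTarget(target);
		r.setRelClass(relClass);
		return r;
	}

	// Same shape as the patch: filter two relation datasets, left-join them, group, fold each group.
	static Dataset<Grouped> joinAndGroup(Dataset<Rel> affiliationInput, Dataset<Rel> mergesInput) {
		Dataset<Rel> affiliation = affiliationInput
			.filter((FilterFunction<Rel>) r -> "hasAuthorInstitution".equalsIgnoreCase(r.getRelClass()));
		Dataset<Rel> orgMerges = mergesInput
			.filter((FilterFunction<Rel>) r -> "merges".equalsIgnoreCase(r.getRelClass()));

		return affiliation
			// Left join keeps affiliations whose organization merges nothing.
			.joinWith(orgMerges, affiliation.col("target").equalTo(orgMerges.col("source")), "left")
			// Key the group on the affiliation's target (the organization id), taken from the
			// left side so it is defined even when the join found no match.
			.groupByKey((MapFunction<Tuple2<Rel, Rel>, String>) t -> t._1().getTarget(), Encoders.STRING())
			.mapGroups((MapGroupsFunction<String, Tuple2<Rel, Rel>, Grouped>) (orgId, it) -> {
				Grouped g = new Grouped();
				g.setOrgId(orgId);
				ArrayList<String> merged = new ArrayList<>();
				Tuple2<Rel, Rel> first = it.next();
				g.setResultId(first._1().getSource());
				// The right side is null for unmatched rows of the left join.
				if (Optional.ofNullable(first._2()).isPresent())
					merged.add(first._2().getTarget());
				it.forEachRemaining(t -> {
					if (t._2() != null)
						merged.add(t._2().getTarget());
				});
				g.setMerges(merged);
				return g;
			}, Encoders.bean(Grouped.class));
	}

	public static void main(String[] args) {
		SparkSession spark = SparkSession.builder().master("local[*]").appName("join-group-sketch").getOrCreate();
		List<Rel> sample = Arrays.asList(
			rel("result1", "orgA", "hasAuthorInstitution"),
			rel("orgA", "orgB", "merges"),
			rel("orgA", "orgC", "merges"),
			rel("result2", "orgD", "hasAuthorInstitution"));
		// Load the input twice, as the patch does with two readPath calls.
		Dataset<Rel> left = spark.createDataset(sample, Encoders.bean(Rel.class));
		Dataset<Rel> right = spark.createDataset(sample, Encoders.bean(Rel.class));
		joinAndGroup(left, right).show(false);
		spark.stop();
	}
}
```

The two createDataset calls mirror the two readPath calls in the patch; keeping the join sides as independently loaded datasets avoids the ambiguous column references that can arise when both sides of a joinWith derive from the same plan.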