refactoring

This commit is contained in:
Miriam Baglioni 2020-04-23 12:57:35 +02:00
parent 44fab140de
commit edb00db86a
3 changed files with 144 additions and 122 deletions

View File

@ -6,13 +6,13 @@ import java.util.List;
public class OrganizationMap extends HashMap<String, List<String>> {
public OrganizationMap(){
public OrganizationMap() {
super();
}
public List<String> get(String key){
public List<String> get(String key) {
if (super.get(key) == null){
if (super.get(key) == null) {
return new ArrayList<>();
}
return super.get(key);

View File

@ -1,20 +1,19 @@
package eu.dnetlib.dhp.resulttocommunityfromorganization;
import static eu.dnetlib.dhp.PropagationConstant.*;
import static eu.dnetlib.dhp.common.SparkSessionSupport.runWithSparkHiveSession;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.gson.Gson;
import eu.dnetlib.dhp.application.ArgumentApplicationParser;
import eu.dnetlib.dhp.schema.oaf.Relation;
import java.util.*;
import org.apache.commons.io.IOUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
import static eu.dnetlib.dhp.PropagationConstant.*;
import static eu.dnetlib.dhp.common.SparkSessionSupport.runWithSparkHiveSession;
public class PrepareResultCommunitySet {
private static final Logger log = LoggerFactory.getLogger(PrepareResultCommunitySet.class);
@ -22,11 +21,12 @@ public class PrepareResultCommunitySet {
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
public static void main(String[] args) throws Exception {
String jsonConfiguration = IOUtils.toString(PrepareResultCommunitySet.class
.getResourceAsStream("/eu/dnetlib/dhp/resulttocommunityfromorganization/input_preparecommunitytoresult_parameters.json"));
String jsonConfiguration =
IOUtils.toString(
PrepareResultCommunitySet.class.getResourceAsStream(
"/eu/dnetlib/dhp/resulttocommunityfromorganization/input_preparecommunitytoresult_parameters.json"));
final ArgumentApplicationParser parser = new ArgumentApplicationParser(
jsonConfiguration);
final ArgumentApplicationParser parser = new ArgumentApplicationParser(jsonConfiguration);
parser.parseArgument(args);
@ -39,76 +39,91 @@ public class PrepareResultCommunitySet {
final String outputPath = parser.get("outputPath");
log.info("outputPath: {}", outputPath);
final OrganizationMap organizationMap = new Gson().fromJson(parser.get("organizationtoresultcommunitymap"), OrganizationMap.class);
final OrganizationMap organizationMap =
new Gson()
.fromJson(
parser.get("organizationtoresultcommunitymap"),
OrganizationMap.class);
log.info("organizationMap: {}", new Gson().toJson(organizationMap));
SparkConf conf = new SparkConf();
conf.set("hive.metastore.uris", parser.get("hive_metastore_uris"));
runWithSparkHiveSession(conf, isSparkSessionManaged,
runWithSparkHiveSession(
conf,
isSparkSessionManaged,
spark -> {
if (isTest(parser)) {
removeOutputDir(spark, outputPath);
}
prepareInfo(spark, inputPath, outputPath, organizationMap);
prepareInfo(spark, inputPath, outputPath, organizationMap);
});
}
private static void prepareInfo(SparkSession spark, String inputPath, String outputPath, OrganizationMap organizationMap) {
private static void prepareInfo(
SparkSession spark,
String inputPath,
String outputPath,
OrganizationMap organizationMap) {
Dataset<Relation> relation = readRelations(spark, inputPath);
relation.createOrReplaceTempView("relation");
String query = "SELECT result_organization.source resultId, result_organization.target orgId, org_set merges " +
"FROM (SELECT source, target " +
" FROM relation " +
" WHERE datainfo.deletedbyinference = false " +
" AND relClass = '" + RELATION_RESULT_ORGANIZATION_REL_CLASS + "') result_organization " +
"LEFT JOIN (SELECT source, collect_set(target) org_set " +
" FROM relation " +
" WHERE datainfo.deletedbyinference = false " +
" AND relClass = '" + RELATION_REPRESENTATIVERESULT_RESULT_CLASS + "' " +
" GROUP BY source) organization_organization " +
"ON result_organization.target = organization_organization.source ";
String query =
"SELECT result_organization.source resultId, result_organization.target orgId, org_set merges "
+ "FROM (SELECT source, target "
+ " FROM relation "
+ " WHERE datainfo.deletedbyinference = false "
+ " AND relClass = '"
+ RELATION_RESULT_ORGANIZATION_REL_CLASS
+ "') result_organization "
+ "LEFT JOIN (SELECT source, collect_set(target) org_set "
+ " FROM relation "
+ " WHERE datainfo.deletedbyinference = false "
+ " AND relClass = '"
+ RELATION_REPRESENTATIVERESULT_RESULT_CLASS
+ "' "
+ " GROUP BY source) organization_organization "
+ "ON result_organization.target = organization_organization.source ";
org.apache.spark.sql.Dataset<ResultOrganizations> result_organizationset = spark.sql(query)
.as(Encoders.bean(ResultOrganizations.class));
org.apache.spark.sql.Dataset<ResultOrganizations> result_organizationset =
spark.sql(query).as(Encoders.bean(ResultOrganizations.class));
result_organizationset
.map(value -> {
String rId = value.getResultId();
List<String> orgs = value.getMerges();
String oTarget = value.getOrgId();
Set<String> communitySet = new HashSet<>();
if (organizationMap.containsKey(oTarget)) {
communitySet.addAll(organizationMap.get(oTarget));
}
try{
for (String oId : orgs) {
if (organizationMap.containsKey(oId)) {
communitySet.addAll(organizationMap.get(oId));
.map(
value -> {
String rId = value.getResultId();
Optional<List<String>> orgs = Optional.ofNullable(value.getMerges());
String oTarget = value.getOrgId();
Set<String> communitySet = new HashSet<>();
if (organizationMap.containsKey(oTarget)) {
communitySet.addAll(organizationMap.get(oTarget));
}
}
}catch(Exception e){
}
if (communitySet.size() > 0){
ResultCommunityList rcl = new ResultCommunityList();
rcl.setResultId(rId);
ArrayList<String> communityList = new ArrayList<>();
communityList.addAll(communitySet);
rcl.setCommunityList(communityList);
return rcl;
}
return null;
}, Encoders.bean(ResultCommunityList.class))
.filter(r -> r!= null)
.toJSON()
.write()
.mode(SaveMode.Overwrite)
.option("compression", "gzip")
.text(outputPath);
if (orgs.isPresent())
// try{
for (String oId : orgs.get()) {
if (organizationMap.containsKey(oId)) {
communitySet.addAll(organizationMap.get(oId));
}
}
// }catch(Exception e){
//
// }
if (communitySet.size() > 0) {
ResultCommunityList rcl = new ResultCommunityList();
rcl.setResultId(rId);
ArrayList<String> communityList = new ArrayList<>();
communityList.addAll(communitySet);
rcl.setCommunityList(communityList);
return rcl;
}
return null;
},
Encoders.bean(ResultCommunityList.class))
.filter(r -> r != null)
.toJSON()
.write()
.mode(SaveMode.Overwrite)
.option("compression", "gzip")
.text(outputPath);
}
}

View File

@ -1,8 +1,13 @@
package eu.dnetlib.dhp.resulttocommunityfromorganization;
import static eu.dnetlib.dhp.PropagationConstant.*;
import static eu.dnetlib.dhp.common.SparkSessionSupport.runWithSparkHiveSession;
import com.fasterxml.jackson.databind.ObjectMapper;
import eu.dnetlib.dhp.application.ArgumentApplicationParser;
import eu.dnetlib.dhp.schema.oaf.*;
import java.util.*;
import java.util.stream.Collectors;
import org.apache.commons.io.IOUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.Encoders;
@ -10,25 +15,21 @@ import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
import java.util.stream.Collectors;
import static eu.dnetlib.dhp.PropagationConstant.*;
import static eu.dnetlib.dhp.common.SparkSessionSupport.runWithSparkHiveSession;
public class SparkResultToCommunityFromOrganizationJob2 {
private static final Logger log = LoggerFactory.getLogger(SparkResultToCommunityFromOrganizationJob2.class);
private static final Logger log =
LoggerFactory.getLogger(SparkResultToCommunityFromOrganizationJob2.class);
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
public static void main(String[] args) throws Exception {
String jsonConfiguration = IOUtils.toString(SparkResultToCommunityFromOrganizationJob2.class
.getResourceAsStream("/eu/dnetlib/dhp/resulttocommunityfromorganization/input_communitytoresult_parameters.json"));
String jsonConfiguration =
IOUtils.toString(
SparkResultToCommunityFromOrganizationJob2.class.getResourceAsStream(
"/eu/dnetlib/dhp/resulttocommunityfromorganization/input_communitytoresult_parameters.json"));
final ArgumentApplicationParser parser = new ArgumentApplicationParser(
jsonConfiguration);
final ArgumentApplicationParser parser = new ArgumentApplicationParser(jsonConfiguration);
parser.parseArgument(args);
@ -47,73 +48,79 @@ public class SparkResultToCommunityFromOrganizationJob2 {
final String resultClassName = parser.get("resultTableName");
log.info("resultTableName: {}", resultClassName);
final Boolean saveGraph = Optional
.ofNullable(parser.get("saveGraph"))
.map(Boolean::valueOf)
.orElse(Boolean.TRUE);
final Boolean saveGraph =
Optional.ofNullable(parser.get("saveGraph"))
.map(Boolean::valueOf)
.orElse(Boolean.TRUE);
log.info("saveGraph: {}", saveGraph);
Class<? extends Result> resultClazz = (Class<? extends Result>) Class.forName(resultClassName);
Class<? extends Result> resultClazz =
(Class<? extends Result>) Class.forName(resultClassName);
SparkConf conf = new SparkConf();
conf.set("hive.metastore.uris", parser.get("hive_metastore_uris"));
runWithSparkHiveSession(conf, isSparkSessionManaged,
runWithSparkHiveSession(
conf,
isSparkSessionManaged,
spark -> {
if (isTest(parser)) {
removeOutputDir(spark, outputPath);
}
execPropagation(spark, inputPath, outputPath, resultClazz, possibleupdatespath);
});
}
private static <R extends Result> void execPropagation(SparkSession spark, String inputPath, String outputPath,
Class<R> resultClazz, String possibleUpdatesPath) {
org.apache.spark.sql.Dataset<ResultCommunityList> possibleUpdates = readResultCommunityList(spark, possibleUpdatesPath);
private static <R extends Result> void execPropagation(
SparkSession spark,
String inputPath,
String outputPath,
Class<R> resultClazz,
String possibleUpdatesPath) {
org.apache.spark.sql.Dataset<ResultCommunityList> possibleUpdates =
readResultCommunityList(spark, possibleUpdatesPath);
org.apache.spark.sql.Dataset<R> result = readPathEntity(spark, inputPath, resultClazz);
result
.joinWith(possibleUpdates, result.col("id").equalTo(possibleUpdates.col("resultId")),
result.joinWith(
possibleUpdates,
result.col("id").equalTo(possibleUpdates.col("resultId")),
"left_outer")
.map(value -> {
R ret = value._1();
Optional<ResultCommunityList> rcl = Optional.ofNullable(value._2());
if(rcl.isPresent()){
ArrayList<String> communitySet = rcl.get().getCommunityList();
List<String> contextList = ret.getContext().stream().map(con -> con.getId()).collect(Collectors.toList());
Result res = new Result();
res.setId(ret.getId());
List<Context> propagatedContexts = new ArrayList<>();
for(String cId:communitySet){
if(!contextList.contains(cId)){
Context newContext = new Context();
newContext.setId(cId);
newContext.setDataInfo(Arrays.asList(getDataInfo(PROPAGATION_DATA_INFO_TYPE,
PROPAGATION_RESULT_COMMUNITY_ORGANIZATION_CLASS_ID, PROPAGATION_RESULT_COMMUNITY_ORGANIZATION_CLASS_NAME)));
propagatedContexts.add(newContext);
.map(
value -> {
R ret = value._1();
Optional<ResultCommunityList> rcl = Optional.ofNullable(value._2());
if (rcl.isPresent()) {
ArrayList<String> communitySet = rcl.get().getCommunityList();
List<String> contextList =
ret.getContext().stream()
.map(con -> con.getId())
.collect(Collectors.toList());
Result res = new Result();
res.setId(ret.getId());
List<Context> propagatedContexts = new ArrayList<>();
for (String cId : communitySet) {
if (!contextList.contains(cId)) {
Context newContext = new Context();
newContext.setId(cId);
newContext.setDataInfo(
Arrays.asList(
getDataInfo(
PROPAGATION_DATA_INFO_TYPE,
PROPAGATION_RESULT_COMMUNITY_ORGANIZATION_CLASS_ID,
PROPAGATION_RESULT_COMMUNITY_ORGANIZATION_CLASS_NAME)));
propagatedContexts.add(newContext);
}
}
res.setContext(propagatedContexts);
ret.mergeFrom(res);
}
}
res.setContext(propagatedContexts);
ret.mergeFrom(res);
}
return ret;
}, Encoders.bean(resultClazz))
return ret;
},
Encoders.bean(resultClazz))
.toJSON()
.write()
.mode(SaveMode.Overwrite)
.option("compression","gzip")
.option("compression", "gzip")
.text(outputPath);
}
private static org.apache.spark.sql.Dataset<ResultCommunityList> readResultCommunityList(SparkSession spark, String possibleUpdatesPath) {
return spark
.read()
.textFile(possibleUpdatesPath)
.map(value -> OBJECT_MAPPER.readValue(value, ResultCommunityList.class), Encoders.bean(ResultCommunityList.class));
}
}