1
14 package gate.creole.ml;
15
16 import java.util.*;
17
18 import org.jdom.Element;
19 import org.jdom.JDOMException;
20 import org.jdom.input.SAXBuilder;
21
22 import gate.*;
23 import gate.creole.*;
24 import gate.gui.ActionsPublisher;
25 import gate.util.*;
26
27
31
32 public class MachineLearningPR extends AbstractLanguageAnalyser
33 implements gate.gui.ActionsPublisher{
34
35 public MachineLearningPR(){
36 actionList = new ArrayList();
37 actionList.add(null);
38 }
39
40
46 public void cleanup() {
47 super.cleanup();
50
51 if (engine!=null) {
54 engine.cleanUp();
55 }
56 }
57
58
59 public Resource init() throws ResourceInstantiationException {
60 if(configFileURL == null){
61 throw new ResourceInstantiationException(
62 "No configuration file provided!");
63 }
64
65 org.jdom.Document jdomDoc;
66 SAXBuilder saxBuilder = new SAXBuilder(false);
67 try {
68 try{
69 jdomDoc = saxBuilder.build(configFileURL);
70 }catch(JDOMException jde){
71 throw new ResourceInstantiationException(jde);
72 }
73 } catch (java.io.IOException ex) {
74 throw new ResourceInstantiationException(ex);
75 }
76
77 Element rootElement = jdomDoc.getRootElement();
79 if(!rootElement.getName().equals("ML-CONFIG"))
80 throw new ResourceInstantiationException(
81 "Root element of dataset defintion file is \"" + rootElement.getName() +
82 "\" instead of \"ML-CONFIG\"!");
83
84 Element datasetElement = rootElement.getChild("DATASET");
86 if(datasetElement == null) throw new ResourceInstantiationException(
87 "No dataset definition provided in the configuration file!");
88 try{
89 datasetDefinition = new DatasetDefintion(datasetElement);
90 }catch(GateException ge){
91 throw new ResourceInstantiationException(ge);
92 }
93
94 Element engineElement = rootElement.getChild("ENGINE");
96 if(engineElement == null) throw new ResourceInstantiationException(
97 "No engine option provided in the configuration file!");
98 Element engineClassElement = engineElement.getChild("WRAPPER");
99 if(engineClassElement == null) throw new ResourceInstantiationException(
100 "No ML engine class provided!");
101 String engineClassName = engineClassElement.getTextTrim();
102 try{
103 Class engineClass =
105 Class.forName(engineClassName, true, Gate.getClassLoader());
106 engine = (MLEngine)engineClass.newInstance();
107 }catch(ClassNotFoundException cnfe){
108 throw new ResourceInstantiationException(
109 "ML engine class:" + engineClassName + "not found!");
110 }catch(IllegalAccessException iae){
111 throw new ResourceInstantiationException(iae);
112 }catch(InstantiationException ie){
113 throw new ResourceInstantiationException(ie);
114 }
115
116 if (engineElement.getChild("BATCH-MODE-CLASSIFICATION") == null) {
118 batchModeClassification = false;
119 } else {
120 if (engine instanceof AdvancedMLEngine){
124 batchModeClassification = ((AdvancedMLEngine)engine).supportsBatchMode();
125 }
126 else batchModeClassification = false;
127 }
128
129 engine.setDatasetDefinition(datasetDefinition);
130 engine.setOptions(engineElement.getChild("OPTIONS"));
131 engine.setOwnerPR(this);
132 try{
133 engine.init();
134 }catch(GateException ge){
135 throw new ResourceInstantiationException(ge);
136 }
137
138 return this;
139 }
141
142
145 public void execute() throws ExecutionException {
146 interrupted = false;
147 if (document == null) {
149 throw new ExecutionException(
150 "No document provided!"
151 );
152 }
153
154 if (inputASName == null ||
155 inputASName.equals(""))
156 annotationSet = document.getAnnotations();
157 else
158 annotationSet = document.getAnnotations(inputASName);
159
160 if (training.booleanValue()) {
161 fireStatusChanged(
162 "Collecting training data from " + document.getName() + "...");
163 }
164 else {
165 fireStatusChanged(
166 "Applying ML model to " + document.getName() + "...");
167 }
168 fireProgressChanged(0);
169 AnnotationSet anns = annotationSet.
170 get(datasetDefinition.getInstanceType());
171 annotations = (anns == null || anns.isEmpty()) ?
172 new ArrayList() : new ArrayList(anns);
173 Collections.sort(annotations, new OffsetComparator());
174 Iterator annotationIter = annotations.iterator();
175 int index = 0;
176 int size = annotations.size();
177
178 cache = new Cache();
180
181 if (!batchModeClassification || training.booleanValue()) {
182 while (annotationIter.hasNext()) {
186 Annotation instanceAnn = (Annotation) annotationIter.next();
187 List attributeValues = new ArrayList(datasetDefinition.
188 getAttributes().size());
189 Iterator attrIter = datasetDefinition.getAttributes().iterator();
191 while (attrIter.hasNext()) {
192 Attribute attr = (Attribute) attrIter.next();
193 if (attr.isClass && !training.booleanValue()) {
194 attributeValues.add(null);
196 }
197 else {
198 attributeValues.add(cache.getAttributeValue(index, attr));
199 }
200 }
201
202 if (training.booleanValue()) {
203 engine.addTrainingInstance(attributeValues);
204 }
205 else {
206 Object result = engine.classifyInstance(attributeValues);
207 if (result instanceof Collection) {
208 Iterator resIter = ( (Collection) result).iterator();
209 while (resIter.hasNext())
210 updateDocument(resIter.next(), index);
211 }
212 else {
213 updateDocument(result, index);
214 }
215 }
216
217 cache.shift();
218 if (index % 10 == 0) {
220 fireProgressChanged(index * 100 / size);
221 if (isInterrupted())
222 throw new ExecutionInterruptedException();
223 }
224 index++;
225 }
226
227 }
228 else {
229
234 List instancesToBeClassified = new ArrayList();
236
237 while (annotationIter.hasNext()) {
238 Annotation instanceAnn = (Annotation) annotationIter.next();
239 List attributeValues = new ArrayList(datasetDefinition.
240 getAttributes().size());
241 Iterator attrIter = datasetDefinition.getAttributes().iterator();
243 while (attrIter.hasNext()) {
244 Attribute attr = (Attribute) attrIter.next();
245 if (attr.isClass) {
246 attributeValues.add(null);
248 }
249 else {
250 attributeValues.add(cache.getAttributeValue(index, attr));
251 }
252 }
253
254 instancesToBeClassified.add(attributeValues);
257
258 cache.shift();
259
260 index++;
261 }
262
263 List classificationResults = engine.batchClassifyInstances(
266 instancesToBeClassified);
267
268
271 index = 0;
273 Iterator resultsIterator = classificationResults.iterator();
274 while (resultsIterator.hasNext()) {
275
276 Object result = resultsIterator.next();
277 if (result instanceof Collection) {
278 Iterator resIter = ( (Collection) result).iterator();
279 while (resIter.hasNext())
280 updateDocument(resIter.next(), index);
281 }
282 else {
283 updateDocument(result, index);
284 }
285
286 index++;
288 }
289 }
290 annotations = null;
291 }
293
294 protected void updateDocument(Object classificationResult, int instanceIndex){
295 Attribute classAttr = datasetDefinition.getClassAttribute();
297 String type = classAttr.getType();
298 String feature = classAttr.getFeature();
299 List classValues = classAttr.getValues();
300 FeatureMap features = Factory.newFeatureMap();
301 boolean shouldCreateAnnotation = true;
302 if(classValues != null && !classValues.isEmpty()){
303 String featureValue = (String)classificationResult;
306 features.put(feature, featureValue);
307 }else{
308 if(feature == null){
309 shouldCreateAnnotation = classificationResult.equals("true");
311 }else{
312 String featureValue = classificationResult.toString();
314 features.put(feature, featureValue);
315 }
316 }
317
318 if(shouldCreateAnnotation){
319 int coveredInstanceIndex = instanceIndex + classAttr.getPosition();
321 if(coveredInstanceIndex >= 0 &&
322 coveredInstanceIndex < annotations.size()){
323 Annotation coveredInstance = (Annotation)annotations.
324 get(coveredInstanceIndex);
325 annotationSet.add(coveredInstance.getStartNode(),
326 coveredInstance.getEndNode(),
327 type, features);
328 }
329 }
330 }
331
332
333
337 public List getActions(){
338 List result = new ArrayList();
339 result.addAll(actionList);
340 if(engine instanceof ActionsPublisher){
341 result.addAll(((ActionsPublisher)engine).getActions());
342 }
343 return result;
344 }
345
346 protected class Cache{
347 public Cache(){
348 int forwardCacheSize = 0;
350 int backwardCacheSize = 0;
351 Iterator attrIter = datasetDefinition.getAttributes().iterator();
352 while(attrIter.hasNext()){
353 Attribute anAttribute = (Attribute)attrIter.next();
354 if(anAttribute.getPosition() > 0){
355 if(anAttribute.getPosition() > forwardCacheSize){
357 forwardCacheSize = anAttribute.getPosition();
358 }
359 }else if(anAttribute.getPosition() < 0){
360 if(-anAttribute.getPosition() > backwardCacheSize){
362 backwardCacheSize = -anAttribute.getPosition();
363 }
364 }
365 }
366 forwardCache = new ArrayList(forwardCacheSize);
368 for(int i =0; i < forwardCacheSize; i++) forwardCache.add(null);
369 backwardCache = new ArrayList(backwardCacheSize);
370 for(int i =0; i < backwardCacheSize; i++) backwardCache.add(null);
371 }
372
373
380 public String getAttributeValue(int instanceIndex, Attribute attribute){
381 int actualPosition = instanceIndex + attribute.getPosition();
383 if(actualPosition < 0 || actualPosition >= annotations.size()) return null;
384
385 if(attribute.getPosition() == 0){
387 if(currentAttributes == null) currentAttributes = new HashMap();
389 return getValue(attribute, instanceIndex, currentAttributes);
390 }else if(attribute.getPosition() > 0){
391 Map attributesMap = (Map)forwardCache.get(attribute.getPosition() - 1);
393 if(attributesMap == null){
394 attributesMap = new HashMap();
395 forwardCache.set(attribute.getPosition() - 1, attributesMap);
396 }
397 return getValue(attribute, actualPosition, attributesMap);
398 }else if(attribute.getPosition() < 0){
399 Map attributesMap = (Map)backwardCache.get(-attribute.getPosition() - 1);
401 if(attributesMap == null){
402 attributesMap = new HashMap();
403 backwardCache.set(-attribute.getPosition() - 1, attributesMap);
404 }
405 return getValue(attribute, actualPosition, attributesMap);
406 }
407 throw new LuckyException(
409 "Attribute position is neither 0, nor negative nor positive!");
410 }
411
412
416 public void shift(){
417 if(backwardCache.isEmpty()){
418 }else{
421 backwardCache.remove(backwardCache.size() - 1);
422 backwardCache.add(0, currentAttributes);
423 }
424 if(forwardCache.isEmpty()){
425 if(currentAttributes != null) currentAttributes.clear();
427 }else{
428 currentAttributes = (Map) forwardCache.remove(0);
429 forwardCache.add(null);
430 }
431 }
432
433
444 protected String getValue(Attribute attribute,
445 int instanceIndex,
446 Map cache){
447 String value = null;
448 String annType = attribute.getType();
449 String featureName = attribute.getFeature();
450 Map typeData = (Map)cache.get(annType);
451 if(typeData != null){
452 if(featureName == null){
453 value = (String)typeData.get(null);
455 }else{
456 value = (String)typeData.get(featureName);
457 }
458 }else{
459 Annotation instanceAnnot = (Annotation)annotations.get(instanceIndex);
462
463 typeData = new HashMap();
464 cache.put(annType, typeData);
465
471 if (instanceAnnot.getType().equals(annType)){
472 typeData.putAll(instanceAnnot.getFeatures());
473 typeData.put(null, "true");
474
475 String stringvalue = (String)typeData.get(featureName);
476 if(featureName == null) return "true";
477 return stringvalue;
478 }
479
480 AnnotationSet typeSubset = annotationSet.get(annType);
484 AnnotationSet coverSubset = null;
485 if (typeSubset!=null) coverSubset = typeSubset.get(
486 annType,
487 instanceAnnot.getStartNode().getOffset(),
488 instanceAnnot.getEndNode().getOffset());
489
490 if(coverSubset == null || coverSubset.isEmpty()){
491 typeData.put(null, "false");
493 if(featureName == null) value = "false";
494 else value = null;
495 }else{
496 typeData.putAll(((Annotation)coverSubset.iterator().next()).
497 getFeatures());
498 typeData.put(null, "true");
499 if(featureName == null) value = "true";
500 else value = (String)typeData.get(featureName);
501 }
502 }
503 return value;
504 }
505
506
519 protected List forwardCache;
520
521
534 protected List backwardCache;
535
536
548 protected Map currentAttributes;
549
550 }
551
552
553 public void setInputASName(String inputASName) {
554 this.inputASName = inputASName;
555 }
556 public String getInputASName() {
557 return inputASName;
558 }
559 public java.net.URL getConfigFileURL() {
560 return configFileURL;
561 }
562 public void setConfigFileURL(java.net.URL configFileURL) {
563 this.configFileURL = configFileURL;
564 }
565 public void setTraining(Boolean training) {
566 this.training = training;
567 }
568 public Boolean getTraining() {
569 return training;
570 }
571 public MLEngine getEngine() {
572 return engine;
573 }
574 public void setEngine(MLEngine engine) {
575 this.engine = engine;
576 }
577
578 private java.net.URL configFileURL;
579 protected DatasetDefintion datasetDefinition;
580
581 protected MLEngine engine;
582
583 protected String inputASName;
584
585 protected AnnotationSet annotationSet;
586
587 protected List annotations;
588
589 protected List actionList;
590
591 protected Cache cache;
592 private Boolean training;
593
594
599 protected boolean batchModeClassification;
600 }
601