1
14
15 package gate.creole.ml.maxent;
16
17 import gate.creole.ml.*;
18 import gate.util.GateException;
19 import gate.creole.ExecutionException;
20 import java.util.List;
21 import java.util.Iterator;
22
23
27 public class MaxentWrapper
28 implements AdvancedMLEngine, gate.gui.ActionsPublisher {
29
30 boolean DEBUG=false;
31
32
41 public MaxentWrapper() {
42 actionsList = new java.util.ArrayList();
43 actionsList.add(new LoadModelAction());
44 actionsList.add(new SaveModelAction());
45 actionsList.add(null);
46 }
47
48
52 public void cleanUp() {
53 }
54
55
63 public List batchClassifyInstances(java.util.List instances)
64 throws ExecutionException {
65 throw new ExecutionException("The Maxent wrapper does not support "+
66 "batch classification. Remove the "+
67 "<BATCH-MODE-CLASSIFICATION/> entry "+
68 "from the XML configuration file and "+
69 "try again.");
70 }
71
72
78 public void setOptions(org.jdom.Element optionsElem) {
79 this.optionsElement = optionsElem;
80 }
81
82
88 private void extractAndCheckOptions() throws gate.creole.
89 ResourceInstantiationException {
90 setCutoff(optionsElement);
91 setConfidenceThreshold(optionsElement);
92 setVerbose(optionsElement);
93 setIterations(optionsElement);
94 setSmoothing(optionsElement);
95 setSmoothingObservation(optionsElement);
96 }
97
98
102 private void setVerbose(org.jdom.Element optionsElem) {
103 if (optionsElem.getChild("VERBOSE") == null) {
104 verbose = false;
105 }
106 else {
107 verbose = true;
108 }
109 }
110
111
115 private void setSmoothing(org.jdom.Element optionsElem) {
116 if (optionsElem.getChild("SMOOTHING") == null) {
117 smoothing = false;
118 }
119 else {
120 smoothing = true;
121 }
122 }
123
124
128 private void setSmoothingObservation(org.jdom.Element optionsElem) throws
129 gate.creole.ResourceInstantiationException {
130 String smoothingObservationString
131 = optionsElem.getChildTextTrim("SMOOTHING-OBSERVATION");
132 if (smoothingObservationString != null) {
133 try {
134 smoothingObservation = Double.parseDouble(smoothingObservationString);
135 }
136 catch (NumberFormatException e) {
137 throw new gate.creole.ResourceInstantiationException("Unable to parse " +
138 "<SMOOTHING-OBSERVATION> value in maxent configuration file.");
139 }
140 }
141 else {
142 smoothingObservation = 0.0;
143 }
144 }
145
146
150 private void setConfidenceThreshold(org.jdom.Element optionsElem) throws gate.
151 creole.ResourceInstantiationException {
152 String confidenceThresholdString
153 = optionsElem.getChildTextTrim("CONFIDENCE-THRESHOLD");
154 if (confidenceThresholdString != null) {
155 try {
156 confidenceThreshold = Double.parseDouble(confidenceThresholdString);
157 }
158 catch (NumberFormatException e) {
159 throw new gate.creole.ResourceInstantiationException("Unable to parse " +
160 "<CONFIDENCE-THRESHOLD> value in maxent configuration file.");
161 }
162 if (confidenceThreshold < 0.0 || confidenceThreshold > 1) {
163 throw new gate.creole.ResourceInstantiationException(
164 "<CONFIDENCE-THRESHOLD> in maxent configuration"
165 + " file must be set to a value between 0 and 1."
166 + " (It is a probability.)");
167 }
168 }
169 else {
170 confidenceThreshold = 0.0;
171 }
172 }
173
174
178 private void setCutoff(org.jdom.Element optionsElem) throws gate.creole.
179 ResourceInstantiationException {
180 String cutoffString = optionsElem.getChildTextTrim("CUT-OFF");
181 if (cutoffString != null) {
182 try {
183 cutoff = Integer.parseInt(cutoffString);
184 }
185 catch (NumberFormatException e) {
186 throw new gate.creole.ResourceInstantiationException(
187 "Unable to parse <CUT-OFF> value in maxent " +
188 "configuration file. It must be an integer.");
189 }
190 }
191 else {
192 cutoff = 0;
193 }
194 }
195
196
201 private void setIterations(org.jdom.Element optionsElem) throws gate.creole.
202 ResourceInstantiationException {
203 String iterationsString = optionsElem.getChildTextTrim("ITERATIONS");
204 if (iterationsString != null) {
205 try {
206 iterations = Integer.parseInt(iterationsString);
207 }
208 catch (NumberFormatException e) {
209 throw new gate.creole.ResourceInstantiationException(
210 "Unable to parse <ITERATIONS> value in maxent " +
211 "configuration file. It must be an integer.");
212 }
213 }
214 else {
215 iterations = 0;
216 }
217 }
218
219
227 public void addTrainingInstance(List attributeValues) {
228 markIndicesOnFeatures(attributeValues);
229 trainingData.add(attributeValues);
230 datasetChanged = true;
231 }
232
233
243 void markIndicesOnFeatures(List attributeValues) {
244 for (int i=0; i<attributeValues.size(); ++i) {
245 if (i != datasetDefinition.getClassIndex())
247 attributeValues.set(i, i+":"+(String)attributeValues.get(i));
248 }
249 }
250
251
258 public void setDatasetDefinition(DatasetDefintion definition) {
259 this.datasetDefinition = definition;
260 }
261
262
268 private void checkDatasetDefinition() throws gate.creole.
269 ResourceInstantiationException {
270 List attributes = datasetDefinition.getAttributes();
273 Iterator attributeIterator = attributes.iterator();
274 while (attributeIterator.hasNext()) {
275 gate.creole.ml.Attribute currentAttribute
276 = (gate.creole.ml.Attribute) attributeIterator.next();
277 if (currentAttribute.semanticType() != gate.creole.ml.Attribute.BOOLEAN) {
278 if (currentAttribute.semanticType() != gate.creole.ml.Attribute.NOMINAL
279 || !currentAttribute.isClass()) {
280 throw new gate.creole.ResourceInstantiationException(
281 "Error in maxent configuration file. All " +
282 "attributes except the <CLASS/> attribute " +
283 "must be boolean, and the <CLASS/> attribute" +
284 " must be boolean or nominal");
285 }
286 }
287 }
288 }
289
290
295 private void initialiseAndTrainClassifier() {
296 opennlp.maxent.GIS.PRINT_MESSAGES = verbose;
297 opennlp.maxent.GIS.SMOOTHING_OBSERVATION = smoothingObservation;
298
299 if (DEBUG) {
301 System.out.println("Number of training instances: "+trainingData.size());
302 System.out.println("Class index: "+datasetDefinition.getClassIndex());
303 System.out.println("Iterations: "+iterations);
304 System.out.println("Cutoff: "+cutoff);
305 System.out.println("Confidence threshold: "+confidenceThreshold);
306 System.out.println("Verbose: "+verbose);
307 System.out.println("Smoothing: "+smoothing);
308 System.out.println("Smoothing observation: "+smoothingObservation);
309
310 System.out.println("");
311 System.out.println("TRAINING DATA\n");
312 System.out.println(trainingData);
313 }
314 maxentClassifier = opennlp.maxent.GIS.trainModel(
315 new GateEventStream(trainingData, datasetDefinition.getClassIndex()),
316 iterations, cutoff,smoothing,verbose);
317 }
318
319
336 public Object classifyInstance(List attributeValues) throws
337 ExecutionException {
338 if (maxentClassifier == null || datasetChanged)
342 initialiseAndTrainClassifier();
343 datasetChanged=false;
346
347 markIndicesOnFeatures(attributeValues);
350
351 attributeValues.remove(datasetDefinition.getClassIndex());
355
356 if (confidenceThreshold == 0) { return maxentClassifier.
360 getBestOutcome(maxentClassifier.eval(
361 (String[])attributeValues.toArray(new String[0])));
362 }
363 else { double[] outcomeProbabilities = maxentClassifier.eval(
365 (String[]) attributeValues.toArray(new String[0]));
366
367 List allOutcomesOverThreshold = new java.util.ArrayList();
368 for (int i = 0; i < outcomeProbabilities.length; i++) {
369 if (outcomeProbabilities[i] >= confidenceThreshold) {
370 allOutcomesOverThreshold.add(maxentClassifier.getOutcome(i));
371 }
372 }
373 return allOutcomesOverThreshold;
374 }
375 }
377
384 public void init() throws GateException {
385 sListener = null;
387 java.util.Map listeners = gate.gui.MainFrame.getListeners();
388 if (listeners != null) {
389 sListener = (gate.event.StatusListener)
390 listeners.get("gate.event.StatusListener");
391 }
392
393 if (sListener != null) {
394 sListener.statusChanged("Setting classifier options...");
395 }
396 extractAndCheckOptions();
397
398 if (sListener != null) {
399 sListener.statusChanged("Checking dataset definition...");
400 }
401 checkDatasetDefinition();
402
403
407 if (sListener != null) {
409 sListener.statusChanged("Initialising dataset...");
410
411 }
412 trainingData = new java.util.ArrayList();
413
414 if (sListener != null) {
415 sListener.statusChanged("");
416 }
417 }
419
423 public void load(java.io.InputStream is) throws java.io.IOException {
424 if (sListener != null) {
425 sListener.statusChanged("Loading model...");
426
427 }
428 java.io.ObjectInputStream ois = new java.io.ObjectInputStream(is);
429
430 try {
431 maxentClassifier = (opennlp.maxent.MaxentModel) ois.readObject();
432 trainingData = (java.util.List) ois.readObject();
433 datasetDefinition = (DatasetDefintion) ois.readObject();
434 datasetChanged = ois.readBoolean();
435
436 cutoff = ois.readInt();
437 confidenceThreshold = ois.readDouble();
438 iterations = ois.readInt();
439 verbose = ois.readBoolean();
440 smoothing = ois.readBoolean();
441 smoothingObservation = ois.readDouble();
442 }
443 catch (ClassNotFoundException cnfe) {
444 throw new gate.util.GateRuntimeException(cnfe.toString());
445 }
446 ois.close();
447
448 if (sListener != null) {
449 sListener.statusChanged("");
450 }
451 }
452
453
457 public void save(java.io.OutputStream os) throws java.io.IOException {
458 if (sListener != null) {
459 sListener.statusChanged("Saving model...");
460
461 }
462 java.io.ObjectOutputStream oos = new java.io.ObjectOutputStream(os);
463
464 oos.writeObject(maxentClassifier);
465 oos.writeObject(trainingData);
466 oos.writeObject(datasetDefinition);
467 oos.writeBoolean(datasetChanged);
468
469 oos.writeInt(cutoff);
470 oos.writeDouble(confidenceThreshold);
471 oos.writeInt(iterations);
472 oos.writeBoolean(verbose);
473 oos.writeBoolean(smoothing);
474 oos.writeDouble(smoothingObservation);
475
476 oos.flush();
477 oos.close();
478
479 if (sListener != null) {
480 sListener.statusChanged("");
481 }
482 }
483
484
488 public java.util.List getActions() {
489 return actionsList;
490 }
491
492
496 public void setOwnerPR(gate.ProcessingResource pr) {
497 this.owner = pr;
498 }
499
500 public DatasetDefintion getDatasetDefinition() {
501 return datasetDefinition;
502 }
503
504 public boolean supportsBatchMode(){
505 return false;
506 }
507
508
511 protected class SaveModelAction
512 extends javax.swing.AbstractAction {
513 public SaveModelAction() {
514 super("Save model");
515 putValue(SHORT_DESCRIPTION, "Saves the ML model to a file");
516 }
517
518
524 public void actionPerformed(java.awt.event.ActionEvent evt) {
525 Runnable runnable = new Runnable() {
526 public void run() {
527 javax.swing.JFileChooser fileChooser
528 = gate.gui.MainFrame.getFileChooser();
529 fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter());
530 fileChooser.setFileSelectionMode(javax.swing.JFileChooser.FILES_ONLY);
531 fileChooser.setMultiSelectionEnabled(false);
532 if (fileChooser.showSaveDialog(null)
533 == javax.swing.JFileChooser.APPROVE_OPTION) {
534 java.io.File file = fileChooser.getSelectedFile();
535 try {
536 gate.gui.MainFrame.lockGUI("Saving ML model...");
537 save(new java.util.zip.GZIPOutputStream(
538 new java.io.FileOutputStream(
539 file.getCanonicalPath(), false)));
540 }
541 catch (java.io.IOException ioe) {
542 javax.swing.JOptionPane.showMessageDialog(null,
543 "Error!\n" +
544 ioe.toString(),
545 "GATE", javax.swing.JOptionPane.ERROR_MESSAGE);
546 ioe.printStackTrace(gate.util.Err.getPrintWriter());
547 }
548 finally {
549 gate.gui.MainFrame.unlockGUI();
550 }
551 }
552 }
553 };
554 Thread thread = new Thread(runnable, "ModelSaver(serialisation)");
555 thread.setPriority(Thread.MIN_PRIORITY);
556 thread.start();
557 }
558 }
559
560
565 protected class LoadModelAction
566 extends javax.swing.AbstractAction {
567 public LoadModelAction() {
568 super("Load model");
569 putValue(SHORT_DESCRIPTION, "Loads a ML model from a file");
570 }
571
572
578 public void actionPerformed(java.awt.event.ActionEvent evt) {
579 Runnable runnable = new Runnable() {
580 public void run() {
581 javax.swing.JFileChooser fileChooser
582 = gate.gui.MainFrame.getFileChooser();
583 fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter());
584 fileChooser.setFileSelectionMode(javax.swing.JFileChooser.FILES_ONLY);
585 fileChooser.setMultiSelectionEnabled(false);
586 if (fileChooser.showOpenDialog(null)
587 == javax.swing.JFileChooser.APPROVE_OPTION) {
588 java.io.File file = fileChooser.getSelectedFile();
589 try {
590 gate.gui.MainFrame.lockGUI("Loading model...");
591 load(new java.util.zip.GZIPInputStream(
592 new java.io.FileInputStream(file)));
593 }
594 catch (java.io.IOException ioe) {
595 javax.swing.JOptionPane.showMessageDialog(null,
596 "Error!\n" +
597 ioe.toString(),
598 "GATE", javax.swing.JOptionPane.ERROR_MESSAGE);
599 ioe.printStackTrace(gate.util.Err.getPrintWriter());
600 }
601 finally {
602 gate.gui.MainFrame.unlockGUI();
603 }
604 }
605 }
606 };
607 Thread thread = new Thread(runnable, "ModelLoader(serialisation)");
608 thread.setPriority(Thread.MIN_PRIORITY);
609 thread.start();
610 }
611 }
612
613 protected gate.creole.ml.DatasetDefintion datasetDefinition;
614
615
618 protected opennlp.maxent.MaxentModel maxentClassifier;
619
620
626 protected List trainingData;
627
628
631 protected org.jdom.Element optionsElement;
632
633
637 protected boolean datasetChanged = false;
638
639
643 protected List actionsList;
644
645 protected gate.ProcessingResource owner;
646
647 protected gate.event.StatusListener sListener;
648
649
655 protected int cutoff = 0;
656 protected double confidenceThreshold = 0;
657 protected int iterations = 10;
658 protected boolean verbose = false;
659 protected boolean smoothing = false;
660 protected double smoothingObservation = 0.1;
661
662 }