SHOGUN  v3.2.0
LatentModel.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2012 Viktor Gal
8  * Copyright (C) 2012 Viktor Gal
9  */
10 
13 
14 using namespace shogun;
15 
17  : m_features(NULL),
18  m_labels(NULL),
19  m_do_caching(false),
20  m_cached_psi(NULL)
21 {
22  register_parameters();
23 }
24 
25 CLatentModel::CLatentModel(CLatentFeatures* feats, CLatentLabels* labels, bool do_caching)
26  : m_features(feats),
27  m_labels(labels),
28  m_do_caching(do_caching),
29  m_cached_psi(NULL)
30 {
31  register_parameters();
34 }
35 
37 {
41 }
42 
44 {
45  return m_features->get_num_vectors();
46 }
47 
49 {
50  SG_REF(labs);
52  m_labels = labs;
53 }
54 
56 {
58  return m_labels;
59 }
60 
62 {
63  SG_REF(feats);
65  m_features = feats;
66 }
67 
69 {
70  int32_t num = get_num_vectors();
72  ASSERT(num > 0)
73  ASSERT(num == m_labels->get_num_labels())
74 
75  // argmax_h only for positive examples
76  for (int32_t i = 0; i < num; ++i)
77  {
78  if (y->get_label(i) == 1)
79  {
80  // infer h and set it for the argmax_h <w,psi(x,h)>
81  CData* latent_data = infer_latent_variable(w, i);
82  m_labels->set_latent_label(i, latent_data);
83  }
84  }
85 }
86 
87 void CLatentModel::register_parameters()
88 {
89  m_parameters->add((CSGObject**) &m_features, "features", "Latent features");
90  m_parameters->add((CSGObject**) &m_labels, "labels", "Latent labels");
91  m_parameters->add((CSGObject**) &m_cached_psi, "cached_psi", "Cached PSI features after argmax_h");
92  m_parameters->add(&m_do_caching, "do_caching", "Indicate whether or not do PSI vector caching after argmax_h");
93 }
94 
95 
97 {
99  return m_features;
100 }
101 
103 {
104  if (m_do_caching)
105  {
106  if (m_cached_psi)
110  }
111 }
112 
114 {
115  if (m_do_caching)
116  {
118  return m_cached_psi;
119  }
120  return NULL;
121 }
virtual int32_t get_num_labels() const
CDotFeatures * m_cached_psi
Definition: LatentModel.h:150
Latent Features class. The class is for representing features for latent learning, e...
#define SG_UNREF(x)
Definition: SGRefObject.h:35
Parameter * m_parameters
Definition: SGObject.h:482
float64_t get_label(int32_t idx)
Features that support dot products among other operations.
Definition: DotFeatures.h:41
void set_labels(CLatentLabels *labs)
Definition: LatentModel.cpp:48
static CBinaryLabels * to_binary(CLabels *base_labels)
void add(bool *param, const char *name, const char *description="")
Definition: Parameter.cpp:27
#define ASSERT(x)
Definition: SGIO.h:203
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:102
CLatentLabels * m_labels
Definition: LatentModel.h:146
Dummy data holder.
Definition: Data.h:23
CLatentLabels * get_labels() const
Definition: LatentModel.cpp:55
#define SG_REF(x)
Definition: SGRefObject.h:34
virtual CData * infer_latent_variable(const SGVector< float64_t > &w, index_t idx)=0
CLatentFeatures * m_features
Definition: LatentModel.h:144
virtual void argmax_h(const SGVector< float64_t > &w)
Definition: LatentModel.cpp:68
virtual CDotFeatures * get_psi_feature_vectors()=0
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:16
CLabels * get_labels() const
Binary Labels for binary classification.
Definition: BinaryLabels.h:36
CLatentFeatures * get_features() const
Definition: LatentModel.cpp:96
virtual int32_t get_num_vectors() const
void set_features(CLatentFeatures *feats)
Definition: LatentModel.cpp:61
CDotFeatures * get_cached_psi_features() const
Abstract class for latent labels. As latent labels always depend on the given application, this class only defines the API that the user has to implement for latent labels.
Definition: LatentLabels.h:24
virtual int32_t get_num_vectors() const
Definition: LatentModel.cpp:43
bool set_latent_label(int32_t idx, CData *label)

SHOGUN Machine Learning Toolbox - Documentation