Skip to content
This repository was archived by the owner on Nov 18, 2023. It is now read-only.

Commit 352aef1

Browse files
authored
Refactor Classifiers (#64)
We have removed repeated code among the supervised classifiers.
1 parent 7a8b0e2 commit 352aef1

File tree

5 files changed

+140
-209
lines changed

5 files changed

+140
-209
lines changed

examples/kgcn/animal_trade/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ The [main](../../../examples/kgcn/animal_trade/main.py) function will:
6161

6262
- Search Grakn for the k-hop neighbours of the selected examples, and store information about them as arrays, denoted in the code as `context_arrays`. This data is saved to file so that subsequent steps can be re-run without recomputing it.
6363

64-
- Build the TensorFlow computation graph using `model.KGCN`, including a multi-class classification step and learning procedure defined by `classify.SupervisedKGCNClassifier`
64+
- Build the TensorFlow computation graph using `model.KGCN`, including a multi-class classification step and learning procedure defined by `classify.SupervisedKGCNMultiClassSingleLabelClassifier`
6565

6666
- Feed the `context_arrays` to the TensorFlow graph, and perform learning
6767

examples/kgcn/animal_trade/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ def main(modes=(TRAIN, EVAL, PREDICT)):
9999
neighbour_sampling_limit_factor=4)
100100

101101
optimizer = tf.train.GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
102-
classifier = classify.SupervisedKGCNClassifier(kgcn, optimizer, FLAGS.num_classes, FLAGS.log_dir,
103-
max_training_steps=FLAGS.max_training_steps)
102+
classifier = classify.SupervisedKGCNMultiClassSingleLabelClassifier(kgcn, optimizer, FLAGS.num_classes, FLAGS.log_dir,
103+
max_training_steps=FLAGS.max_training_steps)
104104

105105
feed_dicts = {}
106106
feed_dict_storer = persistence.FeedDictStorer(BASE_PATH + 'input/')

examples/kgcn/animal_trade/test/end_to_end_test.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,76 @@
7676

7777
class TestEndToEnd(unittest.TestCase):
7878

79-
def test_end_to_end(self):
79+
@classmethod
80+
def setUpClass(cls):
81+
8082
# Unzip the Grakn distribution containing our data
8183
sub.run(['unzip', 'external/animaltrade_dist/file/downloaded', '-d',
8284
'external/animaltrade_dist/file/downloaded-unzipped'])
8385

8486
# Start Grakn
8587
sub.run(['external/animaltrade_dist/file/downloaded-unzipped/grakn-core-all-mac-animaltrade1.5.3/grakn', 'server', 'start'])
8688

89+
def test_multi_class_single_label_classification_end_to_end(self):
90+
tf.reset_default_graph()
91+
92+
modes = (TRAIN, EVAL)
93+
94+
client = grakn.client.GraknClient(uri=URI)
95+
sessions = server_mgmt.get_sessions(client, KEYSPACES)
96+
transactions = server_mgmt.get_transactions(sessions)
97+
98+
batch_size = NUM_PER_CLASS * FLAGS.num_classes
99+
kgcn = model.KGCN(NEIGHBOUR_SAMPLE_SIZES,
100+
FLAGS.features_size,
101+
FLAGS.starting_concepts_features_size,
102+
FLAGS.aggregated_size,
103+
FLAGS.embedding_size,
104+
transactions[TRAIN],
105+
batch_size,
106+
neighbour_sampling_method=random_sampling.random_sample,
107+
neighbour_sampling_limit_factor=4)
108+
109+
optimizer = tf.train.GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
110+
classifier = classify.SupervisedKGCNMultiClassSingleLabelClassifier(kgcn, optimizer, FLAGS.num_classes, None,
111+
max_training_steps=FLAGS.max_training_steps)
112+
113+
feed_dicts = {}
114+
115+
sampling_params = {
116+
TRAIN: {'sample_size': NUM_PER_CLASS, 'population_size': POPULATION_SIZE_PER_CLASS},
117+
EVAL: {'sample_size': NUM_PER_CLASS, 'population_size': POPULATION_SIZE_PER_CLASS},
118+
PREDICT: {'sample_size': NUM_PER_CLASS, 'population_size': POPULATION_SIZE_PER_CLASS},
119+
}
120+
concepts, labels = thing_mgmt.compile_labelled_concepts(EXAMPLES_QUERY, EXAMPLE_CONCEPT_TYPE,
121+
LABEL_ATTRIBUTE_TYPE, ATTRIBUTE_VALUES,
122+
transactions[TRAIN], transactions[PREDICT],
123+
sampling_params)
124+
125+
for mode in modes:
126+
mode_labels = labels[mode]
127+
feed_dicts[mode] = classifier.get_feed_dict(sessions[mode], concepts[mode], labels=mode_labels)
128+
129+
# Note: The ground-truth attribute labels haven't been removed from Grakn, so the results found here are
130+
# invalid, and used as an end-to-end test only
131+
132+
# Train
133+
if TRAIN in modes:
134+
print("\n\n********** TRAIN Keyspace **********")
135+
classifier.train(feed_dicts[TRAIN])
136+
137+
# Eval
138+
if EVAL in modes:
139+
print("\n\n********** EVAL Keyspace **********")
140+
# Presently, eval keyspace is the same as the TRAIN keyspace
141+
classifier.eval(feed_dicts[EVAL])
142+
143+
server_mgmt.close(sessions)
144+
server_mgmt.close(transactions)
145+
146+
def test_multi_class_multi_label_classification_end_to_end(self):
147+
tf.reset_default_graph()
148+
87149
modes = (TRAIN, EVAL)
88150

89151
client = grakn.client.GraknClient(uri=URI)
@@ -102,7 +164,7 @@ def test_end_to_end(self):
102164
neighbour_sampling_limit_factor=4)
103165

104166
optimizer = tf.train.GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
105-
classifier = classify.SupervisedKGCNClassifier(kgcn, optimizer, FLAGS.num_classes, None,
167+
classifier = classify.SupervisedKGCNMultiClassMultiLabelClassifier(kgcn, optimizer, FLAGS.num_classes, None,
106168
max_training_steps=FLAGS.max_training_steps)
107169

108170
feed_dicts = {}
@@ -118,7 +180,8 @@ def test_end_to_end(self):
118180
sampling_params)
119181

120182
for mode in modes:
121-
feed_dicts[mode] = classifier.get_feed_dict(sessions[mode], concepts[mode], labels=labels[mode])
183+
mode_labels = [[0, 1, 1]] * len(labels[mode])
184+
feed_dicts[mode] = classifier.get_feed_dict(sessions[mode], concepts[mode], labels=mode_labels)
122185

123186
# Note: The ground-truth attribute labels haven't been removed from Grakn, so the results found here are
124187
# invalid, and used as an end-to-end test only

kglib/kgcn/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ kgcn = model.KGCN(neighbour_sample_sizes,
6969

7070
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
7171

72-
classifier = classify.SupervisedKGCNClassifier(kgcn,
72+
classifier = classify.SupervisedKGCNMultiClassSingleLabelClassifier(kgcn,
7373
optimizer,
7474
num_classes,
7575
log_dir,

0 commit comments

Comments
 (0)