New Release #288

Merged: 200 commits, Dec 3, 2024

Commits (changes from all commits):
0bbae4d
added coverage report generation on git workflow
alkidbaci Aug 2, 2024
f5e3f7e
added coverage configs
alkidbaci Aug 2, 2024
98a3773
added coverage report
alkidbaci Aug 2, 2024
a7e09af
added static badges
alkidbaci Aug 2, 2024
1d1be16
updated link for coverage badge
alkidbaci Aug 2, 2024
ccb7b26
Merge pull request #253 from dice-group/static_bandages
Demirrr Aug 2, 2024
c53144c
added logo and favicon
alkidbaci Aug 9, 2024
0849d2c
Merge pull request #254 from dice-group/logo
Demirrr Aug 9, 2024
c1d2ef1
Refactoring In Progress: KGE models are updated. Unused Domain and Ran…
Demirrr Aug 14, 2024
7f6d194
DualE added
Demirrr Aug 14, 2024
f8a1aa9
Merge pull request #255 from dice-group/refactoring
Demirrr Aug 15, 2024
0d571d6
.nt file re-added for read_with_pandas
Demirrr Aug 20, 2024
5ec6b1b
Tqdm integrated
Demirrr Aug 20, 2024
e2ed3b3
Merge pull request #257 from dice-group/refactoring
Demirrr Aug 21, 2024
2b44c7d
progress bars are aligned between trainers
Demirrr Aug 22, 2024
5db0643
Potential div by zero avoided
Demirrr Aug 22, 2024
4a6fa5e
Merge pull request #258 from dice-group/refactoring
Demirrr Aug 23, 2024
f8e0c56
torchrun and dicee entry points work together. DDP with two GPUs only…
Demirrr Sep 20, 2024
263b2cb
compile added in ddp
Demirrr Sep 21, 2024
367fbca
default params changed to minimize the default training runtimes
Demirrr Sep 23, 2024
88a8c65
refactoring and removing unused imports and codes
Demirrr Sep 23, 2024
60f62b3
default lr and optim changed to Adam 0.1
Demirrr Sep 25, 2024
61a9a01
version incremented
Demirrr Sep 25, 2024
10c4343
Refactoring In Progress: a KGE trained in C++ code base can be used i…
Demirrr Oct 4, 2024
2a8bdc5
Refactoring In Progress: Default optim Adam in base model
Demirrr Oct 4, 2024
25794d2
Refactoring In Progress: Assertion added for entity and relation mapp…
Demirrr Oct 4, 2024
f069b0e
Refactoring In Progress: KeyError exception added
Demirrr Oct 4, 2024
2cced25
__iter__, __len__ and inverse index added
Demirrr Oct 4, 2024
1ae25ca
predict_top_k() can be used with a string or a list of strings
Demirrr Oct 4, 2024
4c63437
unused import removed
Demirrr Oct 4, 2024
b240cde
Refactoring DDP for better logging
Demirrr Oct 4, 2024
1083598
Refactoring In Progress: unused imports and docstrings
Demirrr Oct 4, 2024
943a8f8
DualE with the abbreviations of noqa
Demirrr Oct 4, 2024
a43b3fb
WIP:Refactoring
Demirrr Oct 4, 2024
f3ec5d4
Docstring added to explain the script
Demirrr Oct 7, 2024
82197dc
lazy import error catch message
Demirrr Oct 7, 2024
c3f1661
Reading/Indexing/Storing
Demirrr Oct 7, 2024
fd74806
indexing example added
Demirrr Oct 7, 2024
772466c
KvsSample Dataset rewritten for the multi-class classification problem
Demirrr Oct 9, 2024
e743425
KvsSample refactored
Demirrr Oct 10, 2024
dff2629
BaseKGELightning.training_setup() implements kvssample training step
Demirrr Oct 10, 2024
95c0ac4
kvssample for Clifford implemented
Demirrr Oct 10, 2024
5fe6c7b
distmult kvssample with einsum
Demirrr Oct 10, 2024
e924134
ComplEx.forward_k_vs_sample() implemented
Demirrr Oct 10, 2024
649b6f1
Keci().forward_k_vs_sample() validated with einsum p=0 q>=0
Demirrr Oct 11, 2024
bb620d0
multinomial() is being used to assign probabilities for entities. Thi…
Demirrr Oct 11, 2024
682eb4b
1vsSample name is being used
Demirrr Oct 11, 2024
5f2d9a6
1vsSample name is being used
Demirrr Oct 11, 2024
bf1bbef
1vsSample added
Demirrr Oct 11, 2024
fcdc52a
renamed from kvssample to onevssample
Demirrr Oct 12, 2024
00bbd72
neg sample stated explicitly
Demirrr Oct 12, 2024
a0bd139
attributes explicitly fixed
Demirrr Oct 12, 2024
256f1da
unused imports are removed
Demirrr Oct 12, 2024
28aeedf
1vsSample working and pickle usage will be deprecated
Demirrr Oct 12, 2024
274a992
new version of polars included
Demirrr Oct 14, 2024
45fcb0c
polars_dataframe_indexer() implemented
Demirrr Oct 14, 2024
d276783
preprocess_with_polars() refactored
Demirrr Oct 14, 2024
b6e43b6
if polars is a backend, indices are stored in csv
Demirrr Oct 14, 2024
de6ab29
Reindexing ignored if polars being used
Demirrr Oct 14, 2024
adf77d4
coefficients function only takes p and q dimensions
Demirrr Oct 14, 2024
9c4ccd4
CMult removed
Demirrr Oct 14, 2024
5aa51a8
Large Scale learning KGE on CPU
Demirrr Oct 15, 2024
09f1953
Refactored
Demirrr Oct 15, 2024
0e0ee7d
refactored
Demirrr Oct 15, 2024
c29db15
KvsSample reincluded
Demirrr Oct 15, 2024
47db2f9
KvsSample reincluded
Demirrr Oct 15, 2024
7884d4e
CMULT removed
Demirrr Oct 15, 2024
f1b809c
CMULT removed
Demirrr Oct 15, 2024
dbe5af8
WIP: Integrating kvssample
Demirrr Oct 18, 2024
b388291
Attempt to solve github coverage error
Demirrr Oct 18, 2024
7b4a6ba
Update github-actions-python-package.yml
Demirrr Oct 18, 2024
28836b5
Update .coveragerc
Demirrr Oct 18, 2024
94ed6b2
Merge pull request #261 from dice-group/larger_than_memory
Demirrr Oct 18, 2024
9867b43
load_term_mapping() is implemented to load indices
Demirrr Oct 21, 2024
c3b95c7
scoring_techniques extended in eval
Demirrr Oct 21, 2024
b8b3e07
read_with_polars() refactored
Demirrr Oct 21, 2024
2ef2dac
raw sets are set to None to decrease memory usage
Demirrr Oct 21, 2024
56e7656
Few comments added
Demirrr Oct 21, 2024
a2a2c83
Assertion scoring_technique checking
Demirrr Oct 21, 2024
e8a329c
reading indices as csv files
Demirrr Oct 21, 2024
d331323
Dynamic KvsSample working!
Demirrr Oct 21, 2024
33cdad5
kvssample is default scoring
Demirrr Oct 21, 2024
a1d87bf
todo added
Demirrr Oct 21, 2024
392daeb
Unused imports are removed
Demirrr Oct 21, 2024
4f5f069
ruff check made more restrictive
Demirrr Oct 21, 2024
4a35799
deprecated class method removed
Demirrr Oct 21, 2024
e95180f
raw datasets are not emptied if bpe is being used
Demirrr Oct 22, 2024
b31151a
typo fixed
Demirrr Oct 22, 2024
fe970d7
Merge pull request #262 from dice-group/larger_than_memory
Demirrr Oct 22, 2024
55bd4fc
separator for polars changed from \t to whitespace
Demirrr Oct 22, 2024
0034cec
loggings are aligned across gpus
Demirrr Oct 22, 2024
6e070e4
from numpy to torch conversion moved into __getitem__ in negsample to…
Demirrr Oct 22, 2024
e6ea7bb
Onevsall with memory map
Demirrr Oct 22, 2024
2eb5db4
WIP: replacing numpy arrays with memory maps
Demirrr Oct 22, 2024
37f06f6
WIP: MultiGPU training with memory map
Demirrr Oct 23, 2024
2ffb4c9
WIP: MultiGPUs training
Demirrr Oct 23, 2024
6d84e19
WIP: multi-gpu memory map
Demirrr Oct 23, 2024
f992c34
fixes for separator
Demirrr Oct 23, 2024
8e5b840
formatting fixes
Demirrr Oct 23, 2024
4dc027f
is_continual_training flag is removed in the start function of Execut…
Demirrr Oct 24, 2024
5277b3f
Update Readme & Exception handling
Demirrr Oct 24, 2024
29cf2bd
WIP: Unused attributes
Demirrr Oct 24, 2024
5f8c404
GradScaler included
Demirrr Oct 24, 2024
3764aca
copy memmap array numpy to pytorch tensor
Demirrr Oct 25, 2024
3b1d838
printing selected opt
Demirrr Oct 25, 2024
f1fbb29
WIP: CL with memory map
Demirrr Oct 25, 2024
5612303
WIP: csvs are used instead of pickling dictionaries
Demirrr Oct 25, 2024
64b0487
format errors are fixed
Demirrr Oct 25, 2024
5e82778
Merge pull request #263 from dice-group/refactoring_memorymap
Demirrr Oct 25, 2024
7b83092
WIP: Model Parallelism
Demirrr Oct 25, 2024
14676c7
WIP: MP negsample
Demirrr Oct 25, 2024
758c858
WIP: Model Parallelism Refactoring
Demirrr Oct 27, 2024
ff7eb1c
PL and ML return the same results on a single GPU
Demirrr Oct 27, 2024
9834ad1
WIP: Refactoring ML
Demirrr Oct 27, 2024
628aa4b
Info about batches added
Demirrr Oct 28, 2024
9fcff4d
Fixed the lint errors
Demirrr Oct 28, 2024
e3dff25
Fixes None trainer error
Demirrr Oct 28, 2024
cf44939
typo fixed at torchDDP
Demirrr Oct 28, 2024
910a2ae
typo fixed at torchDDP
Demirrr Oct 28, 2024
e6f0d78
Merge pull request #265 from dice-group/refactoring_memorymap
Demirrr Oct 28, 2024
4e58fb1
No need to inform user about mappings
Demirrr Oct 29, 2024
71ffcc6
ASWA callback global and local ranks must be 0
Demirrr Oct 31, 2024
e331b0c
Merge branch 'refactoring_memorymap' of https://github.com/dice-group…
Demirrr Oct 31, 2024
9f25d7d
Benchmark results added
Demirrr Oct 31, 2024
f4a86a4
fix for ASWA callback for all trainers
Demirrr Oct 31, 2024
157b2fa
Merge pull request #266 from dice-group/refactoring_memorymap
Demirrr Oct 31, 2024
f32d448
Update README.md
Demirrr Oct 31, 2024
2684c46
Update README.md
Demirrr Oct 31, 2024
6389615
UMLS with AllvsAll
Demirrr Oct 31, 2024
19804c3
Update README.md
Demirrr Nov 6, 2024
2367bee
Update README.md
Demirrr Nov 12, 2024
03854cc
compressed kg can be read by polars.read_csv() directly
Demirrr Nov 12, 2024
7728c70
compressed kg with read few only available in polars
Demirrr Nov 12, 2024
437ffd0
Merge pull request #269 from dice-group/compressed_kg
Demirrr Nov 13, 2024
fb4f72d
write_csv_from_model_parallel implemented
Demirrr Nov 14, 2024
a0ac733
Model Parallel Regression Test and write_csv_from_model_parallel
Demirrr Nov 14, 2024
010c576
if no gpu, no errors fix
Demirrr Nov 14, 2024
794abad
k increased to 10 to avoid assertion error. Randomness in answer_multi_hop…
Demirrr Nov 15, 2024
51ea8a9
Merge pull request #271 from dice-group/model_parallel_to_csv
Demirrr Nov 15, 2024
9aa9c28
Saving embeddings into csv implemented and tested
Demirrr Nov 15, 2024
9cc8348
write_csv_from_model_parallel() and from_pretrained_model_write_embed…
Demirrr Nov 15, 2024
f60485a
Merge pull request #272 from dice-group/extracting_embedding_in_csv
Demirrr Nov 16, 2024
d853720
WIP: Tensor Parallelism
Demirrr Nov 16, 2024
cdc615d
WIP: Forward shapes fixing
Demirrr Nov 16, 2024
e701140
Working version of model/pipeline parallelism
Demirrr Nov 16, 2024
c190112
Tensor Parallelism implemented. Yet, torch seemed to have a bug http…
Demirrr Nov 18, 2024
1dae312
WIP: Tensor Parallelism implemented
Demirrr Nov 19, 2024
7654051
WIP: Training a KGE model with Tensor Parallelism
Demirrr Nov 19, 2024
f16bc54
MP changed to TP
Demirrr Nov 19, 2024
a78a548
TensorParallel Trainer returns the model
Demirrr Nov 20, 2024
69ff7da
TensorParallel with EnsembleKGE written to disk with _partial_x.pt …
Demirrr Nov 20, 2024
598abf0
Initial working version of ensemble kge. Name must be changed
Demirrr Nov 20, 2024
b535fed
todo and comments added
Demirrr Nov 20, 2024
29068db
forward_triples() moved into callable()
Demirrr Nov 22, 2024
85af150
Auto batch finding as default for TP
Demirrr Nov 24, 2024
b0fcbd5
drop_last false and static forward_backward_update_loss used
Demirrr Nov 24, 2024
c1a14e7
Adopt included, pytorch version increased
Demirrr Nov 25, 2024
6e0a51b
Fixing lint errors
Demirrr Nov 25, 2024
d2736c3
Log info changed
Demirrr Nov 25, 2024
244d197
Update README.md
Demirrr Nov 25, 2024
abc5226
Merge pull request #274 from dice-group/tensor_parallel
Demirrr Nov 25, 2024
4e518bf
increment factor is the first batch size
Demirrr Nov 25, 2024
4b1a876
avg GPU usage of the last three batches is measured
Demirrr Nov 25, 2024
a6e15b7
dynamo import removed
Demirrr Nov 25, 2024
4e89b1d
exponential batch size increment is reduced to linear
Demirrr Nov 26, 2024
0417f19
embeddings can be concatenated horizontally for csv
Demirrr Nov 26, 2024
f38bfa8
Merge pull request #275 from dice-group/tensor_parallel
Demirrr Nov 26, 2024
0871973
Auto batch finding as an argument
Demirrr Nov 26, 2024
9ccee78
Tensor Parallel is working
Demirrr Nov 26, 2024
f53c1e7
Merge pull request #276 from dice-group/tensor_parallel
Demirrr Nov 26, 2024
d3081a1
compile is removed and avg is reduced to single in gpu usage signal
Demirrr Nov 28, 2024
5551241
Improved batch finding in TP
Demirrr Nov 28, 2024
f325b8a
Fix deprecated variable.
sshivam95 Nov 28, 2024
9b17326
Add batched evaluation
sshivam95 Nov 28, 2024
b5de7b5
Merge pull request #279 from dice-group/query-generator-test
Demirrr Nov 28, 2024
94ab305
Update README.md
Demirrr Nov 28, 2024
fbf30ab
WIP: Reducing the runtime of finding a good search & removing redanda…
Demirrr Nov 28, 2024
44b9dbd
fstring usage without any placeholder fixed
Demirrr Nov 28, 2024
58aa98c
Merge pull request #280 from dice-group/tensor_parallel
Demirrr Nov 28, 2024
a5f1648
Comment old code
sshivam95 Nov 29, 2024
f1b263b
TP with auto batch finding can be used to train KGE with >20B
Demirrr Nov 29, 2024
e0ee128
Merge branch 'develop' into tensor_parallel
Demirrr Nov 29, 2024
fb436e6
Merge pull request #277 from dice-group/batched-evaluation
sshivam95 Nov 29, 2024
7580606
Merge pull request #282 from dice-group/tensor_parallel
Demirrr Nov 29, 2024
e8dddb6
Refactoring before new release
Demirrr Nov 29, 2024
6638da6
Fix memory allocation issue
sshivam95 Nov 30, 2024
787b535
Merge pull request #285 from dice-group/batched-evaluation-memory-fix
Demirrr Nov 30, 2024
1dc3000
F841 lint error fixed
Demirrr Nov 30, 2024
4326478
Merge branch 'develop' into refactor
Demirrr Nov 30, 2024
3332aec
Merge pull request #283 from dice-group/refactor
Demirrr Dec 1, 2024
b1b1faf
Regression test added for TP
Demirrr Dec 2, 2024
90e9285
EnsembleKGE moved to init
Demirrr Dec 2, 2024
128bd18
TP can be continually trained
Demirrr Dec 2, 2024
ceffb79
deprecated func from executor removed and time postfix removed from continu…
Demirrr Dec 2, 2024
01e696b
unused import removed
Demirrr Dec 2, 2024
a766711
Merge pull request #286 from dice-group/refactor
Demirrr Dec 3, 2024
cddfc17
try catches are removed and example added
Demirrr Dec 3, 2024
fc9b352
if optim doesn't exist, it should return None
Demirrr Dec 3, 2024
6f8e4b5
Simple example of training a KGE with pytorch setup
Demirrr Dec 3, 2024
e4a0bd1
Merge pull request #287 from dice-group/literal_example
Demirrr Dec 3, 2024
Files changed:
19 changes: 19 additions & 0 deletions .coveragerc
@@ -0,0 +1,19 @@
[run]
omit =
    tests/*
    /tmp/*

[report]
exclude_lines =
    pragma: no cover
    def __repr__
    if self.debug:
    if settings.DEBUG
    raise AssertionError
    raise NotImplementedError
    if 0:
    if __name__ == .__main__.:
    if TYPE_CHECKING:
    class .*\bProtocol\):
    @(abc\.)?abstractmethod
    pass
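As a quick illustration of these exclusions (a minimal sketch with a hypothetical class, not code from this PR): when a pattern matches a line that opens a block, coverage.py excludes the whole clause from the report.

# Hypothetical example: both clauses below are omitted from the coverage report.
class Model:
    def __repr__(self):  # matches "def __repr__" -> entire method excluded
        return "Model()"

    def forward(self, x):
        raise NotImplementedError  # matches "raise NotImplementedError" -> excluded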
7 changes: 6 additions & 1 deletion .github/workflows/github-actions-python-package.yml
@@ -22,8 +22,13 @@ jobs:

      - name: Lint with ruff
        run: |
-         ruff check dicee/ --select=E501 --line-length=200
+         ruff check dicee/ --line-length=200
      - name: Test with pytest
        run: |
          wget https://files.dice-research.org/datasets/dice-embeddings/KGs.zip --no-check-certificate && unzip KGs.zip
          python -m pytest -p no:warnings -x
+     - name: Coverage report
+       run: |
+         pip install coverage
+         coverage run -m pytest -p no:warnings -x
+         coverage report -m
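The coverage step can be reproduced locally with the same commands the workflow uses (assuming the KGs.zip test data has been downloaded and unzipped as in the pytest step above):

pip install coverage
coverage run -m pytest -p no:warnings -x
coverage report -m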
362 changes: 311 additions & 51 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dicee/__init__.py
@@ -4,4 +4,4 @@
 from .executer import Execute  # noqa
 from .dataset_classes import *  # noqa
 from .query_generator import QueryGenerator  # noqa
-__version__ = '0.1.4'
+__version__ = '0.1.5'
6 changes: 6 additions & 0 deletions dicee/__main__.py
@@ -0,0 +1,6 @@
# dicee/__main__.py

from dicee.scripts.run import main # Import the main entry point of dicee

if __name__ == "__main__":
    main()  # Call the main function to execute the program logic
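With this module in place, the package can be run directly as a script, mirroring the dicee console entry point; for example (a sketch — the accepted flags are whatever dicee.scripts.run.main defines, which this diff does not show):

python -m dicee --help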
27 changes: 5 additions & 22 deletions dicee/abstracts.py
@@ -26,6 +26,8 @@ def __init__(self, args, callbacks):
         self.attributes = args
         self.callbacks = callbacks
         self.is_global_zero = True
+        self.global_rank=0
+        self.local_rank = 0
         # Set True to use Model summary callback of pl.
         torch.manual_seed(self.attributes.random_seed)
         torch.cuda.manual_seed_all(self.attributes.random_seed)
@@ -178,21 +180,13 @@ def __init__(self, path: str = None, url: str = None, construct_ensemble: bool =
         self.num_relations = len(self.relation_to_idx)
         self.entity_to_idx: dict
         self.relation_to_idx: dict
-        assert list(self.entity_to_idx.values()) == list(range(0, len(self.entity_to_idx)))
-        assert list(self.relation_to_idx.values()) == list(range(0, len(self.relation_to_idx)))
+        # 0, ....,
+        assert sorted(list(self.entity_to_idx.values())) == list(range(0, len(self.entity_to_idx)))
+        assert sorted(list(self.relation_to_idx.values())) == list(range(0, len(self.relation_to_idx)))

         self.idx_to_entity = {v: k for k, v in self.entity_to_idx.items()}
         self.idx_to_relations = {v: k for k, v in self.relation_to_idx.items()}

-        # See https://numpy.org/doc/stable/reference/generated/numpy.memmap.html
-        # @TODO: Ignore temporalryIf file exists
-        # if os.path.exists(self.path + '/train_set.npy'):
-        #     self.train_set = np.load(file=self.path + '/train_set.npy', mmap_mode='r')
-
-        # if apply_semantic_constraint:
-        #     (self.domain_constraints_per_rel, self.range_constraints_per_rel,
-        #      self.domain_per_rel, self.range_per_rel) = create_constraints(self.train_set)
-
     def get_eval_report(self) -> dict:
         return load_json(self.path + "/eval_report.json")
@@ -253,17 +247,6 @@ def get_padded_bpe_triple_representation(self, triples: List[List[str]]) -> Tupl
             padded_bpe_t.append(self.get_bpe_token_representation(str_o))
         return padded_bpe_h, padded_bpe_r, padded_bpe_t

-    def get_domain_of_relation(self, rel: str) -> List[str]:
-        x = [self.idx_to_entity[i] for i in self.domain_per_rel[self.relation_to_idx[rel]]]
-        res = set(x)
-        assert len(x) == len(res)
-        return res
-
-    def get_range_of_relation(self, rel: str) -> List[str]:
-        x = [self.idx_to_entity[i] for i in self.range_per_rel[self.relation_to_idx[rel]]]
-        res = set(x)
-        assert len(x) == len(res)
-        return res

     def set_model_train_mode(self) -> None:
         """
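The sorted() change in the assertions above relaxes the old requirement that dictionary insertion order equal index order; only a contiguous 0..n-1 index range is required now. A minimal sketch with a toy mapping (hypothetical data, not from the PR):

# Indices form a contiguous range 0..n-1, but insertion order differs.
entity_to_idx = {"b": 1, "a": 0}
# Old check would fail: list(entity_to_idx.values()) == [1, 0] != [0, 1]
# New check passes: only contiguity of the indices matters.
assert sorted(entity_to_idx.values()) == list(range(len(entity_to_idx)))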
26 changes: 11 additions & 15 deletions dicee/analyse_experiments.py
@@ -1,4 +1,7 @@
-""" This script should be moved to dicee/scripts"""
+""" This script should be moved to dicee/scripts
+Example:
+python dicee/analyse_experiments.py --dir Experiments --features "model" "trainMRR" "testMRR"
+"""
 import os
 import json
 import pandas as pd
@@ -120,19 +123,13 @@ def analyse(args):
         if os.path.isdir(full_path) is False:
             continue

-
         with open(f'{full_path}/configuration.json', 'r') as f:
             config = json.load(f)
-        try:
-            with open(f'{full_path}/report.json', 'r') as f:
-                report = json.load(f)
-                report = {i: report[i] for i in ['Runtime', 'NumParam']}
-            with open(f'{full_path}/eval_report.json', 'r') as f:
-                eval_report = json.load(f)
-        except FileNotFoundError:
-            print("NOT found")
-            continue
+        with open(f'{full_path}/report.json', 'r') as f:
+            report = json.load(f)
+            report = {i: report[i] for i in ['Runtime', 'NumParam']}
+        with open(f'{full_path}/eval_report.json', 'r') as f:
+            eval_report = json.load(f)
         config.update(eval_report)
         config.update(report)
         if "Train" in config:
@@ -160,10 +157,9 @@
     # print(df.columns)
     try:
         df_features = df[args.features]
-    except:
+    except KeyError:
         print(f"--features ({args.features}) is not a subset of {df.columns}")
-        exit(1)
-
+        raise KeyError
     print(df_features.to_latex(index=False, float_format="%.3f"))
     path_to_save = args.dir + '/summary.csv'
     df_features.to_csv(path_or_buf=path_to_save)
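The narrowed except clause catches only the expected failure and re-raises after printing context, instead of a bare except followed by exit(1). A minimal sketch of the failure mode it handles (toy dataframe, not from the PR):

import pandas as pd

df = pd.DataFrame({"model": ["Keci"], "testMRR": [0.95]})
try:
    df[["model", "trainMRR"]]  # "trainMRR" column is absent -> pandas raises KeyError
except KeyError:
    print(f"--features is not a subset of {list(df.columns)}")
    raise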
11 changes: 8 additions & 3 deletions dicee/callbacks.py
@@ -166,9 +166,10 @@ def on_fit_end(self, trainer, model):
         if self.initial_eval_setting:
             # ADD this info back
             trainer.evaluator.args.eval_model = self.initial_eval_setting

-        param_ensemble = torch.load(f"{self.path}/aswa.pt", torch.device("cpu"))
-        model.load_state_dict(param_ensemble)
+        if trainer.global_rank==trainer.local_rank==0:
+            param_ensemble = torch.load(f"{self.path}/aswa.pt", torch.device("cpu"))
+            model.load_state_dict(param_ensemble)

     @staticmethod
     def compute_mrr(trainer, model) -> float:
@@ -241,6 +242,10 @@ def decide(self, running_model_state_dict, ensemble_state_dict, val_running_mode
         return True

     def on_train_epoch_end(self, trainer, model):
+
+        if (trainer.global_rank == trainer.local_rank == 0) is False:
+            return None
+
         # (1) Increment epoch counter
         self.epoch_count += 1
         # (2) Save the given eval setting if it is not saved.
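Both hunks apply the same rank-zero guard: under multi-GPU training every process runs the callback, so ASWA bookkeeping and the aswa.pt checkpoint I/O are limited to the process whose global and local rank are both 0, avoiding duplicated work and concurrent writes to the same file. A minimal sketch of the pattern (hypothetical callback; trainer rank attributes assumed as in the diff):

class RankZeroCallback:
    def on_train_epoch_end(self, trainer, model):
        # Non-zero ranks return immediately; only rank 0 touches shared state.
        if (trainer.global_rank == trainer.local_rank == 0) is False:
            return None
        # ... rank-0-only work, e.g. updating and saving the parameter ensemble ...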
7 changes: 6 additions & 1 deletion dicee/config.py
@@ -50,6 +50,9 @@ def __init__(self, **kwargs):
         self.backend: str = "pandas"
         """Backend to read, process, and index input knowledge graph. pandas, polars and rdflib available"""

+        self.separator: str = "\s+"
+        """separator for extracting head, relation and tail from a triple"""
+
         self.trainer: str = 'torchCPUTrainer'
         """Trainer for knowledge graph embedding model"""
@@ -82,7 +85,6 @@ def __init__(self, **kwargs):

         self.label_smoothing_rate: float = 0.0

-
         self.num_core: int = 0
         """Number of CPUs to be used in the mini-batch loading process"""
@@ -136,6 +138,9 @@ def __init__(self, **kwargs):
         self.continual_learning=None
         "Path of a pretrained model size of LLM"

+        self.auto_batch_finding=False
+        "A flag for using auto batch finding"
+
     def __iter__(self):
         # Iterate
         for k, v in self.__dict__.items():
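A minimal usage sketch for the two new options (the Namespace class name and attribute-style configuration are assumed from dicee's usual API; only the attributes themselves appear in this diff):

from dicee.config import Namespace  # assumed location of the config class

args = Namespace()
args.separator = r"\s+"          # regex used to split a triple into head, relation, tail
args.auto_batch_finding = True   # opt in to the automatic batch-size search
args.backend = "polars"          # backend option documented above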