+
+
+
+def _single_tensor_adopt(
+    params: List[Tensor],
+    grads: List[Tensor],
+    exp_avgs: List[Tensor],
+    exp_avg_sqs: List[Tensor],
+    state_steps: List[Tensor],
+    grad_scale: Optional[Tensor],
+    found_inf: Optional[Tensor],
+    *,
+    has_complex: bool,
+    beta1: float,
+    beta2: float,
+    lr: Union[float, Tensor],
+    clip_lambda: Optional[Callable[[int], float]],
+    weight_decay: float,
+    decouple: bool,
+    eps: float,
+    maximize: bool,
+    capturable: bool,
+    differentiable: bool,
+):
+    assert grad_scale is None and found_inf is None
+
+    if torch.jit.is_scripting():
+        # this assert is due to JIT being dumb and not realizing that the ops below
+        # have overloads to handle both float and Tensor lrs, so we just assert it's
+        # a float since most people using JIT are using floats
+        assert isinstance(lr, float)
+
+    for i, param in enumerate(params):
+        grad = grads[i] if not maximize else -grads[i]
+        exp_avg = exp_avgs[i]
+        exp_avg_sq = exp_avg_sqs[i]
+        step_t = state_steps[i]
+
+        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
+        if not torch._utils.is_compiling() and capturable:
+            capturable_supported_devices = _get_capturable_supported_devices()
+            assert (
+                param.device.type == step_t.device.type
+                and param.device.type in capturable_supported_devices
+            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
+
+        step = step_t if capturable or differentiable else _get_value(step_t)
+
+        if weight_decay != 0 and not decouple:
+            grad = grad.add(param, alpha=weight_decay)
+
+        if torch.is_complex(param):
+            grad = torch.view_as_real(grad)
+            if exp_avg is not None:
+                exp_avg = torch.view_as_real(exp_avg)
+            if exp_avg_sq is not None:
+                exp_avg_sq = torch.view_as_real(exp_avg_sq)
+            param = torch.view_as_real(param)
+
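+        # On the very first step (step == 0), only seed the second-moment estimate with
+        # grad*grad and advance the step counter; the parameter update itself starts on
+        # the next step (note the `continue` below).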
+        if step == 0:
+            exp_avg_sq.addcmul_(grad, grad.conj())
+            # update step
+            step_t += 1
+            continue
+
+        if weight_decay != 0 and decouple:
+            param.add_(param, alpha=-lr * weight_decay)
+
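+        # Normalize the gradient by the second-moment estimate accumulated up to the
+        # previous step (floored at eps), optionally clip it, and only then fold it into
+        # the momentum buffer; exp_avg_sq itself is refreshed after the parameter update.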
+        denom = torch.clamp(exp_avg_sq.sqrt(), eps)
+        normed_grad = grad.div(denom)
+        if clip_lambda is not None:
+            clip = clip_lambda(step)
+            normed_grad.clamp_(-clip, clip)
+
+        exp_avg.lerp_(normed_grad, 1 - beta1)
+
+        param.add_(exp_avg, alpha=-lr)
+        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
+
+        # update step
+        step_t += 1
+
+
+def _multi_tensor_adopt(
+    params: List[Tensor],
+    grads: List[Tensor],
+    exp_avgs: List[Tensor],
+    exp_avg_sqs: List[Tensor],
+    state_steps: List[Tensor],
+    grad_scale: Optional[Tensor],
+    found_inf: Optional[Tensor],
+    *,
+    has_complex: bool,
+    beta1: float,
+    beta2: float,
+    lr: Union[float, Tensor],
+    clip_lambda: Optional[Callable[[int], float]],
+    weight_decay: float,
+    decouple: bool,
+    eps: float,
+    maximize: bool,
+    capturable: bool,
+    differentiable: bool,
+):
+    if len(params) == 0:
+        return
+
+    if isinstance(lr, Tensor) and not capturable:
+        raise RuntimeError(
+            "lr as a Tensor is not supported for capturable=False and foreach=True"
+        )
+
+    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
+    if not torch._utils.is_compiling() and capturable:
+        capturable_supported_devices = _get_capturable_supported_devices(
+            supports_xla=False
+        )
+        assert all(
+            p.device.type == step.device.type
+            and p.device.type in capturable_supported_devices
+            for p, step in zip(params, state_steps)
+        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
+
+    assert grad_scale is None and found_inf is None
+
+    assert not differentiable, "_foreach ops don't support autograd"
+
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
+        [params, grads, exp_avgs, exp_avg_sqs, state_steps]  # type: ignore[list-item]
+    )
+    for (
+        device_params_,
+        device_grads_,
+        device_exp_avgs_,
+        device_exp_avg_sqs_,
+        device_state_steps_,
+    ), _ in grouped_tensors.values():
+        device_params = cast(List[Tensor], device_params_)
+        device_grads = cast(List[Tensor], device_grads_)
+        device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
+        device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
+        device_state_steps = cast(List[Tensor], device_state_steps_)
+
+        # Handle complex parameters
+        if has_complex:
+            _view_as_real(
+                device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
+            )
+
+        if maximize:
+            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]
+
+        if weight_decay != 0 and not decouple:
+            # Re-use the intermediate memory (device_grads) already allocated for maximize
+            if maximize:
+                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
+            else:
+                device_grads = torch._foreach_add(  # type: ignore[assignment]
+                    device_grads, device_params, alpha=weight_decay
+                )
+
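+        # As in the single-tensor path, the first step only seeds the second-moment
+        # estimates and bumps the step counters, then skips to the next tensor group.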
+        if device_state_steps[0] == 0:
+            torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
+
+            # Update steps
+            # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
+            # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
+            # wrapped it once now. The alpha is required to assure we go to the right overload.
+            if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
+                torch._foreach_add_(
+                    device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
+                )
+            else:
+                torch._foreach_add_(device_state_steps, 1)
+
+            continue
+
+        if weight_decay != 0 and decouple:
+            torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
+
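+        # Same decorrelated update as the single-tensor path, expressed with foreach ops:
+        # divide by the previous second-moment estimates (floored at eps), optionally clip,
+        # fold into the momentum buffers, apply the step, then refresh the second moments.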
+        exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
+        torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
+
+        normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)
+        if clip_lambda is not None:
+            clip = clip_lambda(device_state_steps[0])
+            torch._foreach_maximum_(normed_grad, -clip)
+            torch._foreach_minimum_(normed_grad, clip)
+
+        torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
+
+        torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
+        torch._foreach_mul_(device_exp_avg_sqs, beta2)
+        torch._foreach_addcmul_(
+            device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
+        )
+
+        # Update steps
+        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
+        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
+        # wrapped it once now. The alpha is required to assure we go to the right overload.
+        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
+            torch._foreach_add_(
+                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
+            )
+        else:
+            torch._foreach_add_(device_state_steps, 1)
+
+
+
+@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)
+def adopt(
+    params: List[Tensor],
+    grads: List[Tensor],
+    exp_avgs: List[Tensor],
+    exp_avg_sqs: List[Tensor],
+    state_steps: List[Tensor],
+    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
+    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
+    foreach: Optional[bool] = None,
+    capturable: bool = False,
+    differentiable: bool = False,
+    fused: Optional[bool] = None,
+    grad_scale: Optional[Tensor] = None,
+    found_inf: Optional[Tensor] = None,
+    has_complex: bool = False,
+    *,
+    beta1: float,
+    beta2: float,
+    lr: Union[float, Tensor],
+    clip_lambda: Optional[Callable[[int], float]],
+    weight_decay: float,
+    decouple: bool,
+    eps: float,
+    maximize: bool,
+):
+    r"""Functional API that performs ADOPT algorithm computation.
+
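+    ADOPT normalizes each incoming gradient by the second-moment estimate accumulated up
+    to the previous step before it enters the momentum buffer, and refreshes that estimate
+    only after the parameter update (see the single- and multi-tensor implementations above).
+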
+ """
+ # Respect when the user inputs False/True for foreach or fused. We only want to change
+ # the default when neither have been user-specified. Note that we default to foreach
+ # and pass False to use_fused. This is not a mistake--we want to give the fused impl
+ # bake-in time before making it the default, even if it is typically faster.
+ iffusedisNoneandforeachisNone:
+ _,foreach=_default_to_fused_or_foreach(
+ params,differentiable,use_fused=False
+ )
+ # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
+ ifforeachandisinstance(lr,Tensor)andnotcapturable:
+ foreach=False
+ iffusedisNone:
+ fused=False
+ ifforeachisNone:
+ foreach=False
+
+    # this check is slow during compilation, so we skip it
+    # if it's strictly needed we can add this check back in dynamo
+    if not torch._utils.is_compiling() and not all(
+        isinstance(t, torch.Tensor) for t in state_steps
+    ):
+        raise RuntimeError(
+            "API has changed, `state_steps` argument must contain a list of singleton tensors"
+        )
+
+    if foreach and torch.jit.is_scripting():
+        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
+    if fused and torch.jit.is_scripting():
+        raise RuntimeError("torch.jit.script not supported with fused optimizers")
+
+    if fused and not torch.jit.is_scripting():
+        func = _fused_adopt  # noqa: F821
+    elif foreach and not torch.jit.is_scripting():
+        func = _multi_tensor_adopt
+    else:
+        func = _single_tensor_adopt
+
+    func(
+        params,
+        grads,
+        exp_avgs,
+        exp_avg_sqs,
+        state_steps,
+        has_complex=has_complex,
+        beta1=beta1,
+        beta2=beta2,
+        lr=lr,
+        clip_lambda=clip_lambda,
+        weight_decay=weight_decay,
+        decouple=decouple,
+        eps=eps,
+        maximize=maximize,
+        capturable=capturable,
+        differentiable=differentiable,
+        grad_scale=grad_scale,
+        found_inf=found_inf,
+    )
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
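For orientation, here is a minimal sketch of driving the functional `adopt` API above directly (normally an optimizer class's `step()` does this). The tensors and hyperparameter values below are illustrative assumptions, not defaults taken from this file:

    import torch

    p = torch.zeros(4, requires_grad=True)
    adopt(
        [p],                    # params
        [torch.ones_like(p)],   # grads
        [torch.zeros_like(p)],  # exp_avgs
        [torch.zeros_like(p)],  # exp_avg_sqs
        [torch.tensor(0.0)],    # state_steps
        beta1=0.9,
        beta2=0.9999,
        lr=1e-3,
        clip_lambda=lambda step: step**0.25,
        weight_decay=0.0,
        decouple=False,
        eps=1e-6,
        maximize=False,
    )

On the first call (step 0) this only seeds the second-moment buffer; subsequent calls apply the parameter update described above.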
diff --git a/_modules/dicee/models/base_model.html b/_modules/dicee/models/base_model.html
index 163da5ee..3d39cd84 100644
--- a/_modules/dicee/models/base_model.html
+++ b/_modules/dicee/models/base_model.html
@@ -94,6 +94,7 @@
                                              lr=self.learning_rate, lambd=0.0001, alpha=0.75, weight_decay=self.weight_decay)
         else:
-            raise KeyError()
+            raise KeyError(f"{self.optimizer_name} is not found!")
         print(self.selected_optimizer)
         return self.selected_optimizer
+
+
+"""
+ def __getattr__(self, name):
+ # Create a function that will call the same attribute/method on each model
+ def method(*args, **kwargs):
+ results = []
+ for model in self.models:
+ attr = getattr(model, name)
+ if callable(attr):
+ # If it's a method, call it with provided arguments
+ results.append(attr(*args, **kwargs))
+ else:
+ # If it's an attribute, just get its value
+ results.append(attr)
+ return results
+ return method
+ """
+
+    def __str__(self):
+        return f"EnsembleKGE of {len(self.models)}{self.models[0]}"
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/_modules/dicee/scripts/run.html b/_modules/dicee/scripts/run.html
index 89800321..d9165617 100644
--- a/_modules/dicee/scripts/run.html
+++ b/_modules/dicee/scripts/run.html
@@ -134,9 +134,9 @@
Source code for dicee.scripts.run
help="Available knowledge graph embedding models. ""To use other knowledge graph embedding models available in python, e.g.,""**Pykeen_BoxE** and add this into choices")
- parser.add_argument('--optim',type=str,default='Adam',
+ parser.add_argument('--optim',type=str,default='Adopt',help='An optimizer',
- choices=['Adam','AdamW','SGD',"NAdam","Adagrad","ASGD"])
+ choices=['Adam','AdamW','SGD',"NAdam","Adagrad","ASGD","Adopt"])parser.add_argument('--embedding_dim',type=int,default=32,help='Number of dimensions for an embedding vector. ')parser.add_argument("--num_epochs",type=int,default=10,help='Number of epochs for training. ')
@@ -147,8 +147,8 @@
             print(e)
             print(model.name)
             print('Could not save the model correctly')
+    elif isinstance(model, EnsembleKGE):
+        for i, partial_model in enumerate(model):
+            new_path = path.replace("model.pt", f"model_partial_{i}.pt")
+            torch.save(partial_model.state_dict(), new_path)
     else:
         torch.save(model.model.state_dict(), path)
@@ -437,13 +442,7 @@
Source code for dicee.static_funcs
     assert full_storage_path is not None
     assert isinstance(model_name, str)
     assert len(model_name) > 1
-
-    # (1) Save pytorch model in trained_model .
-    if hasattr(trained_model, "is_ensemble"):
-        for i, kge in enumerate(trained_model):
-            torch.save(kge.state_dict(), full_storage_path + f'/{model_name}_{i}.pt')
-    else:
-        save_checkpoint_model(model=trained_model, path=full_storage_path + f'/{model_name}.pt')
+    save_checkpoint_model(model=trained_model, path=full_storage_path + f'/{model_name}.pt')
     if save_embeddings_as_csv:
         entity_emb, relation_ebm = trained_model.get_embeddings()
diff --git a/_modules/dicee/static_funcs_training.html b/_modules/dicee/static_funcs_training.html
index bb045541..76907c94 100644
--- a/_modules/dicee/static_funcs_training.html
+++ b/_modules/dicee/static_funcs_training.html
@@ -138,13 +138,13 @@
     sanity_checking_with_arguments(args)
     if args.model == 'Shallom':
         args.scoring_technique = 'KvsAll'
-    # TODO: we need need to define as "NONE ?
+
     if args.normalization == 'None':
         args.normalization = None
     assert args.normalization in [None, 'LayerNorm', 'BatchNorm1d']
diff --git a/_modules/dicee/trainer/dice_trainer.html b/_modules/dicee/trainer/dice_trainer.html
index 226742e5..08cd4be9 100644
--- a/_modules/dicee/trainer/dice_trainer.html
+++ b/_modules/dicee/trainer/dice_trainer.html
@@ -97,7 +97,8 @@
-    def __getattr__(self, name):
-        # Create a function that will call the same attribute/method on each model
-        def method(*args, **kwargs):
-            results = []
-            for model in self.models:
-                attr = getattr(model, name)
-                if callable(attr):
-                    # If it's a method, call it with provided arguments
-                    results.append(attr(*args, **kwargs))
-                else:
-                    # If it's an attribute, just get its value
-                    results.append(attr)
-            return results
-        return method
-
-
-
-    def __str__(self):
-        return f"EnsembleKGE of {len(self.models)}{self.models[0]}"
         assert isinstance(knowledge_graph, np.memmap) or isinstance(knowledge_graph, KG), \
             f"knowledge_graph must be an instance of KG or np.memmap. Currently {type(knowledge_graph)}"
         if self.args.num_folds_for_cv == 0:
-            self.trainer: Union[MP, TorchTrainer, TorchDDPTrainer, pl.Trainer]
+            self.trainer: Union[TensorParallel, TorchTrainer, TorchDDPTrainer, pl.Trainer]
             self.trainer = self.initialize_trainer(callbacks=get_callbacks(self.args))
-
             model, form_of_labelling = self.initialize_or_load_model()
             self.trainer.evaluator = self.evaluator
             self.trainer.dataset = knowledge_graph
             self.trainer.form_of_labelling = form_of_labelling
-            if isinstance(self.trainer, MP):
-                model = EnsembleKGE(model)
-            self.trainer.fit(model, train_dataloaders=self.init_dataloader(self.init_dataset()))
+            # TODO: Later, maybe we should write a callback to save the models in disk
+
+            if isinstance(self.trainer, TensorParallel):
+                model = self.trainer.fit(model, train_dataloaders=self.init_dataloader(self.init_dataset()))
+                assert isinstance(model, EnsembleKGE)
+            else:
+                self.trainer.fit(model, train_dataloaders=self.init_dataloader(self.init_dataset()))
             return model, form_of_labelling
diff --git a/_modules/dicee/trainer/model_parallelism.html b/_modules/dicee/trainer/model_parallelism.html
index 4afffb35..f55b3749 100644
--- a/_modules/dicee/trainer/model_parallelism.html
+++ b/_modules/dicee/trainer/model_parallelism.html
@@ -91,78 +91,256 @@
+def forward_backward_update_loss(z: Tuple, ensemble_model):
+    # () Get the i-th batch of data points.
+    x_batch, y_batch = extract_input_outputs(z)
+    # () Move the batch of labels into the master GPU : GPU-0
+    y_batch = y_batch.to("cuda:0")
+    # () Forward Pass on the batch. Yhat located on the master GPU.
+    yhat = ensemble_model(x_batch)
+    # () Compute the loss
+    loss = torch.nn.functional.binary_cross_entropy_with_logits(yhat, y_batch)
+    # () Compute the gradient of the loss w.r.t. parameters.
+    loss.backward()
+    # () Parameter update.
+    ensemble_model.step()
+    # () Report the batch and epoch losses.
+    batch_loss = loss.item()
+    # () Accumulate batch loss
+    return batch_loss
-    def extract_input_outputs_set_device(self, batch: list) -> Tuple:
-        """
-        Construct inputs and outputs from a batch of inputs with outputs From a batch of inputs and put
-
-        Arguments
-        ----------
-        batch: (list) mini-batch inputs on CPU
-
-        Returns
-        -------
-        (tuple) mini-batch on select device
-        """
-        if len(batch) == 2:
-            x_batch, y_batch = batch
-
-            if isinstance(x_batch, tuple):
-                # Triple and Byte
-                return x_batch, y_batch
-            else:
-                # (1) NegSample: x is a triple, y is a float
-                x_batch, y_batch = batch
-                return x_batch.to(self.device), y_batch.to(self.device)
-        elif len(batch) == 3:
-            x_batch, y_idx_batch, y_batch, = batch
-            x_batch, y_idx_batch, y_batch = x_batch.to(self.device), y_idx_batch.to(self.device), y_batch.to(
-                self.device)
-            return (x_batch, y_idx_batch), y_batch
-        else:
-            print(len(batch))
-            print("Unexpected batch shape..")
-            raise RuntimeError
+Parameters need to be specified as collections that have a deterministic
+ordering that is consistent between runs. Examples of objects that don’t
+satisfy those properties are sets and iterators over values of dictionaries.
+
+Parameters:
+    params (iterable) – an iterable of torch.Tensor s or dict s.
+        Specifies what Tensors should be optimized.
+    defaults (dict) – a dict containing default values of optimization
+        options (used when a parameter group doesn’t specify them).
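
To make the ordering requirement concrete, here is a minimal, hypothetical sketch (the `ADOPT` class name and the `net` model are assumptions, not taken from this page): parameters are passed as a list of dicts rather than as a set or a `dict.values()` iterator, so the grouping is deterministic across runs.

    import torch

    # Hypothetical model; any nn.Module works.
    net = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.Linear(16, 1))

    # Assumed optimizer class built on the functional `adopt` API shown earlier.
    optimizer = ADOPT(
        [
            {"params": net[0].parameters(), "weight_decay": 0.0},  # no decay on the first layer
            {"params": net[1].parameters()},                       # falls back to the defaults
        ],
        lr=1e-3,
    )

Passing a set of tensors here could reorder parameter groups between runs, which is exactly what the note above warns against.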