diff --git a/modeling.py b/modeling.py
index a7d719cfb..713be5637 100644
--- a/modeling.py
+++ b/modeling.py
@@ -233,7 +233,7 @@ def relative_positional_encoding(qlen, klen, d_model, clamp_len, attn_type,
     if bsz is not None:
       # With bi_data, the batch size should be divisible by 2.
-      assert bsz%2 == 0
+      tf.debugging.assert_equal(bsz % 2, 0)
       fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz//2)
       bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq, bsz//2)
     else:
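
The patch swaps a Python `assert` for `tf.debugging.assert_equal`, which evaluates when the op runs rather than at trace time and is not stripped by Python's `-O` flag. Below is a minimal sketch of that behavior, assuming TF 2.x eager execution; the helper `check_even_batch` is hypothetical and not part of the patch.

```python
import tensorflow as tf

def check_even_batch(bsz):
    # Hypothetical helper: tf.debugging.assert_equal checks the value when the
    # op actually executes, so it also works if `bsz` is a symbolic tensor,
    # whereas a bare `assert` only sees the Python object.
    tf.debugging.assert_equal(bsz % 2, 0,
                              message="bi_data requires an even batch size")

check_even_batch(8)                 # plain int: passes
check_even_batch(tf.constant(8))    # tensor: passes
# check_even_batch(tf.constant(7))  # raises tf.errors.InvalidArgumentError in eager mode
```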