diff --git a/python/cugraph/traversal/traveling_salesperson.py b/python/cugraph/traversal/traveling_salesperson.py index 80f9cd7441b..ae17555e4ea 100644 --- a/python/cugraph/traversal/traveling_salesperson.py +++ b/python/cugraph/traversal/traveling_salesperson.py @@ -20,7 +20,7 @@ def traveling_salesperson(pos_list, restarts=100000, beam_search=True, k=4, - nstart=1, + nstart=None, verbose=False, ): """ @@ -62,7 +62,7 @@ def traveling_salesperson(pos_list, null_check(pos_list['x']) null_check(pos_list['y']) - if not pos_list[pos_list['vertex'] == nstart].index: + if nstart is not None and not pos_list[pos_list['vertex'] == nstart].index: raise ValueError("nstart should be in vertex ids") route, cost = traveling_salesperson_wrapper.traveling_salesperson( diff --git a/python/cugraph/traversal/traveling_salesperson_wrapper.pyx b/python/cugraph/traversal/traveling_salesperson_wrapper.pyx index b728c3ff37d..5f87c42a638 100644 --- a/python/cugraph/traversal/traveling_salesperson_wrapper.pyx +++ b/python/cugraph/traversal/traveling_salesperson_wrapper.pyx @@ -31,7 +31,7 @@ def traveling_salesperson(pos_list, restarts=100000, beam_search=True, k=4, - nstart=1, + nstart=None, verbose=False, renumber=True, ): @@ -43,6 +43,7 @@ def traveling_salesperson(pos_list, cdef uintptr_t x_pos = NULL cdef uintptr_t y_pos = NULL + pos_list['vertex'] = pos_list['vertex'].astype(np.int32) pos_list['x'] = pos_list['x'].astype(np.float32) pos_list['y'] = pos_list['y'].astype(np.float32) x_pos = pos_list['x'].__cuda_array_interface__['data'][0] @@ -61,7 +62,10 @@ def traveling_salesperson(pos_list, cdef uintptr_t vtx_ptr = NULL vtx_ptr = pos_list['vertex'].__cuda_array_interface__['data'][0] - renumbered_nstart = pos_list[pos_list['vertex'] == nstart].index[0] + if nstart is None: + renumbered_nstart = 0 + else: + renumbered_nstart = pos_list[pos_list['vertex'] == nstart].index[0] final_cost = c_traveling_salesperson(handle_[0], vtx_ptr,