From c20af461a731953d162319614921f7267341ad7c Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Tue, 7 Nov 2023 14:41:48 -0500 Subject: [PATCH] Relax some uniform tile size assumptions --- src/he2hb.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/he2hb.cc b/src/he2hb.cc index 6f63be04a..c55f15f18 100644 --- a/src/he2hb.cc +++ b/src/he2hb.cc @@ -72,7 +72,6 @@ void he2hb( const bool set_hold = 1; // Do tileGetAndHold in the bcast int64_t n = A.n(); - int64_t nb_A = A.tileNb( 0 ); GridOrder grid_order; int nprow, npcol, myrow, mycol; // For the workspace matrix (W) change the GPU distribution to row cyclic, @@ -82,8 +81,8 @@ void he2hb( A.gridinfo( &grid_order, &nprow, &npcol, &myrow, &mycol ); assert( grid_order == GridOrder::Col ); // todo: update for Row - auto tileNb = slate::func::uniform_blocksize( n, nb_A ); - auto tileRank = slate::func::process_2d_grid( GridOrder::Col, nprow, npcol ); + auto tileNb = A.tileNbFunc(); + auto tileRank = A.tileRankFunc(); int num_devices = blas::get_device_count(); auto tileDevice = slate::func::device_1d_grid( GridOrder::Col, nprow, num_devices ); @@ -135,7 +134,7 @@ void he2hb( lapack::Queue* queue = A.compute_queue( panel_device, queue_0 ); - int64_t nb = A.tileNb(0); + int64_t nb = func::max_blocksize(A.nt(), A.tileNbFunc()); size_t size_tau = (size_t) std::min( mlocal, nb ); size_t size_A = (size_t) blas::max( 1, mlocal ) * nb; size_t hsize, dsize;