1 //===- SparseGPUCodegen.cpp - Generates GPU code --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a prototype GPU code generator for the sparsifier.
10 // The objective is to eventually use the right combination of
11 // direct code generation and library calls into vendor-specific
12 // highly optimized sparse libraries (e.g. cuSparse for CUDA).
13 //
14 //===----------------------------------------------------------------------===//
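//
// As an illustrative (not exhaustive) example of the library path below: a
// kernel y = A x written as a linalg.generic over a CSR matrix A is rewritten
// into a chain of gpu.create_csr, gpu.create_dn_tensor, gpu.spmv_buffer_size,
// and gpu.spmv operations, connected by asynchronous tokens and surrounded by
// the necessary host/device data movement.
//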
15 
16 #include "CodegenUtils.h"
17 #include "LoopEmitter.h"
18 
19 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
20 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
21 #include "mlir/Dialect/Linalg/IR/Linalg.h"
22 #include "mlir/Dialect/Linalg/Utils/Utils.h"
23 #include "mlir/Dialect/MemRef/IR/MemRef.h"
24 #include "mlir/Dialect/SCF/IR/SCF.h"
25 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
26 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
27 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
28 #include "mlir/IR/IRMapping.h"
29 #include "mlir/IR/Matchers.h"
30 
31 using namespace mlir;
32 using namespace mlir::sparse_tensor;
33 
34 namespace {
35 
36 // Sparse formats supported by cuSparse.
37 enum class CuSparseFormat {
38  kNone,
39  kCOO,
40  kCSR,
41  kCSC,
42  kBSR,
43 };
44 
45 //===----------------------------------------------------------------------===//
46 // Helper methods.
47 //===----------------------------------------------------------------------===//
48 
49 /// Marks the given top module as a GPU container module.
50 static void markAsGPUContainer(ModuleOp topModule) {
51  topModule->setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
52  UnitAttr::get(topModule->getContext()));
53 }
54 
55 /// Constructs a new GPU module (for GPU kernels) inside the given top module,
56 /// or returns an existing GPU module if one was built previously.
57 static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) {
58  for (auto op : topModule.getBodyRegion().getOps<gpu::GPUModuleOp>())
59  return op; // existing
60  markAsGPUContainer(topModule);
61  builder.setInsertionPointToStart(&topModule.getBodyRegion().front());
62  return builder.create<gpu::GPUModuleOp>(topModule->getLoc(),
63  "sparse_kernels");
64 }
65 
66 /// Constructs a new GPU kernel in the given GPU module.
67 static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
68  SmallVectorImpl<Value> &args) {
69  // Get a unique kernel name. Not very creative,
70  // but we simply try kernel0, kernel1, etc.
71  unsigned kernelNumber = 0;
72  SmallString<16> kernelName;
73  do {
74  kernelName.clear();
75  ("kernel" + Twine(kernelNumber++)).toStringRef(kernelName);
76  } while (gpuModule.lookupSymbol(kernelName));
77  // Then we insert a new kernel with given arguments into the module.
78  builder.setInsertionPointToStart(&gpuModule.getBodyRegion().front());
79  SmallVector<Type> argsTp;
80  for (unsigned i = 0, e = args.size(); i < e; i++)
81  argsTp.push_back(args[i].getType());
82  FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {});
83  auto gpuFunc =
84  builder.create<gpu::GPUFuncOp>(gpuModule->getLoc(), kernelName, type);
85  gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
86  builder.getUnitAttr());
87  return gpuFunc;
88 }
89 
90 /// Constructs code to launch GPU kernel.
91 static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
92  SmallVectorImpl<Value> &args,
93  SmallVectorImpl<Value> &tokens,
94  unsigned numThreads) {
95  Location loc = gpuFunc->getLoc();
96  Value none;
97  Value one = constantIndex(builder, loc, 1);
98  Value numT = constantIndex(builder, loc, numThreads);
99  gpu::KernelDim3 gridSize = {one, one, one};
100  gpu::KernelDim3 blckSize = {numT, one, one};
101  return builder
102  .create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize,
103  /*dynSharedMemSz*/ none, args,
104  builder.getType<gpu::AsyncTokenType>(), tokens)
105  .getAsyncToken();
106 }
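// For a kernel launched with numThreads = 32, the generated op looks roughly
// like (illustrative):
//   %token = gpu.launch_func async [%tokens...] @sparse_kernels::@kernel0
//            blocks in (%c1, %c1, %c1) threads in (%c32, %c1, %c1)
//            args(%arg0 : ..., %arg1 : ...)
// i.e. a single grid block with a one-dimensional thread block.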
107 
108 /// Maps the provided ranked host buffer into the device address space.
109 /// Writes from the host are guaranteed to be visible to device kernels
110 /// that are launched afterwards. Writes from the device are guaranteed
111 /// to be visible on the host after synchronizing with the device kernel
112 /// completion. Needs to cast the buffer to an unranked buffer.
113 static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
114  Value mem) {
115  MemRefType memTp = cast<MemRefType>(mem.getType());
116  UnrankedMemRefType resTp =
117  UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0);
118  Value cast = builder.create<memref::CastOp>(loc, resTp, mem);
119  builder.create<gpu::HostRegisterOp>(loc, cast);
120  return cast;
121 }
122 
123 /// Unmaps the provided buffer, expecting the casted buffer.
124 static void genHostUnregisterMemref(OpBuilder &builder, Location loc,
125  Value cast) {
126  builder.create<gpu::HostUnregisterOp>(loc, cast);
127 }
128 
129 /// Generates first wait in an asynchronous chain.
130 static Value genFirstWait(OpBuilder &builder, Location loc) {
131  Type tokenType = builder.getType<gpu::AsyncTokenType>();
132  return builder.create<gpu::WaitOp>(loc, tokenType, ValueRange())
133  .getAsyncToken();
134 }
135 
136 /// Generates last, blocking wait in an asynchronous chain.
137 static void genBlockingWait(OpBuilder &builder, Location loc,
138  ValueRange operands) {
139  builder.create<gpu::WaitOp>(loc, Type(), operands);
140 }
141 
142 /// Allocates memory on the device.
143 /// TODO: A `host_shared` attribute could be used to indicate that
144 /// the buffer is visible by both host and device, but lowering
145 /// that feature does not seem to be fully supported yet.
146 static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
147  Value token) {
148  auto tp = cast<ShapedType>(mem.getType());
149  auto elemTp = tp.getElementType();
150  auto shape = tp.getShape();
151  auto memTp = MemRefType::get(shape, elemTp);
152  SmallVector<Value> dynamicSizes;
153  for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) {
154  if (shape[r] == ShapedType::kDynamic) {
155  Value dimOp = linalg::createOrFoldDimOp(builder, loc, mem, r);
156  dynamicSizes.push_back(dimOp);
157  }
158  }
159  return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
160  token, dynamicSizes, ValueRange());
161 }
162 
163 // Allocates a typed buffer on the host with given size.
164 static Value genHostBuffer(OpBuilder &builder, Location loc, Type type,
165  Value size) {
166  const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
167  return builder.create<memref::AllocOp>(loc, memTp, size).getResult();
168 }
169 
170 // Allocates a typed buffer on the device with given size.
171 static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type,
172  Value size, Value token) {
173  const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
174  return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
175  token, size, ValueRange());
176 }
177 
178 // Allocates a void buffer on the device with given size.
179 static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size,
180  Value token) {
181  return genAllocBuffer(builder, loc, builder.getI8Type(), size, token);
182 }
183 
184 /// Deallocates memory from the device.
185 static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem,
186  Value token) {
187  return builder.create<gpu::DeallocOp>(loc, token.getType(), token, mem)
188  .getAsyncToken();
189 }
190 
191 /// Copies memory between host and device (direction is implicit).
192 static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst,
193  Value src, Value token) {
194  return builder.create<gpu::MemcpyOp>(loc, token.getType(), token, dst, src)
195  .getAsyncToken();
196 }
197 
198 /// Generates an alloc/copy pair.
199 static Value genAllocCopy(OpBuilder &builder, Location loc, Value b,
200  SmallVectorImpl<Value> &tokens) {
201  Value firstToken = genFirstWait(builder, loc);
202  auto alloc = genAllocMemRef(builder, loc, b, firstToken);
203  Value devMem = alloc.getResult(0);
204  Value depToken = alloc.getAsyncToken(); // copy-after-alloc
205  tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken));
206  return devMem;
207 }
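// For each host buffer, the chain produced above looks roughly like
// (illustrative):
//   %t0 = gpu.wait async
//   %devMem, %t1 = gpu.alloc async [%t0] (...)
//   %t2 = gpu.memcpy async [%t1] %devMem, %hostMem
// where %t2 is appended to `tokens`, so callers can block on all pending
// copies with a single wait before using the device buffers.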
208 
209 /// Generates a memref from tensor operation.
210 static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
211  Value tensor) {
212  auto tensorType = llvm::cast<ShapedType>(tensor.getType());
213  auto memrefType =
214  MemRefType::get(tensorType.getShape(), tensorType.getElementType());
215  return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
216 }
217 
218 /// Prepares the outlined arguments, passing scalars and buffers in. Here we
219 /// assume that the first buffer is the one allocated for output. We create
220 /// a set of properly chained asynchronous allocation/copy pairs to increase
221 /// overlap before launching the kernel.
222 static Value genParametersIn(OpBuilder &builder, Location loc,
223  SmallVectorImpl<Value> &scalars,
224  SmallVectorImpl<Value> &buffers,
225  SmallVectorImpl<Value> &args,
226  SmallVectorImpl<Value> &tokens,
227  bool useHostRegistrationForOut) {
228  Value out;
229  // Scalars are passed by value.
230  for (Value s : scalars)
231  args.push_back(s);
232  // Buffers need to be made visible on the device.
233  for (Value b : buffers) {
234  if (useHostRegistrationForOut) {
235  out = genHostRegisterMemref(builder, loc, b);
236  args.push_back(b);
237  useHostRegistrationForOut = false;
238  continue;
239  }
240  args.push_back(genAllocCopy(builder, loc, b, tokens));
241  }
242  return out;
243 }
244 
245 /// Finalizes the outlined arguments. The output buffer is copied depending
246 /// on the kernel token and then deallocated. All other buffers are simply
247 /// deallocated. Then we wait for all operations to complete.
248 static void genParametersOut(OpBuilder &builder, Location loc, Value out,
249  Value kernelToken, SmallVectorImpl<Value> &scalars,
250  SmallVectorImpl<Value> &buffers,
251  SmallVectorImpl<Value> &args,
252  SmallVectorImpl<Value> &tokens) {
253  unsigned base = scalars.size();
254  for (unsigned i = base, e = args.size(); i < e; i++) {
255  Value firstToken;
256  if (i == base) {
257  // Assumed output parameter: unregister or copy-out.
258  if (out) {
259  genHostUnregisterMemref(builder, loc, out);
260  out = Value();
261  continue;
262  }
263  firstToken =
264  genCopyMemRef(builder, loc, buffers[0], args[i], kernelToken);
265  } else {
266  firstToken = genFirstWait(builder, loc);
267  }
268  tokens.push_back(genDeallocMemRef(builder, loc, args[i], firstToken));
269  }
270 }
271 
272 /// Constructs code for new GPU kernel.
273 static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
274  scf::ParallelOp forallOp,
275  SmallVectorImpl<Value> &constants,
276  SmallVectorImpl<Value> &scalars,
277  SmallVectorImpl<Value> &buffers) {
278  Location loc = gpuFunc->getLoc();
279  Block &block = gpuFunc.getBody().front();
280  rewriter.setInsertionPointToStart(&block);
281 
282  // Re-generate the constants, recapture all arguments.
283  unsigned arg = 0;
284  IRMapping irMap;
285  for (Value c : constants)
286  irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0));
287  for (Value s : scalars)
288  irMap.map(s, block.getArgument(arg++));
289  for (Value b : buffers)
290  irMap.map(b, block.getArgument(arg++));
291 
292  // Assume 1-dimensional grid/block configuration (only x dimension),
293  // so that:
294  // row = blockIdx.x * blockDim.x + threadIdx.x
295  // inc = blockDim.x * gridDim.x
296  Value bid = rewriter.create<gpu::BlockIdOp>(loc, gpu::Dimension::x);
297  Value bsz = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
298  Value tid = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
299  Value gsz = rewriter.create<gpu::GridDimOp>(loc, gpu::Dimension::x);
300  Value mul = rewriter.create<arith::MulIOp>(loc, bid, bsz);
301  Value row = rewriter.create<arith::AddIOp>(loc, mul, tid);
302  Value inc = rewriter.create<arith::MulIOp>(loc, bsz, gsz);
303 
304  // Construct the iteration over the computational space that
305  // accounts for the fact that the total number of threads and
306  // the amount of work to be done usually do not match precisely.
307  // for (r = row; r < N; r += inc) {
308  // <loop-body>
309  // }
310  Value upper = irMap.lookup(forallOp.getUpperBound()[0]);
311  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc);
312  // The scf.for builder creates an empty block. scf.for does not allow multiple
313  // blocks in its region, so delete the block before `cloneRegionBefore` adds
314  // an additional block.
315  rewriter.eraseBlock(forOp.getBody());
316  rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(),
317  forOp.getRegion().begin(), irMap);
318 
319  // Done.
320  rewriter.setInsertionPointAfter(forOp);
321  rewriter.create<gpu::ReturnOp>(gpuFunc->getLoc());
322 }
323 
324 //===----------------------------------------------------------------------===//
325 // Library helper methods.
326 //===----------------------------------------------------------------------===//
327 
328 /// Helper to detect a + b with arguments taken from given block.
329 static bool matchAddOfArgs(Block *block, Value val) {
330  if (auto *def = val.getDefiningOp()) {
331  if (isa<arith::AddFOp, arith::AddIOp>(def)) {
332  Value a = block->getArguments()[0];
333  Value b = block->getArguments()[1];
334  return (def->getOperand(0) == a && def->getOperand(1) == b) ||
335  (def->getOperand(0) == b && def->getOperand(1) == a);
336  }
337  }
338  return false;
339 }
340 
341 /// Helper to detect a * b with arguments taken from given block.
342 static bool matchMulOfArgs(Block *block, Value val) {
343  if (auto *def = val.getDefiningOp()) {
344  if (isa<arith::MulFOp, arith::MulIOp>(def)) {
345  Value a = block->getArguments()[0];
346  Value b = block->getArguments()[1];
347  return (def->getOperand(0) == a && def->getOperand(1) == b) ||
348  (def->getOperand(0) == b && def->getOperand(1) == a);
349  }
350  }
351  return false;
352 }
353 
354 /// Helper to detect x = x + a * b
355 static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
356  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
357  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
358  if (isa<arith::AddFOp, arith::AddIOp>(def)) {
359  Value x = op.getBlock()->getArguments()[2];
360  return (def->getOperand(0) == x &&
361  matchMulOfArgs(op.getBlock(), def->getOperand(1))) ||
362  (def->getOperand(1) == x &&
363  matchMulOfArgs(op.getBlock(), def->getOperand(0)));
364  }
365  }
366  return false;
367 }
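// The pattern above matches, for instance, a linalg.generic body of the form
// (illustrative; operand order may be swapped):
//   ^bb0(%a: f32, %b: f32, %x: f32):
//     %0 = arith.mulf %a, %b : f32
//     %1 = arith.addf %x, %0 : f32
//     linalg.yield %1 : f32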
368 
369 // Helper to detect c += spy(s) x (a * b)
370 static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
371  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
372  // The linalg yields a custom reduce result.
373  Value s_out = op.getBlock()->getArguments()[2];
374  if (auto redOp =
375  yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>()) {
376  // The reduce consumes the output.
377  Value other;
378  if (s_out == redOp->getOperand(0))
379  other = redOp->getOperand(1);
380  else if (s_out == redOp->getOperand(1))
381  other = redOp->getOperand(0);
382  else
383  return false;
384  // The reduce op also consumes a unary op that also consumes the output
385  // and does not define an absent value.
386  if (auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>()) {
387  if (s_out != unOp->getOperand(0) || !unOp.getAbsentRegion().empty())
388  return false;
389  // And the bodies are as expected.
390  auto yieldUn = cast<sparse_tensor::YieldOp>(
391  unOp.getRegion(0).front().getTerminator());
392  auto yieldRed = cast<sparse_tensor::YieldOp>(
393  redOp.getRegion().front().getTerminator());
394  return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) &&
395  matchAddOfArgs(&redOp.getRegion().front(), yieldRed.getOperand(0));
396  }
397  }
398  return false;
399 }
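// In other words, the matched region implements the sampling semantics: a
// sparse_tensor.unary applied to the output (with an empty absent region)
// yields a * b wherever the output is present, and a sparse_tensor.reduce
// whose body is a plain addition folds that product back into the output.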
400 
401 /// Test for dense tensor.
402 static bool isDenseTensor(Value v) {
403  auto sTp = getSparseTensorType(v);
404  return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense();
405 }
406 
407 /// Test for suitable positions/coordinates width.
408 static bool isAdmissibleMetaData(SparseTensorType &aTp) {
409  return (aTp.getPosWidth() == 0 || aTp.getPosWidth() >= 16) &&
410  (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() >= 16);
411 }
412 
413 /// Test for sorted COO matrix with suitable metadata.
414 static bool isAdmissibleCOO(SparseTensorType &aTp) {
415  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
416  aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
417  aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
418  isAdmissibleMetaData(aTp);
419 }
420 
421 /// Test for CSR matrix with suitable metadata.
422 static bool isAdmissibleCSR(SparseTensorType &aTp) {
423  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
424  aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
425  aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
426 }
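// For example, a type along the lines of (illustrative syntax)
//   tensor<?x?xf32, #sparse_tensor.encoding<{
//     map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>
// with default (or at least 16-bit) position/coordinate widths passes the
// CSR admissibility test above.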
427 
428 /// Test for CSC matrix with suitable metadata.
429 static bool isAdmissibleCSC(SparseTensorType &aTp) {
430  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && !aTp.isIdentity() &&
431  aTp.isPermutation() && aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) &&
432  aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
433 }
434 
435 /// Test for BSR matrix with suitable metadata.
436 static bool isAdmissibleBSR(SparseTensorType &aTp) {
437  if (aTp.getDimRank() == 2 && aTp.getLvlRank() == 4 && aTp.isDenseLvl(0) &&
438  aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
439  aTp.isDenseLvl(2) && aTp.isDenseLvl(3) && isAdmissibleMetaData(aTp)) {
440  // CuSparse only supports "square" blocks currently.
441  SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
442  assert(dims.size() == 2);
443  return dims[0] == dims[1] && dims[0] > 1;
444  }
445  return false;
446 }
447 
448 /// Returns a suitable sparse format for the operation and given operand
449 /// types with cuSparse, or kNone if none is available.
450 static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
451  SparseTensorType bTp,
452  SparseTensorType cTp, bool enableRT,
453  bool isMatVec) {
454  // The other operands have a dense type.
455  if (bTp.hasEncoding() || cTp.hasEncoding())
456  return CuSparseFormat::kNone;
457  // Now check for suitable operand type for the main operand.
458  if (isAdmissibleCOO(aTp))
459 #ifdef CUSPARSE_COO_AOS
460  return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
461 #else
462  return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
463 #endif
464  if (isAdmissibleCSR(aTp))
465  return CuSparseFormat::kCSR;
466  if (isAdmissibleCSC(aTp))
467  return CuSparseFormat::kCSC;
468  if (isAdmissibleBSR(aTp))
469  return CuSparseFormat::kBSR;
470  return CuSparseFormat::kNone;
471 }
472 
473 /// Generates the first positions/coordinates of a sparse matrix.
474 static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
475  CuSparseFormat format, bool enableRT) {
476  if (format == CuSparseFormat::kCOO) {
477  // Library uses SoA COO, direct IR uses AoS COO.
478  if (enableRT)
479  return genToCoordinates(builder, loc, a, 0, /*cooStart=*/0);
480  return genToCoordinatesBuffer(builder, loc, a);
481  }
482  // Formats CSR/CSC and BSR use positions at 1.
483  return genToPositions(builder, loc, a, 1);
484 }
485 
486 /// Generates the second coordinates of a sparse matrix.
487 static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
488  CuSparseFormat format, bool enableRT) {
489  bool isCOO = format == CuSparseFormat::kCOO;
490  if (isCOO && !enableRT)
491  return Value(); // nothing needed
492  // Formats CSR/CSC and BSR use coordinates at 1.
493  return genToCoordinates(builder, loc, a, 1, /*cooStart=*/isCOO ? 0 : 2);
494 }
495 
496 /// Generates the sparse matrix handle.
497 static Operation *genSpMat(OpBuilder &builder, Location loc,
498  SparseTensorType &aTp, Type handleTp, Type tokenTp,
499  Value token, Value sz1, Value sz2, Value nseA,
500  Value rowA, Value colA, Value valA,
501  CuSparseFormat format, bool enableRT) {
502  if (format == CuSparseFormat::kCOO) {
503  // Library uses SoA COO, direct IR uses AoS COO.
504  if (enableRT) {
505  assert(colA);
506  return builder.create<gpu::CreateCooOp>(loc, handleTp, tokenTp, token,
507  sz1, sz2, nseA, rowA, colA, valA);
508  }
509 #ifdef CUSPARSE_COO_AOS
510  assert(!colA);
511  return builder.create<gpu::CreateCooAoSOp>(loc, handleTp, tokenTp, token,
512  sz1, sz2, nseA, rowA, valA);
513 #else
514  llvm_unreachable("gpu::CreateCooAoSOp is deprecated");
515 #endif
516  }
517  assert(colA);
518  if (format == CuSparseFormat::kCSR)
519  return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
520  sz2, nseA, rowA, colA, valA);
521  if (format == CuSparseFormat::kCSC)
522  return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
523  sz2, nseA, rowA, colA, valA);
524  // BSR requires a bit more work since we need to pass in the block size
525  // and all other sizes in terms of blocks (#block-rows, #block-cols,
526  // #nonzero-blocks).
527  assert(format == CuSparseFormat::kBSR);
528  SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
529  assert(dims.size() == 2 && dims[0] == dims[1]);
530  uint64_t b = dims[0];
531  Value bSz = constantIndex(builder, loc, b);
532  Value bRows = builder.create<arith::DivUIOp>(loc, sz1, bSz);
533  Value bCols = builder.create<arith::DivUIOp>(loc, sz2, bSz);
534  Value bNum = builder.create<arith::DivUIOp>(
535  loc, nseA, constantIndex(builder, loc, b * b));
536  return builder.create<gpu::CreateBsrOp>(loc, handleTp, tokenTp, token, bRows,
537  bCols, bNum, bSz, bSz, rowA, colA,
538  valA);
539 }
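// As a small worked example for the BSR case above (illustrative): with 2x2
// blocks, an 8x6 matrix with 16 stored entries is described to cuSparse as
// 8/2 = 4 block rows, 6/2 = 3 block columns, and 16/(2*2) = 4 nonzero blocks.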
540 
541 /// Match and rewrite SpMV kernel.
542 static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
543  linalg::GenericOp op, bool enableRT) {
544  Location loc = op.getLoc();
545  Value a = op.getOperand(0);
546  Value x = op.getOperand(1);
547  Value y = op.getOperand(2); // we have y = Ax
548  SmallVector<Value> tokens;
549 
550  // Only admissible sparse matrix format and dense vectors (no BSR).
551  SparseTensorType aTp = getSparseTensorType(a);
552  SparseTensorType xTp = getSparseTensorType(x);
553  SparseTensorType yTp = getSparseTensorType(y);
554  auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT, /*isMatVec=*/true);
555  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
556  return failure();
557 
558  // Start sparse kernel and copy data from host to device.
559  // a : memR/memC/memV -> rowA,colA,valA
560  // x : memX -> vecX
561  // y : memY -> vecY
562  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
563  Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
564  Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
565  Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
566  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
567  Value memV = genToValues(rewriter, loc, a);
568  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
569  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
570  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
571  Value memX = genTensorToMemref(rewriter, loc, x);
572  Value vecX = genAllocCopy(rewriter, loc, memX, tokens);
573  Value memY = genTensorToMemref(rewriter, loc, y);
574  Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
575  genBlockingWait(rewriter, loc, tokens);
576  tokens.clear();
577 
578  // Create sparse environment and sparse matrix/dense vector handles.
579  Type indexTp = rewriter.getIndexType();
580  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
581  Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
582  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
583  Value token = genFirstWait(rewriter, loc);
584  Operation *spGenA =
585  genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szY, szX,
586  nseA, rowA, colA, valA, format, enableRT);
587  Value spMatA = spGenA->getResult(0);
588  token = spGenA->getResult(1);
589  auto dvecX = rewriter.create<gpu::CreateDnTensorOp>(
590  loc, dnTensorHandleTp, tokenTp, token, vecX, szX);
591  Value dnX = dvecX.getResult(0);
592  token = dvecX.getAsyncToken();
593  auto dvecY = rewriter.create<gpu::CreateDnTensorOp>(
594  loc, dnTensorHandleTp, tokenTp, token, vecY, szY);
595  Value dnY = dvecY.getResult(0);
596  token = dvecY.getAsyncToken();
597  auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType();
598 
599  // Precompute buffersize for SpMV.
600  auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>(
601  loc, indexTp, tokenTp, token, spMatA, dnX, dnY,
602  /*computeType=*/dnYType);
603  Value bufferSz = bufferComp.getResult(0);
604  token = bufferComp.getAsyncToken();
605  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
606  Value buffer = buf.getResult(0);
607  token = buf.getAsyncToken();
608 
609  // Perform the SpMV.
610  auto spmvComp = rewriter.create<gpu::SpMVOp>(
611  loc, tokenTp, token, spMatA, dnX, dnY, /*computeType=*/dnYType, buffer);
612  token = spmvComp.getAsyncToken();
613 
614  // Copy data back to host and free all the resources.
615  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
616  .getAsyncToken();
617  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnX)
618  .getAsyncToken();
619  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnY)
620  .getAsyncToken();
621  token = genDeallocMemRef(rewriter, loc, rowA, token);
622  if (colA)
623  token = genDeallocMemRef(rewriter, loc, colA, token);
624  token = genDeallocMemRef(rewriter, loc, valA, token);
625  token = genDeallocMemRef(rewriter, loc, buffer, token);
626  token = genDeallocMemRef(rewriter, loc, vecX, token);
627  token = genCopyMemRef(rewriter, loc, memY, vecY, token);
628  token = genDeallocMemRef(rewriter, loc, vecY, token);
629  tokens.push_back(token);
630  genBlockingWait(rewriter, loc, tokens);
631  tokens.clear();
632 
633  // Done.
634  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, memY);
635  return success();
636 }
637 
638 /// Match and rewrite SpMM kernel.
639 static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
640  linalg::GenericOp op, bool enableRT) {
641  Location loc = op.getLoc();
642  Value a = op.getOperand(0);
643  Value b = op.getOperand(1);
644  Value c = op.getOperand(2); // we have C = AB
645  SmallVector<Value> tokens;
646 
647  // Only admissible sparse matrix format and dense matrices (no BSR).
648  SparseTensorType aTp = getSparseTensorType(a);
649  SparseTensorType bTp = getSparseTensorType(b);
650  SparseTensorType cTp = getSparseTensorType(c);
651  auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false);
652  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
653  return failure();
654 
655  // Start sparse kernel and copy data from host to device.
656  // a : memR/memC/memV -> rowA,colA,valA
657  // b : bufB -> matB
658  // c : bufC -> matC
659  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
660  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
661  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
662  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
663  Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
664  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
665  Value memV = genToValues(rewriter, loc, a);
666  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
667  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
668  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
669  Value bufB = genTensorToMemref(rewriter, loc, b);
670  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
671  Value bufC = genTensorToMemref(rewriter, loc, c);
672  Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
673  genBlockingWait(rewriter, loc, tokens);
674  tokens.clear();
675 
676  // Create sparse environment and sparse matrix/dense matrix handles.
677  Type indexTp = rewriter.getIndexType();
678  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
679  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
680  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
681  Value token = genFirstWait(rewriter, loc);
682  Operation *spGenA =
683  genSpMat(rewriter, loc, aTp, spMatHandleTp, tokenTp, token, szm, szk,
684  nseA, rowA, colA, valA, format, enableRT);
685  Value spMatA = spGenA->getResult(0);
686  token = spGenA->getResult(1);
687  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
688  loc, dnTensorHandleTp, tokenTp, token, matB,
689  SmallVector<Value>{szk, szn});
690  Value dnB = dmatB.getResult(0);
691  token = dmatB.getAsyncToken();
692  auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
693  loc, dnTensorHandleTp, tokenTp, token, matC,
694  SmallVector<Value>{szm, szn});
695  Value dnC = dmatC.getResult(0);
696  token = dmatC.getAsyncToken();
697  auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType();
698 
699  // Precompute buffersize for SpMM.
700  auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
701  loc, indexTp, tokenTp, token, spMatA, dnB, dnC,
702  /*computeType=*/dmatCType);
703  Value bufferSz = bufferComp.getResult(0);
704  token = bufferComp.getAsyncToken();
705  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
706  Value buffer = buf.getResult(0);
707  token = buf.getAsyncToken();
708  auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
709 
710  // Perform the SpMM.
711  auto spmmComp = rewriter.create<gpu::SpMMOp>(
712  loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, buffer);
713  token = spmmComp.getAsyncToken();
714 
715  // Copy data back to host and free all the resources.
716  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
717  .getAsyncToken();
718  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
719  .getAsyncToken();
720  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
721  .getAsyncToken();
722  token = genDeallocMemRef(rewriter, loc, rowA, token);
723  if (colA)
724  token = genDeallocMemRef(rewriter, loc, colA, token);
725  token = genDeallocMemRef(rewriter, loc, valA, token);
726  token = genDeallocMemRef(rewriter, loc, buffer, token);
727  token = genDeallocMemRef(rewriter, loc, matB, token);
728  token = genCopyMemRef(rewriter, loc, bufC, matC, token);
729  token = genDeallocMemRef(rewriter, loc, matC, token);
730  tokens.push_back(token);
731  genBlockingWait(rewriter, loc, tokens);
732  tokens.clear();
733 
734  // Done.
735  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
736  return success();
737 }
738 
739 // Match and rewrite SpGEMM kernel.
740 static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
741  linalg::GenericOp op, bool enableRT) {
742  Location loc = op.getLoc();
743  Value a = op.getOperand(0);
744  Value b = op.getOperand(1);
745  Value c = op.getOperand(2); // we have C = AB
746  SmallVector<Value> tokens;
747 
748  // Only CSR <- CSR x CSR supported.
749  auto format = CuSparseFormat::kCSR;
750  SparseTensorType aTp = getSparseTensorType(a);
751  SparseTensorType bTp = getSparseTensorType(b);
752  SparseTensorType cTp = getSparseTensorType(c);
753  if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp))
754  return failure();
755 
756  // Start sparse kernel and copy data from host to device.
757  // a : amemR/amemC/amemV -> rowA,colA,valA
758  // b : bmemR/bmemC/bmemV -> rowB,colB,valB
759  // c : materializes
760  auto dnCType = cTp.getElementType();
761  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
762  Value nseB = rewriter.create<NumberOfEntriesOp>(loc, b);
763  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
764  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
765  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
766  Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
767  Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT); // not empty
768  Value amemV = genToValues(rewriter, loc, a);
769  Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
770  Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT); // not empty
771  Value bmemV = genToValues(rewriter, loc, b);
772  Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
773  Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
774  Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
775  Value rowB = genAllocCopy(rewriter, loc, bmemR, tokens);
776  Value colB = genAllocCopy(rewriter, loc, bmemC, tokens);
777  Value valB = genAllocCopy(rewriter, loc, bmemV, tokens);
778  genBlockingWait(rewriter, loc, tokens);
779  tokens.clear();
780 
781  // Create sparse environment and sparse matrix/dense vector handles.
782  Type indexTp = rewriter.getIndexType();
783  Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
784  Type descTp = rewriter.getType<gpu::SparseSpGEMMOpHandleType>();
785  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
786  Value token = genFirstWait(rewriter, loc);
787  Operation *spGenA =
788  genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szm, szk,
789  nseA, rowA, colA, valA, format, enableRT);
790  Value spMatA = spGenA->getResult(0);
791  token = spGenA->getResult(1);
792  Operation *spGenB =
793  genSpMat(rewriter, loc, bTp, spmatHandleTp, tokenTp, token, szk, szn,
794  nseB, rowB, colB, valB, format, enableRT);
795  Value spMatB = spGenB->getResult(0);
796  token = spGenB->getResult(1);
797 
798  // Sparse matrix C materializes (also assumes beta == 0).
799  Value zero = constantIndex(rewriter, loc, 0);
800  Value one = constantIndex(rewriter, loc, 1);
801  Value mplus1 = rewriter.create<arith::AddIOp>(loc, szm, one);
802  auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token);
803  Value rowC = e1.getResult(0);
804  token = e1.getAsyncToken();
805  auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token);
806  Value colC = e2.getResult(0); // no free needed
807  token = e2.getAsyncToken();
808  auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token);
809  Value valC = e3.getResult(0); // no free needed
810  token = e3.getAsyncToken();
811  Operation *spGenC =
812  genSpMat(rewriter, loc, cTp, spmatHandleTp, tokenTp, token, szm, szn,
813  zero, rowC, colC, valC, format, enableRT);
814  Value spMatC = spGenC->getResult(0);
815  token = spGenC->getResult(1);
816 
817  // Precompute buffersizes for SpGEMM.
818  Operation *descOp =
819  rewriter.create<gpu::SpGEMMCreateDescrOp>(loc, descTp, tokenTp, token);
820  Value desc = descOp->getResult(0);
821  token = descOp->getResult(1);
822  Operation *work1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
823  loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
824  gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
825  valC, gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
826  Value bufferSz1 = work1->getResult(0);
827  token = work1->getResult(1);
828  auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
829  Value buffer1 = buf1.getResult(0);
830  token = buf1.getAsyncToken();
831  Operation *work2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
832  loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
833  gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
834  bufferSz1, buffer1,
835  gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
836  token = work2->getResult(1);
837 
838  // Compute step.
839  Operation *compute1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
840  loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
841  gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
842  valC, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
843  Value bufferSz2 = compute1->getResult(0);
844  token = compute1->getResult(1);
845  auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
846  Value buffer2 = buf2.getResult(0);
847  token = buf2.getAsyncToken();
848  Operation *compute2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
849  loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
850  gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
851  bufferSz2, buffer2, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
852  token = compute2->getResult(1);
853 
854  // Get sizes.
855  Operation *sizes = rewriter.create<gpu::SpMatGetSizeOp>(
856  loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC);
857  Value nnz = sizes->getResult(2);
858  token = sizes->getResult(3);
859  auto a2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), nnz, token);
860  colC = a2.getResult(0);
861  token = a2.getAsyncToken();
862  auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token);
863  valC = a3.getResult(0);
864  token = a3.getAsyncToken();
865 
866  // Update C with new pointers and copy final product back into C.
867  Operation *update = rewriter.create<gpu::SetCsrPointersOp>(
868  loc, tokenTp, token, spMatC, rowC, colC, valC);
869  token = update->getResult(0);
870  Operation *copy = rewriter.create<gpu::SpGEMMCopyOp>(
871  loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
872  gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType);
873  token = copy->getResult(0);
874 
875  // Allocate buffers on host.
876  Value rowH = genHostBuffer(rewriter, loc, cTp.getPosType(), mplus1);
877  Value colH = genHostBuffer(rewriter, loc, cTp.getCrdType(), nnz);
878  Value valH = genHostBuffer(rewriter, loc, dnCType, nnz);
879 
880  // Copy data back to host and free all the resources.
881  token = rewriter.create<gpu::SpGEMMDestroyDescrOp>(loc, tokenTp, token, desc)
882  .getAsyncToken();
883  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
884  .getAsyncToken();
885  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatB)
886  .getAsyncToken();
887  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
888  .getAsyncToken();
889  token = genCopyMemRef(rewriter, loc, rowH, rowC, token);
890  token = genCopyMemRef(rewriter, loc, colH, colC, token);
891  token = genCopyMemRef(rewriter, loc, valH, valC, token);
892  token = genDeallocMemRef(rewriter, loc, rowA, token);
893  token = genDeallocMemRef(rewriter, loc, colA, token);
894  token = genDeallocMemRef(rewriter, loc, valA, token);
895  token = genDeallocMemRef(rewriter, loc, rowB, token);
896  token = genDeallocMemRef(rewriter, loc, colB, token);
897  token = genDeallocMemRef(rewriter, loc, valB, token);
898  token = genDeallocMemRef(rewriter, loc, rowC, token);
899  token = genDeallocMemRef(rewriter, loc, colC, token);
900  token = genDeallocMemRef(rewriter, loc, valC, token);
901  token = genDeallocMemRef(rewriter, loc, buffer1, token);
902  token = genDeallocMemRef(rewriter, loc, buffer2, token);
903  tokens.push_back(token);
904  genBlockingWait(rewriter, loc, tokens);
905  tokens.clear();
906 
907  // Done.
908  Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
909  Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
910  Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
911  rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), vt,
912  ValueRange{rt, ct});
913  return success();
914 }
915 
916 // Match and rewrite 2:4 SpMM kernel.
917 static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
918  linalg::GenericOp op) {
919  Location loc = op.getLoc();
920  Value A = op.getOperand(0);
921  Value B = op.getOperand(1);
922  Value C = op.getOperand(2); // we have C = AB
923  SmallVector<Value> tokens;
924 
925  // All inputs should be dense tensors.
926  if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
927  return failure();
928 
929  // Start sparse kernel and copy data from host to device.
930  // a : bufA -> matA
931  // b : bufB -> matB
932  // c : bufC -> matC
933  Value bufA = genTensorToMemref(rewriter, loc, A);
934  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
935  Value bufB = genTensorToMemref(rewriter, loc, B);
936  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
937  Value bufC = genTensorToMemref(rewriter, loc, C);
938  Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
939  genBlockingWait(rewriter, loc, tokens);
940  tokens.clear();
941 
942  // Create sparse environment and sparse matrix/dense vector handles.
943  Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0);
944  Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0);
945  Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1);
946  Type indexTp = rewriter.getIndexType();
947  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
948  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
949  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
950  Value token = genFirstWait(rewriter, loc);
951  Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>(
952  loc, spMatHandleTp, tokenTp, token, szm, szk,
953  gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA);
954  Value spMatA = spGenA->getResult(0);
955  token = spGenA->getResult(1);
956  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
957  loc, dnTensorHandleTp, tokenTp, token, matB,
958  SmallVector<Value>{szk, szn});
959  Value dnB = dmatB.getResult(0);
960  token = dmatB.getAsyncToken();
961  auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
962  loc, dnTensorHandleTp, tokenTp, token, matC,
963  SmallVector<Value>{szm, szn});
964  Value dnC = dmatC.getResult(0);
965  token = dmatC.getAsyncToken();
966  auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
967 
968  // Precompute buffersize for SpMM.
969  SmallVector<Type> bufferTypes_{indexTp, indexTp, indexTp};
970  TypeRange bufferTypes(bufferTypes_);
971  auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
972  loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE,
973  gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC,
974  /*computeType=*/dmatCType);
975  token = bufferComp.getAsyncToken();
976 
977  // Allocate buffers on host.
978  Value bufferSz1 = bufferComp.getResult(0);
979  auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
980  Value buffer1 = buf1.getResult(0);
981  token = buf1.getAsyncToken();
982  Value bufferSz2 = bufferComp.getResult(1);
983  auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
984  Value buffer2 = buf2.getResult(0);
985  token = buf2.getAsyncToken();
986  Value bufferSz3 = bufferComp.getResult(2);
987  auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token);
988  Value buffer3 = buf3.getResult(0);
989  token = buf3.getAsyncToken();
990 
991  // Perform the SpMM.
992  auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
993  auto spmmComp = rewriter.create<gpu::SpMMOp>(
994  loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType,
995  SmallVector<Value>{buffer1, buffer2, buffer3});
996  token = spmmComp.getAsyncToken();
997 
998  // Copy data back to host and free all the resources.
999  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
1000  .getAsyncToken();
1001  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
1002  .getAsyncToken();
1003  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
1004  .getAsyncToken();
1005  SmallVector<Value> newDynamicSizes;
1006  token = genDeallocMemRef(rewriter, loc, buffer1, token);
1007  token = genDeallocMemRef(rewriter, loc, buffer2, token);
1008  token = genDeallocMemRef(rewriter, loc, buffer3, token);
1009  token = genDeallocMemRef(rewriter, loc, matA, token);
1010  token = genDeallocMemRef(rewriter, loc, matB, token);
1011  token = genCopyMemRef(rewriter, loc, bufC, matC, token);
1012  token = genDeallocMemRef(rewriter, loc, matC, token);
1013  tokens.push_back(token);
1014  genBlockingWait(rewriter, loc, tokens);
1015  tokens.clear();
1016 
1017  // Done.
1018  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
1019  return success();
1020 }
1021 
1022 /// Match and rewrite SDDMM kernel.
1023 static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
1024  linalg::GenericOp op, bool enableRT) {
1025  Location loc = op.getLoc();
1026  Value a = op.getOperand(0);
1027  Value b = op.getOperand(1);
1028  Value c = op.getOperand(2);
1029  SmallVector<Value> tokens;
1030 
1031  // Only admissible sparse matrix format (no COO/CSC) and dense matrices.
1032  SparseTensorType aTp = getSparseTensorType(a);
1033  SparseTensorType bTp = getSparseTensorType(b);
1034  SparseTensorType cTp = getSparseTensorType(c);
1035  auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT, /*isMatVec=*/false);
1036  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO ||
1037  format == CuSparseFormat::kCSC)
1038  return failure();
1039 
1040  // The SDDMM is computed in place on the sparse output matrix.
1041  // Start sparse kernel and copy data from host to device.
1042  // a : bufA -> matA
1043  // b : bufB -> matB
1044  // c : memR/memC/memV -> rowC,colC,valC
1045  Value nseC = rewriter.create<NumberOfEntriesOp>(loc, c);
1046  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
1047  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
1048  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
1049  Value bufA = genTensorToMemref(rewriter, loc, a);
1050  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
1051  Value bufB = genTensorToMemref(rewriter, loc, b);
1052  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
1053  Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
1054  Value memC = genSecondCrds(rewriter, loc, c, format, enableRT); // or empty
1055  Value memV = genToValues(rewriter, loc, c);
1056  Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
1057  Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
1058  Value valC = genAllocCopy(rewriter, loc, memV, tokens);
1059  genBlockingWait(rewriter, loc, tokens);
1060  tokens.clear();
1061 
1062  // Create sparse environment and sparse matrix/dense matrix handles.
1063  Type indexTp = rewriter.getIndexType();
1064  Type dnMatHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
1065  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
1066  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
1067  Value token = genFirstWait(rewriter, loc);
1068  auto dmatA = rewriter.create<gpu::CreateDnTensorOp>(
1069  loc, dnMatHandleTp, tokenTp, token, matA, SmallVector<Value>{szm, szk});
1070  Value dnA = dmatA.getResult(0);
1071  token = dmatA.getAsyncToken();
1072  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
1073  loc, dnMatHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn});
1074  Value dnB = dmatB.getResult(0);
1075  token = dmatB.getAsyncToken();
1076  Operation *spGenC =
1077  genSpMat(rewriter, loc, cTp, spMatHandleTp, tokenTp, token, szm, szn,
1078  nseC, rowC, colC, valC, format, enableRT);
1079  Value spMatC = spGenC->getResult(0);
1080  token = spGenC->getResult(1);
1081  auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
1082 
1083  // Precompute buffersize for SDDMM.
1084  auto bufferComp = rewriter.create<gpu::SDDMMBufferSizeOp>(
1085  loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType);
1086  Value bufferSz = bufferComp.getResult(0);
1087  token = bufferComp.getAsyncToken();
1088  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
1089  Value buffer = buf.getResult(0);
1090  token = buf.getAsyncToken();
1091 
1092  // Perform the SDDMM.
1093  auto sddmmComp = rewriter.create<gpu::SDDMMOp>(loc, tokenTp, token, dnA, dnB,
1094  spMatC, dnCType, buffer);
1095  token = sddmmComp.getAsyncToken();
1096 
1097  // Copy data back to host and free all the resources.
1098  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnA)
1099  .getAsyncToken();
1100  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
1101  .getAsyncToken();
1102  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
1103  .getAsyncToken();
1104  token = genDeallocMemRef(rewriter, loc, buffer, token);
1105  token = genDeallocMemRef(rewriter, loc, matA, token);
1106  token = genDeallocMemRef(rewriter, loc, matB, token);
1107  token = genDeallocMemRef(rewriter, loc, rowC, token);
1108  if (colC)
1109  token = genDeallocMemRef(rewriter, loc, colC, token);
1110  token = genCopyMemRef(rewriter, loc, memV, valC, token);
1111  token = genDeallocMemRef(rewriter, loc, valC, token);
1112  tokens.push_back(token);
1113  genBlockingWait(rewriter, loc, tokens);
1114  tokens.clear();
1115 
1116  // Done.
1117  rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c);
1118  return success();
1119 }
1120 
1121 //===----------------------------------------------------------------------===//
1122 // Rewriting rules for direct code generation.
1123 //===----------------------------------------------------------------------===//
1124 
1125 /// Proof-of-concept rewriter. This rule generates a GPU implementation
1126 /// for each outermost forall loop generated by the sparsifier.
1127 /// TODO: right now this works with parallelization-strategy=dense-outer-loop,
1128 /// but give this its own flag in the future.
1129 struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
1130  using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
1131 
1132  ForallRewriter(MLIRContext *context, unsigned nT)
1133  : OpRewritePattern(context), numThreads(nT){};
1134 
1135  LogicalResult matchAndRewrite(scf::ParallelOp forallOp,
1136  PatternRewriter &rewriter) const override {
1137  // Reject inadmissible loop form.
1138  // Essentially only accept a loop, generated by the sparsifier,
1139  // of the form
1140  // forall (i = 0; i < N; i++)
1141  // so that cyclic scheduling over the threads is easy.
1142  if (!forallOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ||
1143  forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 ||
1144  !matchPattern(forallOp.getLowerBound()[0], m_Zero()) ||
1145  !matchPattern(forallOp.getStep()[0], m_One()))
1146  return failure();
1147  // Collect every value that is computed outside the parallel loop.
1148  SetVector<Value> invariants; // stable iteration!
1149  forallOp->walk([&](Operation *op) {
1150  // Collect all values of admissible ops.
1151  for (OpOperand &o : op->getOpOperands()) {
1152  Value val = o.get();
1153  Block *block;
1154  if (auto arg = dyn_cast<BlockArgument>(val))
1155  block = arg.getOwner();
1156  else
1157  block = val.getDefiningOp()->getBlock();
1158  if (!isNestedIn(block, forallOp))
1159  invariants.insert(val);
1160  }
1161  });
1162  // Outline the outside values as proper parameters. Fail when sharing
1163  // a value between host and device is not straightforward.
1164  SmallVector<Value> constants;
1165  SmallVector<Value> scalars;
1166  SmallVector<Value> buffers;
1167  for (Value val : invariants) {
1168  Type tp = val.getType();
1169  if (val.getDefiningOp<arith::ConstantOp>())
1170  constants.push_back(val);
1171  else if (isa<FloatType>(tp) || tp.isIntOrIndex())
1172  scalars.push_back(val);
1173  else if (isa<MemRefType>(tp))
1174  buffers.push_back(val);
1175  else
1176  return failure(); // don't know how to share
1177  }
1178  // Pass outlined non-constant values.
1179  // TODO: Experiment with `useHostRegistrationForOut` to see if we want to
1180  // keep the feature at all (either through a heuristic or compiler
1181  // option for gpu codegen).
1182  Location loc = forallOp->getLoc();
1183  SmallVector<Value> args;
1184  SmallVector<Value> tokens;
1185  Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens,
1186  /*useHostRegistrationForOut=*/false);
1187  // Set up GPU module and construct GPU function.
1188  auto saveIp = rewriter.saveInsertionPoint();
1189  ModuleOp topModule = forallOp->getParentOfType<ModuleOp>();
1190  auto gpuModule = genGPUModule(rewriter, topModule);
1191  auto gpuFunc = genGPUFunc(rewriter, gpuModule, args);
1192  genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers);
1193  // Generate code that launches the kernel asynchronously, blocking on all
1194  // open tokens and yielding a new token for the output.
1195  // TODO: Passing tokens into the launch op does not seem to be properly
1196  // lowered to cubin yet, hence the current blocking wait.
1197  rewriter.restoreInsertionPoint(saveIp);
1198  genBlockingWait(rewriter, loc, tokens);
1199  tokens.clear();
1200  Value kernelToken =
1201  genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads);
1202  // Finalize the outlined arguments.
1203  genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
1204  tokens);
1205  genBlockingWait(rewriter, loc, tokens);
1206  rewriter.eraseOp(forallOp);
1207  return success();
1208  }
1209 
1210 private:
1211  // Helper method to see if block appears in given loop.
1212  static bool isNestedIn(Block *block, scf::ParallelOp forallOp) {
1213  for (Operation *o = block->getParentOp(); o; o = o->getParentOp()) {
1214  if (o == forallOp)
1215  return true;
1216  }
1217  return false;
1218  }
1219 
1220  unsigned numThreads;
1221 };
1222 
1223 //===----------------------------------------------------------------------===//
1224 // Rewriting rules for library recognition and code generation.
1225 //===----------------------------------------------------------------------===//
1226 
1227 /// Proof-of-concept rewriter. This rule recognizes certain math kernels
1228 /// and replaces these with corresponding calls into a sparse library.
1229 struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
1230  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
1231 
1232  LinalgOpRewriter(MLIRContext *context, bool rt)
1233  : OpRewritePattern(context), enableRT(rt) {}
1234 
1235  LogicalResult matchAndRewrite(linalg::GenericOp op,
1236  PatternRewriter &rewriter) const override {
1237  if (op.getNumDpsInits() != 1)
1238  return failure(); // reject multi-output
1239 
1240  const unsigned numLoops = op.getNumLoops();
1241  const unsigned numTensors = op->getNumOperands();
1242  const auto iteratorTypes = op.getIteratorTypesArray();
1243  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1244 
1245  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1246  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1247  AffineExpr i, j, k;
1248  bindDims(getContext(), i, j, k);
1249 
1250  // TODO: more robust patterns, tranposed versions, more kernels,
1251  // identify alpha and beta and pass them to the CUDA calls.
1252 
1253  // Recognize a SpMV kernel.
1254  if (numLoops == 2 && numTensors == 3 &&
1255  linalg::isParallelIterator(iteratorTypes[0]) &&
1256  linalg::isReductionIterator(iteratorTypes[1]) &&
1257  maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
1258  return rewriteSpMV(rewriter, op, enableRT);
1259  }
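// The SpMV shape test above corresponds to indexing maps of the form
// (illustrative):
//   affine_map<(i, j) -> (i, j)>  // A
//   affine_map<(i, j) -> (j)>     // x
//   affine_map<(i, j) -> (i)>     // y
// with iterator_types = ["parallel", "reduction"].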
1260 
1261  // Recognize a SpGEMM, 2:4-SpMM, or SpMM kernel.
1262  if (numLoops == 3 && numTensors == 3 &&
1263  linalg::isParallelIterator(iteratorTypes[0]) &&
1264  linalg::isParallelIterator(iteratorTypes[1]) &&
1265  linalg::isReductionIterator(iteratorTypes[2]) &&
1266  maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
1267  if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
1268  return rewriteSpGEMM(rewriter, op, enableRT);
1269  if (op->getAttr("DENSE24"))
1270  return rewrite2To4SpMM(rewriter, op);
1271  return rewriteSpMM(rewriter, op, enableRT);
1272  }
1273 
1274  // Recognize a SDDMM kernel.
1275  if (numLoops == 3 && numTensors == 3 &&
1276  linalg::isParallelIterator(iteratorTypes[0]) &&
1277  linalg::isParallelIterator(iteratorTypes[1]) &&
1278  linalg::isReductionIterator(iteratorTypes[2]) &&
1279  maps == infer({{i, k}, {k, j}, {i, j}}) &&
1280  matchSumReductionOfMulUnary(op)) {
1281  return rewriteSDDMM(rewriter, op, enableRT);
1282  }
1283 
1284  return failure();
1285  }
1286 
1287 private:
1288  bool enableRT;
1289 };
1290 
1291 } // namespace
1292 
1293 //===----------------------------------------------------------------------===//
1294 // Public method for populating GPU rewriting rules.
1295 //
1296 // Currently two sets of rewriting rules are made available. The first set
1297 // implements direct code generation, currently by means of converting the
1298 // outermost parallel loop into GPU threads. The second set implements
1299 // library recognition of a set of sparse operations. Eventually, the right
1300 // combination of these two approaches has to be found.
1301 //===----------------------------------------------------------------------===//
1302 
1303 void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
1304  unsigned numThreads) {
1305  patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
1306 }
1307 
1308 void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
1309  bool enableRT) {
1310  patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT);
1311 }
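
// A typical way to use these entry points from a pass pipeline (an
// illustrative sketch only; the surrounding pass is not part of this file):
//   RewritePatternSet patterns(ctx);
//   populateSparseGPUCodegenPatterns(patterns, /*numThreads=*/1024);
//   populateSparseGPULibgenPatterns(patterns, /*enableRT=*/true);
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));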
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:560
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:113
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName()
Definition: LoopEmitter.h:253
A wrapper around RankedTensorType, which has three goals:
unsigned getCrdWidth() const
Returns the coordinate-overhead bitwidth, defaulting to zero.
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
Dimension getDimRank() const
Returns the dimension-rank.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
Level getLvlRank() const
Returns the level-rank.
unsigned getPosWidth() const
Returns the position-overhead bitwidth, defaulting to zero.
bool isPermutation() const
Returns true if the dimToLvl mapping is a permutation.
AffineMap getDimToLvl() const
Returns the dimToLvl mapping (or the null-map for the identity).
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
Definition: Utils.cpp:184
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition: Utils.cpp:188
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:88
Value genToCoordinatesBuffer(OpBuilder &builder, Location loc, Value tensor)
Infers the result type and generates ToCoordinatesBufferOp.
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:361
Value genToValues(OpBuilder &builder, Location loc, Value tensor)
Infers the result type and generates ToValuesOp.
Value genToPositions(OpBuilder &builder, Location loc, Value tensor, Level lvl)
Infers the result type and generates ToPositionsOp.
Value genToCoordinates(OpBuilder &builder, Location loc, Value tensor, Level lvl, Level cooStart)
Infers the result type and generates ToCoordinatesOp.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:494
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:334
void populateSparseGPULibgenPatterns(RewritePatternSet &patterns, bool enableRT)
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:378
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:389
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateSparseGPUCodegenPatterns(RewritePatternSet &patterns, unsigned numThreads)
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
Utility class for the GPU dialect to represent triples of Values accessible through ....
Definition: GPUDialect.h:37
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.