SparseGPUCodegen.cpp
1 //===- SparseGPUCodegen.cpp - Generates GPU code --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a prototype GPU code generator for the sparsifier.
10 // The objective is to eventually use the right combination of
11 // direct code generation and library calls into vendor-specific
12 // highly optimized sparse libraries (e.g. cuSparse for CUDA).
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "Utils/CodegenUtils.h"
17 #include "Utils/LoopEmitter.h"
18 
19 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
20 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
21 #include "mlir/Dialect/Linalg/IR/Linalg.h"
22 #include "mlir/Dialect/Linalg/Utils/Utils.h"
23 #include "mlir/Dialect/MemRef/IR/MemRef.h"
24 #include "mlir/Dialect/SCF/IR/SCF.h"
25 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
26 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
27 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
28 #include "mlir/IR/IRMapping.h"
29 #include "mlir/IR/Matchers.h"
30 
31 using namespace mlir;
32 using namespace mlir::sparse_tensor;
33 
34 namespace {
35 
36 // Sparse formats supported by cuSparse.
37 enum class CuSparseFormat {
38  kNone,
39  kCOO,
40  kCSR,
41  kCSC,
42  kBSR,
43 };
44 
45 //===----------------------------------------------------------------------===//
46 // Helper methods.
47 //===----------------------------------------------------------------------===//
48 
49 /// Marks the given top module as a GPU container module.
50 static void markAsGPUContainer(ModuleOp topModule) {
51  topModule->setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
52  UnitAttr::get(topModule->getContext()));
53 }
54 
55 /// Constructs a new GPU module (for GPU kernels) inside the given top module,
56 /// or returns an existing GPU module if one was built previously.
57 static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) {
58  for (auto op : topModule.getBodyRegion().getOps<gpu::GPUModuleOp>())
59  return op; // existing
60  markAsGPUContainer(topModule);
61  builder.setInsertionPointToStart(topModule.getBody());
62  return gpu::GPUModuleOp::create(builder, topModule->getLoc(),
63  "sparse_kernels");
64 }
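// After this step the enclosing module has the following overall shape
// (a sketch; only the attribute and the module name are prescribed by the
// code above, the kernel contents come later):
//
//   module attributes {gpu.container_module} {
//     gpu.module @sparse_kernels {
//       gpu.func @kernel0(...) kernel { ... }
//     }
//     ... host code ...
//   }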
65 
66 /// Constructs a new GPU kernel in the given GPU module.
67 static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
68  SmallVectorImpl<Value> &args) {
69  // Get a unique kernel name. Not very creative,
70  // but we simply try kernel0, kernel1, etc.
71  unsigned kernelNumber = 0;
72  SmallString<16> kernelName;
73  do {
74  kernelName.clear();
75  ("kernel" + Twine(kernelNumber++)).toStringRef(kernelName);
76  } while (gpuModule.lookupSymbol(kernelName));
77  // Then we insert a new kernel with given arguments into the module.
78  builder.setInsertionPointToStart(gpuModule.getBody());
79  SmallVector<Type> argsTp;
80  for (auto arg : args)
81  argsTp.push_back(arg.getType());
82  FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {});
83  auto gpuFunc =
84  gpu::GPUFuncOp::create(builder, gpuModule->getLoc(), kernelName, type);
85  gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
86  builder.getUnitAttr());
87  return gpuFunc;
88 }
89 
90 /// Constructs code to launch GPU kernel.
91 static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
92  SmallVectorImpl<Value> &args,
93  SmallVectorImpl<Value> &tokens,
94  unsigned numThreads) {
95  Location loc = gpuFunc->getLoc();
96  Value none = TypedValue<::mlir::IntegerType>{};
97  Value one = constantIndex(builder, loc, 1);
98  Value numT = constantIndex(builder, loc, numThreads);
99  gpu::KernelDim3 gridSize = {one, one, one};
100  gpu::KernelDim3 blckSize = {numT, one, one};
101  return gpu::LaunchFuncOp::create(builder, loc, gpuFunc, gridSize, blckSize,
102  /*dynSharedMemSz*/ none, args,
103  builder.getType<gpu::AsyncTokenType>(),
104  tokens)
105  .getAsyncToken();
106 }
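// The generated launch reads roughly as follows (illustrative sketch only;
// the argument types are placeholders):
//
//   %t = gpu.launch_func async [%tokens] @sparse_kernels::@kernel0
//          blocks in (%c1, %c1, %c1) threads in (%numT, %c1, %c1)
//          args(%s : index, %buf : memref<?xf64>)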
107 
108 /// Maps the provided ranked host buffer into the device address space.
109 /// Writes from the host are guaranteed to be visible to device kernels
110 /// that are launched afterwards. Writes from the device are guaranteed
111 /// to be visible on the host after synchronizing with the device kernel
112 /// completion. Needs to cast the buffer to an unranked buffer.
113 static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
114  Value mem) {
115  MemRefType memTp = cast<MemRefType>(mem.getType());
116  UnrankedMemRefType resTp =
117  UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0);
118  Value cast = memref::CastOp::create(builder, loc, resTp, mem);
119  gpu::HostRegisterOp::create(builder, loc, cast);
120  return cast;
121 }
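// For illustration, the IR emitted by the helper above is simply (a sketch;
// the element type and rank are assumptions):
//
//   %cast = memref.cast %mem : memref<?xf64> to memref<*xf64>
//   gpu.host_register %cast : memref<*xf64>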
122 
123 /// Unmaps the provided buffer, expecting the previously cast (unranked) buffer.
124 static void genHostUnregisterMemref(OpBuilder &builder, Location loc,
125  Value cast) {
126  gpu::HostUnregisterOp::create(builder, loc, cast);
127 }
128 
129 /// Generates first wait in an asynchronous chain.
130 static Value genFirstWait(OpBuilder &builder, Location loc) {
131  Type tokenType = builder.getType<gpu::AsyncTokenType>();
132  return gpu::WaitOp::create(builder, loc, tokenType, ValueRange())
133  .getAsyncToken();
134 }
135 
136 /// Generates last, blocking wait in an asynchronous chain.
137 static void genBlockingWait(OpBuilder &builder, Location loc,
138  ValueRange operands) {
139  gpu::WaitOp::create(builder, loc, Type(), operands);
140 }
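// Taken together, genFirstWait and genBlockingWait bracket an asynchronous
// chain of GPU operations, e.g. (illustrative sketch only):
//
//   %t0 = gpu.wait async        // first, non-blocking wait starts the chain
//   ...                         // async ops threaded through tokens
//   gpu.wait [%tn]              // last, blocking wait joins the chain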
141 
142 /// Allocates memory on the device.
143 /// TODO: A `host_shared` attribute could be used to indicate that
144 /// the buffer is visible to both host and device, but lowering
145 /// that feature does not seem to be fully supported yet.
146 static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
147  Value token) {
148  auto tp = cast<ShapedType>(mem.getType());
149  auto elemTp = tp.getElementType();
150  auto shape = tp.getShape();
151  auto memTp = MemRefType::get(shape, elemTp);
152  SmallVector<Value> dynamicSizes;
153  for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) {
154  if (shape[r] == ShapedType::kDynamic) {
155  Value dimOp = linalg::createOrFoldDimOp(builder, loc, mem, r);
156  dynamicSizes.push_back(dimOp);
157  }
158  }
159  return gpu::AllocOp::create(builder, loc, TypeRange({memTp, token.getType()}),
160  token, dynamicSizes, ValueRange());
161 }
162 
163 // Allocates a typed buffer on the host with given size.
164 static Value genHostBuffer(OpBuilder &builder, Location loc, Type type,
165  Value size) {
166  const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
167  return memref::AllocOp::create(builder, loc, memTp, size).getResult();
168 }
169 
170 // Allocates a typed buffer on the device with given size.
171 static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type,
172  Value size, Value token) {
173  const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
174  return gpu::AllocOp::create(builder, loc, TypeRange({memTp, token.getType()}),
175  token, size, ValueRange());
176 }
177 
178 // Allocates a void buffer on the device with given size.
179 static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size,
180  Value token) {
181  return genAllocBuffer(builder, loc, builder.getI8Type(), size, token);
182 }
183 
184 /// Deallocates memory from the device.
185 static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem,
186  Value token) {
187  return gpu::DeallocOp::create(builder, loc, token.getType(), token, mem)
188  .getAsyncToken();
189 }
190 
191 /// Copies memory between host and device (direction is implicit).
192 static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst,
193  Value src, Value token) {
194  return gpu::MemcpyOp::create(builder, loc, token.getType(), token, dst, src)
195  .getAsyncToken();
196 }
197 
198 /// Generates an alloc/copy pair.
199 static Value genAllocCopy(OpBuilder &builder, Location loc, Value b,
200  SmallVectorImpl<Value> &tokens) {
201  Value firstToken = genFirstWait(builder, loc);
202  auto alloc = genAllocMemRef(builder, loc, b, firstToken);
203  Value devMem = alloc.getResult(0);
204  Value depToken = alloc.getAsyncToken(); // copy-after-alloc
205  tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken));
206  return devMem;
207 }
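// A single alloc/copy pair produced by the helper above reads roughly as
// (a sketch; the memref type is an assumption):
//
//   %t0 = gpu.wait async
//   %dev, %t1 = gpu.alloc async [%t0] (%d0) : memref<?xf64>
//   %t2 = gpu.memcpy async [%t1] %dev, %host : memref<?xf64>, memref<?xf64>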
208 
209 /// Generates a memref-from-tensor (bufferization.to_buffer) operation.
210 static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
211  Value tensor) {
212  auto tensorType = llvm::cast<ShapedType>(tensor.getType());
213  auto memrefType =
214  MemRefType::get(tensorType.getShape(), tensorType.getElementType());
215  return bufferization::ToBufferOp::create(rewriter, loc, memrefType, tensor);
216 }
217 
218 /// Prepares the outlined arguments, passing scalars and buffers in. Here we
219 /// assume that the first buffer is the one allocated for output. We create
220 /// a set of properly chained asynchronous allocation/copy pairs to increase
221 /// overlap before launching the kernel.
222 static Value genParametersIn(OpBuilder &builder, Location loc,
223  SmallVectorImpl<Value> &scalars,
224  SmallVectorImpl<Value> &buffers,
225  SmallVectorImpl<Value> &args,
226  SmallVectorImpl<Value> &tokens,
227  bool useHostRegistrationForOut) {
228  Value out;
229  // Scalars are passed by value.
230  for (Value s : scalars)
231  args.push_back(s);
232  // Buffers need to be made visible on the device.
233  for (Value b : buffers) {
234  if (useHostRegistrationForOut) {
235  out = genHostRegisterMemref(builder, loc, b);
236  args.push_back(b);
237  useHostRegistrationForOut = false;
238  continue;
239  }
240  args.push_back(genAllocCopy(builder, loc, b, tokens));
241  }
242  return out;
243 }
244 
245 /// Finalizes the outlined arguments. The output buffer is copied depending
246 /// on the kernel token and then deallocated. All other buffers are simply
247 /// deallocated. Then we wait for all operations to complete.
248 static void genParametersOut(OpBuilder &builder, Location loc, Value out,
249  Value kernelToken, SmallVectorImpl<Value> &scalars,
250  SmallVectorImpl<Value> &buffers,
251  SmallVectorImpl<Value> &args,
252  SmallVectorImpl<Value> &tokens) {
253  unsigned base = scalars.size();
254  for (unsigned i = base, e = args.size(); i < e; i++) {
255  Value firstToken;
256  if (i == base) {
257  // Assumed output parameter: unregister or copy-out.
258  if (out) {
259  genHostUnregisterMemref(builder, loc, out);
260  out = Value();
261  continue;
262  }
263  firstToken =
264  genCopyMemRef(builder, loc, buffers[0], args[i], kernelToken);
265  } else {
266  firstToken = genFirstWait(builder, loc);
267  }
268  tokens.push_back(genDeallocMemRef(builder, loc, args[i], firstToken));
269  }
270 }
271 
272 /// Constructs code for new GPU kernel.
273 static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
274  scf::ParallelOp forallOp,
275  SmallVectorImpl<Value> &constants,
276  SmallVectorImpl<Value> &scalars,
277  SmallVectorImpl<Value> &buffers) {
278  Location loc = gpuFunc->getLoc();
279  Block &block = gpuFunc.getBody().front();
280  rewriter.setInsertionPointToStart(&block);
281 
282  // Re-generate the constants, recapture all arguments.
283  unsigned arg = 0;
284  IRMapping irMap;
285  for (Value c : constants)
286  irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0));
287  for (Value s : scalars)
288  irMap.map(s, block.getArgument(arg++));
289  for (Value b : buffers)
290  irMap.map(b, block.getArgument(arg++));
291 
292  // Assume 1-dimensional grid/block configuration (only x dimension),
293  // so that:
294  // row = blockIdx.x * blockDim.x + threadIdx.x
295  // inc = blockDim.x * gridDim.x
296  Value bid = gpu::BlockIdOp::create(rewriter, loc, gpu::Dimension::x);
297  Value bsz = gpu::BlockDimOp::create(rewriter, loc, gpu::Dimension::x);
298  Value tid = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x);
299  Value gsz = gpu::GridDimOp::create(rewriter, loc, gpu::Dimension::x);
300  Value mul = arith::MulIOp::create(rewriter, loc, bid, bsz);
301  Value row = arith::AddIOp::create(rewriter, loc, mul, tid);
302  Value inc = arith::MulIOp::create(rewriter, loc, bsz, gsz);
303 
304  // Construct the iteration over the computational space that
305  // accounts for the fact that the total number of threads and
306  // the amount of work to be done usually do not match precisely.
307  // for (r = row; r < N; r += inc) {
308  // <loop-body>
309  // }
310  Value upper = irMap.lookup(forallOp.getUpperBound()[0]);
311  scf::ForOp forOp = scf::ForOp::create(rewriter, loc, row, upper, inc);
312  // The scf.for builder creates an empty block. scf.for does not allow multiple
313  // blocks in its region, so delete the block before `cloneRegionBefore` adds
314  // an additional block.
315  rewriter.eraseBlock(forOp.getBody());
316  rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(),
317  forOp.getRegion().begin(), irMap);
318  // Replace the scf.reduce terminator.
319  rewriter.setInsertionPoint(forOp.getBody()->getTerminator());
320  rewriter.replaceOpWithNewOp<scf::YieldOp>(forOp.getBody()->getTerminator());
321 
322  // Done.
323  rewriter.setInsertionPointAfter(forOp);
324  gpu::ReturnOp::create(rewriter, gpuFunc->getLoc());
325 }
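// Putting the pieces together, the outlined kernel follows the usual
// grid-stride-loop shape, roughly (sketch only):
//
//   gpu.func @kernel0(%s0: index, ..., %buf0: memref<?xf64>, ...) kernel {
//     // row = blockIdx.x * blockDim.x + threadIdx.x
//     // inc = blockDim.x * gridDim.x
//     scf.for %r = %row to %upper step %inc {
//       <cloned loop body>
//     }
//     gpu.return
//   }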
326 
327 //===----------------------------------------------------------------------===//
328 // Library helper methods.
329 //===----------------------------------------------------------------------===//
330 
331 /// Helper to detect a + b with arguments taken from given block.
332 static bool matchAddOfArgs(Block *block, Value val) {
333  if (auto *def = val.getDefiningOp()) {
334  if (isa<arith::AddFOp, arith::AddIOp>(def)) {
335  Value a = block->getArguments()[0];
336  Value b = block->getArguments()[1];
337  return (def->getOperand(0) == a && def->getOperand(1) == b) ||
338  (def->getOperand(0) == b && def->getOperand(1) == a);
339  }
340  }
341  return false;
342 }
343 
344 /// Helper to detect a * b with arguments taken from given block.
345 static bool matchMulOfArgs(Block *block, Value val) {
346  if (auto *def = val.getDefiningOp()) {
347  if (isa<arith::MulFOp, arith::MulIOp>(def)) {
348  Value a = block->getArguments()[0];
349  Value b = block->getArguments()[1];
350  return (def->getOperand(0) == a && def->getOperand(1) == b) ||
351  (def->getOperand(0) == b && def->getOperand(1) == a);
352  }
353  }
354  return false;
355 }
356 
357 /// Helper to detect x = x + a * b
358 static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
359  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
360  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
361  if (isa<arith::AddFOp, arith::AddIOp>(def)) {
362  Value x = op.getBlock()->getArguments()[2];
363  return (def->getOperand(0) == x &&
364  matchMulOfArgs(op.getBlock(), def->getOperand(1))) ||
365  (def->getOperand(1) == x &&
366  matchMulOfArgs(op.getBlock(), def->getOperand(0)));
367  }
368  }
369  return false;
370 }
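// As an example, matchSumOfMultOfArgs accepts a linalg.generic whose body
// has the following shape (a sketch; f64 is just an assumption):
//
//   ^bb0(%a: f64, %b: f64, %x: f64):
//     %p = arith.mulf %a, %b : f64
//     %s = arith.addf %x, %p : f64
//     linalg.yield %s : f64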
371 
372 // Helper to detect c += spy(s) x (a * b)
373 static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
374  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
375  // The linalg yields a custom reduce result.
376  Value s_out = op.getBlock()->getArguments()[2];
377  if (auto redOp =
378  yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>()) {
379  // The reduce consumes the output.
380  Value other;
381  if (s_out == redOp->getOperand(0))
382  other = redOp->getOperand(1);
383  else if (s_out == redOp->getOperand(1))
384  other = redOp->getOperand(0);
385  else
386  return false;
387  // The reduce op also consumes a unary operation that also consumes the output
388  // and does not define an absent value.
389  if (auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>()) {
390  if (s_out != unOp->getOperand(0) || !unOp.getAbsentRegion().empty())
391  return false;
392  // And the bodies are as expected.
393  auto yieldUn = cast<sparse_tensor::YieldOp>(
394  unOp.getRegion(0).front().getTerminator());
395  auto yieldRed = cast<sparse_tensor::YieldOp>(
396  redOp.getRegion().front().getTerminator());
397  return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) &&
398  matchAddOfArgs(&redOp.getRegion().front(), yieldRed.getOperand(0));
399  }
400  }
401  return false;
402 }
403 
404 /// Test for dense tensor.
405 static bool isDenseTensor(Value v) {
406  auto sTp = getSparseTensorType(v);
407  return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense();
408 }
409 
410 /// Test for suitable positions/coordinates width.
411 static bool isAdmissibleMetaData(SparseTensorType &aTp) {
412  return (aTp.getPosWidth() == 0 || aTp.getPosWidth() >= 16) &&
413  (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() >= 16);
414 }
415 
416 /// Test for sorted COO matrix with suitable metadata.
417 static bool isAdmissibleCOO(SparseTensorType &aTp) {
418  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
419  aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
420  aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
421  isAdmissibleMetaData(aTp);
422 }
423 
424 /// Test for CSR matrix with suitable metadata.
425 static bool isAdmissibleCSR(SparseTensorType &aTp) {
426  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
427  aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
428  aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
429 }
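// For reference, a sparse tensor type accepted by isAdmissibleCSR could be
// written as follows (a sketch; the alias name and element type are only
// for illustration):
//
//   #CSR = #sparse_tensor.encoding<{
//     map = (i, j) -> (i : dense, j : compressed),
//     posWidth = 32, crdWidth = 32
//   }>
//   ... tensor<?x?xf64, #CSR> ...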
430 
431 /// Test for CSC matrix with suitable metadata.
432 static bool isAdmissibleCSC(SparseTensorType &aTp) {
433  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && !aTp.isIdentity() &&
434  aTp.isPermutation() && aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) &&
435  aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
436 }
437 
438 /// Test for BSR matrix with suitable metadata.
439 static bool isAdmissibleBSR(SparseTensorType &aTp) {
440  if (aTp.getDimRank() == 2 && aTp.getLvlRank() == 4 && aTp.isDenseLvl(0) &&
441  aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
442  aTp.isDenseLvl(2) && aTp.isDenseLvl(3) && isAdmissibleMetaData(aTp)) {
443  // CuSparse only supports "square" blocks currently.
444  SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
445  assert(dims.size() == 2);
446  return dims[0] == dims[1] && dims[0] > 1;
447  }
448  return false;
449 }
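// Similarly, a BSR type with 2x2 blocks that passes isAdmissibleBSR could
// be declared as (a sketch; the block size is only an example):
//
//   #BSR = #sparse_tensor.encoding<{
//     map = (i, j) -> (i floordiv 2 : dense,
//                      j floordiv 2 : compressed,
//                      i mod 2      : dense,
//                      j mod 2      : dense)
//   }>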
450 
451 /// Test for 2:4 matrix with suitable metadata.
452 static bool isAdmissible24(SparseTensorType &aTp) {
453  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) &&
454  aTp.isDenseLvl(1) && aTp.isNOutOfMLvl(2) && isAdmissibleMetaData(aTp);
455 }
456 
457 /// Test for conversion into 2:4 matrix.
458 static bool isConversionInto24(Value v) {
459  if (auto cnv = v.getDefiningOp<ConvertOp>()) {
460  Value a = cnv.getResult();
461  Value d = cnv.getSource();
462  SparseTensorType aTp = getSparseTensorType(a);
463  return isDenseTensor(d) && isAdmissible24(aTp);
464  }
465  return false;
466 }
467 
468 /// Returns a suitable sparse format for the operation and given operand
469 /// types with cuSparse, or kNone if none is available.
470 static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
471  SparseTensorType bTp,
472  SparseTensorType cTp, bool enableRT,
473  bool isMatVec) {
474  // The other operands have a dense type.
475  if (bTp.hasEncoding() || cTp.hasEncoding())
476  return CuSparseFormat::kNone;
477  // Now check for suitable operand type for the main operand.
478  if (isAdmissibleCOO(aTp))
479 #ifdef CUSPARSE_COO_AOS
480  return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
481 #else
482  return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
483 #endif
484  if (isAdmissibleCSR(aTp))
485  return CuSparseFormat::kCSR;
486  if (isAdmissibleCSC(aTp))
487  return CuSparseFormat::kCSC;
488  if (isAdmissibleBSR(aTp))
489  return CuSparseFormat::kBSR;
490  return CuSparseFormat::kNone;
491 }
492 
493 /// Generates the first positions/coordinates of a sparse matrix.
494 static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
495  CuSparseFormat format, bool enableRT) {
496  if (format == CuSparseFormat::kCOO) {
497  // Library uses SoA COO, direct IR uses AoS COO.
498  if (enableRT)
499  return ToCoordinatesOp::create(builder, loc, a, 0);
500  return ToCoordinatesBufferOp::create(builder, loc, a);
501  }
502  // Formats CSR/CSC and BSR use positions at 1.
503  return ToPositionsOp::create(builder, loc, a, 1);
504 }
505 
506 /// Generates the second coordinates of a sparse matrix.
507 static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
508  CuSparseFormat format, bool enableRT) {
509  bool isCOO = format == CuSparseFormat::kCOO;
510  if (isCOO && !enableRT)
511  return Value(); // nothing needed
512  // Formats CSR/CSC and BSR use coordinates at 1.
513  return ToCoordinatesOp::create(builder, loc, a, 1);
514 }
515 
516 /// Generates the sparse matrix handle.
517 static Operation *genSpMat(OpBuilder &builder, Location loc,
518  SparseTensorType &aTp, Type handleTp, Type tokenTp,
519  Value token, Value sz1, Value sz2, Value nseA,
520  Value rowA, Value colA, Value valA,
521  CuSparseFormat format, bool enableRT) {
522  if (format == CuSparseFormat::kCOO) {
523  // Library uses SoA COO, direct IR uses AoS COO.
524  if (enableRT) {
525  assert(colA);
526  return gpu::CreateCooOp::create(builder, loc, handleTp, tokenTp, token,
527  sz1, sz2, nseA, rowA, colA, valA);
528  }
529 #ifdef CUSPARSE_COO_AOS
530  assert(!colA);
531  return gpu::CreateCooAoSOp::create(builder, loc, handleTp, tokenTp, token,
532  sz1, sz2, nseA, rowA, valA);
533 #else
534  llvm_unreachable("gpu::CreateCooAoSOp is deprecated");
535 #endif
536  }
537  assert(colA);
538  if (format == CuSparseFormat::kCSR)
539  return gpu::CreateCsrOp::create(builder, loc, handleTp, tokenTp, token, sz1,
540  sz2, nseA, rowA, colA, valA);
541  if (format == CuSparseFormat::kCSC)
542  return gpu::CreateCscOp::create(builder, loc, handleTp, tokenTp, token, sz1,
543  sz2, nseA, rowA, colA, valA);
544  // BSR requires a bit more work since we need to pass in the block size
545  // and all other sizes in terms of blocks (#block-rows, #block-cols,
546  // #nonzero-blocks).
547  assert(format == CuSparseFormat::kBSR);
548  SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
549  assert(dims.size() == 2 && dims[0] == dims[1]);
550  uint64_t b = dims[0];
551  Value bSz = constantIndex(builder, loc, b);
552  Value bRows = arith::DivUIOp::create(builder, loc, sz1, bSz);
553  Value bCols = arith::DivUIOp::create(builder, loc, sz2, bSz);
554  Value bNum = arith::DivUIOp::create(builder, loc, nseA,
555  constantIndex(builder, loc, b * b));
556  return gpu::CreateBsrOp::create(builder, loc, handleTp, tokenTp, token, bRows,
557  bCols, bNum, bSz, bSz, rowA, colA, valA);
558 }
559 
560 /// Match and rewrite SpMV kernel.
561 static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
562  linalg::GenericOp op, bool enableRT) {
563  Location loc = op.getLoc();
564  Value a = op.getOperand(0);
565  Value x = op.getOperand(1);
566  Value y = op.getOperand(2); // we have y = Ax
567  SmallVector<Value> tokens;
568 
569  // Only admissible sparse matrix format and dense vectors (no BSR).
570  SparseTensorType aTp = getSparseTensorType(a);
571  SparseTensorType xTp = getSparseTensorType(x);
572  SparseTensorType yTp = getSparseTensorType(y);
573  auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT, /*isMatVec=*/true);
574  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
575  return failure();
576 
577  // Start sparse kernel and copy data from host to device.
578  // a : memR/memC/memV -> rowA,colA,valA
579  // x : memX -> vecX
580  // y : memY -> vecY
581  Value nseA = NumberOfEntriesOp::create(rewriter, loc, a);
582  Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
583  Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
584  Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
585  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
586  Value memV = ToValuesOp::create(rewriter, loc, a);
587  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
588  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
589  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
590  Value memX = genTensorToMemref(rewriter, loc, x);
591  Value vecX = genAllocCopy(rewriter, loc, memX, tokens);
592  Value memY = genTensorToMemref(rewriter, loc, y);
593  Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
594  genBlockingWait(rewriter, loc, tokens);
595  tokens.clear();
596 
597  // Create sparse environment and sparse matrix/dense vector handles.
598  Type indexTp = rewriter.getIndexType();
599  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
600  Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
601  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
602  Value token = genFirstWait(rewriter, loc);
603  Operation *spGenA =
604  genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szY, szX,
605  nseA, rowA, colA, valA, format, enableRT);
606  Value spMatA = spGenA->getResult(0);
607  token = spGenA->getResult(1);
608  auto dvecX = gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp,
609  tokenTp, token, vecX, szX);
610  Value dnX = dvecX.getResult(0);
611  token = dvecX.getAsyncToken();
612  auto dvecY = gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp,
613  tokenTp, token, vecY, szY);
614  Value dnY = dvecY.getResult(0);
615  token = dvecY.getAsyncToken();
616  auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType();
617 
618  // Precompute buffersize for SpMV.
619  auto bufferComp = gpu::SpMVBufferSizeOp::create(
620  rewriter, loc, indexTp, tokenTp, token, spMatA, dnX, dnY,
621  /*computeType=*/dnYType);
622  Value bufferSz = bufferComp.getResult(0);
623  token = bufferComp.getAsyncToken();
624  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
625  Value buffer = buf.getResult(0);
626  token = buf.getAsyncToken();
627 
628  // Perform the SpMV.
629  auto spmvComp =
630  gpu::SpMVOp::create(rewriter, loc, tokenTp, token, spMatA, dnX, dnY,
631  /*computeType=*/dnYType, buffer);
632  token = spmvComp.getAsyncToken();
633 
634  // Copy data back to host and free all the resources.
635  token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatA)
636  .getAsyncToken();
637  token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnX)
638  .getAsyncToken();
639  token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnY)
640  .getAsyncToken();
641  token = genDeallocMemRef(rewriter, loc, rowA, token);
642  if (colA)
643  token = genDeallocMemRef(rewriter, loc, colA, token);
644  token = genDeallocMemRef(rewriter, loc, valA, token);
645  token = genDeallocMemRef(rewriter, loc, buffer, token);
646  token = genDeallocMemRef(rewriter, loc, vecX, token);
647  token = genCopyMemRef(rewriter, loc, memY, vecY, token);
648  token = genDeallocMemRef(rewriter, loc, vecY, token);
649  tokens.push_back(token);
650  genBlockingWait(rewriter, loc, tokens);
651  tokens.clear();
652 
653  // Done.
654  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, y.getType(), memY);
655  return success();
656 }
657 
658 /// Match and rewrite SpMM kernel.
659 static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
660  linalg::GenericOp op, bool enableRT) {
661  Location loc = op.getLoc();
662  Value a = op.getOperand(0);
663  Value b = op.getOperand(1);
664  Value c = op.getOperand(2); // we have C = AB
665  SmallVector<Value> tokens;
666 
667  // Only admissible sparse matrix format and dense matrices (no BSR).
668  SparseTensorType aTp = getSparseTensorType(a);
669  SparseTensorType bTp = getSparseTensorType(b);
670  SparseTensorType cTp = getSparseTensorType(c);
671  auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false);
672  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
673  return failure();
674 
675  // Start sparse kernel and copy data from host to device.
676  // a : memR/memC/memV -> rowA,colA,valA
677  // b : bufB -> matB
678  // c : bufC -> matC
679  Value nseA = NumberOfEntriesOp::create(rewriter, loc, a);
680  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
681  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
682  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
683  Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
684  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
685  Value memV = ToValuesOp::create(rewriter, loc, a);
686  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
687  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
688  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
689  Value bufB = genTensorToMemref(rewriter, loc, b);
690  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
691  Value bufC = genTensorToMemref(rewriter, loc, c);
692  Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
693  genBlockingWait(rewriter, loc, tokens);
694  tokens.clear();
695 
696  // Create sparse environment and sparse matrix/dense matrix handles.
697  Type indexTp = rewriter.getIndexType();
698  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
699  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
700  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
701  Value token = genFirstWait(rewriter, loc);
702  Operation *spGenA =
703  genSpMat(rewriter, loc, aTp, spMatHandleTp, tokenTp, token, szm, szk,
704  nseA, rowA, colA, valA, format, enableRT);
705  Value spMatA = spGenA->getResult(0);
706  token = spGenA->getResult(1);
707  auto dmatB =
708  gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp, tokenTp,
709  token, matB, SmallVector<Value>{szk, szn});
710  Value dnB = dmatB.getResult(0);
711  token = dmatB.getAsyncToken();
712  auto dmatC =
713  gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp, tokenTp,
714  token, matC, SmallVector<Value>{szm, szn});
715  Value dnC = dmatC.getResult(0);
716  token = dmatC.getAsyncToken();
717  auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType();
718 
719  // Precompute buffersize for SpMM.
720  auto bufferComp = gpu::SpMMBufferSizeOp::create(
721  rewriter, loc, indexTp, tokenTp, token, spMatA, dnB, dnC,
722  /*computeType=*/dmatCType);
723  Value bufferSz = bufferComp.getResult(0);
724  token = bufferComp.getAsyncToken();
725  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
726  Value buffer = buf.getResult(0);
727  token = buf.getAsyncToken();
728  auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
729 
730  // Perform the SpMM.
731  auto spmmComp =
732  gpu::SpMMOp::create(rewriter, loc, tokenTp, token, spMatA, dnB, dnC,
733  /*computeType=*/dnCType, buffer);
734  token = spmmComp.getAsyncToken();
735 
736  // Copy data back to host and free all the resources.
737  token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatA)
738  .getAsyncToken();
739  token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnB)
740  .getAsyncToken();
741  token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnC)
742  .getAsyncToken();
743  token = genDeallocMemRef(rewriter, loc, rowA, token);
744  if (colA)
745  token = genDeallocMemRef(rewriter, loc, colA, token);
746  token = genDeallocMemRef(rewriter, loc, valA, token);
747  token = genDeallocMemRef(rewriter, loc, buffer, token);
748  token = genDeallocMemRef(rewriter, loc, matB, token);
749  token = genCopyMemRef(rewriter, loc, bufC, matC, token);
750  token = genDeallocMemRef(rewriter, loc, matC, token);
751  tokens.push_back(token);
752  genBlockingWait(rewriter, loc, tokens);
753  tokens.clear();
754 
755  // Done.
756  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, c.getType(), bufC);
757  return success();
758 }
759 
760 // Match and rewrite SpGEMM kernel.
761 static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
762  linalg::GenericOp op, bool enableRT) {
763  Location loc = op.getLoc();
764  Value a = op.getOperand(0);
765  Value b = op.getOperand(1);
766  Value c = op.getOperand(2); // we have C = AB
767  SmallVector<Value> tokens;
768 
769  // Only CSR <- CSR x CSR supported.
770  auto format = CuSparseFormat::kCSR;
771  SparseTensorType aTp = getSparseTensorType(a);
772  SparseTensorType bTp = getSparseTensorType(b);
773  SparseTensorType cTp = getSparseTensorType(c);
774  if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp))
775  return failure();
776 
777  // Start sparse kernel and copy data from host to device.
778  // a : amemR/amemC/amemV -> rowA,colA,valA
779  // b : bmemR/bmemC/bmemV -> rowB,colB,valB
780  // c : materializes
781  auto dnCType = cTp.getElementType();
782  Value nseA = NumberOfEntriesOp::create(rewriter, loc, a);
783  Value nseB = NumberOfEntriesOp::create(rewriter, loc, b);
784  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
785  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
786  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
787  Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
788  Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT); // not empty
789  Value amemV = ToValuesOp::create(rewriter, loc, a);
790  Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
791  Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT); // not empty
792  Value bmemV = ToValuesOp::create(rewriter, loc, b);
793  Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
794  Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
795  Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
796  Value rowB = genAllocCopy(rewriter, loc, bmemR, tokens);
797  Value colB = genAllocCopy(rewriter, loc, bmemC, tokens);
798  Value valB = genAllocCopy(rewriter, loc, bmemV, tokens);
799  genBlockingWait(rewriter, loc, tokens);
800  tokens.clear();
801 
802  // Create sparse environment and sparse matrix/dense vector handles.
803  Type indexTp = rewriter.getIndexType();
804  Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
805  Type descTp = rewriter.getType<gpu::SparseSpGEMMOpHandleType>();
806  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
807  Value token = genFirstWait(rewriter, loc);
808  Operation *spGenA =
809  genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szm, szk,
810  nseA, rowA, colA, valA, format, enableRT);
811  Value spMatA = spGenA->getResult(0);
812  token = spGenA->getResult(1);
813  Operation *spGenB =
814  genSpMat(rewriter, loc, bTp, spmatHandleTp, tokenTp, token, szk, szn,
815  nseB, rowB, colB, valB, format, enableRT);
816  Value spMatB = spGenB->getResult(0);
817  token = spGenB->getResult(1);
818 
819  // Sparse matrix C materializes (also assumes beta == 0).
820  Value zero = constantIndex(rewriter, loc, 0);
821  Value one = constantIndex(rewriter, loc, 1);
822  Value mplus1 = arith::AddIOp::create(rewriter, loc, szm, one);
823  auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token);
824  Value rowC = e1.getResult(0);
825  token = e1.getAsyncToken();
826  auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token);
827  Value colC = e2.getResult(0); // no free needed
828  token = e2.getAsyncToken();
829  auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token);
830  Value valC = e3.getResult(0); // no free needed
831  token = e3.getAsyncToken();
832  Operation *spGenC =
833  genSpMat(rewriter, loc, cTp, spmatHandleTp, tokenTp, token, szm, szn,
834  zero, rowC, colC, valC, format, enableRT);
835  Value spMatC = spGenC->getResult(0);
836  token = spGenC->getResult(1);
837 
838  // Precompute buffersizes for SpGEMM.
839  Operation *descOp =
840  gpu::SpGEMMCreateDescrOp::create(rewriter, loc, descTp, tokenTp, token);
841  Value desc = descOp->getResult(0);
842  token = descOp->getResult(1);
843  Operation *work1 = gpu::SpGEMMWorkEstimationOrComputeOp::create(
844  rewriter, loc, indexTp, tokenTp, token, desc,
845  gpu::TransposeMode::NON_TRANSPOSE, gpu::TransposeMode::NON_TRANSPOSE,
846  spMatA, spMatB, spMatC, dnCType, zero, valC,
847  gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
848  Value bufferSz1 = work1->getResult(0);
849  token = work1->getResult(1);
850  auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
851  Value buffer1 = buf1.getResult(0);
852  token = buf1.getAsyncToken();
853  Operation *work2 = gpu::SpGEMMWorkEstimationOrComputeOp::create(
854  rewriter, loc, indexTp, tokenTp, token, desc,
855  gpu::TransposeMode::NON_TRANSPOSE, gpu::TransposeMode::NON_TRANSPOSE,
856  spMatA, spMatB, spMatC, dnCType, bufferSz1, buffer1,
857  gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
858  token = work2->getResult(1);
859 
860  // Compute step.
861  Operation *compute1 = gpu::SpGEMMWorkEstimationOrComputeOp::create(
862  rewriter, loc, indexTp, tokenTp, token, desc,
863  gpu::TransposeMode::NON_TRANSPOSE, gpu::TransposeMode::NON_TRANSPOSE,
864  spMatA, spMatB, spMatC, dnCType, zero, valC,
865  gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
866  Value bufferSz2 = compute1->getResult(0);
867  token = compute1->getResult(1);
868  auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
869  Value buffer2 = buf2.getResult(0);
870  token = buf2.getAsyncToken();
871  Operation *compute2 = gpu::SpGEMMWorkEstimationOrComputeOp::create(
872  rewriter, loc, indexTp, tokenTp, token, desc,
873  gpu::TransposeMode::NON_TRANSPOSE, gpu::TransposeMode::NON_TRANSPOSE,
874  spMatA, spMatB, spMatC, dnCType, bufferSz2, buffer2,
875  gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
876  token = compute2->getResult(1);
877 
878  // Get sizes.
879  Operation *sizes = gpu::SpMatGetSizeOp::create(
880  rewriter, loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC);
881  Value nnz = sizes->getResult(2);
882  token = sizes->getResult(3);
883  auto a2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), nnz, token);
884  colC = a2.getResult(0);
885  token = a2.getAsyncToken();
886  auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token);
887  valC = a3.getResult(0);
888  token = a3.getAsyncToken();
889 
890  // Update C with new pointers and copy final product back into C.
891  Operation *update = gpu::SetCsrPointersOp::create(
892  rewriter, loc, tokenTp, token, spMatC, rowC, colC, valC);
893  token = update->getResult(0);
894  Operation *copy = gpu::SpGEMMCopyOp::create(
895  rewriter, loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
896  gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType);
897  token = copy->getResult(0);
898 
899  // Allocate buffers on host.
900  Value rowH = genHostBuffer(rewriter, loc, cTp.getPosType(), mplus1);
901  Value colH = genHostBuffer(rewriter, loc, cTp.getCrdType(), nnz);
902  Value valH = genHostBuffer(rewriter, loc, dnCType, nnz);
903 
904  // Copy data back to host and free all the resources.
905  token = gpu::SpGEMMDestroyDescrOp::create(rewriter, loc, tokenTp, token, desc)
906  .getAsyncToken();
907  token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatA)
908  .getAsyncToken();
909  token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatB)
910  .getAsyncToken();
911  token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatC)
912  .getAsyncToken();
913  token = genCopyMemRef(rewriter, loc, rowH, rowC, token);
914  token = genCopyMemRef(rewriter, loc, colH, colC, token);
915  token = genCopyMemRef(rewriter, loc, valH, valC, token);
916  token = genDeallocMemRef(rewriter, loc, rowA, token);
917  token = genDeallocMemRef(rewriter, loc, colA, token);
918  token = genDeallocMemRef(rewriter, loc, valA, token);
919  token = genDeallocMemRef(rewriter, loc, rowB, token);
920  token = genDeallocMemRef(rewriter, loc, colB, token);
921  token = genDeallocMemRef(rewriter, loc, valB, token);
922  token = genDeallocMemRef(rewriter, loc, rowC, token);
923  token = genDeallocMemRef(rewriter, loc, colC, token);
924  token = genDeallocMemRef(rewriter, loc, valC, token);
925  token = genDeallocMemRef(rewriter, loc, buffer1, token);
926  token = genDeallocMemRef(rewriter, loc, buffer2, token);
927  tokens.push_back(token);
928  genBlockingWait(rewriter, loc, tokens);
929  tokens.clear();
930 
931  // Done.
932  Value vt = bufferization::ToTensorOp::create(
933  rewriter, loc, memref::getTensorTypeFromMemRefType(valH.getType()), valH);
934  Value rt = bufferization::ToTensorOp::create(
935  rewriter, loc, memref::getTensorTypeFromMemRefType(rowH.getType()), rowH);
936  Value ct = bufferization::ToTensorOp::create(
937  rewriter, loc, memref::getTensorTypeFromMemRefType(colH.getType()), colH);
938  rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), ValueRange{rt, ct},
939  vt);
940  return success();
941 }
942 
943 // Match and rewrite 2:4 SpMM kernel.
944 static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
945  linalg::GenericOp op) {
946  Location loc = op.getLoc();
947  Value A = op.getOperand(0);
948  Value B = op.getOperand(1);
949  Value C = op.getOperand(2); // we have C = AB
950  SmallVector<Value> tokens;
951 
952  // The cuSparselt API currently only allows pruning and compression
953  // to occur on the device. So we recognize the pattern
954  // A' = convert A ; dense to 2:4
955  // C = A'B ; 2:4 matrix mult
956  // and then perform compression and matrix multiplication on device.
957  auto cnv = A.getDefiningOp<ConvertOp>();
958  assert(cnv);
959  A = cnv.getSource();
960 
961  // All inputs should be dense tensors.
962  if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
963  return failure();
964 
965  // Start sparse kernel and copy data from host to device.
966  // a : bufA -> matA
967  // b : bufB -> matB
968  // c : bufC -> matC
969  Value bufA = genTensorToMemref(rewriter, loc, A);
970  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
971  Value bufB = genTensorToMemref(rewriter, loc, B);
972  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
973  Value bufC = genTensorToMemref(rewriter, loc, C);
974  Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
975  genBlockingWait(rewriter, loc, tokens);
976  tokens.clear();
977 
978  // Create sparse environment and sparse matrix/dense vector handles.
979  Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0);
980  Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0);
981  Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1);
982  Type indexTp = rewriter.getIndexType();
983  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
984  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
985  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
986  Value token = genFirstWait(rewriter, loc);
987  Operation *spGenA = gpu::Create2To4SpMatOp::create(
988  rewriter, loc, spMatHandleTp, tokenTp, token, szm, szk,
989  gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA);
990  Value spMatA = spGenA->getResult(0);
991  token = spGenA->getResult(1);
992  auto dmatB =
993  gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp, tokenTp,
994  token, matB, SmallVector<Value>{szk, szn});
995  Value dnB = dmatB.getResult(0);
996  token = dmatB.getAsyncToken();
997  auto dmatC =
998  gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp, tokenTp,
999  token, matC, SmallVector<Value>{szm, szn});
1000  Value dnC = dmatC.getResult(0);
1001  token = dmatC.getAsyncToken();
1002  auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
1003 
1004  // Precompute buffersize for SpMM.
1005  SmallVector<Type> bufferTypes_{indexTp, indexTp, indexTp};
1006  TypeRange bufferTypes(bufferTypes_);
1007  auto bufferComp = gpu::SpMMBufferSizeOp::create(
1008  rewriter, loc, bufferTypes, tokenTp, token,
1009  gpu::TransposeMode::NON_TRANSPOSE, gpu::TransposeMode::NON_TRANSPOSE,
1010  spMatA, dnB, dnC,
1011  /*computeType=*/dmatCType);
1012  token = bufferComp.getAsyncToken();
1013 
1014  // Allocate buffers on host.
1015  Value bufferSz1 = bufferComp.getResult(0);
1016  auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
1017  Value buffer1 = buf1.getResult(0);
1018  token = buf1.getAsyncToken();
1019  Value bufferSz2 = bufferComp.getResult(1);
1020  auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
1021  Value buffer2 = buf2.getResult(0);
1022  token = buf2.getAsyncToken();
1023  Value bufferSz3 = bufferComp.getResult(2);
1024  auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token);
1025  Value buffer3 = buf3.getResult(0);
1026  token = buf3.getAsyncToken();
1027 
1028  // Perform the SpMM.
1029  auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
1030  auto spmmComp = gpu::SpMMOp::create(
1031  rewriter, loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType,
1032  SmallVector<Value>{buffer1, buffer2, buffer3});
1033  token = spmmComp.getAsyncToken();
1034 
1035  // Copy data back to host and free all the resources.
1036  token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatA)
1037  .getAsyncToken();
1038  token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnB)
1039  .getAsyncToken();
1040  token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnC)
1041  .getAsyncToken();
1042  token = genDeallocMemRef(rewriter, loc, buffer1, token);
1043  token = genDeallocMemRef(rewriter, loc, buffer2, token);
1044  token = genDeallocMemRef(rewriter, loc, buffer3, token);
1045  token = genDeallocMemRef(rewriter, loc, matA, token);
1046  token = genDeallocMemRef(rewriter, loc, matB, token);
1047  token = genCopyMemRef(rewriter, loc, bufC, matC, token);
1048  token = genDeallocMemRef(rewriter, loc, matC, token);
1049  tokens.push_back(token);
1050  genBlockingWait(rewriter, loc, tokens);
1051  tokens.clear();
1052 
1053  // Done.
1054  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, C.getType(), bufC);
1055  return success();
1056 }
1057 
1058 /// Match and rewrite SDDMM kernel.
1059 static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
1060  linalg::GenericOp op, bool enableRT) {
1061  Location loc = op.getLoc();
1062  Value a = op.getOperand(0);
1063  Value b = op.getOperand(1);
1064  Value c = op.getOperand(2);
1065  SmallVector<Value> tokens;
1066 
1067  // Only admissible sparse matrix format (no COO/CSC) and dense matrices.
1068  SparseTensorType aTp = getSparseTensorType(a);
1069  SparseTensorType bTp = getSparseTensorType(b);
1070  SparseTensorType cTp = getSparseTensorType(c);
1071  auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT, /*isMatVec=*/false);
1072  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO ||
1073  format == CuSparseFormat::kCSC)
1074  return failure();
1075 
1076  // The SDDMM does the in-place operation.
1077  // Start sparse kernel and copy data from host to device.
1078  // a : bufA -> matA
1079  // b : bufB -> matB
1080  // c : memR/memC/memV -> rowC,colC,valC
1081  Value nseC = NumberOfEntriesOp::create(rewriter, loc, c);
1082  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
1083  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
1084  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
1085  Value bufA = genTensorToMemref(rewriter, loc, a);
1086  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
1087  Value bufB = genTensorToMemref(rewriter, loc, b);
1088  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
1089  Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
1090  Value memC = genSecondCrds(rewriter, loc, c, format, enableRT); // or empty
1091  Value memV = ToValuesOp::create(rewriter, loc, c);
1092  Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
1093  Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
1094  Value valC = genAllocCopy(rewriter, loc, memV, tokens);
1095  genBlockingWait(rewriter, loc, tokens);
1096  tokens.clear();
1097 
1098  // Create sparse environment and sparse matrix/dense matrix handles.
1099  Type indexTp = rewriter.getIndexType();
1100  Type dnMatHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
1101  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
1102  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
1103  Value token = genFirstWait(rewriter, loc);
1104  auto dmatA =
1105  gpu::CreateDnTensorOp::create(rewriter, loc, dnMatHandleTp, tokenTp,
1106  token, matA, SmallVector<Value>{szm, szk});
1107  Value dnA = dmatA.getResult(0);
1108  token = dmatA.getAsyncToken();
1109  auto dmatB =
1110  gpu::CreateDnTensorOp::create(rewriter, loc, dnMatHandleTp, tokenTp,
1111  token, matB, SmallVector<Value>{szk, szn});
1112  Value dnB = dmatB.getResult(0);
1113  token = dmatB.getAsyncToken();
1114  Operation *spGenC =
1115  genSpMat(rewriter, loc, cTp, spMatHandleTp, tokenTp, token, szm, szn,
1116  nseC, rowC, colC, valC, format, enableRT);
1117  Value spMatC = spGenC->getResult(0);
1118  token = spGenC->getResult(1);
1119  auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
1120 
1121  // Precompute buffersize for SDDMM.
1122  auto bufferComp = gpu::SDDMMBufferSizeOp::create(
1123  rewriter, loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType);
1124  Value bufferSz = bufferComp.getResult(0);
1125  token = bufferComp.getAsyncToken();
1126  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
1127  Value buffer = buf.getResult(0);
1128  token = buf.getAsyncToken();
1129 
1130  // Perform the SDDMM.
1131  auto sddmmComp = gpu::SDDMMOp::create(rewriter, loc, tokenTp, token, dnA, dnB,
1132  spMatC, dnCType, buffer);
1133  token = sddmmComp.getAsyncToken();
1134 
1135  // Copy data back to host and free all the resources.
1136  token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnA)
1137  .getAsyncToken();
1138  token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnB)
1139  .getAsyncToken();
1140  token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatC)
1141  .getAsyncToken();
1142  token = genDeallocMemRef(rewriter, loc, buffer, token);
1143  token = genDeallocMemRef(rewriter, loc, matA, token);
1144  token = genDeallocMemRef(rewriter, loc, matB, token);
1145  token = genDeallocMemRef(rewriter, loc, rowC, token);
1146  if (colC)
1147  token = genDeallocMemRef(rewriter, loc, colC, token);
1148  token = genCopyMemRef(rewriter, loc, memV, valC, token);
1149  token = genDeallocMemRef(rewriter, loc, valC, token);
1150  tokens.push_back(token);
1151  genBlockingWait(rewriter, loc, tokens);
1152  tokens.clear();
1153 
1154  // Done.
1155  rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c);
1156  return success();
1157 }
1158 
1159 //===----------------------------------------------------------------------===//
1160 // Rewriting rules for direct code generation.
1161 //===----------------------------------------------------------------------===//
1162 
1163 /// Proof-of-concept rewriter. This rule generates a GPU implementation
1164 /// for each outermost forall loop generated by the sparsifier.
1165 /// TODO: right now works with parallelization-strategy=dense-outer-loop
1166 /// but give this its own flags in the future
1167 struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
1168  using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
1169 
1170  ForallRewriter(MLIRContext *context, unsigned nT)
1171  : OpRewritePattern(context), numThreads(nT) {};
1172 
1173  LogicalResult matchAndRewrite(scf::ParallelOp forallOp,
1174  PatternRewriter &rewriter) const override {
1175  // Reject inadmissible loop form.
1176  // Essentially only accept a loop, generated by the sparsifier,
1177  // of the form
1178  // forall (i = 0; i < N; i++)
1179  // so that cyclic scheduling over the threads is easy.
1180  if (!forallOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ||
1181  forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 ||
1182  !matchPattern(forallOp.getLowerBound()[0], m_Zero()) ||
1183  !matchPattern(forallOp.getStep()[0], m_One()))
1184  return failure();
1185  // Collect every value that is computed outside the parallel loop.
1186  SetVector<Value> invariants; // stable iteration!
1187  forallOp->walk([&](Operation *op) {
1188  // Collect all values of admissible ops.
1189  for (OpOperand &o : op->getOpOperands()) {
1190  Value val = o.get();
1191  Block *block;
1192  if (auto arg = dyn_cast<BlockArgument>(val))
1193  block = arg.getOwner();
1194  else
1195  block = val.getDefiningOp()->getBlock();
1196  if (!forallOp.getRegion().findAncestorBlockInRegion(*block))
1197  invariants.insert(val);
1198  }
1199  });
1200  // Outline the outside values as proper parameters. Fail when sharing
1201  // a value between host and device is not straightforward.
1202  SmallVector<Value> constants;
1203  SmallVector<Value> scalars;
1204  SmallVector<Value> buffers;
1205  for (Value val : invariants) {
1206  Type tp = val.getType();
1207  if (val.getDefiningOp<arith::ConstantOp>())
1208  constants.push_back(val);
1209  else if (isa<FloatType>(tp) || tp.isIntOrIndex())
1210  scalars.push_back(val);
1211  else if (isa<MemRefType>(tp))
1212  buffers.push_back(val);
1213  else
1214  return failure(); // don't know how to share
1215  }
1216  // Pass outlined non-constant values.
1217  // TODO: Experiment with `useHostRegistrationForOut` to see if we want to
1218  // keep the feature at all (either through a heuristic or compiler
1219  // option for gpu codegen).
1220  Location loc = forallOp->getLoc();
1221  SmallVector<Value> args;
1222  SmallVector<Value> tokens;
1223  Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens,
1224  /*useHostRegistrationForOut=*/false);
1225  // Set up GPU module and construct GPU function.
1226  auto saveIp = rewriter.saveInsertionPoint();
1227  ModuleOp topModule = forallOp->getParentOfType<ModuleOp>();
1228  auto gpuModule = genGPUModule(rewriter, topModule);
1229  auto gpuFunc = genGPUFunc(rewriter, gpuModule, args);
1230  genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers);
1231  // Generate code that launches the kernel asynchronously, blocking on all
1232  // open tokens and yielding a new token for the output.
1233  // TODO: passing tokens into the launch op does not seem to be lowered
1234  // properly to cubin yet, hence the current blocking wait.
1235  rewriter.restoreInsertionPoint(saveIp);
1236  genBlockingWait(rewriter, loc, tokens);
1237  tokens.clear();
1238  Value kernelToken =
1239  genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads);
1240  // Finalize the outlined arguments.
1241  genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
1242  tokens);
1243  genBlockingWait(rewriter, loc, tokens);
1244  rewriter.eraseOp(forallOp);
1245  return success();
1246  }
1247 
1248 private:
1249  unsigned numThreads;
1250 };
1251 
1252 //===----------------------------------------------------------------------===//
1253 // Rewriting rules for library recognition and code generation.
1254 //===----------------------------------------------------------------------===//
1255 
1256 /// Proof-of-concept rewriter. This rule recognizes certain math kernels
1257 /// and replaces these with corresponding calls into a sparse library.
1258 struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
1259  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
1260 
1261  LinalgOpRewriter(MLIRContext *context, bool rt)
1262  : OpRewritePattern(context), enableRT(rt) {}
1263 
1264  LogicalResult matchAndRewrite(linalg::GenericOp op,
1265  PatternRewriter &rewriter) const override {
1266  if (op.getNumDpsInits() != 1)
1267  return failure(); // reject multi-output
1268 
1269  const unsigned numLoops = op.getNumLoops();
1270  const unsigned numTensors = op->getNumOperands();
1271  const auto iteratorTypes = op.getIteratorTypesArray();
1272  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1273 
1274  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1275  auto infer = [&](MapList m) {
1276  return AffineMap::inferFromExprList(m, op.getContext());
1277  };
1278  AffineExpr i, j, k;
1279  bindDims(getContext(), i, j, k);
1280 
1281  // TODO: more robust patterns, transposed versions, more kernels,
1282  // identify alpha and beta and pass them to the CUDA calls.
1283 
1284  // Recognize a SpMV kernel.
1285  if (numLoops == 2 && numTensors == 3 &&
1286  linalg::isParallelIterator(iteratorTypes[0]) &&
1287  linalg::isReductionIterator(iteratorTypes[1]) &&
1288  maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
1289  return rewriteSpMV(rewriter, op, enableRT);
1290  }
1291 
1292  // Recognize a SpGEMM, 2:4-SpMM, or SpMM kernel.
1293  if (numLoops == 3 && numTensors == 3 &&
1294  linalg::isParallelIterator(iteratorTypes[0]) &&
1295  linalg::isParallelIterator(iteratorTypes[1]) &&
1296  linalg::isReductionIterator(iteratorTypes[2]) &&
1297  maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
1298  if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
1299  return rewriteSpGEMM(rewriter, op, enableRT);
1300  if (isConversionInto24(op.getOperand(0)))
1301  return rewrite2To4SpMM(rewriter, op);
1302  return rewriteSpMM(rewriter, op, enableRT);
1303  }
1304 
1305  // Recognize a SDDMM kernel.
1306  if (numLoops == 3 && numTensors == 3 &&
1307  linalg::isParallelIterator(iteratorTypes[0]) &&
1308  linalg::isParallelIterator(iteratorTypes[1]) &&
1309  linalg::isReductionIterator(iteratorTypes[2]) &&
1310  maps == infer({{i, k}, {k, j}, {i, j}}) &&
1311  matchSumReductionOfMulUnary(op)) {
1312  return rewriteSDDMM(rewriter, op, enableRT);
1313  }
1314 
1315  return failure();
1316  }
1317 
1318 private:
1319  bool enableRT;
1320 };
1321 
1322 } // namespace
1323 
1324 //===----------------------------------------------------------------------===//
1325 // Public method for populating GPU rewriting rules.
1326 //
1327 // Currently two sets of rewriting rules are made available. The first set
1328 // implements direct code generation, currently by means of converting the
1329 // outermost parallel loop into GPU threads. The second set implements
1330 // library recognition of a set of sparse operations. Eventually, the right
1331 // combination of these two approaches has to be found.
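//
// For example, a conversion pass might wire these in roughly as follows
// (illustrative values; pass boilerplate omitted):
//
//   RewritePatternSet patterns(&getContext());
//   populateSparseGPUCodegenPatterns(patterns, /*numThreads=*/1024);
//   populateSparseGPULibgenPatterns(patterns, /*enableRT=*/true);
//   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//     signalPassFailure();
//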
1332 //===----------------------------------------------------------------------===//
1333 
1334 void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
1335  unsigned numThreads) {
1336  patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
1337 }
1338 
1339 void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
1340  bool enableRT) {
1341  patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT);
1342 }