SparseGPUCodegen.cpp
1//===- SparseGPUCodegen.cpp - Generates GPU code --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is a prototype GPU code generator for the sparsifier.
10// The objective is to eventually use the right combination of
11// direct code generation and library calls into vendor-specific
12// highly optimized sparse libraries (e.g. cuSparse for CUDA).
13//
14//===----------------------------------------------------------------------===//
15
16#include "Utils/CodegenUtils.h"
17#include "Utils/LoopEmitter.h"
18
19#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
20#include "mlir/Dialect/GPU/IR/GPUDialect.h"
21#include "mlir/Dialect/Linalg/IR/Linalg.h"
22#include "mlir/Dialect/Linalg/Utils/Utils.h"
23#include "mlir/Dialect/MemRef/IR/MemRef.h"
24#include "mlir/Dialect/SCF/IR/SCF.h"
25#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
26#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
27#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
28#include "mlir/IR/IRMapping.h"
29#include "mlir/IR/Matchers.h"
30
31#include "llvm/Support/Casting.h"
32
33using namespace mlir;
34using namespace mlir::sparse_tensor;
35
36namespace {
37
38// Sparse formats supported by cuSparse.
39enum class CuSparseFormat {
40 kNone,
41 kCOO,
42 kCSR,
43 kCSC,
44 kBSR,
45};
46
47//===----------------------------------------------------------------------===//
48// Helper methods.
49//===----------------------------------------------------------------------===//
50
51/// Marks the given top module as a GPU container module.
52static void markAsGPUContainer(ModuleOp topModule) {
53 topModule->setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
54 UnitAttr::get(topModule->getContext()));
55}
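// For reference, the net effect is just the unit attribute that the GPU
// dialect requires on any module containing gpu.module ops (sketch, not
// verbatim output):
//
//   module attributes {gpu.container_module} { ... }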
56
57/// Constructs a new GPU module (for GPU kernels) inside the given top module,
58/// or returns an existing GPU module if one was built previously.
59static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) {
60 for (auto op : topModule.getBodyRegion().getOps<gpu::GPUModuleOp>())
61 return op; // existing
62 markAsGPUContainer(topModule);
63 builder.setInsertionPointToStart(topModule.getBody());
64 return gpu::GPUModuleOp::create(builder, topModule->getLoc(),
65 "sparse_kernels");
66}
67
68/// Constructs a new GPU kernel in the given GPU module.
69static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
70 SmallVectorImpl<Value> &args) {
71 // Get a unique kernel name. Not very creative,
72 // but we simply try kernel0, kernel1, etc.
73 unsigned kernelNumber = 0;
74 SmallString<16> kernelName;
75 do {
76 kernelName.clear();
77 ("kernel" + Twine(kernelNumber++)).toStringRef(kernelName);
78 } while (gpuModule.lookupSymbol(kernelName));
79 // Then we insert a new kernel with given arguments into the module.
80 builder.setInsertionPointToStart(gpuModule.getBody());
81 SmallVector<Type> argsTp;
82 for (auto arg : args)
83 argsTp.push_back(arg.getType());
84 FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {});
85 auto gpuFunc =
86 gpu::GPUFuncOp::create(builder, gpuModule->getLoc(), kernelName, type);
87 gpuFunc.setKernel(true);
88 return gpuFunc;
89}
90
91/// Constructs code to launch GPU kernel.
92static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
93 SmallVectorImpl<Value> &args,
94 SmallVectorImpl<Value> &tokens,
95 unsigned numThreads) {
96 Location loc = gpuFunc->getLoc();
97 Value none = TypedValue<::mlir::IntegerType>{};
98 Value one = constantIndex(builder, loc, 1);
99 Value numT = constantIndex(builder, loc, numThreads);
100 gpu::KernelDim3 gridSize = {one, one, one};
101 gpu::KernelDim3 blckSize = {numT, one, one};
102 return gpu::LaunchFuncOp::create(builder, loc, gpuFunc, gridSize, blckSize,
103 /*dynSharedMemSz*/ none, args,
104 builder.getType<gpu::AsyncTokenType>(),
105 tokens)
106 .getAsyncToken();
107}
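// A minimal sketch of the launch this emits, assuming illustrative SSA names
// and the kernel symbol generated above:
//
//   %t = gpu.launch_func async [%deps] @sparse_kernels::@kernel0
//          blocks in (%c1, %c1, %c1) threads in (%numT, %c1, %c1)
//          args(%s : index, %buf : memref<?xf64>)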
108
109/// Maps the provided ranked host buffer into the device address space.
110/// Writes from the host are guaranteed to be visible to device kernels
111/// that are launched afterwards. Writes from the device are guaranteed
112/// to be visible on the host after synchronizing with the device kernel
113/// completion. Needs to cast the buffer to an unranked buffer.
114static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
115 Value mem) {
116 MemRefType memTp = cast<MemRefType>(mem.getType());
117 UnrankedMemRefType resTp =
118 UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0);
119 Value cast = memref::CastOp::create(builder, loc, resTp, mem);
120 gpu::HostRegisterOp::create(builder, loc, cast);
121 return cast;
122}
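// Sketch of the generated registration (illustrative types and names):
//
//   %cast = memref.cast %mem : memref<?xf64> to memref<*xf64>
//   gpu.host_register %cast : memref<*xf64>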
123
124/// Unmaps the provided buffer, expecting the cast unranked buffer.
125static void genHostUnregisterMemref(OpBuilder &builder, Location loc,
126 Value cast) {
127 gpu::HostUnregisterOp::create(builder, loc, cast);
128}
129
130/// Generates first wait in an asynchronous chain.
131static Value genFirstWait(OpBuilder &builder, Location loc) {
132 Type tokenType = builder.getType<gpu::AsyncTokenType>();
133 return gpu::WaitOp::create(builder, loc, tokenType, ValueRange())
134 .getAsyncToken();
135}
136
137/// Generates last, blocking wait in an asynchronous chain.
138static void genBlockingWait(OpBuilder &builder, Location loc,
139 ValueRange operands) {
140 gpu::WaitOp::create(builder, loc, Type(), operands);
141}
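// Together these two helpers bracket an asynchronous chain, roughly:
//
//   %t0 = gpu.wait async        // first, non-blocking wait yields a token
//   ... async ops chained through their tokens ...
//   gpu.wait [%t1, %t2]         // last wait blocks on all open tokens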
142
143/// Allocates memory on the device.
144/// TODO: A `host_shared` attribute could be used to indicate that
145/// the buffer is visible by both host and device, but lowering
146/// that feature does not seem to be fully supported yet.
147static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
148 Value token) {
149 auto tp = cast<ShapedType>(mem.getType());
150 auto elemTp = tp.getElementType();
151 auto shape = tp.getShape();
152 auto memTp = MemRefType::get(shape, elemTp);
153 SmallVector<Value> dynamicSizes;
154 for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) {
155 if (shape[r] == ShapedType::kDynamic) {
156 Value dimOp = linalg::createOrFoldDimOp(builder, loc, mem, r);
157 dynamicSizes.push_back(dimOp);
158 }
159 }
160 return gpu::AllocOp::create(builder, loc, TypeRange({memTp, token.getType()}),
161 token, dynamicSizes, ValueRange());
162}
163
164// Allocates a typed buffer on the host with given size.
165static Value genHostBuffer(OpBuilder &builder, Location loc, Type type,
166 Value size) {
167 const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
168 return memref::AllocOp::create(builder, loc, memTp, size).getResult();
169}
170
171// Allocates a typed buffer on the device with given size.
172static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type,
173 Value size, Value token) {
174 const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
175 return gpu::AllocOp::create(builder, loc, TypeRange({memTp, token.getType()}),
176 token, size, ValueRange());
177}
178
179// Allocates a void buffer on the device with given size.
180static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size,
181 Value token) {
182 return genAllocBuffer(builder, loc, builder.getI8Type(), size, token);
183}
184
185/// Deallocates memory from the device.
186static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem,
187 Value token) {
188 return gpu::DeallocOp::create(builder, loc, token.getType(), token, mem)
189 .getAsyncToken();
190}
191
192/// Copies memory between host and device (direction is implicit).
193static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst,
194 Value src, Value token) {
195 return gpu::MemcpyOp::create(builder, loc, token.getType(), token, dst, src)
196 .getAsyncToken();
197}
198
199/// Generates an alloc/copy pair.
200static Value genAllocCopy(OpBuilder &builder, Location loc, Value b,
201 SmallVectorImpl<Value> &tokens) {
202 Value firstToken = genFirstWait(builder, loc);
203 auto alloc = genAllocMemRef(builder, loc, b, firstToken);
204 Value devMem = alloc.getResult(0);
205 Value depToken = alloc.getAsyncToken(); // copy-after-alloc
206 tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken));
207 return devMem;
208}
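// Sketch of one such alloc/copy pair on the token chain (illustrative
// types and names):
//
//   %t0 = gpu.wait async
//   %dev, %t1 = gpu.alloc async [%t0] (%d0) : memref<?xf64>
//   %t2 = gpu.memcpy async [%t1] %dev, %host : memref<?xf64>, memref<?xf64>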
209
210/// Generates a memref from tensor operation.
211static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
212 Value tensor) {
213 auto tensorType = llvm::cast<ShapedType>(tensor.getType());
214 auto memrefType =
215 MemRefType::get(tensorType.getShape(), tensorType.getElementType());
216 return bufferization::ToBufferOp::create(rewriter, loc, memrefType, tensor);
217}
218
219/// Prepares the outlined arguments, passing scalars and buffers in. Here we
220/// assume that the first buffer is the one allocated for output. We create
221/// a set of properly chained asynchronous allocation/copy pairs to increase
222/// overlap before launching the kernel.
223static Value genParametersIn(OpBuilder &builder, Location loc,
224 SmallVectorImpl<Value> &scalars,
225 SmallVectorImpl<Value> &buffers,
226 SmallVectorImpl<Value> &args,
227 SmallVectorImpl<Value> &tokens,
228 bool useHostRegistrationForOut) {
229 Value out;
230 // Scalars are passed by value.
231 for (Value s : scalars)
232 args.push_back(s);
233 // Buffers need to be made visible to the device.
234 for (Value b : buffers) {
235 if (useHostRegistrationForOut) {
236 out = genHostRegisterMemref(builder, loc, b);
237 args.push_back(b);
238 useHostRegistrationForOut = false;
239 continue;
240 }
241 args.push_back(genAllocCopy(builder, loc, b, tokens));
242 }
243 return out;
244}
245
246/// Finalizes the outlined arguments. The output buffer is copied depending
247/// on the kernel token and then deallocated. All other buffers are simply
248/// deallocated. Then we wait for all operations to complete.
249///
250/// `copyBack` maps 1:1 to the `buffers` array. It tracks which buffers were
251/// mutated by the kernel and require a device-to-host copy. An empty
252/// `copyBack` array implies no buffers are "copied back".
253static void genParametersOut(OpBuilder &builder, Location loc, Value out,
254 Value kernelToken, SmallVectorImpl<Value> &scalars,
255 SmallVectorImpl<Value> &buffers,
256 SmallVectorImpl<Value> &args,
257 SmallVectorImpl<Value> &tokens,
258 ArrayRef<bool> copyBack) {
259 unsigned base = scalars.size();
260
261 // `args` stores scalars followed by buffers. `base` is the index of the first
262 // buffer. `bufIdx` maps the current buffer to its exact 1:1 counterpart in
263 // the `copyBack` mask.
264 for (unsigned i = base, e = args.size(); i < e; i++) {
265 unsigned bufIdx = i - base;
266 Value firstToken;
267
268 // Checks if the current buffer needs a device-to-host copy.
269 if (copyBack[bufIdx]) {
270 if (out && bufIdx == 0) {
271 genHostUnregisterMemref(builder, loc, out);
272 out = Value();
273 continue;
274 }
275 firstToken =
276 genCopyMemRef(builder, loc, buffers[bufIdx], args[i], kernelToken);
277 } else {
278 firstToken = genFirstWait(builder, loc);
279 }
280 tokens.push_back(genDeallocMemRef(builder, loc, args[i], firstToken));
281 }
282}
283
284/// Constructs code for new GPU kernel.
285static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
286 scf::ParallelOp forallOp,
287 SmallVectorImpl<Value> &constants,
288 SmallVectorImpl<Value> &scalars,
289 SmallVectorImpl<Value> &buffers) {
290 Location loc = gpuFunc->getLoc();
291 Block &block = gpuFunc.getBody().front();
292 rewriter.setInsertionPointToStart(&block);
293
294 // Re-generate the constants, recapture all arguments.
295 unsigned arg = 0;
296 IRMapping irMap;
297 for (Value c : constants)
298 irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0));
299 for (Value s : scalars)
300 irMap.map(s, block.getArgument(arg++));
301 for (Value b : buffers)
302 irMap.map(b, block.getArgument(arg++));
303
304 // Assume 1-dimensional grid/block configuration (only x dimension),
305 // so that:
306 // row = blockIdx.x * blockDim.x + threadIdx.x
307 // inc = blockDim.x * gridDim.x
308 Value bid = gpu::BlockIdOp::create(rewriter, loc, gpu::Dimension::x);
309 Value bsz = gpu::BlockDimOp::create(rewriter, loc, gpu::Dimension::x);
310 Value tid = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x);
311 Value gsz = gpu::GridDimOp::create(rewriter, loc, gpu::Dimension::x);
312 Value mul = arith::MulIOp::create(rewriter, loc, bid, bsz);
313 Value row = arith::AddIOp::create(rewriter, loc, mul, tid);
314 Value inc = arith::MulIOp::create(rewriter, loc, bsz, gsz);
315
316 // Construct the iteration over the computational space that
317 // accounts for the fact that the total number of threads and
318 // the amount of work to be done usually do not match precisely.
319 // for (r = row; r < N; r += inc) {
320 // <loop-body>
321 // }
322 Value upper = irMap.lookup(forallOp.getUpperBound()[0]);
323 scf::ForOp forOp = scf::ForOp::create(rewriter, loc, row, upper, inc);
324 // The scf.for builder creates an empty block. scf.for does not allow multiple
325 // blocks in its region, so delete the block before `cloneRegionBefore` adds
326 // an additional block.
327 rewriter.eraseBlock(forOp.getBody());
328 rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(),
329 forOp.getRegion().begin(), irMap);
330 // Replace the scf.reduce terminator.
331 rewriter.setInsertionPoint(forOp.getBody()->getTerminator());
332 rewriter.replaceOpWithNewOp<scf::YieldOp>(forOp.getBody()->getTerminator());
333
334 // Done.
335 rewriter.setInsertionPointAfter(forOp);
336 gpu::ReturnOp::create(rewriter, gpuFunc->getLoc());
337}
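// The outlined kernel thus has the following overall shape (sketch only;
// the index arithmetic is emitted as individual arith ops):
//
//   gpu.func @kernel0(...) kernel {
//     // row = blockIdx.x * blockDim.x + threadIdx.x
//     // inc = blockDim.x * gridDim.x
//     scf.for %r = %row to %upper step %inc { <cloned loop-body> }
//     gpu.return
//   }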
338
339//===----------------------------------------------------------------------===//
340// Library helper methods.
341//===----------------------------------------------------------------------===//
342
343/// Helper to detect a + b with arguments taken from given block.
344static bool matchAddOfArgs(Block *block, Value val) {
345 if (auto *def = val.getDefiningOp()) {
346 if (isa<arith::AddFOp, arith::AddIOp>(def)) {
347 Value a = block->getArguments()[0];
348 Value b = block->getArguments()[1];
349 return (def->getOperand(0) == a && def->getOperand(1) == b) ||
350 (def->getOperand(0) == b && def->getOperand(1) == a);
351 }
352 }
353 return false;
354}
355
356/// Helper to detect a * b with arguments taken from given block.
357static bool matchMulOfArgs(Block *block, Value val) {
358 if (auto *def = val.getDefiningOp()) {
359 if (isa<arith::MulFOp, arith::MulIOp>(def)) {
360 Value a = block->getArguments()[0];
361 Value b = block->getArguments()[1];
362 return (def->getOperand(0) == a && def->getOperand(1) == b) ||
363 (def->getOperand(0) == b && def->getOperand(1) == a);
364 }
365 }
366 return false;
367}
368
369/// Helper to detect x = x + a * b
370static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
371 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
372 if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
373 if (isa<arith::AddFOp, arith::AddIOp>(def)) {
374 Value x = op.getBlock()->getArguments()[2];
375 return (def->getOperand(0) == x &&
376 matchMulOfArgs(op.getBlock(), def->getOperand(1))) ||
377 (def->getOperand(1) == x &&
378 matchMulOfArgs(op.getBlock(), def->getOperand(0)));
379 }
380 }
381 return false;
382}
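// E.g., the body of a matching linalg.generic looks like (illustrative):
//
//   ^bb0(%a: f64, %b: f64, %x: f64):
//     %0 = arith.mulf %a, %b : f64
//     %1 = arith.addf %x, %0 : f64
//     linalg.yield %1 : f64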
383
384// Helper to detect c += spy(s) x (a * b)
385static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
386 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
387 // The linalg yields a custom reduce result.
388 Value s_out = op.getBlock()->getArguments()[2];
389 if (auto redOp =
390 yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>()) {
391 // The reduce consumes the output.
392 Value other;
393 if (s_out == redOp->getOperand(0))
394 other = redOp->getOperand(1);
395 else if (s_out == redOp->getOperand(1))
396 other = redOp->getOperand(0);
397 else
398 return false;
399 // The reduce op also consumes a unary which also consumes the output
400 // and does not define an absent value.
401 if (auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>()) {
402 if (s_out != unOp->getOperand(0) || !unOp.getAbsentRegion().empty())
403 return false;
404 // And the bodies are as expected.
405 auto yieldUn = cast<sparse_tensor::YieldOp>(
406 unOp.getRegion(0).front().getTerminator());
407 auto yieldRed = cast<sparse_tensor::YieldOp>(
408 redOp.getRegion().front().getTerminator());
409 return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) &&
410 matchAddOfArgs(&redOp.getRegion().front(), yieldRed.getOperand(0));
411 }
412 }
413 return false;
414}
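// Sketch of a matching body (simplified; region syntax abbreviated):
//
//   ^bb0(%a: f64, %b: f64, %s_out: f64):
//     %u = sparse_tensor.unary %s_out : f64 to f64
//            present = { %m = arith.mulf %a, %b; sparse_tensor.yield %m }
//            absent = {}
//     %r = sparse_tensor.reduce %s_out, %u, %id : f64
//            { ^bb0(%x: f64, %y: f64):
//                %s = arith.addf %x, %y; sparse_tensor.yield %s }
//     linalg.yield %r : f64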
415
416/// Test for dense tensor.
417static bool isDenseTensor(Value v) {
418 auto sTp = getSparseTensorType(v);
419 return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense();
420}
421
422/// Test for suitable positions/coordinates width.
423static bool isAdmissibleMetaData(SparseTensorType &aTp) {
424 return (aTp.getPosWidth() == 0 || aTp.getPosWidth() >= 16) &&
425 (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() >= 16);
426}
427
428/// Test for sorted COO matrix with suitable metadata.
429static bool isAdmissibleCOO(SparseTensorType &aTp) {
430 return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
431 aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
432 aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
433 isAdmissibleMetaData(aTp);
434}
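// E.g., an admissible sorted COO matrix type (illustrative encoding):
//
//   #COO = #sparse_tensor.encoding<{
//     map = (i, j) -> (i : compressed(nonunique), j : singleton)
//   }>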
435
436/// Test for CSR matrix with suitable metadata.
437static bool isAdmissibleCSR(SparseTensorType &aTp) {
438 return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
439 aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
440 aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
441}
442
443/// Test for CSC matrix with suitable metadata.
444static bool isAdmissibleCSC(SparseTensorType &aTp) {
445 return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && !aTp.isIdentity() &&
446 aTp.isPermutation() && aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) &&
447 aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
448}
449
450/// Test for BSR matrix with suitable metadata.
451static bool isAdmissibleBSR(SparseTensorType &aTp) {
452 if (aTp.getDimRank() == 2 && aTp.getLvlRank() == 4 && aTp.isDenseLvl(0) &&
453 aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
454 aTp.isDenseLvl(2) && aTp.isDenseLvl(3) && isAdmissibleMetaData(aTp)) {
455 // CuSparse only supports "square" blocks currently.
456 SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
457 assert(dims.size() == 2);
458 return dims[0] == dims[1] && dims[0] > 1;
459 }
460 return false;
461}
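// E.g., an admissible BSR matrix type with 2x2 blocks (illustrative):
//
//   #BSR = #sparse_tensor.encoding<{
//     map = (i, j) -> (i floordiv 2 : dense,
//                      j floordiv 2 : compressed,
//                      i mod 2 : dense,
//                      j mod 2 : dense)
//   }>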
462
463/// Test for 2:4 matrix with suitable metadata.
464static bool isAdmissible24(SparseTensorType &aTp) {
465 return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) &&
466 aTp.isDenseLvl(1) && aTp.isNOutOfMLvl(2) && isAdmissibleMetaData(aTp);
467}
468
469/// Test for conversion into 2:4 matrix.
470static bool isConversionInto24(Value v) {
471 if (auto cnv = v.getDefiningOp<ConvertOp>()) {
472 Value a = cnv.getResult();
473 Value d = cnv.getSource();
474 SparseTensorType aTp = getSparseTensorType(a);
475 return isDenseTensor(d) && isAdmissible24(aTp);
476 }
477 return false;
478}
479
480/// Returns a suitable sparse format for the operation and given operand
481/// types with cuSparse, or kNone if none is available.
482static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
483 SparseTensorType bTp,
484 SparseTensorType cTp, bool enableRT,
485 bool isMatVec) {
486 // The other operands have a dense type.
487 if (bTp.hasEncoding() || cTp.hasEncoding())
488 return CuSparseFormat::kNone;
489 // Now check for suitable operand type for the main operand.
490 if (isAdmissibleCOO(aTp))
491#ifdef CUSPARSE_COO_AOS
492 return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
493#else
494 return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
495#endif
496 if (isAdmissibleCSR(aTp))
497 return CuSparseFormat::kCSR;
498 if (isAdmissibleCSC(aTp))
499 return CuSparseFormat::kCSC;
500 if (isAdmissibleBSR(aTp))
501 return CuSparseFormat::kBSR;
502 return CuSparseFormat::kNone;
503}
504
505/// Generates the first positions/coordinates of a sparse matrix.
506static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
507 CuSparseFormat format, bool enableRT) {
508 if (format == CuSparseFormat::kCOO) {
509 // Library uses SoA COO, direct IR uses AoS COO.
510 if (enableRT)
511 return ToCoordinatesOp::create(builder, loc, a, 0);
512 return ToCoordinatesBufferOp::create(builder, loc, a);
513 }
514 // Formats CSR/CSC and BSR use positions at 1.
515 return ToPositionsOp::create(builder, loc, a, 1);
516}
517
518/// Generates the second coordinates of a sparse matrix.
519static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
520 CuSparseFormat format, bool enableRT) {
521 bool isCOO = format == CuSparseFormat::kCOO;
522 if (isCOO && !enableRT)
523 return Value(); // nothing needed
524 // Formats CSR/CSC and BSR use coordinates at 1.
525 return ToCoordinatesOp::create(builder, loc, a, 1);
526}
527
528/// Generates the sparse matrix handle.
529static Operation *genSpMat(OpBuilder &builder, Location loc,
530 SparseTensorType &aTp, Type handleTp, Type tokenTp,
531 Value token, Value sz1, Value sz2, Value nseA,
532 Value rowA, Value colA, Value valA,
533 CuSparseFormat format, bool enableRT) {
534 if (format == CuSparseFormat::kCOO) {
535 // Library uses SoA COO, direct IR uses AoS COO.
536 if (enableRT) {
537 assert(colA);
538 return gpu::CreateCooOp::create(builder, loc, handleTp, tokenTp, token,
539 sz1, sz2, nseA, rowA, colA, valA);
540 }
541#ifdef CUSPARSE_COO_AOS
542 assert(!colA);
543 return gpu::CreateCooAoSOp::create(builder, loc, handleTp, tokenTp, token,
544 sz1, sz2, nseA, rowA, valA);
545#else
546 llvm_unreachable("gpu::CreateCooAoSOp is deprecated");
547#endif
548 }
549 assert(colA);
550 if (format == CuSparseFormat::kCSR)
551 return gpu::CreateCsrOp::create(builder, loc, handleTp, tokenTp, token, sz1,
552 sz2, nseA, rowA, colA, valA);
553 if (format == CuSparseFormat::kCSC)
554 return gpu::CreateCscOp::create(builder, loc, handleTp, tokenTp, token, sz1,
555 sz2, nseA, rowA, colA, valA);
556 // BSR requires a bit more work since we need to pass in the block size
557 // and all other sizes in terms of blocks (#block-rows, #block-cols,
558 // #nonzero-blocks).
559 assert(format == CuSparseFormat::kBSR);
560 SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
561 assert(dims.size() == 2 && dims[0] == dims[1]);
562 uint64_t b = dims[0];
563 Value bSz = constantIndex(builder, loc, b);
564 Value bRows = arith::DivUIOp::create(builder, loc, sz1, bSz);
565 Value bCols = arith::DivUIOp::create(builder, loc, sz2, bSz);
566 Value bNum = arith::DivUIOp::create(builder, loc, nseA,
567 constantIndex(builder, loc, b * b));
568 return gpu::CreateBsrOp::create(builder, loc, handleTp, tokenTp, token, bRows,
569 bCols, bNum, bSz, bSz, rowA, colA, valA);
570}
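// E.g., for CSR the generated handle creation looks roughly like
// (illustrative types and names):
//
//   %spmat, %t1 = gpu.create_csr async [%t0] %rows, %cols, %nnz,
//                   %rowPos, %colIdx, %values
//                   : memref<?xindex>, memref<?xindex>, memref<?xf64>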
571
572/// Match and rewrite SpMV kernel.
573static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
574 linalg::GenericOp op, bool enableRT) {
575 Location loc = op.getLoc();
576 Value a = op.getOperand(0);
577 Value x = op.getOperand(1);
578 Value y = op.getOperand(2); // we have y = Ax
579 SmallVector<Value> tokens;
580
581 // Only admissible sparse matrix format and dense vectors (no BSR).
582 SparseTensorType aTp = getSparseTensorType(a);
583 SparseTensorType xTp = getSparseTensorType(x);
584 SparseTensorType yTp = getSparseTensorType(y);
585 auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT, /*isMatVec=*/true);
586 if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
587 return failure();
588
589 // Start sparse kernel and copy data from host to device.
590 // a : memR/memC/memV -> rowA,colA,valA
591 // x : memX -> vecX
592 // y : memY -> vecY
593 Value nseA = NumberOfEntriesOp::create(rewriter, loc, a);
594 Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
595 Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
596 Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
597 Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
598 Value memV = ToValuesOp::create(rewriter, loc, a);
599 Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
600 Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
601 Value valA = genAllocCopy(rewriter, loc, memV, tokens);
602 Value memX = genTensorToMemref(rewriter, loc, x);
603 Value vecX = genAllocCopy(rewriter, loc, memX, tokens);
604 Value memY = genTensorToMemref(rewriter, loc, y);
605 Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
606 genBlockingWait(rewriter, loc, tokens);
607 tokens.clear();
608
609 // Create sparse environment and sparse matrix/dense vector handles.
610 Type indexTp = rewriter.getIndexType();
611 Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
612 Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
613 Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
614 Value token = genFirstWait(rewriter, loc);
615 Operation *spGenA =
616 genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szY, szX,
617 nseA, rowA, colA, valA, format, enableRT);
618 Value spMatA = spGenA->getResult(0);
619 token = spGenA->getResult(1);
620 auto dvecX = gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp,
621 tokenTp, token, vecX, szX);
622 Value dnX = dvecX.getResult(0);
623 token = dvecX.getAsyncToken();
624 auto dvecY = gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp,
625 tokenTp, token, vecY, szY);
626 Value dnY = dvecY.getResult(0);
627 token = dvecY.getAsyncToken();
628 auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType();
629
630 // Precompute buffersize for SpMV.
631 auto bufferComp = gpu::SpMVBufferSizeOp::create(
632 rewriter, loc, indexTp, tokenTp, token, spMatA, dnX, dnY,
633 /*computeType=*/dnYType);
634 Value bufferSz = bufferComp.getResult(0);
635 token = bufferComp.getAsyncToken();
636 auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
637 Value buffer = buf.getResult(0);
638 token = buf.getAsyncToken();
639
640 // Perform the SpMV.
641 auto spmvComp =
642 gpu::SpMVOp::create(rewriter, loc, tokenTp, token, spMatA, dnX, dnY,
643 /*computeType=*/dnYType, buffer);
644 token = spmvComp.getAsyncToken();
645
646 // Copy data back to host and free all the resources.
647 token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatA)
648 .getAsyncToken();
649 token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnX)
650 .getAsyncToken();
651 token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnY)
652 .getAsyncToken();
653 token = genDeallocMemRef(rewriter, loc, rowA, token);
654 if (colA)
655 token = genDeallocMemRef(rewriter, loc, colA, token);
656 token = genDeallocMemRef(rewriter, loc, valA, token);
657 token = genDeallocMemRef(rewriter, loc, buffer, token);
658 token = genDeallocMemRef(rewriter, loc, vecX, token);
659 token = genCopyMemRef(rewriter, loc, memY, vecY, token);
660 token = genDeallocMemRef(rewriter, loc, vecY, token);
661 tokens.push_back(token);
662 genBlockingWait(rewriter, loc, tokens);
663 tokens.clear();
664
665 // Done.
666 rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, y.getType(), memY);
667 return success();
668}
669
670/// Match and rewrite SpMM kernel.
671static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
672 linalg::GenericOp op, bool enableRT) {
673 Location loc = op.getLoc();
674 Value a = op.getOperand(0);
675 Value b = op.getOperand(1);
676 Value c = op.getOperand(2); // we have C = AB
677 SmallVector<Value> tokens;
678
679 // Only admissible sparse matrix format and dense matrices (no BSR).
680 SparseTensorType aTp = getSparseTensorType(a);
681 SparseTensorType bTp = getSparseTensorType(b);
682 SparseTensorType cTp = getSparseTensorType(c);
683 auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false);
684 if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
685 return failure();
686
687 // Start sparse kernel and copy data from host to device.
688 // a : memR/memC/memV -> rowA,colA,valA
689 // b : bufB -> matB
690 // c : bufC -> matC
691 Value nseA = NumberOfEntriesOp::create(rewriter, loc, a);
692 Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
693 Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
694 Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
695 Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
696 Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
697 Value memV = ToValuesOp::create(rewriter, loc, a);
698 Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
699 Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
700 Value valA = genAllocCopy(rewriter, loc, memV, tokens);
701 Value bufB = genTensorToMemref(rewriter, loc, b);
702 Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
703 Value bufC = genTensorToMemref(rewriter, loc, c);
704 Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
705 genBlockingWait(rewriter, loc, tokens);
706 tokens.clear();
707
708 // Create sparse environment and sparse matrix/dense matrix handles.
709 Type indexTp = rewriter.getIndexType();
710 Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
711 Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
712 Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
713 Value token = genFirstWait(rewriter, loc);
714 Operation *spGenA =
715 genSpMat(rewriter, loc, aTp, spMatHandleTp, tokenTp, token, szm, szk,
716 nseA, rowA, colA, valA, format, enableRT);
717 Value spMatA = spGenA->getResult(0);
718 token = spGenA->getResult(1);
719 auto dmatB =
720 gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp, tokenTp,
721 token, matB, SmallVector<Value>{szk, szn});
722 Value dnB = dmatB.getResult(0);
723 token = dmatB.getAsyncToken();
724 auto dmatC =
725 gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp, tokenTp,
726 token, matC, SmallVector<Value>{szm, szn});
727 Value dnC = dmatC.getResult(0);
728 token = dmatC.getAsyncToken();
729 auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType();
730
731 // Precompute buffersize for SpMM.
732 auto bufferComp = gpu::SpMMBufferSizeOp::create(
733 rewriter, loc, indexTp, tokenTp, token, spMatA, dnB, dnC,
734 /*computeType=*/dmatCType);
735 Value bufferSz = bufferComp.getResult(0);
736 token = bufferComp.getAsyncToken();
737 auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
738 Value buffer = buf.getResult(0);
739 token = buf.getAsyncToken();
740 auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
741
742 // Perform the SpMM.
743 auto spmmComp =
744 gpu::SpMMOp::create(rewriter, loc, tokenTp, token, spMatA, dnB, dnC,
745 /*computeType=*/dnCType, buffer);
746 token = spmmComp.getAsyncToken();
747
748 // Copy data back to host and free all the resources.
749 token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatA)
750 .getAsyncToken();
751 token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnB)
752 .getAsyncToken();
753 token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnC)
754 .getAsyncToken();
755 token = genDeallocMemRef(rewriter, loc, rowA, token);
756 if (colA)
757 token = genDeallocMemRef(rewriter, loc, colA, token);
758 token = genDeallocMemRef(rewriter, loc, valA, token);
759 token = genDeallocMemRef(rewriter, loc, buffer, token);
760 token = genDeallocMemRef(rewriter, loc, matB, token);
761 token = genCopyMemRef(rewriter, loc, bufC, matC, token);
762 token = genDeallocMemRef(rewriter, loc, matC, token);
763 tokens.push_back(token);
764 genBlockingWait(rewriter, loc, tokens);
765 tokens.clear();
766
767 // Done.
768 rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, c.getType(), bufC);
769 return success();
770}
771
772// Match and rewrite SpGEMM kernel.
773static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
774 linalg::GenericOp op, bool enableRT) {
775 Location loc = op.getLoc();
776 Value a = op.getOperand(0);
777 Value b = op.getOperand(1);
778 Value c = op.getOperand(2); // we have C = AB
779 SmallVector<Value> tokens;
780
781 // Only CSR <- CSR x CSR supported.
782 auto format = CuSparseFormat::kCSR;
783 SparseTensorType aTp = getSparseTensorType(a);
784 SparseTensorType bTp = getSparseTensorType(b);
785 SparseTensorType cTp = getSparseTensorType(c);
786 if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp))
787 return failure();
788
789 // Start sparse kernel and copy data from host to device.
790 // a : amemR/amemC/amemV -> rowA,colA,valA
791 // b : bmemR/bmemC/bmemV -> rowB,colB,valB
792 // c : materializes
793 auto dnCType = cTp.getElementType();
794 Value nseA = NumberOfEntriesOp::create(rewriter, loc, a);
795 Value nseB = NumberOfEntriesOp::create(rewriter, loc, b);
796 Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
797 Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
798 Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
799 Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
800 Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT); // not empty
801 Value amemV = ToValuesOp::create(rewriter, loc, a);
802 Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
803 Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT); // not empty
804 Value bmemV = ToValuesOp::create(rewriter, loc, b);
805 Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
806 Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
807 Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
808 Value rowB = genAllocCopy(rewriter, loc, bmemR, tokens);
809 Value colB = genAllocCopy(rewriter, loc, bmemC, tokens);
810 Value valB = genAllocCopy(rewriter, loc, bmemV, tokens);
811 genBlockingWait(rewriter, loc, tokens);
812 tokens.clear();
813
814 // Create sparse environment and sparse matrix handles.
815 Type indexTp = rewriter.getIndexType();
816 Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
817 Type descTp = rewriter.getType<gpu::SparseSpGEMMOpHandleType>();
818 Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
819 Value token = genFirstWait(rewriter, loc);
820 Operation *spGenA =
821 genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szm, szk,
822 nseA, rowA, colA, valA, format, enableRT);
823 Value spMatA = spGenA->getResult(0);
824 token = spGenA->getResult(1);
825 Operation *spGenB =
826 genSpMat(rewriter, loc, bTp, spmatHandleTp, tokenTp, token, szk, szn,
827 nseB, rowB, colB, valB, format, enableRT);
828 Value spMatB = spGenB->getResult(0);
829 token = spGenB->getResult(1);
830
831 // Sparse matrix C materializes (also assumes beta == 0).
832 Value zero = constantIndex(rewriter, loc, 0);
833 Value one = constantIndex(rewriter, loc, 1);
834 Value mplus1 = arith::AddIOp::create(rewriter, loc, szm, one);
835 auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token);
836 Value rowC = e1.getResult(0);
837 token = e1.getAsyncToken();
838 auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token);
839 Value colC = e2.getResult(0); // no free needed
840 token = e2.getAsyncToken();
841 auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token);
842 Value valC = e3.getResult(0); // no free needed
843 token = e3.getAsyncToken();
844 Operation *spGenC =
845 genSpMat(rewriter, loc, cTp, spmatHandleTp, tokenTp, token, szm, szn,
846 zero, rowC, colC, valC, format, enableRT);
847 Value spMatC = spGenC->getResult(0);
848 token = spGenC->getResult(1);
849
850 // Precompute buffersizes for SpGEMM.
851 Operation *descOp =
852 gpu::SpGEMMCreateDescrOp::create(rewriter, loc, descTp, tokenTp, token);
853 Value desc = descOp->getResult(0);
854 token = descOp->getResult(1);
855 Operation *work1 = gpu::SpGEMMWorkEstimationOrComputeOp::create(
856 rewriter, loc, indexTp, tokenTp, token, desc,
857 gpu::TransposeMode::NON_TRANSPOSE, gpu::TransposeMode::NON_TRANSPOSE,
858 spMatA, spMatB, spMatC, dnCType, zero, valC,
859 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
860 Value bufferSz1 = work1->getResult(0);
861 token = work1->getResult(1);
862 auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
863 Value buffer1 = buf1.getResult(0);
864 token = buf1.getAsyncToken();
865 Operation *work2 = gpu::SpGEMMWorkEstimationOrComputeOp::create(
866 rewriter, loc, indexTp, tokenTp, token, desc,
867 gpu::TransposeMode::NON_TRANSPOSE, gpu::TransposeMode::NON_TRANSPOSE,
868 spMatA, spMatB, spMatC, dnCType, bufferSz1, buffer1,
869 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
870 token = work2->getResult(1);
871
872 // Compute step.
873 Operation *compute1 = gpu::SpGEMMWorkEstimationOrComputeOp::create(
874 rewriter, loc, indexTp, tokenTp, token, desc,
875 gpu::TransposeMode::NON_TRANSPOSE, gpu::TransposeMode::NON_TRANSPOSE,
876 spMatA, spMatB, spMatC, dnCType, zero, valC,
877 gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
878 Value bufferSz2 = compute1->getResult(0);
879 token = compute1->getResult(1);
880 auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
881 Value buffer2 = buf2.getResult(0);
882 token = buf2.getAsyncToken();
883 Operation *compute2 = gpu::SpGEMMWorkEstimationOrComputeOp::create(
884 rewriter, loc, indexTp, tokenTp, token, desc,
885 gpu::TransposeMode::NON_TRANSPOSE, gpu::TransposeMode::NON_TRANSPOSE,
886 spMatA, spMatB, spMatC, dnCType, bufferSz2, buffer2,
887 gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
888 token = compute2->getResult(1);
889
890 // Get sizes.
891 Operation *sizes = gpu::SpMatGetSizeOp::create(
892 rewriter, loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC);
893 Value nnz = sizes->getResult(2);
894 token = sizes->getResult(3);
895 auto a2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), nnz, token);
896 colC = a2.getResult(0);
897 token = a2.getAsyncToken();
898 auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token);
899 valC = a3.getResult(0);
900 token = a3.getAsyncToken();
901
902 // Update C with new pointers and copy final product back into C.
903 Operation *update = gpu::SetCsrPointersOp::create(
904 rewriter, loc, tokenTp, token, spMatC, rowC, colC, valC);
905 token = update->getResult(0);
906 Operation *copy = gpu::SpGEMMCopyOp::create(
907 rewriter, loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
908 gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType);
909 token = copy->getResult(0);
910
911 // Allocate buffers on host.
912 Value rowH = genHostBuffer(rewriter, loc, cTp.getPosType(), mplus1);
913 Value colH = genHostBuffer(rewriter, loc, cTp.getCrdType(), nnz);
914 Value valH = genHostBuffer(rewriter, loc, dnCType, nnz);
915
916 // Copy data back to host and free all the resources.
917 token = gpu::SpGEMMDestroyDescrOp::create(rewriter, loc, tokenTp, token, desc)
918 .getAsyncToken();
919 token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatA)
920 .getAsyncToken();
921 token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatB)
922 .getAsyncToken();
923 token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatC)
924 .getAsyncToken();
925 token = genCopyMemRef(rewriter, loc, rowH, rowC, token);
926 token = genCopyMemRef(rewriter, loc, colH, colC, token);
927 token = genCopyMemRef(rewriter, loc, valH, valC, token);
928 token = genDeallocMemRef(rewriter, loc, rowA, token);
929 token = genDeallocMemRef(rewriter, loc, colA, token);
930 token = genDeallocMemRef(rewriter, loc, valA, token);
931 token = genDeallocMemRef(rewriter, loc, rowB, token);
932 token = genDeallocMemRef(rewriter, loc, colB, token);
933 token = genDeallocMemRef(rewriter, loc, valB, token);
934 token = genDeallocMemRef(rewriter, loc, rowC, token);
935 token = genDeallocMemRef(rewriter, loc, colC, token);
936 token = genDeallocMemRef(rewriter, loc, valC, token);
937 token = genDeallocMemRef(rewriter, loc, buffer1, token);
938 token = genDeallocMemRef(rewriter, loc, buffer2, token);
939 tokens.push_back(token);
940 genBlockingWait(rewriter, loc, tokens);
941 tokens.clear();
942
943 // Done.
944 Value vt = bufferization::ToTensorOp::create(
945 rewriter, loc, memref::getTensorTypeFromMemRefType(valH.getType()), valH);
946 Value rt = bufferization::ToTensorOp::create(
947 rewriter, loc, memref::getTensorTypeFromMemRefType(rowH.getType()), rowH);
948 Value ct = bufferization::ToTensorOp::create(
949 rewriter, loc, memref::getTensorTypeFromMemRefType(colH.getType()), colH);
950 rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), ValueRange{rt, ct},
951 vt);
952 return success();
953}
954
955// Match and rewrite 2:4 SpMM kernel.
956static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
957 linalg::GenericOp op) {
958 Location loc = op.getLoc();
959 Value A = op.getOperand(0);
960 Value B = op.getOperand(1);
961 Value C = op.getOperand(2); // we have C = AB
962 SmallVector<Value> tokens;
963
964 // The cuSPARSELt API currently only allows pruning and compression
965 // to occur on the device. So we recognize the pattern
966 // A' = convert A ; dense to 2:4
967 // C = A'B ; 2:4 matrix mult
968 // and then perform compression and matrix multiplication on device.
969 auto cnv = A.getDefiningOp<ConvertOp>();
970 assert(cnv);
971 A = cnv.getSource();
972
973 // All input should be dense tensors.
974 if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
975 return failure();
976
977 // Start sparse kernel and copy data from host to device.
978 // a : bufA -> matA
979 // b : bufB -> matB
980 // c : bufC -> matC
981 Value bufA = genTensorToMemref(rewriter, loc, A);
982 Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
983 Value bufB = genTensorToMemref(rewriter, loc, B);
984 Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
985 Value bufC = genTensorToMemref(rewriter, loc, C);
986 Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
987 genBlockingWait(rewriter, loc, tokens);
988 tokens.clear();
989
990 // Create sparse environment and sparse matrix/dense matrix handles.
991 Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0);
992 Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0);
993 Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1);
994 Type indexTp = rewriter.getIndexType();
995 Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
996 Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
997 Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
998 Value token = genFirstWait(rewriter, loc);
999 Operation *spGenA = gpu::Create2To4SpMatOp::create(
1000 rewriter, loc, spMatHandleTp, tokenTp, token, szm, szk,
1001 gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA);
1002 Value spMatA = spGenA->getResult(0);
1003 token = spGenA->getResult(1);
1004 auto dmatB =
1005 gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp, tokenTp,
1006 token, matB, SmallVector<Value>{szk, szn});
1007 Value dnB = dmatB.getResult(0);
1008 token = dmatB.getAsyncToken();
1009 auto dmatC =
1010 gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp, tokenTp,
1011 token, matC, SmallVector<Value>{szm, szn});
1012 Value dnC = dmatC.getResult(0);
1013 token = dmatC.getAsyncToken();
1014 auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
1015
1016 // Precompute buffersize for SpMM.
1017 SmallVector<Type> bufferTypes_{indexTp, indexTp, indexTp};
1018 TypeRange bufferTypes(bufferTypes_);
1019 auto bufferComp = gpu::SpMMBufferSizeOp::create(
1020 rewriter, loc, bufferTypes, tokenTp, token,
1021 gpu::TransposeMode::NON_TRANSPOSE, gpu::TransposeMode::NON_TRANSPOSE,
1022 spMatA, dnB, dnC,
1023 /*computeType=*/dmatCType);
1024 token = bufferComp.getAsyncToken();
1025
1026 // Allocate temporary buffers on the device.
1027 Value bufferSz1 = bufferComp.getResult(0);
1028 auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
1029 Value buffer1 = buf1.getResult(0);
1030 token = buf1.getAsyncToken();
1031 Value bufferSz2 = bufferComp.getResult(1);
1032 auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
1033 Value buffer2 = buf2.getResult(0);
1034 token = buf2.getAsyncToken();
1035 Value bufferSz3 = bufferComp.getResult(2);
1036 auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token);
1037 Value buffer3 = buf3.getResult(0);
1038 token = buf3.getAsyncToken();
1039
1040 // Perform the SpMM.
1041 auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
1042 auto spmmComp = gpu::SpMMOp::create(
1043 rewriter, loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType,
1044 SmallVector<Value>{buffer1, buffer2, buffer3});
1045 token = spmmComp.getAsyncToken();
1046
1047 // Copy data back to host and free all the resources.
1048 token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatA)
1049 .getAsyncToken();
1050 token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnB)
1051 .getAsyncToken();
1052 token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnC)
1053 .getAsyncToken();
1054 token = genDeallocMemRef(rewriter, loc, buffer1, token);
1055 token = genDeallocMemRef(rewriter, loc, buffer2, token);
1056 token = genDeallocMemRef(rewriter, loc, buffer3, token);
1057 token = genDeallocMemRef(rewriter, loc, matA, token);
1058 token = genDeallocMemRef(rewriter, loc, matB, token);
1059 token = genCopyMemRef(rewriter, loc, bufC, matC, token);
1060 token = genDeallocMemRef(rewriter, loc, matC, token);
1061 tokens.push_back(token);
1062 genBlockingWait(rewriter, loc, tokens);
1063 tokens.clear();
1064
1065 // Done.
1066 rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, C.getType(), bufC);
1067 return success();
1068}
1069
1070/// Match and rewrite SDDMM kernel.
1071static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
1072 linalg::GenericOp op, bool enableRT) {
1073 Location loc = op.getLoc();
1074 Value a = op.getOperand(0);
1075 Value b = op.getOperand(1);
1076 Value c = op.getOperand(2);
1077 SmallVector<Value> tokens;
1078
1079 // Only admissible sparse matrix format (no COO/CSC) and dense matrices.
1080 SparseTensorType aTp = getSparseTensorType(a);
1081 SparseTensorType bTp = getSparseTensorType(b);
1082 SparseTensorType cTp = getSparseTensorType(c);
1083 auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT, /*isMatVec=*/false);
1084 if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO ||
1085 format == CuSparseFormat::kCSC)
1086 return failure();
1087
1088 // SDDMM performs its operation in place on the sparse output.
1089 // Start sparse kernel and copy data from host to device.
1090 // a : bufA -> matA
1091 // b : bufB -> matB
1092 // c : memR/memC/memV -> rowC,colC,valC
1093 Value nseC = NumberOfEntriesOp::create(rewriter, loc, c);
1094 Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
1095 Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
1096 Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
1097 Value bufA = genTensorToMemref(rewriter, loc, a);
1098 Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
1099 Value bufB = genTensorToMemref(rewriter, loc, b);
1100 Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
1101 Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
1102 Value memC = genSecondCrds(rewriter, loc, c, format, enableRT); // or empty
1103 Value memV = ToValuesOp::create(rewriter, loc, c);
1104 Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
1105 Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
1106 Value valC = genAllocCopy(rewriter, loc, memV, tokens);
1107 genBlockingWait(rewriter, loc, tokens);
1108 tokens.clear();
1109
1110 // Create sparse environment and sparse matrix/dense matrix handles.
1111 Type indexTp = rewriter.getIndexType();
1112 Type dnMatHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
1113 Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
1114 Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
1115 Value token = genFirstWait(rewriter, loc);
1116 auto dmatA =
1117 gpu::CreateDnTensorOp::create(rewriter, loc, dnMatHandleTp, tokenTp,
1118 token, matA, SmallVector<Value>{szm, szk});
1119 Value dnA = dmatA.getResult(0);
1120 token = dmatA.getAsyncToken();
1121 auto dmatB =
1122 gpu::CreateDnTensorOp::create(rewriter, loc, dnMatHandleTp, tokenTp,
1123 token, matB, SmallVector<Value>{szk, szn});
1124 Value dnB = dmatB.getResult(0);
1125 token = dmatB.getAsyncToken();
1126 Operation *spGenC =
1127 genSpMat(rewriter, loc, cTp, spMatHandleTp, tokenTp, token, szm, szn,
1128 nseC, rowC, colC, valC, format, enableRT);
1129 Value spMatC = spGenC->getResult(0);
1130 token = spGenC->getResult(1);
1131 auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
1132
1133 // Precompute buffersize for SDDMM.
1134 auto bufferComp = gpu::SDDMMBufferSizeOp::create(
1135 rewriter, loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType);
1136 Value bufferSz = bufferComp.getResult(0);
1137 token = bufferComp.getAsyncToken();
1138 auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
1139 Value buffer = buf.getResult(0);
1140 token = buf.getAsyncToken();
1141
1142 // Perform the SDDMM.
1143 auto sddmmComp = gpu::SDDMMOp::create(rewriter, loc, tokenTp, token, dnA, dnB,
1144 spMatC, dnCType, buffer);
1145 token = sddmmComp.getAsyncToken();
1146
1147 // Copy data back to host and free all the resources.
1148 token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnA)
1149 .getAsyncToken();
1150 token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnB)
1151 .getAsyncToken();
1152 token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatC)
1153 .getAsyncToken();
1154 token = genDeallocMemRef(rewriter, loc, buffer, token);
1155 token = genDeallocMemRef(rewriter, loc, matA, token);
1156 token = genDeallocMemRef(rewriter, loc, matB, token);
1157 token = genDeallocMemRef(rewriter, loc, rowC, token);
1158 if (colC)
1159 token = genDeallocMemRef(rewriter, loc, colC, token);
1160 token = genCopyMemRef(rewriter, loc, memV, valC, token);
1161 token = genDeallocMemRef(rewriter, loc, valC, token);
1162 tokens.push_back(token);
1163 genBlockingWait(rewriter, loc, tokens);
1164 tokens.clear();
1165
1166 // Done.
1167 rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c);
1168 return success();
1169}
1170
1171//===----------------------------------------------------------------------===//
1172// Rewriting rules for direct code generation.
1173//===----------------------------------------------------------------------===//
1174
1175/// Proof-of-concept rewriter. This rule generates a GPU implementation
1176/// for each outermost forall loop generated by the sparsifier.
1177/// TODO: right now this works with parallelization-strategy=dense-outer-loop,
1178/// but it should be given its own flag in the future.
1179struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
1180 using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
1181
1182 ForallRewriter(MLIRContext *context, unsigned nT)
1183 : OpRewritePattern(context), numThreads(nT) {};
1184
1185 LogicalResult matchAndRewrite(scf::ParallelOp forallOp,
1186 PatternRewriter &rewriter) const override {
1187 // Reject inadmissible loop form.
1188 // Essentially only accept a loop, generated by the sparsifier,
1189 // of the form
1190 // forall (i = 0; i < N; i++)
1191 // so that cyclic scheduling over the threads is easy.
1192 if (!forallOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ||
1193 forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 ||
1194 !matchPattern(forallOp.getLowerBound()[0], m_Zero()) ||
1195 !matchPattern(forallOp.getStep()[0], m_One()))
1196 return failure();
1197 // Collect every value that is computed outside the parallel loop.
1198 SetVector<Value> invariants; // stable iteration!
1199 forallOp->walk([&](Operation *op) {
1200 // Collect all values of admissible ops.
1201 for (OpOperand &o : op->getOpOperands()) {
1202 Value val = o.get();
1203 Block *block;
1204 if (auto arg = dyn_cast<BlockArgument>(val))
1205 block = arg.getOwner();
1206 else
1207 block = val.getDefiningOp()->getBlock();
1208 if (!forallOp.getRegion().findAncestorBlockInRegion(*block))
1209 invariants.insert(val);
1210 }
1211 });
1212 // Outline the outside values as proper parameters. Fail when sharing
1213 // a value between host and device is not straightforward.
1214 SmallVector<Value> constants;
1215 SmallVector<Value> scalars;
1216 SmallVector<Value> buffers;
1217 // A boolean mask aligned 1:1 with the `buffers` array, tracking which
1218 // of those buffers were mutated by the loop. If true, the corresponding
1219 // buffer needs to be "copied back" using a device-to-host copy.
1220 SmallVector<bool> copyBack;
1221 for (Value val : invariants) {
1222 Type tp = val.getType();
1223 if (val.getDefiningOp<arith::ConstantOp>())
1224 constants.push_back(val);
1225 else if (isa<FloatType>(tp) || tp.isIntOrIndex())
1226 scalars.push_back(val);
1227 else if (isa<MemRefType>(tp)) {
1228 buffers.push_back(val);
1229
1230 // Determine if the buffer needs to be "copied back" from device
1231 // to host by checking for `memref.store` and the write memory effect.
1232 bool isWrite = false;
1233 for (Operation *user : val.getUsers()) {
1234 if (isa<memref::StoreOp>(user)) {
1235 isWrite = true;
1236 break;
1237 }
1238 if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(user)) {
1239 if (memInterface.getEffectOnValue<MemoryEffects::Write>(val)) {
1240 isWrite = true;
1241 break;
1242 }
1243 }
1244 }
1245 copyBack.push_back(isWrite);
1246 } else
1247 return failure(); // don't know how to share
1248 }
1249 // Pass outlined non-constant values.
1250 // TODO: Experiment with `useHostRegistrationForOut` to see if we want to
1251 // keep the feature at all (either through a heuristic or compiler
1252 // option for gpu codegen).
1253 Location loc = forallOp->getLoc();
1254 SmallVector<Value> args;
1255 SmallVector<Value> tokens;
1256 Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens,
1257 /*useHostRegistrationForOut=*/false);
1258 // Set up GPU module and construct GPU function.
1259 auto saveIp = rewriter.saveInsertionPoint();
1260 ModuleOp topModule = forallOp->getParentOfType<ModuleOp>();
1261 auto gpuModule = genGPUModule(rewriter, topModule);
1262 auto gpuFunc = genGPUFunc(rewriter, gpuModule, args);
1263 genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers);
1264 // Generate code that launches the kernel asynchronously, blocking on all
1265 // open tokens and yielding a new token for the output.
1266 // TODO: Passing tokens into the launch op does not seem to be properly
1267 // lowered to cubin yet, hence the current blocking wait.
1268 rewriter.restoreInsertionPoint(saveIp);
1269 genBlockingWait(rewriter, loc, tokens);
1270 tokens.clear();
1271 Value kernelToken =
1272 genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads);
1273 // Finalize the outlined arguments.
1274 genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
1275 tokens, copyBack);
1276 genBlockingWait(rewriter, loc, tokens);
1277 rewriter.eraseOp(forallOp);
1278 return success();
1279 }
1280
1281private:
1282 unsigned numThreads;
1283};
1284
1285//===----------------------------------------------------------------------===//
1286// Rewriting rules for library recognition and code generation.
1287//===----------------------------------------------------------------------===//
1288
1289/// Proof-of-concept rewriter. This rule recognizes certain math kernels
1290/// and replaces these with corresponding calls into a sparse library.
1291struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
1292 using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
1293
1294 LinalgOpRewriter(MLIRContext *context, bool rt)
1295 : OpRewritePattern(context), enableRT(rt) {}
1296
1297 LogicalResult matchAndRewrite(linalg::GenericOp op,
1298 PatternRewriter &rewriter) const override {
1299 if (op.getNumDpsInits() != 1)
1300 return failure(); // reject multi-output
1301
1302 const unsigned numLoops = op.getNumLoops();
1303 const unsigned numTensors = op->getNumOperands();
1304 const auto iteratorTypes = op.getIteratorTypesArray();
1305 SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1306
1307 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1308 auto infer = [&](MapList m) {
1309 return AffineMap::inferFromExprList(m, op.getContext());
1310 };
1311 AffineExpr i, j, k;
1312 bindDims(getContext(), i, j, k);
1313
1314 // TODO: more robust patterns, transposed versions, more kernels,
1315 // identify alpha and beta and pass them to the CUDA calls.
1316
1317 // Recognize a SpMV kernel.
1318 if (numLoops == 2 && numTensors == 3 &&
1319 linalg::isParallelIterator(iteratorTypes[0]) &&
1320 linalg::isReductionIterator(iteratorTypes[1]) &&
1321 maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
1322 return rewriteSpMV(rewriter, op, enableRT);
1323 }
1324
1325 // Recognize a SpGEMM, 2:4-SpMM, or SpMM kernel.
1326 if (numLoops == 3 && numTensors == 3 &&
1327 linalg::isParallelIterator(iteratorTypes[0]) &&
1328 linalg::isParallelIterator(iteratorTypes[1]) &&
1329 linalg::isReductionIterator(iteratorTypes[2]) &&
1330 maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
1331 if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
1332 return rewriteSpGEMM(rewriter, op, enableRT);
1333 if (isConversionInto24(op.getOperand(0)))
1334 return rewrite2To4SpMM(rewriter, op);
1335 return rewriteSpMM(rewriter, op, enableRT);
1336 }
1337
1338 // Recognize a SDDMM kernel.
1339 if (numLoops == 3 && numTensors == 3 &&
1340 linalg::isParallelIterator(iteratorTypes[0]) &&
1341 linalg::isParallelIterator(iteratorTypes[1]) &&
1342 linalg::isReductionIterator(iteratorTypes[2]) &&
1343 maps == infer({{i, k}, {k, j}, {i, j}}) &&
1344 matchSumReductionOfMulUnary(op)) {
1345 return rewriteSDDMM(rewriter, op, enableRT);
1346 }
1347
1348 return failure();
1349 }
1350
1351private:
1352 bool enableRT;
1353};
1354
1355} // namespace
1356
1357//===----------------------------------------------------------------------===//
1358// Public method for populating GPU rewriting rules.
1359//
1360// Currently, two sets of rewriting rules are made available. The first set
1361// implements direct code generation, currently by means of converting the
1362// outermost parallel loop into GPU threads. The second set implements
1363// library recognition of a set of sparse operations. Eventually, the right
1364// combination of these two approaches has to be found.
1365//===----------------------------------------------------------------------===//
1366
1367void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
1368 unsigned numThreads) {
1369 patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
1370}
1371
1372void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
1373 bool enableRT) {
1374 patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT);
1375}
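// A minimal usage sketch (hypothetical driver, not part of this file): a
// pass would typically populate one or both rule sets and hand them to the
// standard greedy driver. Assumes GreedyPatternRewriteDriver.h is included;
// the thread count and the enableRT flag are illustrative choices.
static void runSparseGPURewrites(ModuleOp module, bool enableRT) {
  RewritePatternSet patterns(module.getContext());
  populateSparseGPULibgenPatterns(patterns, enableRT);
  populateSparseGPUCodegenPatterns(patterns, /*numThreads=*/1024);
  if (failed(applyPatternsGreedily(module, std::move(patterns))))
    module.emitError("sparse GPU rewriting did not converge");
}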