//===- GPUTransformOps.cpp - Implementation of GPU transform ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/IRMapping.h"

using namespace mlir;
using namespace mlir::gpu;
using namespace mlir::transform;

/// Checks that each of the given mapping attributes is one of the desired
/// attributes and that no mapping attribute is duplicated.
static DiagnosedSilenceableFailure checkAttributeType(
    ArrayRef<DeviceMappingAttrInterface> threadMappingAttributes,
    const std::optional<ArrayAttr> &foreachMapping,
    std::optional<TransformOpInterface> transformOp) {
  if (!foreachMapping.has_value())
    return transformOp->emitSilenceableError() << "mapping must be present";

  DenseSet<Attribute> seen;
  for (Attribute map : foreachMapping->getValue()) {
    if (!llvm::is_contained(threadMappingAttributes, map)) {
      return transformOp->emitDefiniteFailure()
             << "mapping must be one of " << threadMappingAttributes;
    }
    if (llvm::is_contained(seen, map)) {
      return transformOp->emitDefiniteFailure()
             << map
             << " is duplicated, cannot map different "
                "loops to the same processor";
    }
    seen.insert(map);
  }

  return DiagnosedSilenceableFailure::success();
}
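
// For illustration: with desired attributes {#gpu.thread<x>, #gpu.thread<y>,
// #gpu.thread<z>}, a mapping array of [#gpu.thread<y>, #gpu.thread<x>] passes
// the check above, [#gpu.block<x>] fails with "mapping must be one of ...",
// and [#gpu.thread<x>, #gpu.thread<x>] fails as a duplicated mapping.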

/// Determines if the size of the kernel configuration is supported by the GPU
/// architecture being used. It presently uses CUDA's limits; support for other
/// GPU architectures may be added later.
static DiagnosedSilenceableFailure checkGpuLimits(
    TransformOpInterface transformOp, std::optional<int64_t> gridDimX,
    std::optional<int64_t> gridDimY, std::optional<int64_t> gridDimZ,
    std::optional<int64_t> blockDimX, std::optional<int64_t> blockDimY,
    std::optional<int64_t> blockDimZ) {

  static constexpr int maxTotalBlockdim = 1024;
  static constexpr int maxBlockdimx = 1024;
  static constexpr int maxBlockdimy = 1024;
  static constexpr int maxBlockdimz = 64;
  static constexpr int maxTotalGriddim = 2147483647;
  static constexpr int maxGriddimx = 2147483647;
  static constexpr int maxGriddimy = 65535;
  static constexpr int maxGriddimz = 65535;

  if ((blockDimX.value_or(1) * blockDimY.value_or(1) * blockDimZ.value_or(1)) >
          maxTotalBlockdim ||
      (gridDimX.value_or(1) * gridDimY.value_or(1) * gridDimZ.value_or(1)) >
          maxTotalGriddim ||
      blockDimX.value_or(1) > maxBlockdimx ||
      blockDimY.value_or(1) > maxBlockdimy ||
      blockDimZ.value_or(1) > maxBlockdimz ||
      gridDimY.value_or(1) > maxGriddimy ||
      gridDimZ.value_or(1) > maxGriddimz ||
      gridDimX.value_or(1) > maxGriddimx) {
    return transformOp.emitSilenceableError()
           << "Trying to launch a GPU kernel with gridDim = ("
           << gridDimX.value_or(1) << ", " << gridDimY.value_or(1) << ", "
           << gridDimZ.value_or(1) << ") blockDim = (" << blockDimX.value_or(1)
           << ", " << blockDimY.value_or(1) << ", " << blockDimZ.value_or(1)
           << "). It is larger than the limits.";
  }
  return DiagnosedSilenceableFailure::success();
}
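
// Worked example for the check above (illustrative values): blockDim =
// (32, 32, 2) requests 32 * 32 * 2 = 2048 threads per block and is rejected
// because it exceeds maxTotalBlockdim = 1024, whereas blockDim = (32, 8, 4)
// requests exactly 1024 threads and passes. Unset dimensions default to 1
// through value_or(1).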

/// Creates an empty-body gpu::LaunchOp using the provided kernel settings and
/// puts a terminator within it.
static DiagnosedSilenceableFailure
createGpuLaunch(RewriterBase &rewriter, Location loc,
                TransformOpInterface transformOp, LaunchOp &launchOp,
                std::optional<int64_t> gridDimX = std::nullopt,
                std::optional<int64_t> gridDimY = std::nullopt,
                std::optional<int64_t> gridDimZ = std::nullopt,
                std::optional<int64_t> blockDimX = std::nullopt,
                std::optional<int64_t> blockDimY = std::nullopt,
                std::optional<int64_t> blockDimZ = std::nullopt) {
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, gridDimX, gridDimY, gridDimZ, blockDimX,
                     blockDimY, blockDimZ);
  if (!diag.succeeded())
    return diag;

  auto createConst = [&](int dim) {
    return rewriter.create<arith::ConstantIndexOp>(loc, dim);
  };
  OpBuilder::InsertionGuard guard(rewriter);
  Value one = createConst(1);
  Value gridSizeX = gridDimX.has_value() ? createConst(gridDimX.value()) : one;
  Value gridSizeY = gridDimY.has_value() ? createConst(gridDimY.value()) : one;
  Value gridSizeZ = gridDimZ.has_value() ? createConst(gridDimZ.value()) : one;
  Value blkSizeX = blockDimX.has_value() ? createConst(blockDimX.value()) : one;
  Value blkSizeY = blockDimY.has_value() ? createConst(blockDimY.value()) : one;
  Value blkSizeZ = blockDimZ.has_value() ? createConst(blockDimZ.value()) : one;
  launchOp = rewriter.create<LaunchOp>(loc, gridSizeX, gridSizeY, gridSizeZ,
                                       blkSizeX, blkSizeY, blkSizeZ);
  rewriter.setInsertionPointToEnd(&launchOp.getBody().front());
  rewriter.create<TerminatorOp>(loc);
  return DiagnosedSilenceableFailure::success();
}
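
// The function above produces an empty launch over constant bounds, e.g.
// (illustrative IR, for gridDimX = 4 with every other dimension unset):
//   %c1 = arith.constant 1 : index
//   %c4 = arith.constant 4 : index
//   gpu.launch blocks(%bx, %by, %bz) in (%gx = %c4, %gy = %c1, %gz = %c1)
//              threads(%tx, %ty, %tz) in (%sx = %c1, %sy = %c1, %sz = %c1) {
//     gpu.terminator
//   }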

/// Alters the kernel configuration (grid and block sizes) of the given kernel.
static DiagnosedSilenceableFailure
alterGpuLaunch(TrivialPatternRewriter &rewriter, LaunchOp gpuLaunch,
               TransformOpInterface transformOp,
               std::optional<int64_t> gridDimX = std::nullopt,
               std::optional<int64_t> gridDimY = std::nullopt,
               std::optional<int64_t> gridDimZ = std::nullopt,
               std::optional<int64_t> blockDimX = std::nullopt,
               std::optional<int64_t> blockDimY = std::nullopt,
               std::optional<int64_t> blockDimZ = std::nullopt) {
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, gridDimX, gridDimY, gridDimZ, blockDimX,
                     blockDimY, blockDimZ);
  if (!diag.succeeded())
    return diag;

  KernelDim3 currentBlockdim = gpuLaunch.getBlockSizeOperandValues();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointAfterValue(currentBlockdim.x);
  auto createConstValue = [&](int dim) {
    return rewriter.create<arith::ConstantIndexOp>(currentBlockdim.x.getLoc(),
                                                   dim);
  };

  if (gridDimX.has_value())
    gpuLaunch.getGridSizeXMutable().assign(createConstValue(gridDimX.value()));
  if (gridDimY.has_value())
    gpuLaunch.getGridSizeYMutable().assign(createConstValue(gridDimY.value()));
  if (gridDimZ.has_value())
    gpuLaunch.getGridSizeZMutable().assign(createConstValue(gridDimZ.value()));
  if (blockDimX.has_value())
    gpuLaunch.getBlockSizeXMutable().assign(
        createConstValue(blockDimX.value()));
  if (blockDimY.has_value())
    gpuLaunch.getBlockSizeYMutable().assign(
        createConstValue(blockDimY.value()));
  if (blockDimZ.has_value())
    gpuLaunch.getBlockSizeZMutable().assign(
        createConstValue(blockDimZ.value()));
  return DiagnosedSilenceableFailure::success();
}
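
// For instance (illustrative call): alterGpuLaunch(rewriter, launch,
// transformOp, /*gridDimX=*/std::nullopt, /*gridDimY=*/std::nullopt,
// /*gridDimZ=*/std::nullopt, /*blockDimX=*/128) leaves the grid size
// untouched and only replaces the x block size with a fresh
// `arith.constant 128 : index` inserted after the current block-size operand.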

//===----------------------------------------------------------------------===//
// MapForeachToBlocks
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
    RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
    function_ref<void(RewriterBase &, scf::ForeachThreadOp,
                      SmallVectorImpl<Value> &)>
        blockIdGenerator,
    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp,
    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes) {
  // Step 0. Target-specific verifications. There is no good place to anchor
  // those right now: the ForeachThreadOp is target-independent and the
  // transform op does not apply to individual ForeachThreadOp.
  Location loc = foreachThreadOp->getLoc();

  if (foreachThreadOp.getNumResults() > 0)
    return transformOp.emitSilenceableError()
           << "only bufferized scf.foreach_thread lowers to "
              "gpu.block_id";
  if (foreachThreadOp.getNumThreads().size() > 3)
    return transformOp.emitSilenceableError()
           << "scf.foreach_thread with rank > 3 does not lower to "
              "gpu.block_id";
  if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) {
        return !v.getDefiningOp<arith::ConstantIndexOp>();
      })) {
    return transformOp.emitSilenceableError()
           << "unsupported dynamic griddim size";
  }
  SmallVector<Attribute> blockMapping =
      llvm::to_vector(foreachThreadOp.getMapping()->getValue());

  // Step 1. Complete the blockMapping to a full mapping (with 1s) if necessary.
  SmallVector<Value> numBlocks =
      llvm::to_vector(foreachThreadOp.getNumThreads());
  // Ensure we have 3 block sizes, one for each id.
  Value one;
  for (auto attr : mappingAttributes) {
    if (std::find(blockMapping.begin(), blockMapping.end(), attr) ==
        blockMapping.end()) {
      blockMapping.push_back(attr);
      one = one ? one : rewriter.create<arith::ConstantIndexOp>(loc, 1);
      numBlocks.push_back(one);
    }
  }

  // Step 2. Sort the values by the corresponding DeviceMappingAttrInterface.
  auto comparator = [&](DeviceMappingAttrInterface a,
                        DeviceMappingAttrInterface b) -> bool {
    return a.getMappingId() < b.getMappingId();
  };
  SmallVector<Value> gridDimValues = scf::ForeachThreadOp::getValuesSortedByKey(
      blockMapping, numBlocks, comparator);
  for (Value v : gridDimValues)
    gridDims.push_back(v.getDefiningOp<arith::ConstantIndexOp>().value());

  // Step 3. Generate the blockIds using the provided generator and map the
  // induction variables to the newly created ops.
  SmallVector<Value> blockOps;
  blockIdGenerator(rewriter, foreachThreadOp, blockOps);
  IRMapping bvm;
  for (auto [blockIdx, blockDim] :
       llvm::zip(foreachThreadOp.getThreadIndices(), blockMapping)) {
    bvm.map(blockIdx,
            blockOps[static_cast<int64_t>(
                blockDim.cast<DeviceMappingAttrInterface>().getMappingId())]);
  }

  // Step 4. Move the body of foreachThreadOp.
  // Erase the terminator first, it will not be used since we are on buffers.
  rewriter.eraseOp(foreachThreadOp.getTerminator());
  Block *targetBlock = foreachThreadOp->getBlock();
  Block::iterator insertionPoint = Block::iterator(foreachThreadOp);
  Block &sourceBlock = foreachThreadOp.getRegion().front();
  targetBlock->getOperations().splice(insertionPoint,
                                      sourceBlock.getOperations());

  // Step 5. RAUW thread indices to thread ops.
  for (Value loopIndex : foreachThreadOp.getThreadIndices()) {
    Value blockIdx = bvm.lookup(loopIndex);
    rewriter.replaceAllUsesWith(loopIndex, blockIdx);
  }

  // Step 6. Erase old op.
  rewriter.eraseOp(foreachThreadOp);

  return DiagnosedSilenceableFailure::success();
}
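
// Illustrative rewrite performed by mapForeachToBlocksImpl (abbreviated):
//   scf.foreach_thread (%i, %j) in (%c4, %c8) { ... }
//     {mapping = [#gpu.block<y>, #gpu.block<x>]}
// becomes, after blockIdGenerator has materialized the ids,
//   %bx = gpu.block_id x
//   %by = gpu.block_id y
//   ... the inlined body, with %i replaced by %by and %j by %bx ...
// while gridDims is filled with the mapping-sorted sizes (8, 4, 1).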

DiagnosedSilenceableFailure mlir::transform::gpu::findTopLevelForeachThreadOp(
    Operation *target, scf::ForeachThreadOp &topLevelForeachThreadOp,
    TransformOpInterface transformOp) {
  auto walkResult = target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
    if (foreachThreadOp->getParentOfType<scf::ForeachThreadOp>())
      return WalkResult::advance();
    if (topLevelForeachThreadOp)
      // TODO: Handle multiple foreach ops if there are no dependences between
      // them.
      return WalkResult::interrupt();
    topLevelForeachThreadOp = foreachThreadOp;
    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted())
    return transformOp.emitSilenceableError()
           << "could not find a unique topLevel scf.foreach_thread";
  return DiagnosedSilenceableFailure::success();
}

/// This is a helper that is only used in
/// rewriteTopLevelForeachThreadToGpuBlocks. It generates the GPU dialect's
/// block_id ops.
static void generateGpuBlockIds(RewriterBase &rewriter,
                                scf::ForeachThreadOp foreachOp,
                                SmallVectorImpl<Value> &blockOps) {
  Location loc = foreachOp->getLoc();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(foreachOp);
  IndexType indexType = rewriter.getIndexType();
  blockOps = SmallVector<Value>{
      rewriter.create<BlockIdOp>(loc, indexType, Dimension::x),
      rewriter.create<BlockIdOp>(loc, indexType, Dimension::y),
      rewriter.create<BlockIdOp>(loc, indexType, Dimension::z)};
}

DiagnosedSilenceableFailure
transform::MapForeachToBlocks::applyToOne(Operation *target,
                                          ApplyToEachResultList &results,
                                          transform::TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  TrivialPatternRewriter rewriter(getContext());
  auto transformOp = cast<TransformOpInterface>(getOperation());

  if (!getGenerateGpuLaunch() && !gpuLaunch) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "Given target is not gpu.launch, set `generate_gpu_launch` "
           "attribute";
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }

  scf::ForeachThreadOp topLevelForeachThreadOp;
  DiagnosedSilenceableFailure diag =
      mlir::transform::gpu::findTopLevelForeachThreadOp(
          target, topLevelForeachThreadOp, transformOp);
  if (!diag.succeeded()) {
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(topLevelForeachThreadOp);

  // Generate the gpu.launch here and move the foreach_thread inside it.
  if (getGenerateGpuLaunch()) {
    diag = createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch);
    if (!diag.succeeded()) {
      return diag;
    }
    rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
    Operation *newForeachThreadOp = rewriter.clone(*topLevelForeachThreadOp);
    rewriter.eraseOp(topLevelForeachThreadOp);
    topLevelForeachThreadOp = cast<scf::ForeachThreadOp>(newForeachThreadOp);
  }

  SmallVector<int64_t> gridDim = extractFromI64ArrayAttr(getGridDim());
  SmallVector<DeviceMappingAttrInterface> blockMappingAttributes = {
      GPUBlockMappingAttr::get(getContext(), Blocks::DimX),
      GPUBlockMappingAttr::get(getContext(), Blocks::DimY),
      GPUBlockMappingAttr::get(getContext(), Blocks::DimZ)};

  diag = checkAttributeType(blockMappingAttributes,
                            topLevelForeachThreadOp.getMapping(), transformOp);
  if (diag.succeeded())
    diag = mlir::transform::gpu::mapForeachToBlocksImpl(
        rewriter, topLevelForeachThreadOp, generateGpuBlockIds, gridDim,
        transformOp, blockMappingAttributes);
  if (diag.succeeded()) {
    diag = alterGpuLaunch(rewriter, gpuLaunch,
                          cast<TransformOpInterface>(getOperation()),
                          gridDim[0], gridDim[1], gridDim[2]);
  }

  results.push_back(gpuLaunch);
  return diag;
}
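
// In transform IR, this op is typically used along these lines (illustrative
// syntax; the attribute spelling follows the op definition):
//   %gpu_launch = transform.gpu.map_foreach_to_blocks %target
//       {gridDim = [4, 8, 1], generate_gpu_launch}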

//===----------------------------------------------------------------------===//
// MapNestedForeachToThreads
//===----------------------------------------------------------------------===//

/// Searches `scf.foreach_thread` ops nested under `target` and maps each such
/// op to GPU threads. The mapping is one-to-one and the induction variables of
/// `scf.foreach_thread` are rewritten to gpu.thread_id according to the
/// thread_dim_mapping attribute. Sibling `scf.foreach_thread` ops are
/// supported, in which case the union of the number of threads is computed
/// and may result in predication. Dynamic `scf.foreach_thread` trip counts
/// and dynamic block dim sizes are currently not supported.
static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
    RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
    const SmallVectorImpl<int64_t> &globalBlockDims, bool syncAfterDistribute,
    std::optional<TransformOpInterface> transformOp,
    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
  // Step 0. Target-specific verifications. There is no good place to anchor
  // those right now: the ForeachThreadOp is target-independent and the
  // transform op does not apply to individual ForeachThreadOp.
  auto failureHelper =
      [&](const Twine &message) -> DiagnosedSilenceableFailure {
    if (transformOp.has_value()) {
      return transformOp->emitSilenceableError() << message;
    }
    return emitDefiniteFailure(foreachThreadOp, message);
  };
  Location loc = foreachThreadOp->getLoc();
  if (foreachThreadOp.getNumResults() > 0)
    return failureHelper(
        "only bufferized scf.foreach_thread lowers to gpu.thread_id");
  if (foreachThreadOp.getNumThreads().size() > 3)
    return failureHelper(
        "scf.foreach_thread with rank > 3 does not lower to gpu.thread_id");
  if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) {
        return !v.getDefiningOp<arith::ConstantIndexOp>();
      })) {
    return failureHelper("unsupported dynamic blockdim size");
  }
  if (!foreachThreadOp.getMapping().has_value())
    return failureHelper("mapping must be present");
  SmallVector<Attribute> threadMapping =
      llvm::to_vector(foreachThreadOp.getMapping()->getValue());

  // Step 1. Complete the threadMapping to a full mapping (with 1s) if
  // necessary.
  SmallVector<Value> numThreads =
      llvm::to_vector(foreachThreadOp.getNumThreads());
  // Ensure we have 3 block sizes, one for each id.
  Value one;
  for (auto attr : threadMappingAttributes) {
    if (std::find(threadMapping.begin(), threadMapping.end(), attr) ==
        threadMapping.end()) {
      threadMapping.push_back(attr);
      one = one ? one : rewriter.create<arith::ConstantIndexOp>(loc, 1);
      numThreads.push_back(one);
    }
  }

  // Step 2. Sort the values by the corresponding DeviceMappingAttrInterface.
  auto comparator = [&](DeviceMappingAttrInterface a,
                        DeviceMappingAttrInterface b) -> bool {
    return a.getMappingId() < b.getMappingId();
  };
  SmallVector<Value> blockDimValues =
      scf::ForeachThreadOp::getValuesSortedByKey(threadMapping, numThreads,
                                                 comparator);
  SmallVector<int64_t> blockDims =
      llvm::to_vector(llvm::map_range(blockDimValues, [](Value v) {
        return v.getDefiningOp<arith::ConstantIndexOp>().value();
      }));

  // Step 3. Create the gpu.thread ops and map the induction variables to the
  // newly created ops.
  IndexType indexType = rewriter.getIndexType();
  SmallVector<Value> threadOps{
      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x),
      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y),
      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z)};
  // Replace ids of dimension size 1 by zero to simplify the IR.
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  for (size_t i : llvm::seq(size_t(0), globalBlockDims.size())) {
    if (globalBlockDims[i] == 1)
      threadOps[i] = zero;
  }
  IRMapping bvm;
  for (auto [blockIdx, blockDim] :
       llvm::zip(foreachThreadOp.getThreadIndices(), threadMapping)) {
    bvm.map(
        blockIdx,
        threadOps[blockDim.cast<DeviceMappingAttrInterface>().getMappingId()]);
  }

  // Step 4. Maybe create conditionals to predicate the region.
  Value predicate;
  for (auto [threadId, blockDim, globalBlockDim] :
       llvm::zip(threadOps, blockDims, globalBlockDims)) {
    if (blockDim > globalBlockDim) {
      return failureHelper(
          "The requested GPU threads are fewer than the number of loop trip "
          "counts. Try to tile scf.foreach_thread before mapping or set "
          "small blockDim.");
    }
    if (blockDim == globalBlockDim)
      continue;
    Value blockIdx = rewriter.create<arith::ConstantIndexOp>(loc, blockDim);
    Value tmpPredicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ult, threadId, blockIdx);
    predicate =
        predicate ? rewriter.create<arith::AndIOp>(loc, predicate, tmpPredicate)
                  : tmpPredicate;
  }

  // Step 5. Move the body of foreachThreadOp.
  // Erase the terminator first, it will not be used.
  rewriter.eraseOp(foreachThreadOp.getTerminator());
  Block *targetBlock;
  Block::iterator insertionPoint;
  if (predicate) {
    // Step 5.a. If predicated, move at the beginning.
    auto ifOp =
        rewriter.create<scf::IfOp>(loc, predicate, /*withElseRegion=*/false);
    targetBlock = ifOp.thenBlock();
    insertionPoint = ifOp.thenBlock()->begin();
  } else {
    // Step 5.b. Otherwise, move inline just before foreachThreadOp.
    targetBlock = foreachThreadOp->getBlock();
    insertionPoint = Block::iterator(foreachThreadOp);
  }
  Block &sourceBlock = foreachThreadOp.getRegion().front();
  targetBlock->getOperations().splice(insertionPoint,
                                      sourceBlock.getOperations());

  // Step 6. RAUW thread indices to thread ops.
  for (Value loopIndex : foreachThreadOp.getThreadIndices()) {
    Value threadIdx = bvm.lookup(loopIndex);
    rewriter.replaceAllUsesWith(loopIndex, threadIdx);
  }

  // Step 7. syncthreads.
  // TODO: Need warpsync.
  if (syncAfterDistribute)
    rewriter.create<BarrierOp>(loc);

  // Step 8. Erase old op.
  rewriter.eraseOp(foreachThreadOp);

  return DiagnosedSilenceableFailure::success();
}
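
// Illustrative result of the rewrite above (abbreviated): with
// globalBlockDims = (8, 1, 1),
//   scf.foreach_thread (%i) in (%c4) { ... } {mapping = [#gpu.thread<x>]}
// becomes
//   %tx = gpu.thread_id x
//   %c4 = arith.constant 4 : index
//   %p = arith.cmpi ult, %tx, %c4 : index
//   scf.if %p {
//     ... body with %i replaced by %tx ...
//   }
//   gpu.barrier  // emitted when syncAfterDistribute is true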

DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
    RewriterBase &rewriter, Operation *target,
    const SmallVectorImpl<int64_t> &blockDim, bool syncAfterDistribute,
    std::optional<TransformOpInterface> transformOp,
    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
  DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
  target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
    diag = checkAttributeType(threadMappingAttributes,
                              foreachThreadOp.getMapping(), transformOp);
    if (diag.succeeded()) {
      rewriter.setInsertionPoint(foreachThreadOp);
      diag = rewriteOneForeachThreadToGpuThreads(
          rewriter, foreachThreadOp, blockDim, syncAfterDistribute, transformOp,
          threadMappingAttributes);
    }
    return diag.succeeded() ? WalkResult::advance() : WalkResult::interrupt();
  });
  return diag;
}

DiagnosedSilenceableFailure transform::MapNestedForeachToThreads::applyToOne(
    Operation *target, ApplyToEachResultList &results, TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  if (!gpuLaunch) {
    return emitSilenceableError() << "Given target is not gpu.launch";
  }

  SmallVector<int64_t> blockDim = extractFromI64ArrayAttr(getBlockDim());
  blockDim.resize(/*size=*/3, /*value=*/1);

  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
                     blockDim[0], blockDim[1], blockDim[2]);
  if (diag.isSilenceableFailure()) {
    diag.attachNote(getLoc()) << getBlockDimAttrName() << " is very large";
    return diag;
  }

  MLIRContext *ctx = getContext();
  TrivialPatternRewriter rewriter(ctx);
  rewriter.setInsertionPoint(target);

  SmallVector<DeviceMappingAttrInterface> threadMappingAttributes = {
      GPUThreadMappingAttr::get(ctx, Threads::DimX),
      GPUThreadMappingAttr::get(ctx, Threads::DimY),
      GPUThreadMappingAttr::get(ctx, Threads::DimZ)};

  diag = mlir::transform::gpu::mapNestedForeachToThreadsImpl(
      rewriter, target, blockDim, getSyncAfterDistribute(), transformOp,
      threadMappingAttributes);

  if (diag.succeeded()) {
    diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
                          std::nullopt, std::nullopt, blockDim[0], blockDim[1],
                          blockDim[2]);
  }

  results.push_back(gpuLaunch.getOperation());
  return diag;
}
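
// In transform IR, this op is typically applied to the gpu.launch produced by
// the block-mapping step (illustrative syntax):
//   transform.gpu.map_nested_foreach_to_threads %gpu_launch
//       {blockDim = [32, 4, 1]}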

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
/// Registers new ops and declares PDL as a dependent dialect since the
/// additional ops use PDL types for operands and results.
class GPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          GPUTransformDialectExtension> {
public:
  GPUTransformDialectExtension() {
    declareDependentDialect<pdl::PDLDialect>();
    declareGeneratedDialect<scf::SCFDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<GPUDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"

void mlir::gpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<GPUTransformDialectExtension>();
}