//===- GPUTransformOps.cpp - Implementation of GPU transform ops ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"

// (Remaining dialect, conversion, and transform #includes are elided in this
// listing.)
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::gpu;
using namespace mlir::transform;
using namespace mlir::transform::gpu;

#define DEBUG_TYPE "gpu-transforms"
#define DEBUG_TYPE_ALIAS "gpu-transforms-alias"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")

//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//

void transform::ApplyGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  // NVVM uses alloca in the default address space to represent private
  // memory allocations, so drop private annotations. NVVM uses address
  // space 3 for shared memory. NVVM uses the default address space to
  // represent global memory.
  // Used in populateGpuToNVVMConversionPatterns, so attaching here for now.
  // TODO: We should have a single to_nvvm_type_converter.
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](AddressSpace space) -> unsigned {
        switch (space) {
        case AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
  // Used in GPUToNVVM/WmmaOpsToNvvm.cpp, so attaching here for now.
  // TODO: We should have a single to_nvvm_type_converter.
  llvmTypeConverter.addConversion(
      [&](MMAMatrixType type) -> Type { return convertMMAToLLVMType(type); });
  populateGpuToNVVMConversionPatterns(llvmTypeConverter, patterns);
}

LogicalResult
transform::ApplyGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

void transform::ApplyGPUWwmaToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuWMMAToNVVMConversionPatterns(llvmTypeConverter, patterns);
}

LogicalResult
transform::ApplyGPUWwmaToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

void transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    populatePatterns(TypeConverter &typeConverter,
                     RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuSubgroupReduceOpLoweringPattern(llvmTypeConverter, patterns);
}

LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    verifyTypeConverter(transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}
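
// Example (illustrative only): these pattern descriptors are intended to be
// nested under a `transform.apply_conversion_patterns` op in transform IR,
// roughly as sketched below. The exact mnemonics and type-converter syntax are
// not defined in this file and are shown here as an assumption.
//
//   transform.apply_conversion_patterns to %gpu_module {
//     transform.apply_conversion_patterns.gpu.gpu_to_nvvm
//     transform.apply_conversion_patterns.gpu.gpu_wmma_to_nvvm
//     transform.apply_conversion_patterns.gpu.gpu_subgroup_reduce_to_nvvm
//   } with type_converter {
//     transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
//   } : !transform.any_op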

//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//

void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuRewritePatterns(patterns);
}

//===----------------------------------------------------------------------===//
// ApplyUnrollVectorsSubgroupMmaOp
//===----------------------------------------------------------------------===//

/// Pick an unrolling order that will allow the tensor core operation to reuse
/// the LHS register.
static std::optional<SmallVector<int64_t>>
gpuMmaUnrollOrder(vector::ContractionOp contract) {
  SmallVector<int64_t> order;
  // First make reduction the outer dimensions.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isReductionIterator(iter)) {
      order.push_back(index);
    }
  }

  llvm::SmallDenseSet<int64_t> dims;
  for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults()) {
    dims.insert(expr.cast<AffineDimExpr>().getPosition());
  }
  // Then parallel dimensions that are part of Lhs as we want to re-use Lhs.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isParallelIterator(iter) && dims.count(index)) {
      order.push_back(index);
    }
  }
  // Then the remaining parallel loops.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isParallelIterator(iter) && !dims.count(index)) {
      order.push_back(index);
    }
  }
  return order;
}
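
// For instance (illustrative): for a canonical matmul contraction with
// iterator types [parallel m, parallel n, reduction k] and an LHS indexing
// map of (m, k), the order produced above is [k, m, n], i.e. [2, 0, 1]: the
// reduction dimension first, then the parallel dimension used by the LHS,
// then the remaining parallel dimension.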

/// Returns the target vector size for the target operation based on the native
/// vector size specified with `m`, `n`, and `k`.
static std::optional<SmallVector<int64_t>>
getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
  if (auto contract = dyn_cast<vector::ContractionOp>(op)) {
    int64_t contractRank = contract.getIteratorTypes().size();
    if (contractRank < 3)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(contractRank - 3, 1);
    nativeSize.append({m, n, k});
    return nativeSize;
  }
  if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    int64_t writeRank = writeOp.getVectorType().getRank();
    if (writeRank < 2)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(writeRank - 2, 1);
    nativeSize.append({m, n});
    return nativeSize;
  }
  if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
    // Transfer read ops may need different shapes based on how they are being
    // used. For simplicity just match the shape used by the extract strided op.
    VectorType sliceType;
    for (Operation *users : op->getUsers()) {
      auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
      if (!extract)
        return std::nullopt;
      auto vecType = extract.getResult().getType().cast<VectorType>();
      if (sliceType && sliceType != vecType)
        return std::nullopt;
      sliceType = vecType;
    }
    return llvm::to_vector(sliceType.getShape());
  }
  if ((OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)) {
    if (auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>()) {
      // TODO: The condition for unrolling elementwise should be restricted
      // only to operations that need unrolling (connected to the contract).
      if (vecType.getRank() < 2)
        return std::nullopt;

      // First check whether there is a slice to infer the shape from. This is
      // required for cases where the accumulator type differs from the input
      // types, in which case we will see an `arith.ext_` between the contract
      // and transfer_read which needs to be unrolled.
      VectorType sliceType;
      for (Operation *users : op->getUsers()) {
        auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
        if (!extract)
          return std::nullopt;
        auto vecType = extract.getResult().getType().cast<VectorType>();
        if (sliceType && sliceType != vecType)
          return std::nullopt;
        sliceType = vecType;
      }
      if (sliceType)
        return llvm::to_vector(sliceType.getShape());

      // Else unroll for trailing elementwise.
      SmallVector<int64_t> nativeSize(vecType.getRank() - 2, 1);
      // Map elementwise ops to the output shape.
      nativeSize.append({m, n});
      return nativeSize;
    }
  }
  return std::nullopt;
}
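
// For example (illustrative): with native WMMA sizes m = 16, n = 16, k = 8, a
// rank-3 vector.contract unrolls to [16, 16, 8], a rank-2 transfer_write to
// [16, 16], and a rank-2 trailing elementwise op to [16, 16].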

void transform::ApplyUnrollVectorsSubgroupMmaOp::populatePatterns(
    RewritePatternSet &patterns) {
  auto unrollOrder = [](Operation *op) -> std::optional<SmallVector<int64_t>> {
    auto contract = dyn_cast<vector::ContractionOp>(op);
    if (!contract)
      return std::nullopt;
    return gpuMmaUnrollOrder(contract);
  };

  int64_t m = getM();
  int64_t n = getN();
  int64_t k = getK();
  auto nativeShapeFn =
      [m, n, k](Operation *op) -> std::optional<SmallVector<int64_t>> {
    return getSubgroupMmaNativeVectorSize(op, m, n, k);
  };
  vector::populateVectorUnrollPatterns(
      patterns, vector::UnrollVectorOptions()
                    .setNativeShapeFn(nativeShapeFn)
                    .setUnrollTraversalOrderFn(unrollOrder));
}

//===----------------------------------------------------------------------===//
// EliminateBarriersOp
//===----------------------------------------------------------------------===//

// The functions below provide interface-like verification, but are too
// specific to barrier elimination to become interfaces.

/// Implement the MemoryEffectsOpInterface in the suitable way.
static bool isKnownNoEffectsOpWithoutInterface(Operation *op) {
  // memref::AssumeAlignment is conceptually pure, but marking it as such would
  // make DCE immediately remove it.
  return isa<memref::AssumeAlignmentOp>(op);
}

/// Returns `true` if the op defines the parallel region that is subject to
/// barrier synchronization.
static bool isParallelRegionBoundary(Operation *op) {
  if (op->hasAttr("__parallel_region_boundary_for_test"))
    return true;

  return isa<GPUFuncOp, LaunchOp>(op);
}

/// Returns `true` if the op behaves like a sequential loop, e.g., the control
/// flow "wraps around" from the end of the body region back to its start.
static bool isSequentialLoopLike(Operation *op) { return isa<scf::ForOp>(op); }

/// Returns `true` if the regions of the op are guaranteed to be executed at
/// most once. Thus, if an operation in one of the nested regions of `op` is
/// executed, then so are all the other operations in this region.
static bool hasSingleExecutionBody(Operation *op) {
  return isa<scf::IfOp, memref::AllocaScopeOp>(op);
}

/// Returns `true` if the operation is known to produce a pointer-like object
/// distinct from any other object produced by a similar operation. For
/// example, an allocation produces such an object.
static bool producesDistinctBase(Operation *op) {
  return isa_and_nonnull<memref::AllocOp, memref::AllocaOp>(op);
}

/// Populates `effects` with all memory effects without associating them to a
/// specific value.
static void addAllValuelessEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Read>());
  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Write>());
  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Allocate>());
  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Free>());
}

/// Collect the memory effects of the given op in 'effects'. Returns 'true' if
/// it could extract the effect information from the op, otherwise returns
/// 'false' and conservatively populates the list with all possible effects
/// associated with no particular value or symbol.
static bool
collectEffects(Operation *op,
               SmallVectorImpl<MemoryEffects::EffectInstance> &effects,
               bool ignoreBarriers = true) {
  // Skip over barriers to avoid infinite recursion (those barriers would ask
  // this barrier again).
  if (ignoreBarriers && isa<BarrierOp>(op))
    return true;

  // Skip over ops that we know have no effects.
  if (isKnownNoEffectsOpWithoutInterface(op))
    return true;

  // Collect effect instances of the operation. Note that the implementation of
  // getEffects erases all effect instances that have a type other than the
  // template parameter, so we collect them first in a local buffer and then
  // copy.
  if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
    SmallVector<MemoryEffects::EffectInstance> localEffects;
    iface.getEffects(localEffects);
    llvm::append_range(effects, localEffects);
    return true;
  }
  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
    for (auto &region : op->getRegions()) {
      for (auto &block : region) {
        for (auto &innerOp : block)
          if (!collectEffects(&innerOp, effects, ignoreBarriers))
            return false;
      }
    }
    return true;
  }

  // We need to be conservative here in case the op doesn't have the interface
  // and assume it can have any possible effect.
  addAllValuelessEffects(effects);
  return false;
}

/// Collects memory effects from operations that may be executed before `op` in
/// a trivial structured control flow, e.g., without branches. Stops at the
/// parallel region boundary or at the barrier operation if `stopAtBarrier` is
/// set. Returns `true` if the memory effects added to `effects` are exact,
/// `false` if they are a conservative over-approximation. The latter means that
/// `effects` contain instances not associated with a specific value.
static bool
getEffectsBefore(Operation *op,
                 SmallVectorImpl<MemoryEffects::EffectInstance> &effects,
                 bool stopAtBarrier) {
  if (!op->getBlock())
    return true;

  // If there is a non-structured control flow, bail.
  Region *region = op->getBlock()->getParent();
  if (region && !llvm::hasSingleElement(region->getBlocks())) {
    addAllValuelessEffects(effects);
    return false;
  }

  // Collect all effects before the op.
  if (op != &op->getBlock()->front()) {
    for (Operation *it = op->getPrevNode(); it != nullptr;
         it = it->getPrevNode()) {
      if (isa<BarrierOp>(it)) {
        if (stopAtBarrier)
          return true;
        else
          continue;
      }
      if (!collectEffects(it, effects))
        return false;
    }
  }

  // Stop if reached the parallel region boundary.
  if (isParallelRegionBoundary(op->getParentOp()))
    return true;

  // Otherwise, keep collecting above the parent operation.
  if (!getEffectsBefore(op->getParentOp(), effects, stopAtBarrier))
    return false;

  // If the op is loop-like, collect effects from the trailing operations until
  // we hit a barrier because they can be executed before the current operation
  // by the previous iteration of this loop. For example, in the following loop
  //
  //   for i = ... {
  //     op1
  //     ...
  //     barrier
  //     op2
  //   }
  //
  // the operation `op2` at iteration `i` is known to be executed before the
  // operation `op1` at iteration `i+1` and the side effects must be ordered
  // appropriately.
  if (isSequentialLoopLike(op->getParentOp())) {
    // Assuming loop terminators have no side effects.
    return getEffectsBefore(op->getBlock()->getTerminator(), effects,
                            /*stopAtBarrier=*/true);
  }

  // If the parent operation is not guaranteed to execute its (single-block)
  // region once, walk the block.
  bool conservative = false;
  if (!hasSingleExecutionBody(op->getParentOp()))
    op->getParentOp()->walk([&](Operation *in) {
      if (conservative)
        return WalkResult::interrupt();
      if (!collectEffects(in, effects)) {
        conservative = true;
        return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });

  return !conservative;
}

/// Collects memory effects from operations that may be executed after `op` in
/// a trivial structured control flow, e.g., without branches. Stops at the
/// parallel region boundary or at the barrier operation if `stopAtBarrier` is
/// set. Returns `true` if the memory effects added to `effects` are exact,
/// `false` if they are a conservative over-approximation. The latter means that
/// `effects` contain instances not associated with a specific value.
static bool
getEffectsAfter(Operation *op,
                SmallVectorImpl<MemoryEffects::EffectInstance> &effects,
                bool stopAtBarrier) {
  if (!op->getBlock())
    return true;

  // If there is a non-structured control flow, bail.
  Region *region = op->getBlock()->getParent();
  if (region && !llvm::hasSingleElement(region->getBlocks())) {
    addAllValuelessEffects(effects);
    return false;
  }

  // Collect all effects after the op.
  if (op != &op->getBlock()->back())
    for (Operation *it = op->getNextNode(); it != nullptr;
         it = it->getNextNode()) {
      if (isa<BarrierOp>(it)) {
        if (stopAtBarrier)
          return true;
        continue;
      }
      if (!collectEffects(it, effects))
        return false;
    }

  // Stop if reached the parallel region boundary.
  if (isParallelRegionBoundary(op->getParentOp()))
    return true;

  // Otherwise, keep collecting below the parent operation.
  if (!getEffectsAfter(op->getParentOp(), effects, stopAtBarrier))
    return false;

  // If the op is loop-like, collect effects from the leading operations until
  // we hit a barrier because they can be executed after the current operation
  // by the next iteration of this loop. For example, in the following loop
  //
  //   for i = ... {
  //     op1
  //     ...
  //     barrier
  //     op2
  //   }
  //
  // the operation `op1` at iteration `i` is known to be executed after the
  // operation `op2` at iteration `i-1` and the side effects must be ordered
  // appropriately.
  if (isSequentialLoopLike(op->getParentOp())) {
    if (isa<BarrierOp>(op->getBlock()->front()))
      return true;

    bool exact = collectEffects(&op->getBlock()->front(), effects);
    return getEffectsAfter(&op->getBlock()->front(), effects,
                           /*stopAtBarrier=*/true) &&
           exact;
  }

  // If the parent operation is not guaranteed to execute its (single-block)
  // region once, walk the block.
  bool conservative = false;
  if (!hasSingleExecutionBody(op->getParentOp()))
    op->getParentOp()->walk([&](Operation *in) {
      if (conservative)
        return WalkResult::interrupt();
      if (!collectEffects(in, effects)) {
        conservative = true;
        return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });

  return !conservative;
}

/// Looks through known "view-like" ops to find the base memref.
static Value getBase(Value v) {
  while (true) {
    Operation *definingOp = v.getDefiningOp();
    if (!definingOp)
      break;

    bool shouldContinue =
        llvm::TypeSwitch<Operation *, bool>(definingOp)
            .Case<memref::CastOp, memref::SubViewOp, memref::ViewOp>(
                [&](auto op) {
                  v = op.getSource();
                  return true;
                })
            .Case<memref::TransposeOp>([&](auto op) {
              v = op.getIn();
              return true;
            })
            .Case<memref::CollapseShapeOp, memref::ExpandShapeOp>([&](auto op) {
              v = op.getSrc();
              return true;
            })
            .Default([](Operation *) { return false; });
    if (!shouldContinue)
      break;
  }
  return v;
}
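
// For example (illustrative): given
//   %cast = memref.cast %alloc : ...
//   %view = memref.subview %cast[...] : ...
// getBase(%view) looks through the subview and the cast and returns %alloc.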

/// Returns `true` if the value is defined as a function argument.
static bool isFunctionArgument(Value v) {
  auto arg = dyn_cast<BlockArgument>(v);
  return arg && isa<FunctionOpInterface>(arg.getOwner()->getParentOp());
}

/// Returns the operand that the operation "propagates" through it for capture
/// purposes. That is, if the value produced by this operation is captured, then
/// so is the returned value.
static Value propagatesCapture(Operation *op) {
  return llvm::TypeSwitch<Operation *, Value>(op)
      .Case(
          [](ViewLikeOpInterface viewLike) { return viewLike.getViewSource(); })
      .Case([](CastOpInterface castLike) { return castLike->getOperand(0); })
      .Case([](memref::TransposeOp transpose) { return transpose.getIn(); })
      .Case<memref::ExpandShapeOp, memref::CollapseShapeOp>(
          [](auto op) { return op.getSrc(); })
      .Default([](Operation *) { return Value(); });
}

/// Returns `true` if the given operation is known to capture the given value,
/// `false` if it is known not to capture the given value, `nullopt` if neither
/// is known.
static std::optional<bool> getKnownCapturingStatus(Operation *op, Value v) {
  return llvm::TypeSwitch<Operation *, std::optional<bool>>(op)
      // Store-like operations don't capture the destination, but do capture
      // the value.
      .Case<memref::StoreOp, vector::TransferWriteOp>(
          [&](auto op) { return op.getValue() == v; })
      .Case<vector::StoreOp, vector::MaskedStoreOp>(
          [&](auto op) { return op.getValueToStore() == v; })
      // These operations are known not to capture.
      .Case([](memref::DeallocOp) { return false; })
      // By default, we don't know anything.
      .Default([](Operation *) { return std::nullopt; });
}

/// Returns `true` if the value may be captured by any of its users, i.e., if
/// the user may be storing this value into memory. This makes aliasing analysis
/// more conservative as it cannot assume the pointer-like value is only passed
/// around through SSA use-def.
static bool maybeCaptured(Value v) {
  SmallVector<Value> todo = {v};
  while (!todo.empty()) {
    Value v = todo.pop_back_val();
    for (Operation *user : v.getUsers()) {
      // A user that is known to only read cannot capture.
      auto iface = dyn_cast<MemoryEffectOpInterface>(user);
      if (iface) {
        SmallVector<MemoryEffects::EffectInstance> effects;
        iface.getEffects(effects);
        if (llvm::all_of(effects,
                         [](const MemoryEffects::EffectInstance &effect) {
                           return isa<MemoryEffects::Read>(effect.getEffect());
                         })) {
          continue;
        }
      }

      // When an operation is known to create an alias, consider if the
      // source is captured as well.
      if (Value v = propagatesCapture(user)) {
        todo.push_back(v);
        continue;
      }

      std::optional<bool> knownCaptureStatus = getKnownCapturingStatus(user, v);
      if (!knownCaptureStatus || *knownCaptureStatus)
        return true;
    }
  }

  return false;
}

/// Returns true if two values may be referencing aliasing memory. This is a
/// rather naive and conservative analysis. Values defined by different
/// allocation-like operations as well as values derived from those by casts and
/// views cannot alias each other. Similarly, values defined by allocations
/// inside a function cannot alias function arguments. Global values cannot
/// alias each other or local allocations. Values that are captured, i.e.
/// themselves potentially stored in memory, are considered as aliasing with
/// everything. This seems sufficient to achieve barrier removal in structured
/// control flow; more complex cases would require a proper dataflow analysis.
static bool mayAlias(Value first, Value second) {
  DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, {
    DBGS_ALIAS() << "checking aliasing between ";
    DBGS_ALIAS() << first << "\n";
    DBGS_ALIAS() << " and ";
    DBGS_ALIAS() << second << "\n";
  });

  first = getBase(first);
  second = getBase(second);

  DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, {
    DBGS_ALIAS() << "base ";
    DBGS_ALIAS() << first << "\n";
    DBGS_ALIAS() << " and ";
    DBGS_ALIAS() << second << "\n";
  });

  // Values derived from the same base memref do alias (unless we do a more
  // advanced analysis to prove non-overlapping accesses).
  if (first == second) {
    DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, DBGS_ALIAS() << "-> do alias!\n");
    return true;
  }

  // Different globals cannot alias.
  if (auto globFirst = first.getDefiningOp<memref::GetGlobalOp>()) {
    if (auto globSecond = second.getDefiningOp<memref::GetGlobalOp>()) {
      return globFirst.getNameAttr() == globSecond.getNameAttr();
    }
  }

  // Two function arguments marked as noalias do not alias.
  auto isNoaliasFuncArgument = [](Value value) {
    auto bbArg = dyn_cast<BlockArgument>(value);
    if (!bbArg)
      return false;
    auto iface = dyn_cast<FunctionOpInterface>(bbArg.getOwner()->getParentOp());
    if (!iface)
      return false;
    // TODO: we need a way to not depend on the LLVM dialect here.
    return iface.getArgAttr(bbArg.getArgNumber(), "llvm.noalias") != nullptr;
  };
  if (isNoaliasFuncArgument(first) && isNoaliasFuncArgument(second))
    return false;

  bool isDistinct[] = {producesDistinctBase(first.getDefiningOp()),
                       producesDistinctBase(second.getDefiningOp())};
  bool isGlobal[] = {first.getDefiningOp<memref::GetGlobalOp>() != nullptr,
                     second.getDefiningOp<memref::GetGlobalOp>() != nullptr};

  // Non-equivalent distinct bases and globals cannot alias. At this point, we
  // have already filtered out based on values being equal and global name being
  // equal.
  if ((isDistinct[0] || isGlobal[0]) && (isDistinct[1] || isGlobal[1]))
    return false;

  bool isArg[] = {isFunctionArgument(first), isFunctionArgument(second)};

  // Distinct bases (allocations) cannot have been passed as an argument.
  if ((isDistinct[0] && isArg[1]) || (isDistinct[1] && isArg[0]))
    return false;

  // Non-captured base distinct values cannot conflict with another base value.
  if (isDistinct[0] && !maybeCaptured(first))
    return false;
  if (isDistinct[1] && !maybeCaptured(second))
    return false;

  // Otherwise, conservatively assume aliasing.
  DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, DBGS_ALIAS() << "-> may alias!\n");
  return true;
}
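
// For example (illustrative): two distinct memref.alloc results never alias,
// and an alloc cannot alias a function argument it was never passed to, but an
// alloc whose value is stored into memory (captured) is conservatively treated
// as aliasing with everything that is not provably distinct.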

/// Returns `true` if the effect may be affecting memory aliasing the value. If
/// the effect is not associated with any value, it is assumed to affect all
/// memory and therefore aliases with everything.
static bool mayAlias(MemoryEffects::EffectInstance a, Value v2) {
  if (Value v = a.getValue()) {
    return mayAlias(v, v2);
  }
  return true;
}

/// Returns `true` if the two effects may be affecting aliasing memory. If
/// an effect is not associated with any value, it is assumed to affect all
/// memory and therefore aliases with everything. Effects on different resources
/// cannot alias.
static bool mayAlias(MemoryEffects::EffectInstance a,
                     MemoryEffects::EffectInstance b) {
  if (a.getResource()->getResourceID() != b.getResource()->getResourceID())
    return false;
  if (Value v2 = b.getValue()) {
    return mayAlias(a, v2);
  } else if (Value v = a.getValue()) {
    return mayAlias(b, v);
  }
  return true;
}

/// Returns `true` if any of the "before" effect instances has a conflict with
/// any "after" instance for the purpose of barrier elimination. The effects are
/// supposed to be limited to a barrier synchronization scope. A conflict exists
/// if effect instances affect aliasing memory locations and at least one of
/// them is a write. As an exception, if the non-write effect is an allocation
/// effect, there is no conflict since we are only expected to see the
/// allocation happening in the same thread and it cannot be accessed from
/// another thread without capture (which we do handle in alias analysis).
static bool
haveConflictingEffects(ArrayRef<MemoryEffects::EffectInstance> beforeEffects,
                       ArrayRef<MemoryEffects::EffectInstance> afterEffects) {
  for (const MemoryEffects::EffectInstance &before : beforeEffects) {
    for (const MemoryEffects::EffectInstance &after : afterEffects) {
      // If cannot alias, definitely no conflict.
      if (!mayAlias(before, after))
        continue;

      // Read/read is not a conflict.
      if (isa<MemoryEffects::Read>(before.getEffect()) &&
          isa<MemoryEffects::Read>(after.getEffect())) {
        continue;
      }

      // Allocate/* is not a conflict since the allocation happens within the
      // thread context.
      // TODO: This is not the case for */Free unless the allocation happened in
      // the thread context, which we could also check for.
      if (isa<MemoryEffects::Allocate>(before.getEffect()) ||
          isa<MemoryEffects::Allocate>(after.getEffect())) {
        continue;
      }

      // In the particular case that the before effect is a free, we only have
      // two possibilities:
      //   1. the program is well-formed and there must be an interleaved alloc
      //      that limits the scope of effect lookback, so we can safely ignore
      //      the free -> read, free -> write, and free -> free conflicts; or
      //   2. the program is ill-formed and we are in undefined behavior
      //      territory.
      if (isa<MemoryEffects::Free>(before.getEffect()))
        continue;

      // Other kinds of effects create a conflict, e.g. read-after-write.
      LLVM_DEBUG(
          DBGS() << "found a conflict between (before): " << before.getValue()
                 << " read:" << isa<MemoryEffects::Read>(before.getEffect())
                 << " write:" << isa<MemoryEffects::Write>(before.getEffect())
                 << " alloc:"
                 << isa<MemoryEffects::Allocate>(before.getEffect()) << " free:"
                 << isa<MemoryEffects::Free>(before.getEffect()) << "\n");
      LLVM_DEBUG(
          DBGS() << "and (after): " << after.getValue()
                 << " read:" << isa<MemoryEffects::Read>(after.getEffect())
                 << " write:" << isa<MemoryEffects::Write>(after.getEffect())
                 << " alloc:" << isa<MemoryEffects::Allocate>(after.getEffect())
                 << " free:" << isa<MemoryEffects::Free>(after.getEffect())
                 << "\n");
      return true;
    }
  }

  return false;
}

namespace {
/// Barrier elimination pattern. If a barrier does not enforce any conflicting
/// pair of memory effects, including a pair that is enforced by another
/// barrier, it is unnecessary and can be removed. Adapted from
/// "High-Performance GPU-to-CPU Transpilation and Optimization via High-Level
/// Parallel Constructs" by Moses, Ivanov, Domke, Endo, Doerfert, and Zinenko in
/// PPoPP 2023 and the implementation in Polygeist.
class BarrierElimination final : public OpRewritePattern<BarrierOp> {
public:
  using OpRewritePattern<BarrierOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BarrierOp barrier,
                                PatternRewriter &rewriter) const override {
    LLVM_DEBUG(DBGS() << "checking the necessity of: " << barrier << " "
                      << barrier.getLoc() << "\n");

    SmallVector<MemoryEffects::EffectInstance> beforeEffects;
    getEffectsBefore(barrier, beforeEffects, /*stopAtBarrier=*/true);

    SmallVector<MemoryEffects::EffectInstance> afterEffects;
    getEffectsAfter(barrier, afterEffects, /*stopAtBarrier=*/true);

    if (!haveConflictingEffects(beforeEffects, afterEffects)) {
      LLVM_DEBUG(DBGS() << "the surrounding barriers are sufficient, removing "
                        << barrier << "\n");
      rewriter.eraseOp(barrier);
      return success();
    }

    LLVM_DEBUG(DBGS() << "barrier is necessary: " << barrier << " "
                      << barrier.getLoc() << "\n");
    return failure();
  }
};
} // namespace

void EliminateBarriersOp::populatePatterns(RewritePatternSet &patterns) {
  patterns.insert<BarrierElimination>(getContext());
}
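
// For example (illustrative): in
//   %0 = memref.load %a[%i] : memref<?xf32>
//   gpu.barrier
//   %1 = memref.load %a[%j] : memref<?xf32>
// the barrier only separates two reads, which never conflict, so the pattern
// above removes it; a write to memory aliasing %a on either side would keep it.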

//===----------------------------------------------------------------------===//
// Block and thread mapping utilities.
//===----------------------------------------------------------------------===//

namespace {
/// Local types used for mapping verification.
struct MappingKind {};
struct BlockMappingKind : MappingKind {};
struct ThreadMappingKind : MappingKind {};
} // namespace

static DiagnosedSilenceableFailure
definiteFailureHelper(std::optional<TransformOpInterface> transformOp,
                      Operation *target, const Twine &message) {
  if (transformOp.has_value())
    return transformOp->emitDefiniteFailure() << message;
  return emitDefiniteFailure(target, message);
}

/// Check if the given mapping attributes are of the desired kind.
template <typename MappingKindType>
static DiagnosedSilenceableFailure
checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                           scf::ForallOp forallOp) {
  if (!forallOp.getMapping().has_value()) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute");
  }

  bool hasBlockMapping =
      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
        return isa<GPUBlockMappingAttr>(attr);
      });
  bool hasWarpgroupMapping =
      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
        return isa<GPUWarpgroupMappingAttr>(attr);
      });
  bool hasWarpMapping =
      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
        return isa<GPUWarpMappingAttr>(attr);
      });
  bool hasThreadMapping =
      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
        return isa<GPUThreadMappingAttr>(attr);
      });
  int64_t countMappingTypes = 0;
  countMappingTypes += hasBlockMapping ? 1 : 0;
  countMappingTypes += hasWarpgroupMapping ? 1 : 0;
  countMappingTypes += hasWarpMapping ? 1 : 0;
  countMappingTypes += hasThreadMapping ? 1 : 0;
  if (countMappingTypes > 1) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix different mapping types, use nesting");
  }
  if (std::is_same<MappingKindType, BlockMappingKind>::value &&
      !hasBlockMapping) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "scf.forall op requires a mapping attribute of kind 'block'");
  }
  if (std::is_same<MappingKindType, ThreadMappingKind>::value &&
      !hasThreadMapping && !hasWarpMapping && !hasWarpgroupMapping) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute "
                                 "of kind 'thread' or 'warp'");
  }

  DenseSet<Attribute> seen;
  for (Attribute map : forallOp.getMapping()->getValue()) {
    if (seen.contains(map)) {
      return definiteFailureHelper(
          transformOp, forallOp,
          "duplicate attribute, cannot map different loops "
          "to the same mapping id");
    }
    seen.insert(map);
  }

  auto isLinear = [](Attribute a) {
    return cast<DeviceMappingAttrInterface>(a).isLinearMapping();
  };
  if (llvm::any_of(forallOp.getMapping()->getValue(), isLinear) &&
      !llvm::all_of(forallOp.getMapping()->getValue(), isLinear)) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix linear and non-linear mapping modes");
  }

  return DiagnosedSilenceableFailure::success();
}

template <typename MappingKindType>
static DiagnosedSilenceableFailure
verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
                 scf::ForallOp forallOp) {
  // Check the types of the mapping attributes match.
  DiagnosedSilenceableFailure typeRes =
      checkMappingAttributeTypes<MappingKindType>(transformOp, forallOp);
  if (!typeRes.succeeded())
    return typeRes;

  // Perform other non-type verifications.
  if (!forallOp.isNormalized())
    return definiteFailureHelper(transformOp, forallOp,
                                 "unsupported non-normalized loops");
  if (forallOp.getNumResults() > 0)
    return definiteFailureHelper(transformOp, forallOp,
                                 "only bufferized scf.forall can be mapped");
  bool useLinearMapping = cast<DeviceMappingAttrInterface>(
                              forallOp.getMapping()->getValue().front())
                              .isLinearMapping();
  // TODO: This would be more natural with support for Optional<EnumParameter>
  // in GPUDeviceMappingAttr.
  int64_t maxNumMappingsSupported =
      useLinearMapping ? (getMaxEnumValForMappingId() -
                          static_cast<uint64_t>(MappingId::DimZ))
                       : 3;
  if (forallOp.getRank() > maxNumMappingsSupported) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall with rank > ")
           << maxNumMappingsSupported
           << " does not lower for the specified mapping attribute type";
  }
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  return DiagnosedSilenceableFailure::success();
}
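
// For example (illustrative): a forall such as
//   scf.forall (%i, %j) in (8, 32) { ... }
//     {mapping = [#gpu.block<y>, #gpu.block<x>]}
// passes the BlockMappingKind checks above, while mixing kinds, e.g.
//   {mapping = [#gpu.block<x>, #gpu.thread<x>]}
// or repeating the same mapping id is rejected with a definite failure.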

/// Struct to return the result of the rewrite of a forall operation.
struct ForallRewriteResult {
  SmallVector<int64_t> mappingSizes;
  SmallVector<Value> mappingIds;
};

/// Helper to replace ids of dimensions known to be 1 by 0 to simplify the IR.
template <typename OpTy, typename OperationOrBlock>
static void
replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc,
                            OperationOrBlock *parent, Value replacement,
                            ArrayRef<int64_t> availableMappingSizes) {
  parent->walk([&](OpTy idOp) {
    if (availableMappingSizes[static_cast<int64_t>(idOp.getDimension())] == 1)
      rewriter.replaceAllUsesWith(idOp.getResult(), replacement);
  });
}

static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
    ForallRewriteResult &result, const GpuIdBuilder &gpuIdBuilder) {
  LDBG("--start rewriteOneForallCommonImpl");

  // Step 1. Complete the mapping to a full mapping (with 1s) if necessary.
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  assert(forallOp.isNormalized() && numParallelIterations.has_value() &&
         "requires statically sized, normalized forall op");
  SmallVector<int64_t> tmpMappingSizes = numParallelIterations.value();
  SetVector<Attribute> forallMappingAttrs;
  forallMappingAttrs.insert(forallOp.getMapping()->getValue().begin(),
                            forallOp.getMapping()->getValue().end());
  auto comparator = [](Attribute a, Attribute b) -> bool {
    return cast<DeviceMappingAttrInterface>(a).getMappingId() <
           cast<DeviceMappingAttrInterface>(b).getMappingId();
  };

  // Step 1.b. In the linear case, compute the max mapping to avoid needlessly
  // mapping all dimensions. In the 3-D mapping case we need to map all
  // dimensions.
  DeviceMappingAttrInterface maxMapping =
      cast<DeviceMappingAttrInterface>(*std::max_element(
          forallMappingAttrs.begin(), forallMappingAttrs.end(), comparator));
  DeviceMappingAttrInterface maxLinearMapping;
  if (maxMapping.isLinearMapping())
    maxLinearMapping = maxMapping;
  for (auto attr : gpuIdBuilder.mappingAttributes) {
    // If attr overflows, just skip.
    if (maxLinearMapping && comparator(maxLinearMapping, attr))
      continue;
    // Try to insert. If element was already present, just continue.
    if (!forallMappingAttrs.insert(attr))
      continue;
    // Otherwise, we have a new insertion without a size -> use size 1.
    tmpMappingSizes.push_back(1);
  }
  LLVM_DEBUG(
      llvm::interleaveComma(
          tmpMappingSizes,
          DBGS() << "----tmpMappingSizes extracted from scf.forall op: ");
      llvm::dbgs() << "\n");

  // Step 2. Sort the values by the corresponding DeviceMappingAttrInterface.
  SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey(
      forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator);
  LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
                                   DBGS() << "----forallMappingSizes: ");
             llvm::dbgs() << "\n"; llvm::interleaveComma(
                 forallMappingAttrs, DBGS() << "----forallMappingAttrs: ");
             llvm::dbgs() << "\n");

  // Step 3. Generate the mappingIdOps using the provided generator.
  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forallOp);
  SmallVector<int64_t> originalBasis(availableMappingSizes);
  bool originalBasisWasProvided = !originalBasis.empty();
  if (!originalBasisWasProvided) {
    originalBasis = forallMappingSizes;
    while (originalBasis.size() < 3)
      originalBasis.push_back(1);
  }

  IdBuilderResult builderResult =
      gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);

  // Step 4. Map the induction variables to the mappingIdOps, this may involve
  // a permutation.
  SmallVector<Value> mappingIdOps = builderResult.mappingIdOps;
  IRMapping bvm;
  for (auto [iv, dim] : llvm::zip_equal(
           forallOp.getInductionVars(),
           forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
    Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
    bvm.map(iv, peIdOp);
  }

  // Step 5. If the originalBasis is already known, create conditionals to
  // predicate the region. Otherwise, the current forall determines the
  // originalBasis and no predication occurs.
  Value predicate;
  if (originalBasisWasProvided) {
    SmallVector<int64_t> activeMappingSizes = builderResult.activeMappingSizes;
    SmallVector<int64_t> availableMappingSizes =
        builderResult.availableMappingSizes;
    SmallVector<Value> activeIdOps = builderResult.activeIdOps;
    // clang-format off
    LLVM_DEBUG(
        llvm::interleaveComma(
          activeMappingSizes, DBGS() << "----activeMappingSizes: ");
        llvm::dbgs() << "\n";
        llvm::interleaveComma(
          availableMappingSizes, DBGS() << "----availableMappingSizes: ");
        llvm::dbgs() << "\n";
        llvm::interleaveComma(activeIdOps, DBGS() << "----activeIdOps: ");
        llvm::dbgs() << "\n");
    // clang-format on
    for (auto [activeId, activeMappingSize, availableMappingSize] :
         llvm::zip_equal(activeIdOps, activeMappingSizes,
                         availableMappingSizes)) {
      if (activeMappingSize > availableMappingSize) {
        return definiteFailureHelper(
            transformOp, forallOp,
            "Trying to map to fewer GPU threads than loop iterations but "
            "overprovisioning is not yet supported. "
            "Try additional tiling before mapping or map to more threads.");
      }
      if (activeMappingSize == availableMappingSize)
        continue;
      Value idx =
          rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize);
      Value tmpPredicate = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ult, activeId, idx);
      LDBG("----predicate: " << tmpPredicate);
      predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
                                                             tmpPredicate)
                            : tmpPredicate;
    }
  }

  // Step 6. Move the body of forallOp.
  // Erase the terminator first, it will not be used.
  rewriter.eraseOp(forallOp.getTerminator());
  Block *targetBlock;
  Block::iterator insertionPoint;
  if (predicate) {
    // Step 6.a. If predicated, move at the beginning.
    auto ifOp = rewriter.create<scf::IfOp>(loc, predicate,
                                           /*withElseRegion=*/false);
    targetBlock = ifOp.thenBlock();
    insertionPoint = ifOp.thenBlock()->begin();
  } else {
    // Step 6.b. Otherwise, move inline just at the rewriter insertion
    // point.
    targetBlock = forallOp->getBlock();
    insertionPoint = rewriter.getInsertionPoint();
  }
  Block &sourceBlock = forallOp.getRegion().front();
  targetBlock->getOperations().splice(insertionPoint,
                                      sourceBlock.getOperations());

  // Step 7. RAUW indices.
  for (Value loopIndex : forallOp.getInductionVars()) {
    Value threadIdx = bvm.lookup(loopIndex);
    rewriter.replaceAllUsesWith(loopIndex, threadIdx);
  }

  // Step 8. Erase old op.
  rewriter.eraseOp(forallOp);

  LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
                                   DBGS() << "----result forallMappingSizes: ");
             llvm::dbgs() << "\n"; llvm::interleaveComma(
                 mappingIdOps, DBGS() << "----result mappingIdOps: ");
             llvm::dbgs() << "\n");

  result = ForallRewriteResult{forallMappingSizes, mappingIdOps};
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// MapForallToBlocks
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
    RewriterBase &rewriter, TransformOpInterface transformOp,
    scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
    const GpuIdBuilder &gpuIdBuilder) {
  LDBG("Start mapForallToBlocksImpl");

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOp.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<BlockMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  Location loc = forallOp.getLoc();
  Block *parentBlock = forallOp->getBlock();
  Value zero;
  {
    // Create an early zero index value for replacements and immediately reset
    // the insertion point.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(parentBlock);
    zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  }

  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp,
      /*availableMappingSizes=*/gridDims, rewriteResult, gpuIdBuilder);

  // Return if anything goes wrong, use silenceable failure as a match
  // failure.
  if (!diag.succeeded())
    return diag;

  // If gridDims was not provided already, set it from the return.
  if (gridDims.empty()) {
    gridDims = rewriteResult.mappingSizes;
    while (gridDims.size() < 3)
      gridDims.push_back(1);
  }
  assert(gridDims.size() == 3 && "Need 3-D gridDims");

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<BlockDimOp>(rewriter, loc, parentBlock, zero,
                                          rewriteResult.mappingSizes);

  return DiagnosedSilenceableFailure::success();
}

static DiagnosedSilenceableFailure
findTopLevelForallOp(Operation *target, scf::ForallOp &topLevelForallOp,
                     TransformOpInterface transformOp) {
  auto walkResult = target->walk([&](scf::ForallOp forallOp) {
    if (forallOp->getParentOfType<scf::ForallOp>())
      return WalkResult::advance();
    if (topLevelForallOp)
      // TODO: Handle multiple forall if they are independent.
      return WalkResult::interrupt();
    topLevelForallOp = forallOp;
    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted() || !topLevelForallOp)
    return transformOp.emitSilenceableError()
           << "could not find a unique topLevel scf.forall";
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  if (!getGenerateGpuLaunch() && !gpuLaunch) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "Given target is not gpu.launch, set `generate_gpu_launch` "
           "attribute";
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }

  scf::ForallOp topLevelForallOp;
  DiagnosedSilenceableFailure diag =
      findTopLevelForallOp(target, topLevelForallOp, transformOp);
  if (!diag.succeeded()) {
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }
  assert(topLevelForallOp && "expect an scf.forall");

  SmallVector<int64_t> gridDims{getGridDims()};
  if (!getGenerateGpuLaunch() && gridDims.size() != 3)
    return transformOp.emitDefiniteFailure("transform requires size-3 mapping");

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(topLevelForallOp);

  // Generate the gpu launch here and move the forall inside.
  if (getGenerateGpuLaunch()) {
    DiagnosedSilenceableFailure diag =
        createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch);
    if (!diag.succeeded())
      return diag;

    rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
    Operation *newForallOp = rewriter.clone(*topLevelForallOp);
    rewriter.eraseOp(topLevelForallOp);
    topLevelForallOp = cast<scf::ForallOp>(newForallOp);
  }

  // The BlockIdBuilder adapts to whatever is thrown at it.
  bool useLinearMapping = false;
  if (topLevelForallOp.getMapping()) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(
        topLevelForallOp.getMapping()->getValue().front());
    useLinearMapping = mappingAttr.isLinearMapping();
  }
  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping);

  diag = mlir::transform::gpu::mapForallToBlocksImpl(
      rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
  if (!diag.succeeded())
    return diag;

  // Set the GPU launch configuration for the grid dims late, this is
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch,
                        cast<TransformOpInterface>(getOperation()), gridDims[0],
                        gridDims[1], gridDims[2]);

  results.push_back(gpuLaunch);
  return diag;
}

LogicalResult transform::MapForallToBlocks::verify() {
  if (!getGridDims().empty() && getGridDims().size() != 3) {
    return emitOpError() << "transform requires empty or size-3 grid_dims";
  }
  return success();
}
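
// Example (illustrative only; the exact transform IR syntax may differ):
// mapping a top-level scf.forall with a #gpu.block mapping onto a generated
// gpu.launch:
//
//   %launch = transform.gpu.map_forall_to_blocks %func
//       generate_gpu_launch grid_dims = [8, 4, 1]
//       : (!transform.any_op) -> !transform.any_op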

//===----------------------------------------------------------------------===//
// MapNestedForallToThreads
//===----------------------------------------------------------------------===//

static DiagnosedSilenceableFailure checkMappingSpec(
    std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp,
    ArrayRef<int64_t> numParallelIterations, ArrayRef<int64_t> blockOrGridSizes,
    int factor, bool useLinearMapping = false) {
  if (!useLinearMapping && blockOrGridSizes.front() % factor != 0) {
    auto diag = definiteFailureHelper(
        transformOp, forallOp,
        Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
            std::to_string(factor));
    return diag;
  }
  if (computeProduct(numParallelIterations) * factor >
      computeProduct(blockOrGridSizes)) {
    auto diag = definiteFailureHelper(
        transformOp, forallOp,
        Twine("the number of required parallel resources (blocks or "
              "threads) ") +
            std::to_string(computeProduct(numParallelIterations) * factor) +
            std::string(" overflows the number of available resources ") +
            std::to_string(computeProduct(blockOrGridSizes)));
    return diag;
  }
  return DiagnosedSilenceableFailure::success();
}

static DiagnosedSilenceableFailure
getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
                   scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
                   int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
  auto mappingAttr = cast<DeviceMappingAttrInterface>(
      forallOp.getMapping()->getValue().front());
  bool useLinearMapping = mappingAttr.isLinearMapping();

  // Sanity checks that may result in runtime verification errors.
  auto numParallelIterations =
      getConstantIntValues((forallOp.getMixedUpperBound()));
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  int64_t factor = 1;
  if (isa<GPUWarpgroupMappingAttr>(mappingAttr)) {
    factor = GpuWarpgroupIdBuilder::kNumWarpsPerGroup * warpSize;
  } else if (isa<GPUWarpMappingAttr>(mappingAttr)) {
    factor = warpSize;
  }
  DiagnosedSilenceableFailure diag =
      checkMappingSpec(transformOp, forallOp, numParallelIterations.value(),
                       blockSizes, factor, useLinearMapping);
  if (!diag.succeeded())
    return diag;

  // Start mapping.
  MLIRContext *ctx = forallOp.getContext();
  gpuIdBuilder =
      llvm::TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
          .Case([&](GPUWarpgroupMappingAttr) {
            return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping);
          })
          .Case([&](GPUWarpMappingAttr) {
            return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping);
          })
          .Case([&](GPUThreadMappingAttr) {
            return GpuThreadIdBuilder(ctx, useLinearMapping);
          })
          .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
            llvm_unreachable("unknown mapping attribute");
          });
  return DiagnosedSilenceableFailure::success();
}

static DiagnosedSilenceableFailure mapOneForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes, int64_t warpSize,
    bool syncAfterDistribute) {

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOp.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<ThreadMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  GpuIdBuilder gpuIdBuilder;
  {
    // Try to construct the id builder, if it fails, return.
    DiagnosedSilenceableFailure diag = getThreadIdBuilder(
        transformOp, forallOp, blockSizes, warpSize, gpuIdBuilder);
    if (!diag.succeeded())
      return diag;
  }

  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  // Insert after to allow for syncthreads after `forall` is erased.
  rewriter.setInsertionPointAfter(forallOp);
  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp, blockSizes, rewriteResult, gpuIdBuilder);
  if (!diag.succeeded())
    return diag;
  // Add a syncthreads if needed. TODO: warpsync
  if (syncAfterDistribute)
    rewriter.create<BarrierOp>(loc);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    Operation *target, ArrayRef<int64_t> blockDims, int64_t warpSize,
    bool syncAfterDistribute) {
  LDBG("Start mapNestedForallToThreadsImpl");
  if (blockDims.size() != 3) {
    return definiteFailureHelper(transformOp, target,
                                 "requires size-3 thread mapping");
  }

  // Create an early zero index value for replacements.
  Location loc = target->getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
  WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) {
    diag = mapOneForallToThreadsImpl(rewriter, transformOp, forallOp, blockDims,
                                     warpSize, syncAfterDistribute);
    if (diag.isDefiniteFailure())
      return WalkResult::interrupt();
    if (diag.succeeded())
      return WalkResult::skip();
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return diag;

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<ThreadIdOp>(rewriter, loc, target, zero,
                                          blockDims);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  // Basic high-level verifications.
  if (!gpuLaunch)
    return emitSilenceableError() << "Given target is not a gpu.launch";

  // Mapping to block ids.
  SmallVector<int64_t> blockDims{getBlockDims()};
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
                     blockDims[0], blockDims[1], blockDims[2]);
  if (diag.isSilenceableFailure()) {
    diag.attachNote(getLoc()) << getBlockDimsAttrName() << " is too large";
    return diag;
  }

  // Set the GPU launch configuration for the block dims early, this is not
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
                        std::nullopt, std::nullopt, blockDims[0], blockDims[1],
                        blockDims[2]);

  rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
  diag =
      mapNestedForallToThreadsImpl(rewriter, transformOp, gpuLaunch, blockDims,
                                   getWarpSize(), getSyncAfterDistribute());

  results.push_back(gpuLaunch.getOperation());
  return diag;
}
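
// Example (illustrative only; the exact transform IR syntax may differ):
// mapping the nested scf.forall ops inside a gpu.launch to threads/warps:
//
//   transform.gpu.map_nested_forall_to_threads %launch
//       block_dims = [32, 4, 1]
//       : (!transform.any_op) -> !transform.any_op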

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
/// Registers new ops and declares PDL as dependent dialect since the
/// additional ops are using PDL types for operands and results.
class GPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          GPUTransformDialectExtension> {
public:
  GPUTransformDialectExtension() {
    declareGeneratedDialect<scf::SCFDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<GPUDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"

void mlir::gpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<GPUTransformDialectExtension>();
}
static bool maybeCaptured(Value v)
Returns true if the value may be captured by any of its users, i.e., if the user may be storing this ...
static DiagnosedSilenceableFailure verifyGpuMapping(std::optional< TransformOpInterface > transformOp, scf::ForallOp forallOp)
#define LDBG(X)
static DiagnosedSilenceableFailure getThreadIdBuilder(std::optional< TransformOpInterface > transformOp, scf::ForallOp forallOp, ArrayRef< int64_t > blockSizes, int64_t warpSize, GpuIdBuilder &gpuIdBuilder)
static bool getEffectsBefore(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects, bool stopAtBarrier)
Collects memory effects from operations that may be executed before op in a trivial structured contro...
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
unsigned getPosition() const
Definition: AffineExpr.cpp:325
Base type for affine expression.
Definition: AffineExpr.h:68
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:133
Operation & back()
Definition: Block.h:145
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:238
OpListType & getOperations()
Definition: Block.h:130
Operation & front()
Definition: Block.h:146
iterator begin()
Definition: Block.h:136
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool succeeded() const
Returns true if this is a success.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:33
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:430
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:528
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:416
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:397
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:728
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:538
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:776
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:655
result_type_range getResultTypes()
Definition: Operation.h:423
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:852
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockListType & getBlocks()
Definition: Region.h:45
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:615
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
Value getValue() const
Return the value the effect is applied on, or nullptr if there isn't a known value being affected.
TypeID getResourceID() const
Return the unique identifier for the base resource class.
Type conversion class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
user_range getUsers() const
Definition: Value.h:222
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
static WalkResult skip()
Definition: Visitors.h:53
static WalkResult advance()
Definition: Visitors.h:52
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:56
static WalkResult interrupt()
Definition: Visitors.h:51
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:90
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition: GPUDialect.h:127
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void push_back(Operation *op)
Appends an element to the list.
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1344
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void registerTransformDialectExtension(DialectRegistry &registry)
DiagnosedSilenceableFailure mapOneForallToThreadsImpl(RewriterBase &rewriter, std::optional< TransformOpInterface > transformOp, scf::ForallOp forallOp, ArrayRef< int64_t > blockDims, int64_t warpSize, bool syncAfterDistribute)
Search scf.forall ops nested under target and map each such op to an explicit GPU implementation alon...
DiagnosedSilenceableFailure findTopLevelForallOp(Operation *target, scf::ForallOp &topLevelForallOp, TransformOpInterface transformOp)
Find the unique top level scf::ForallOp within a given target op.
DiagnosedSilenceableFailure alterGpuLaunch(RewriterBase &rewriter, mlir::gpu::LaunchOp gpuLaunch, TransformOpInterface transformOp, std::optional< int64_t > gridDimX=std::nullopt, std::optional< int64_t > gridDimY=std::nullopt, std::optional< int64_t > gridDimZ=std::nullopt, std::optional< int64_t > blockDimX=std::nullopt, std::optional< int64_t > blockDimY=std::nullopt, std::optional< int64_t > blockDimZ=std::nullopt)
Alter kernel configuration of the given kernel.
DiagnosedSilenceableFailure createGpuLaunch(RewriterBase &rewriter, Location loc, TransformOpInterface transformOp, mlir::gpu::LaunchOp &launchOp, std::optional< int64_t > gridDimX=std::nullopt, std::optional< int64_t > gridDimY=std::nullopt, std::optional< int64_t > gridDimZ=std::nullopt, std::optional< int64_t > blockDimX=std::nullopt, std::optional< int64_t > blockDimY=std::nullopt, std::optional< int64_t > blockDimZ=std::nullopt)
Create an empty-body gpu::LaunchOp using the provided kernel settings and put a terminator within.
DiagnosedSilenceableFailure mapForallToBlocksImpl(RewriterBase &rewriter, TransformOpInterface transformOp, scf::ForallOp forallOp, SmallVectorImpl< int64_t > &gridDims, const GpuIdBuilder &gpuIdBuilder)
Map the top level scf.forall op to GPU blocks.
DiagnosedSilenceableFailure checkGpuLimits(TransformOpInterface transformOp, std::optional< int64_t > gridDimX, std::optional< int64_t > gridDimY, std::optional< int64_t > gridDimZ, std::optional< int64_t > blockDimX, std::optional< int64_t > blockDimY, std::optional< int64_t > blockDimZ)
Determine if the size of the kernel configuration is supported by the GPU architecture being used.
Definition: Utils.cpp:232
DiagnosedSilenceableFailure mapNestedForallToThreadsImpl(RewriterBase &rewriter, std::optional< TransformOpInterface > transformOp, Operation *target, ArrayRef< int64_t > blockDims, int64_t warpSize, bool syncAfterDistribute)
Search scf.forall ops nested under target and map each such op to an explicit GPU implementation alon...
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition: VectorOps.h:130
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition: VectorOps.h:125
void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit=1)
Collect a set of pattern to unroll vector operations to a smaller shapes.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type)
Return the LLVMStructureType corresponding to the MMAMatrixType type.
void populateGpuRewritePatterns(RewritePatternSet &patterns)
Collect all patterns to rewrite ops within the GPU dialect.
Definition: Passes.h:66
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM.
void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert from the GPU dialect to NVVM.
void populateGpuSubgroupReduceOpLoweringPattern(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate GpuSubgroupReduce pattern to NVVM.
std::optional< SmallVector< int64_t > > getConstantIntValues(ArrayRef< OpFoldResult > ofrs)
If all ofrs are constant integers or IntegerAttrs, return the integers.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
SmallVector< Value > getValuesSortedByKey(ArrayRef< Attribute > keys, ArrayRef< Value > values, llvm::function_ref< bool(Attribute, Attribute)> compare)
Helper to sort values according to matching keys.
Struct to return the result of the rewrite of a forall operation.
SmallVector< Value > mappingIds
SmallVector< int64_t > mappingSizes
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
Builder for gpu::BlockIdOps used to map scf.forall to blocks.
Definition: Utils.h:83
Helper struct for configuring the rewrite of mapped scf.forall ops to various gpu id configurations.
Definition: Utils.h:63
SmallVector< DeviceMappingAttrInterface > mappingAttributes
The mapping attributes targeted by this generator.
Definition: Utils.h:72
GpuIdBuilderFnType idBuilder
The constructor that builds the concrete IR for mapping ids.
Definition: Utils.h:75
Builder for warp ids used to map scf.forall to reindexed threads.
Definition: Utils.h:116
Builder for warp ids used to map scf.forall to reindexed warps.
Definition: Utils.h:105
Builder for warpgroup ids used to map scf.forall to reindexed warpgroups.
Definition: Utils.h:92
Helper type for functions that generate ids for the mapping of a scf.forall.
Definition: Utils.h:38
SmallVector< int64_t > availableMappingSizes
Definition: Utils.h:43
SmallVector< Value > mappingIdOps
Definition: Utils.h:40
SmallVector< Value > activeIdOps
Definition: Utils.h:49
SmallVector< int64_t > activeMappingSizes
Definition: Utils.h:46
Options that control the vector unrolling.