MLIR  21.0.0git
IndexIntrinsicsOpLowering.h
Go to the documentation of this file.
1 //===- IndexIntrinsicsOpLowering.h - GPU IndexOps Lowering class *- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #ifndef MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
9 #define MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
10 
15 
16 namespace mlir {
17 namespace gpu {
18 namespace index_lowering {
// Category of index an op queries. `Block` and `Grid` select which known-size
// bounds are consulted during lowering; `Other` consults none.
enum class IndexKind : uint32_t {
  Other = 0,
  Block = 1,
  Grid = 2,
};
// Flavor of intrinsic being lowered: `Id` for an index, `Dim` for a size, and
// `None` when no range annotation should be attached.
enum class IntrType : uint32_t { None = 0, Id = 1, Dim = 2 };
25 
26 // Rewriting that replaces Op with XOp, YOp, or ZOp depending on the dimension
27 // that Op operates on. Op is assumed to return an `index` value and
28 // XOp, YOp and ZOp are assumed to return an `llvm.i32` value. Depending on
29 // `indexBitwidth`, sign-extend or truncate the resulting value to match the
30 // bitwidth expected by the consumers of the value.
31 template <typename Op, typename XOp, typename YOp, typename ZOp>
32 struct OpLowering : public ConvertOpToLLVMPattern<Op> {
33 private:
34  unsigned indexBitwidth;
35  IndexKind indexKind;
36  IntrType intrType;
37 
38 public:
40  PatternBenefit benefit = 1)
42  indexBitwidth(typeConverter.getIndexTypeBitwidth()),
43  indexKind(IndexKind::Other), intrType(IntrType::None) {}
44 
46  IndexKind indexKind, IntrType intrType,
47  PatternBenefit benefit = 1)
49  indexBitwidth(typeConverter.getIndexTypeBitwidth()),
50  indexKind(indexKind), intrType(intrType) {}
51 
52  // Convert the kernel arguments to an LLVM type, preserve the rest.
53  LogicalResult
54  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
55  ConversionPatternRewriter &rewriter) const override {
56  auto loc = op->getLoc();
57  MLIRContext *context = rewriter.getContext();
58  Operation *newOp;
59  switch (op.getDimension()) {
60  case gpu::Dimension::x:
61  newOp = rewriter.create<XOp>(loc, IntegerType::get(context, 32));
62  break;
63  case gpu::Dimension::y:
64  newOp = rewriter.create<YOp>(loc, IntegerType::get(context, 32));
65  break;
66  case gpu::Dimension::z:
67  newOp = rewriter.create<ZOp>(loc, IntegerType::get(context, 32));
68  break;
69  }
70 
71  // Order of priority for bounds:
72  // 1. The upper_bound attribute
73  // 2. Inherent attributes on a surrounding gpu.func
74  // 3. Discardable attributes on a surrounding function of any kind
75  // The below code handles these in reverse order so that more important
76  // sources overwrite less important ones.
77  DenseI32ArrayAttr funcBounds = nullptr;
78  if (auto funcOp = op->template getParentOfType<FunctionOpInterface>()) {
79  switch (indexKind) {
80  case IndexKind::Block: {
81  auto blockHelper =
82  gpu::GPUDialect::KnownBlockSizeAttrHelper(op.getContext());
83  if (blockHelper.isAttrPresent(funcOp))
84  funcBounds = blockHelper.getAttr(funcOp);
85  break;
86  }
87  case IndexKind::Grid: {
88  auto gridHelper =
89  gpu::GPUDialect::KnownGridSizeAttrHelper(op.getContext());
90  if (gridHelper.isAttrPresent(funcOp))
91  funcBounds = gridHelper.getAttr(funcOp);
92  break;
93  }
94  case IndexKind::Other:
95  break;
96  }
97  }
98  if (auto gpuFunc = op->template getParentOfType<gpu::GPUFuncOp>()) {
99  switch (indexKind) {
100  case IndexKind::Block:
101  funcBounds = gpuFunc.getKnownBlockSizeAttr();
102  break;
103  case IndexKind::Grid:
104  funcBounds = gpuFunc.getKnownGridSizeAttr();
105  break;
106  case IndexKind::Other:
107  break;
108  }
109  }
110  std::optional<int32_t> upperBound;
111  if (funcBounds)
112  upperBound =
113  funcBounds.asArrayRef()[static_cast<uint32_t>(op.getDimension())];
114  if (auto opBound = op.getUpperBound())
115  upperBound = opBound->getZExtValue();
116 
117  if (upperBound && intrType != IntrType::None) {
118  int32_t min = (intrType == IntrType::Dim ? 1 : 0);
119  int32_t max = *upperBound + (intrType == IntrType::Id ? 0 : 1);
120  newOp->setAttr("range", LLVM::ConstantRangeAttr::get(
121  rewriter.getContext(), 32, min, max));
122  }
123  if (indexBitwidth > 32) {
124  newOp = rewriter.create<LLVM::SExtOp>(
125  loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0));
126  } else if (indexBitwidth < 32) {
127  newOp = rewriter.create<LLVM::TruncOp>(
128  loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0));
129  }
130 
131  rewriter.replaceOp(op, newOp->getResults());
132  return success();
133  }
134 };
135 } // namespace index_lowering
136 } // namespace gpu
137 } // namespace mlir
138 
139 #endif // MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
@ None
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Block represents an ordered list of Operations.
Definition: Block.h:33
MLIRContext * getContext() const
Definition: Builders.h:56
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:155
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
MLIRContext * getContext()
Return the context this operation belongs to.
Definition: OpDefinition.h:114
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:128
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
result_range getResults()
Definition: Operation.h:415
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpLowering(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
OpLowering(const LLVMTypeConverter &typeConverter, IndexKind indexKind, IntrType intrType, PatternBenefit benefit=1)