MLIR  20.0.0git
IndexIntrinsicsOpLowering.h
Go to the documentation of this file.
1 //===- IndexIntrinsicsOpLowering.h - GPU IndexOps Lowering class *- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #ifndef MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
9 #define MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
10 
15 
16 namespace mlir {
17 namespace gpu {
18 namespace index_lowering {
/// Selects which launch-size attribute on an enclosing function, if any, is
/// consulted for a bound on the queried value: `Block` reads the known block
/// size, `Grid` reads the known grid size, and `Other` reads neither.
enum class IndexKind : uint32_t {
  Other = 0,
  Block = 1,
  Grid = 2,
};

/// Selects the shape of the `range` attribute attached to the lowered op:
/// `Id` values lie in [0, bound), `Dim` values lie in [1, bound], and `None`
/// attaches no range at all.
enum class IntrType : uint32_t { None = 0, Id = 1, Dim = 2 };
25 
26 // Rewriting that replaces Op with XOp, YOp, or ZOp depending on the dimension
27 // that Op operates on. Op is assumed to return an `index` value and
28 // XOp, YOp and ZOp are assumed to return an `llvm.i32` value. Depending on
29 // `indexBitwidth`, sign-extend or truncate the resulting value to match the
30 // bitwidth expected by the consumers of the value.
31 template <typename Op, typename XOp, typename YOp, typename ZOp>
32 struct OpLowering : public ConvertOpToLLVMPattern<Op> {
33 private:
34  unsigned indexBitwidth;
35  IndexKind indexKind;
36  IntrType intrType;
37 
38 public:
41  indexBitwidth(typeConverter.getIndexTypeBitwidth()),
42  indexKind(IndexKind::Other), intrType(IntrType::None) {}
43 
45  IndexKind indexKind, IntrType intrType)
47  indexBitwidth(typeConverter.getIndexTypeBitwidth()),
48  indexKind(indexKind), intrType(intrType) {}
49 
50  // Convert the kernel arguments to an LLVM type, preserve the rest.
51  LogicalResult
52  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
53  ConversionPatternRewriter &rewriter) const override {
54  auto loc = op->getLoc();
55  MLIRContext *context = rewriter.getContext();
56  Operation *newOp;
57  switch (op.getDimension()) {
58  case gpu::Dimension::x:
59  newOp = rewriter.create<XOp>(loc, IntegerType::get(context, 32));
60  break;
61  case gpu::Dimension::y:
62  newOp = rewriter.create<YOp>(loc, IntegerType::get(context, 32));
63  break;
64  case gpu::Dimension::z:
65  newOp = rewriter.create<ZOp>(loc, IntegerType::get(context, 32));
66  break;
67  }
68 
69  // Order of priority for bounds:
70  // 1. The upper_bound attribute
71  // 2. Inherent attributes on a surrounding gpu.func
72  // 3. Discardable attributes on a surrounding function of any kind
73  // The below code handles these in reverse order so that more important
74  // sources overwrite less important ones.
75  DenseI32ArrayAttr funcBounds = nullptr;
76  if (auto funcOp = op->template getParentOfType<FunctionOpInterface>()) {
77  switch (indexKind) {
78  case IndexKind::Block: {
79  auto blockHelper =
80  gpu::GPUDialect::KnownBlockSizeAttrHelper(op.getContext());
81  if (blockHelper.isAttrPresent(funcOp))
82  funcBounds = blockHelper.getAttr(funcOp);
83  break;
84  }
85  case IndexKind::Grid: {
86  auto gridHelper =
87  gpu::GPUDialect::KnownGridSizeAttrHelper(op.getContext());
88  if (gridHelper.isAttrPresent(funcOp))
89  funcBounds = gridHelper.getAttr(funcOp);
90  break;
91  }
92  case IndexKind::Other:
93  break;
94  }
95  }
96  if (auto gpuFunc = op->template getParentOfType<gpu::GPUFuncOp>()) {
97  switch (indexKind) {
98  case IndexKind::Block:
99  funcBounds = gpuFunc.getKnownBlockSizeAttr();
100  break;
101  case IndexKind::Grid:
102  funcBounds = gpuFunc.getKnownGridSizeAttr();
103  break;
104  case IndexKind::Other:
105  break;
106  }
107  }
108  std::optional<int32_t> upperBound;
109  if (funcBounds)
110  upperBound =
111  funcBounds.asArrayRef()[static_cast<uint32_t>(op.getDimension())];
112  if (auto opBound = op.getUpperBound())
113  upperBound = opBound->getZExtValue();
114 
115  if (upperBound && intrType != IntrType::None) {
116  int32_t min = (intrType == IntrType::Dim ? 1 : 0);
117  int32_t max = *upperBound + (intrType == IntrType::Id ? 0 : 1);
118  newOp->setAttr("range", LLVM::ConstantRangeAttr::get(
119  rewriter.getContext(), 32, min, max));
120  }
121  if (indexBitwidth > 32) {
122  newOp = rewriter.create<LLVM::SExtOp>(
123  loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0));
124  } else if (indexBitwidth < 32) {
125  newOp = rewriter.create<LLVM::TruncOp>(
126  loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0));
127  }
128 
129  rewriter.replaceOp(op, newOp->getResults());
130  return success();
131  }
132 };
133 } // namespace index_lowering
134 } // namespace gpu
135 } // namespace mlir
136 
137 #endif // MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
@ None
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Block represents an ordered list of Operations.
Definition: Block.h:33
MLIRContext * getContext() const
Definition: Builders.h:56
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:143
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
MLIRContext * getContext()
Return the context this operation belongs to.
Definition: OpDefinition.h:111
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:125
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
void setAttr(StringAttr name, Attribute value)
If an attribute with the specified name exists, change it to the new value.
Definition: Operation.h:582
result_range getResults()
Definition: Operation.h:415
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
OpLowering(const LLVMTypeConverter &typeConverter)
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
OpLowering(const LLVMTypeConverter &typeConverter, IndexKind indexKind, IntrType intrType)