MLIR 22.0.0git
IndexIntrinsicsOpLowering.h
Go to the documentation of this file.
1//===- IndexIntrinsicsOpLowering.h - GPU IndexOps Lowering class *- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8#ifndef MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
9#define MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
10
15#include <limits>
16
17namespace mlir {
18namespace gpu {
19namespace index_lowering {
/// Selects which set of function-level launch bounds an index op may draw its
/// upper bound from when it is lowered by `OpLowering`:
///  - `Other`: no function-level bounds are consulted.
///  - `Block`: the known block size attribute of the enclosing function.
///  - `Grid`:  the known grid size attribute of the enclosing function.
enum class IndexKind : uint32_t { Other = 0, Block = 1, Grid = 2 };
/// Describes what kind of quantity a lowered intrinsic returns, which decides
/// how `OpLowering` turns a known upper bound into a `range` attribute:
///  - `None`: no `range` attribute is attached.
///  - `Id`:   the range is built with lower = 0 and upper = bound.
///  - `Dim`:  the range is built with lower = 1 and upper = bound + 1
///            (the upper value is left at the bound when it already equals
///            the maximum `int32_t`).
enum class IntrType : uint32_t {
  None = 0,
  Id = 1,
  Dim = 2,
};
26
27// Rewriting that replaces Op with XOp, YOp, or ZOp depending on the dimension
28// that Op operates on. Op is assumed to return an `index` value and
29// XOp, YOp and ZOp are assumed to return an `llvm.i32` value. Depending on
30// `indexBitwidth`, sign-extend or truncate the resulting value to match the
31// bitwidth expected by the consumers of the value.
32template <typename Op, typename XOp, typename YOp, typename ZOp>
34private:
35 unsigned indexBitwidth;
36 IndexKind indexKind;
37 IntrType intrType;
38
39public:
40 explicit OpLowering(const LLVMTypeConverter &typeConverter,
41 PatternBenefit benefit = 1)
42 : ConvertOpToLLVMPattern<Op>(typeConverter, benefit),
43 indexBitwidth(typeConverter.getIndexTypeBitwidth()),
44 indexKind(IndexKind::Other), intrType(IntrType::None) {}
45
46 explicit OpLowering(const LLVMTypeConverter &typeConverter,
47 IndexKind indexKind, IntrType intrType,
48 PatternBenefit benefit = 1)
49 : ConvertOpToLLVMPattern<Op>(typeConverter, benefit),
50 indexBitwidth(typeConverter.getIndexTypeBitwidth()),
51 indexKind(indexKind), intrType(intrType) {}
52
53 // Convert the kernel arguments to an LLVM type, preserve the rest.
54 LogicalResult
55 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
56 ConversionPatternRewriter &rewriter) const override {
57 auto loc = op->getLoc();
58 MLIRContext *context = rewriter.getContext();
59 Operation *newOp;
60 switch (op.getDimension()) {
61 case gpu::Dimension::x:
62 newOp = XOp::create(rewriter, loc, IntegerType::get(context, 32));
63 break;
64 case gpu::Dimension::y:
65 newOp = YOp::create(rewriter, loc, IntegerType::get(context, 32));
66 break;
67 case gpu::Dimension::z:
68 newOp = ZOp::create(rewriter, loc, IntegerType::get(context, 32));
69 break;
70 }
71
72 // Order of priority for bounds:
73 // 1. The upper_bound attribute
74 // 2. Inherent attributes on a surrounding gpu.func
75 // 3. Discardable attributes on a surrounding function of any kind
76 // The below code handles these in reverse order so that more important
77 // sources overwrite less important ones.
78 DenseI32ArrayAttr funcBounds = nullptr;
79 if (auto funcOp = op->template getParentOfType<FunctionOpInterface>()) {
80 switch (indexKind) {
81 case IndexKind::Block: {
82 auto blockHelper =
83 gpu::GPUDialect::KnownBlockSizeAttrHelper(op.getContext());
84 if (blockHelper.isAttrPresent(funcOp))
85 funcBounds = blockHelper.getAttr(funcOp);
86 break;
87 }
88 case IndexKind::Grid: {
89 auto gridHelper =
90 gpu::GPUDialect::KnownGridSizeAttrHelper(op.getContext());
91 if (gridHelper.isAttrPresent(funcOp))
92 funcBounds = gridHelper.getAttr(funcOp);
93 break;
94 }
96 break;
97 }
98 }
99 if (auto gpuFunc = op->template getParentOfType<gpu::GPUFuncOp>()) {
100 switch (indexKind) {
101 case IndexKind::Block:
102 funcBounds = gpuFunc.getKnownBlockSizeAttr();
103 break;
104 case IndexKind::Grid:
105 funcBounds = gpuFunc.getKnownGridSizeAttr();
106 break;
107 case IndexKind::Other:
108 break;
109 }
110 }
111 std::optional<int32_t> upperBound;
112 if (funcBounds)
113 upperBound =
114 funcBounds.asArrayRef()[static_cast<uint32_t>(op.getDimension())];
115 if (auto opBound = op.getUpperBound())
116 upperBound = opBound->getZExtValue();
117
118 if (upperBound && intrType != IntrType::None) {
119 int32_t min = (intrType == IntrType::Dim ? 1 : 0);
120 int32_t max = *upperBound == std::numeric_limits<int32_t>::max()
121 ? *upperBound
122 : *upperBound + (intrType == IntrType::Id ? 0 : 1);
123 newOp->setAttr("range", LLVM::ConstantRangeAttr::get(
124 rewriter.getContext(), 32, min, max));
125 }
126 if (indexBitwidth > 32) {
127 newOp = LLVM::SExtOp::create(rewriter, loc,
128 IntegerType::get(context, indexBitwidth),
129 newOp->getResult(0));
130 } else if (indexBitwidth < 32) {
131 newOp = LLVM::TruncOp::create(rewriter, loc,
132 IntegerType::get(context, indexBitwidth),
133 newOp->getResult(0));
134 }
135
136 rewriter.replaceOp(op, newOp->getResults());
137 return success();
138 }
139};
140} // namespace index_lowering
141} // namespace gpu
142} // namespace mlir
143
144#endif // MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
return success()
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:215
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
MLIRContext * getContext()
Return the context this operation belongs to.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:582
result_range getResults()
Definition Operation.h:415
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
OpLowering(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
OpLowering(const LLVMTypeConverter &typeConverter, IndexKind indexKind, IntrType intrType, PatternBenefit benefit=1)