MLIR 22.0.0git
IndexIntrinsicsOpLowering.h
Go to the documentation of this file.
1//===- IndexIntrinsicsOpLowering.h - GPU IndexOps Lowering class *- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8#ifndef MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
9#define MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
10
15#include <limits>
16
17namespace mlir {
18namespace gpu {
19namespace index_lowering {
/// Selects which launch-size attribute on an enclosing function is consulted
/// when bounding an index op: the block size, the grid size, the cluster size,
/// or nothing at all (`Other`).
enum class IndexKind : uint32_t {
  Other = 0,
  Block = 1,
  Grid = 2,
  Cluster = 3,
};

/// Describes what an index op yields. The lowering uses it to pick the `range`
/// attribute emitted for a known bound `b`:
///   - `None`: no `range` attribute is attached.
///   - `Id`:   the emitted range is [0, b).
///   - `Dim`:  the emitted range is [1, b + 1).
enum class IntrType : uint32_t { None = 0, Id = 1, Dim = 2 };
26
27// Rewriting that replaces Op with XOp, YOp, or ZOp depending on the dimension
28// that Op operates on. Op is assumed to return an `index` value and
29// XOp, YOp and ZOp are assumed to return an `llvm.i32` value. Depending on
30// `indexBitwidth`, sign-extend or truncate the resulting value to match the
31// bitwidth expected by the consumers of the value.
32template <typename Op, typename XOp, typename YOp, typename ZOp>
34private:
35 unsigned indexBitwidth;
36 IndexKind indexKind;
37 IntrType intrType;
38
39public:
40 explicit OpLowering(const LLVMTypeConverter &typeConverter,
41 PatternBenefit benefit = 1)
42 : ConvertOpToLLVMPattern<Op>(typeConverter, benefit),
43 indexBitwidth(typeConverter.getIndexTypeBitwidth()),
44 indexKind(IndexKind::Other), intrType(IntrType::None) {}
45
46 explicit OpLowering(const LLVMTypeConverter &typeConverter,
47 IndexKind indexKind, IntrType intrType,
48 PatternBenefit benefit = 1)
49 : ConvertOpToLLVMPattern<Op>(typeConverter, benefit),
50 indexBitwidth(typeConverter.getIndexTypeBitwidth()),
51 indexKind(indexKind), intrType(intrType) {}
52
53 // Convert the kernel arguments to an LLVM type, preserve the rest.
54 LogicalResult
55 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
56 ConversionPatternRewriter &rewriter) const override {
57 auto loc = op->getLoc();
58 MLIRContext *context = rewriter.getContext();
59 Operation *newOp;
60 switch (op.getDimension()) {
61 case gpu::Dimension::x:
62 newOp = XOp::create(rewriter, loc, IntegerType::get(context, 32));
63 break;
64 case gpu::Dimension::y:
65 newOp = YOp::create(rewriter, loc, IntegerType::get(context, 32));
66 break;
67 case gpu::Dimension::z:
68 newOp = ZOp::create(rewriter, loc, IntegerType::get(context, 32));
69 break;
70 }
71
72 // Order of priority for bounds:
73 // 1. The upper_bound attribute
74 // 2. Inherent attributes on a surrounding gpu.func
75 // 3. Discardable attributes on a surrounding function of any kind
76 // The below code handles these in reverse order so that more important
77 // sources overwrite less important ones.
78 DenseI32ArrayAttr funcBounds = nullptr;
79 if (auto funcOp = op->template getParentOfType<FunctionOpInterface>()) {
80 switch (indexKind) {
81 case IndexKind::Block: {
82 auto blockHelper =
83 gpu::GPUDialect::KnownBlockSizeAttrHelper(op.getContext());
84 if (blockHelper.isAttrPresent(funcOp))
85 funcBounds = blockHelper.getAttr(funcOp);
86 break;
87 }
88 case IndexKind::Grid: {
89 auto gridHelper =
90 gpu::GPUDialect::KnownGridSizeAttrHelper(op.getContext());
91 if (gridHelper.isAttrPresent(funcOp))
92 funcBounds = gridHelper.getAttr(funcOp);
93 break;
94 }
95 case IndexKind::Cluster: {
96 auto clusterHelper =
97 gpu::GPUDialect::KnownClusterSizeAttrHelper(op.getContext());
98 if (clusterHelper.isAttrPresent(funcOp))
99 funcBounds = clusterHelper.getAttr(funcOp);
100 break;
101 }
102 case IndexKind::Other:
103 break;
104 }
105 }
106 if (auto gpuFunc = op->template getParentOfType<gpu::GPUFuncOp>()) {
107 switch (indexKind) {
108 case IndexKind::Block:
109 funcBounds = gpuFunc.getKnownBlockSizeAttr();
110 break;
111 case IndexKind::Grid:
112 funcBounds = gpuFunc.getKnownGridSizeAttr();
113 break;
115 funcBounds = gpuFunc.getKnownClusterSizeAttr();
116 break;
117 case IndexKind::Other:
118 break;
119 }
120 }
121 std::optional<int32_t> upperBound;
122 if (funcBounds)
123 upperBound =
124 funcBounds.asArrayRef()[static_cast<uint32_t>(op.getDimension())];
125 if (auto opBound = op.getUpperBound())
126 upperBound = opBound->getZExtValue();
127
128 if (upperBound && intrType != IntrType::None) {
129 int32_t min = (intrType == IntrType::Dim ? 1 : 0);
130 int32_t max = *upperBound == std::numeric_limits<int32_t>::max()
131 ? *upperBound
132 : *upperBound + (intrType == IntrType::Id ? 0 : 1);
133 newOp->setAttr("range", LLVM::ConstantRangeAttr::get(
134 rewriter.getContext(), 32, min, max));
135 }
136 if (indexBitwidth > 32) {
137 newOp = LLVM::SExtOp::create(rewriter, loc,
138 IntegerType::get(context, indexBitwidth),
139 newOp->getResult(0));
140 } else if (indexBitwidth < 32) {
141 newOp = LLVM::TruncOp::create(rewriter, loc,
142 IntegerType::get(context, indexBitwidth),
143 newOp->getResult(0));
144 }
145
146 rewriter.replaceOp(op, newOp->getResults());
147 return success();
148 }
149};
150} // namespace index_lowering
151} // namespace gpu
152} // namespace mlir
153
154#endif // MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
return success()
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:222
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
MLIRContext * getContext()
Return the context this operation belongs to.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition Operation.h:582
result_range getResults()
Definition Operation.h:415
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
OpLowering(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
OpLowering(const LLVMTypeConverter &typeConverter, IndexKind indexKind, IntrType intrType, PatternBenefit benefit=1)