MLIR 22.0.0git
IndexToSPIRV.cpp
Go to the documentation of this file.
1//===- IndexToSPIRV.cpp - Index to SPIRV dialect conversion -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
16
17using namespace mlir;
18using namespace index;
19
20namespace {
21
22//===----------------------------------------------------------------------===//
23// Trivial Conversions
24//===----------------------------------------------------------------------===//
25
37
38using ConvertIndexShl =
40using ConvertIndexShrS =
42using ConvertIndexShrU =
44
/// When converting bitwise operations to SPIR-V we must take into account the
/// special rule in SPIR-V that boolean operands use `SPIRVLogicalOp`, while
/// non-boolean operands use `SPIRVBitwiseOp`. However, the index bitwise ops
/// (index.and, index.or, index.xor) never operate on booleans, so we can
/// directly convert them to the Bitwise[And|Or|Xor]Op.
54
55//===----------------------------------------------------------------------===//
56// ConvertConstantBool
57//===----------------------------------------------------------------------===//
58
59// Converts index.bool.constant operation to spirv.Constant.
60struct ConvertIndexConstantBoolOpPattern final
61 : OpConversionPattern<BoolConstantOp> {
62 using Base::Base;
63
64 LogicalResult
65 matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor,
66 ConversionPatternRewriter &rewriter) const override {
67 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, op.getType(),
68 op.getValueAttr());
69 return success();
70 }
71};
72
73//===----------------------------------------------------------------------===//
74// ConvertConstant
75//===----------------------------------------------------------------------===//
76
77// Converts index.constant op to spirv.Constant. Will truncate from i64 to i32
78// when required.
79struct ConvertIndexConstantOpPattern final : OpConversionPattern<ConstantOp> {
80 using Base::Base;
81
82 LogicalResult
83 matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
84 ConversionPatternRewriter &rewriter) const override {
85 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
86 Type indexType = typeConverter->getIndexType();
87
88 APInt value = op.getValue().trunc(typeConverter->getIndexTypeBitwidth());
89 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
90 op, indexType, IntegerAttr::get(indexType, value));
91 return success();
92 }
93};
94
95//===----------------------------------------------------------------------===//
96// ConvertIndexCeilDivS
97//===----------------------------------------------------------------------===//
98
99/// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then
100/// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. Formula taken from the equivalent
101/// conversion in IndexToLLVM.
102struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
103 using Base::Base;
104
105 LogicalResult
106 matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
107 ConversionPatternRewriter &rewriter) const override {
108 Location loc = op.getLoc();
109 Value n = adaptor.getLhs();
110 Type n_type = n.getType();
111 Value m = adaptor.getRhs();
112
113 // Define the constants
114 Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
115 IntegerAttr::get(n_type, 0));
116 Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type,
117 IntegerAttr::get(n_type, 1));
118 Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type,
119 IntegerAttr::get(n_type, -1));
120
121 // Compute `x`.
122 Value mPos = spirv::SGreaterThanOp::create(rewriter, loc, m, zero);
123 Value x = spirv::SelectOp::create(rewriter, loc, mPos, negOne, posOne);
124
125 // Compute the positive result.
126 Value nPlusX = spirv::IAddOp::create(rewriter, loc, n, x);
127 Value nPlusXDivM = spirv::SDivOp::create(rewriter, loc, nPlusX, m);
128 Value posRes = spirv::IAddOp::create(rewriter, loc, nPlusXDivM, posOne);
129
130 // Compute the negative result.
131 Value negN = spirv::ISubOp::create(rewriter, loc, zero, n);
132 Value negNDivM = spirv::SDivOp::create(rewriter, loc, negN, m);
133 Value negRes = spirv::ISubOp::create(rewriter, loc, zero, negNDivM);
134
135 // Pick the positive result if `n` and `m` have the same sign and `n` is
136 // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
137 Value nPos = spirv::SGreaterThanOp::create(rewriter, loc, n, zero);
138 Value sameSign = spirv::LogicalEqualOp::create(rewriter, loc, nPos, mPos);
139 Value nNonZero = spirv::INotEqualOp::create(rewriter, loc, n, zero);
140 Value cmp = spirv::LogicalAndOp::create(rewriter, loc, sameSign, nNonZero);
141 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
142 return success();
143 }
144};
145
146//===----------------------------------------------------------------------===//
147// ConvertIndexCeilDivU
148//===----------------------------------------------------------------------===//
149
150/// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`. Formula taken
151/// from the equivalent conversion in IndexToLLVM.
152struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
153 using Base::Base;
154
155 LogicalResult
156 matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
157 ConversionPatternRewriter &rewriter) const override {
158 Location loc = op.getLoc();
159 Value n = adaptor.getLhs();
160 Type n_type = n.getType();
161 Value m = adaptor.getRhs();
162
163 // Define the constants
164 Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
165 IntegerAttr::get(n_type, 0));
166 Value one = spirv::ConstantOp::create(rewriter, loc, n_type,
167 IntegerAttr::get(n_type, 1));
168
169 // Compute the non-zero result.
170 Value minusOne = spirv::ISubOp::create(rewriter, loc, n, one);
171 Value quotient = spirv::UDivOp::create(rewriter, loc, minusOne, m);
172 Value plusOne = spirv::IAddOp::create(rewriter, loc, quotient, one);
173
174 // Pick the result
175 Value cmp = spirv::IEqualOp::create(rewriter, loc, n, zero);
176 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, zero, plusOne);
177 return success();
178 }
179};
180
181//===----------------------------------------------------------------------===//
182// ConvertIndexFloorDivS
183//===----------------------------------------------------------------------===//
184
185/// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then
186/// `n*m < 0 ? -1 - (x-n)/m : n/m`. Formula taken from the equivalent conversion
187/// in IndexToLLVM.
188struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
189 using Base::Base;
190
191 LogicalResult
192 matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
193 ConversionPatternRewriter &rewriter) const override {
194 Location loc = op.getLoc();
195 Value n = adaptor.getLhs();
196 Type n_type = n.getType();
197 Value m = adaptor.getRhs();
198
199 // Define the constants
200 Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
201 IntegerAttr::get(n_type, 0));
202 Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type,
203 IntegerAttr::get(n_type, 1));
204 Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type,
205 IntegerAttr::get(n_type, -1));
206
207 // Compute `x`.
208 Value mNeg = spirv::SLessThanOp::create(rewriter, loc, m, zero);
209 Value x = spirv::SelectOp::create(rewriter, loc, mNeg, posOne, negOne);
210
211 // Compute the negative result
212 Value xMinusN = spirv::ISubOp::create(rewriter, loc, x, n);
213 Value xMinusNDivM = spirv::SDivOp::create(rewriter, loc, xMinusN, m);
214 Value negRes = spirv::ISubOp::create(rewriter, loc, negOne, xMinusNDivM);
215
216 // Compute the positive result.
217 Value posRes = spirv::SDivOp::create(rewriter, loc, n, m);
218
219 // Pick the negative result if `n` and `m` have different signs and `n` is
220 // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
221 Value nNeg = spirv::SLessThanOp::create(rewriter, loc, n, zero);
222 Value diffSign =
223 spirv::LogicalNotEqualOp::create(rewriter, loc, nNeg, mNeg);
224 Value nNonZero = spirv::INotEqualOp::create(rewriter, loc, n, zero);
225
226 Value cmp = spirv::LogicalAndOp::create(rewriter, loc, diffSign, nNonZero);
227 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
228 return success();
229 }
230};
231
232//===----------------------------------------------------------------------===//
233// ConvertIndexCast
234//===----------------------------------------------------------------------===//
235
236/// Convert a cast op. If the materialized index type is the same as the other
237/// type, fold away the op. Otherwise, use the Convert SPIR-V operation.
238/// Signed casts sign extend when the result bitwidth is larger. Unsigned casts
239/// zero extend when the result bitwidth is larger.
240template <typename CastOp, typename ConvertOp>
241struct ConvertIndexCast final : OpConversionPattern<CastOp> {
242 using OpConversionPattern<CastOp>::OpConversionPattern;
243
244 LogicalResult
245 matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
246 ConversionPatternRewriter &rewriter) const override {
247 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
248 Type indexType = typeConverter->getIndexType();
249
250 Type srcType = adaptor.getInput().getType();
251 Type dstType = op.getType();
252 if (isa<IndexType>(srcType)) {
253 srcType = indexType;
254 }
255 if (isa<IndexType>(dstType)) {
256 dstType = indexType;
257 }
258
259 if (srcType == dstType) {
260 rewriter.replaceOp(op, adaptor.getInput());
261 } else {
262 rewriter.template replaceOpWithNewOp<ConvertOp>(op, dstType,
263 adaptor.getOperands());
264 }
265 return success();
266 }
267};
268
269using ConvertIndexCastS = ConvertIndexCast<CastSOp, spirv::SConvertOp>;
270using ConvertIndexCastU = ConvertIndexCast<CastUOp, spirv::UConvertOp>;
271
272//===----------------------------------------------------------------------===//
273// ConvertIndexCmp
274//===----------------------------------------------------------------------===//
275
276// Helper template to replace the operation
277template <typename ICmpOp>
278static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor,
279 ConversionPatternRewriter &rewriter) {
280 rewriter.replaceOpWithNewOp<ICmpOp>(op, adaptor.getLhs(), adaptor.getRhs());
281 return success();
282}
283
284struct ConvertIndexCmpPattern final : OpConversionPattern<CmpOp> {
285 using Base::Base;
286
287 LogicalResult
288 matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
289 ConversionPatternRewriter &rewriter) const override {
290 // We must convert the predicates to the corresponding int comparions.
291 switch (op.getPred()) {
292 case IndexCmpPredicate::EQ:
293 return rewriteCmpOp<spirv::IEqualOp>(op, adaptor, rewriter);
294 case IndexCmpPredicate::NE:
295 return rewriteCmpOp<spirv::INotEqualOp>(op, adaptor, rewriter);
296 case IndexCmpPredicate::SGE:
297 return rewriteCmpOp<spirv::SGreaterThanEqualOp>(op, adaptor, rewriter);
298 case IndexCmpPredicate::SGT:
299 return rewriteCmpOp<spirv::SGreaterThanOp>(op, adaptor, rewriter);
300 case IndexCmpPredicate::SLE:
301 return rewriteCmpOp<spirv::SLessThanEqualOp>(op, adaptor, rewriter);
302 case IndexCmpPredicate::SLT:
303 return rewriteCmpOp<spirv::SLessThanOp>(op, adaptor, rewriter);
304 case IndexCmpPredicate::UGE:
305 return rewriteCmpOp<spirv::UGreaterThanEqualOp>(op, adaptor, rewriter);
306 case IndexCmpPredicate::UGT:
307 return rewriteCmpOp<spirv::UGreaterThanOp>(op, adaptor, rewriter);
308 case IndexCmpPredicate::ULE:
309 return rewriteCmpOp<spirv::ULessThanEqualOp>(op, adaptor, rewriter);
310 case IndexCmpPredicate::ULT:
311 return rewriteCmpOp<spirv::ULessThanOp>(op, adaptor, rewriter);
312 }
313 llvm_unreachable("Unknown predicate in ConvertIndexCmpPattern");
314 }
315};
316
317//===----------------------------------------------------------------------===//
318// ConvertIndexSizeOf
319//===----------------------------------------------------------------------===//
320
321/// Lower `index.sizeof` to a constant with the value of the index bitwidth.
322struct ConvertIndexSizeOf final : OpConversionPattern<SizeOfOp> {
323 using Base::Base;
324
325 LogicalResult
326 matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
327 ConversionPatternRewriter &rewriter) const override {
328 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
329 Type indexType = typeConverter->getIndexType();
330 unsigned bitwidth = typeConverter->getIndexTypeBitwidth();
331 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
332 op, indexType, IntegerAttr::get(indexType, bitwidth));
333 return success();
334 }
335};
336} // namespace
337
338//===----------------------------------------------------------------------===//
339// Pattern Population
340//===----------------------------------------------------------------------===//
341
343 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
344 patterns.add<
345 // clang-format off
346 ConvertIndexAdd,
347 ConvertIndexSub,
348 ConvertIndexMul,
349 ConvertIndexDivS,
350 ConvertIndexDivU,
351 ConvertIndexRemS,
352 ConvertIndexRemU,
353 ConvertIndexMaxS,
354 ConvertIndexMaxU,
355 ConvertIndexMinS,
356 ConvertIndexMinU,
357 ConvertIndexShl,
358 ConvertIndexShrS,
359 ConvertIndexShrU,
360 ConvertIndexAnd,
361 ConvertIndexOr,
362 ConvertIndexXor,
363 ConvertIndexConstantBoolOpPattern,
364 ConvertIndexConstantOpPattern,
365 ConvertIndexCeilDivSPattern,
366 ConvertIndexCeilDivUPattern,
367 ConvertIndexFloorDivSPattern,
368 ConvertIndexCastS,
369 ConvertIndexCastU,
370 ConvertIndexCmpPattern,
371 ConvertIndexSizeOf
372 >(typeConverter, patterns.getContext());
373}
374
375//===----------------------------------------------------------------------===//
376// ODS-Generated Definitions
377//===----------------------------------------------------------------------===//
378
379namespace mlir {
380#define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS
381#include "mlir/Conversion/Passes.h.inc"
382} // namespace mlir
383
384//===----------------------------------------------------------------------===//
385// Pass Definition
386//===----------------------------------------------------------------------===//
387
388namespace {
389struct ConvertIndexToSPIRVPass
390 : public impl::ConvertIndexToSPIRVPassBase<ConvertIndexToSPIRVPass> {
391 using Base::Base;
392
393 void runOnOperation() override {
394 Operation *op = getOperation();
395 spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
396 std::unique_ptr<SPIRVConversionTarget> target =
397 SPIRVConversionTarget::get(targetAttr);
398
399 SPIRVConversionOptions options;
400 options.use64bitIndex = this->use64bitIndex;
401 SPIRVTypeConverter typeConverter(targetAttr, options);
402
403 // Use UnrealizedConversionCast as the bridge so that we don't need to pull
404 // in patterns for other dialects.
405 target->addLegalOp<UnrealizedConversionCastOp>();
406
407 // Allow the spirv operations we are converting to
408 target->addLegalDialect<spirv::SPIRVDialect>();
409 // Fail hard when there are any remaining 'index' ops.
410 target->addIllegalDialect<index::IndexDialect>();
411
412 RewritePatternSet patterns(&getContext());
414
415 if (failed(applyPartialConversion(op, *target, std::move(patterns))))
416 signalPassFailure();
417 }
418};
419} // namespace
return success()
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getType() const
Return the type of this value.
Definition Value.h:105
void populateIndexToSPIRVPatterns(const SPIRVTypeConverter &converter, RewritePatternSet &patterns)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition Pattern.h:23