MLIR 23.0.0git
IndexToSPIRV.cpp
Go to the documentation of this file.
1//===- IndexToSPIRV.cpp - Index to SPIRV dialect conversion -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
16
17using namespace mlir;
18using namespace index;
19
20namespace {
21
22//===----------------------------------------------------------------------===//
23// Trivial Conversions
24//===----------------------------------------------------------------------===//
25
41
42using ConvertIndexShl =
44using ConvertIndexShrS =
46using ConvertIndexShrU =
48
/// When converting bitwise operations to SPIR-V we must take into account that
/// SPIR-V uses `SPIRVLogicalOp` when the operands are boolean values and
/// `SPIRVBitwiseOp` for non-boolean operands. However, index.and and index.or
/// never operate on boolean values, so they can be converted directly to the
/// Bitwise[And|Or]Op.
58
59//===----------------------------------------------------------------------===//
60// ConvertConstantBool
61//===----------------------------------------------------------------------===//
62
63// Converts index.bool.constant operation to spirv.Constant.
64struct ConvertIndexConstantBoolOpPattern final
65 : OpConversionPattern<BoolConstantOp> {
66 using Base::Base;
67
68 LogicalResult
69 matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor,
70 ConversionPatternRewriter &rewriter) const override {
71 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, op.getType(),
72 op.getValueAttr());
73 return success();
74 }
75};
76
77//===----------------------------------------------------------------------===//
78// ConvertConstant
79//===----------------------------------------------------------------------===//
80
81// Converts index.constant op to spirv.Constant. Will truncate from i64 to i32
82// when required.
83struct ConvertIndexConstantOpPattern final : OpConversionPattern<ConstantOp> {
84 using Base::Base;
85
86 LogicalResult
87 matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
88 ConversionPatternRewriter &rewriter) const override {
89 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
90 Type indexType = typeConverter->getIndexType();
91
92 APInt value = op.getValue().trunc(typeConverter->getIndexTypeBitwidth());
93 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
94 op, indexType, IntegerAttr::get(indexType, value));
95 return success();
96 }
97};
98
99//===----------------------------------------------------------------------===//
100// ConvertIndexCeilDivS
101//===----------------------------------------------------------------------===//
102
103/// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then
104/// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. Formula taken from the equivalent
105/// conversion in IndexToLLVM.
106struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
107 using Base::Base;
108
109 LogicalResult
110 matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
111 ConversionPatternRewriter &rewriter) const override {
112 Location loc = op.getLoc();
113 Value n = adaptor.getLhs();
114 Type nType = n.getType();
115 Value m = adaptor.getRhs();
116
117 // Define the constants
118 Value zero = spirv::ConstantOp::create(rewriter, loc, nType,
119 IntegerAttr::get(nType, 0));
120 Value posOne = spirv::ConstantOp::create(rewriter, loc, nType,
121 IntegerAttr::get(nType, 1));
122 Value negOne = spirv::ConstantOp::create(rewriter, loc, nType,
123 IntegerAttr::get(nType, -1));
124
125 // Compute `x`.
126 Value mPos = spirv::SGreaterThanOp::create(rewriter, loc, m, zero);
127 Value x = spirv::SelectOp::create(rewriter, loc, mPos, negOne, posOne);
128
129 // Compute the positive result.
130 Value nPlusX = spirv::IAddOp::create(rewriter, loc, n, x);
131 Value nPlusXDivM = spirv::SDivOp::create(rewriter, loc, nPlusX, m);
132 Value posRes = spirv::IAddOp::create(rewriter, loc, nPlusXDivM, posOne);
133
134 // Compute the negative result.
135 Value negN = spirv::ISubOp::create(rewriter, loc, zero, n);
136 Value negNDivM = spirv::SDivOp::create(rewriter, loc, negN, m);
137 Value negRes = spirv::ISubOp::create(rewriter, loc, zero, negNDivM);
138
139 // Pick the positive result if `n` and `m` have the same sign and `n` is
140 // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
141 Value nPos = spirv::SGreaterThanOp::create(rewriter, loc, n, zero);
142 Value sameSign = spirv::LogicalEqualOp::create(rewriter, loc, nPos, mPos);
143 Value nNonZero = spirv::INotEqualOp::create(rewriter, loc, n, zero);
144 Value cmp = spirv::LogicalAndOp::create(rewriter, loc, sameSign, nNonZero);
145 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
146 return success();
147 }
148};
149
150//===----------------------------------------------------------------------===//
151// ConvertIndexCeilDivU
152//===----------------------------------------------------------------------===//
153
154/// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`. Formula taken
155/// from the equivalent conversion in IndexToLLVM.
156struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
157 using Base::Base;
158
159 LogicalResult
160 matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
161 ConversionPatternRewriter &rewriter) const override {
162 Location loc = op.getLoc();
163 Value n = adaptor.getLhs();
164 Type nType = n.getType();
165 Value m = adaptor.getRhs();
166
167 // Define the constants
168 Value zero = spirv::ConstantOp::create(rewriter, loc, nType,
169 IntegerAttr::get(nType, 0));
170 Value one = spirv::ConstantOp::create(rewriter, loc, nType,
171 IntegerAttr::get(nType, 1));
172
173 // Compute the non-zero result.
174 Value minusOne = spirv::ISubOp::create(rewriter, loc, n, one);
175 Value quotient = spirv::UDivOp::create(rewriter, loc, minusOne, m);
176 Value plusOne = spirv::IAddOp::create(rewriter, loc, quotient, one);
177
178 // Pick the result
179 Value cmp = spirv::IEqualOp::create(rewriter, loc, n, zero);
180 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, zero, plusOne);
181 return success();
182 }
183};
184
185//===----------------------------------------------------------------------===//
186// ConvertIndexFloorDivS
187//===----------------------------------------------------------------------===//
188
189/// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then
190/// `n*m < 0 ? -1 - (x-n)/m : n/m`. Formula taken from the equivalent conversion
191/// in IndexToLLVM.
192struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
193 using Base::Base;
194
195 LogicalResult
196 matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
197 ConversionPatternRewriter &rewriter) const override {
198 Location loc = op.getLoc();
199 Value n = adaptor.getLhs();
200 Type nType = n.getType();
201 Value m = adaptor.getRhs();
202
203 // Define the constants
204 Value zero = spirv::ConstantOp::create(rewriter, loc, nType,
205 IntegerAttr::get(nType, 0));
206 Value posOne = spirv::ConstantOp::create(rewriter, loc, nType,
207 IntegerAttr::get(nType, 1));
208 Value negOne = spirv::ConstantOp::create(rewriter, loc, nType,
209 IntegerAttr::get(nType, -1));
210
211 // Compute `x`.
212 Value mNeg = spirv::SLessThanOp::create(rewriter, loc, m, zero);
213 Value x = spirv::SelectOp::create(rewriter, loc, mNeg, posOne, negOne);
214
215 // Compute the negative result
216 Value xMinusN = spirv::ISubOp::create(rewriter, loc, x, n);
217 Value xMinusNDivM = spirv::SDivOp::create(rewriter, loc, xMinusN, m);
218 Value negRes = spirv::ISubOp::create(rewriter, loc, negOne, xMinusNDivM);
219
220 // Compute the positive result.
221 Value posRes = spirv::SDivOp::create(rewriter, loc, n, m);
222
223 // Pick the negative result if `n` and `m` have different signs and `n` is
224 // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
225 Value nNeg = spirv::SLessThanOp::create(rewriter, loc, n, zero);
226 Value diffSign =
227 spirv::LogicalNotEqualOp::create(rewriter, loc, nNeg, mNeg);
228 Value nNonZero = spirv::INotEqualOp::create(rewriter, loc, n, zero);
229
230 Value cmp = spirv::LogicalAndOp::create(rewriter, loc, diffSign, nNonZero);
231 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
232 return success();
233 }
234};
235
236//===----------------------------------------------------------------------===//
237// ConvertIndexCast
238//===----------------------------------------------------------------------===//
239
240/// Convert a cast op. If the materialized index type is the same as the other
241/// type, fold away the op. Otherwise, use the Convert SPIR-V operation.
242/// Signed casts sign extend when the result bitwidth is larger. Unsigned casts
243/// zero extend when the result bitwidth is larger.
244template <typename CastOp, typename ConvertOp>
245struct ConvertIndexCast final : OpConversionPattern<CastOp> {
246 using OpConversionPattern<CastOp>::OpConversionPattern;
247
248 LogicalResult
249 matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
250 ConversionPatternRewriter &rewriter) const override {
251 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
252 Type indexType = typeConverter->getIndexType();
253
254 Type srcType = adaptor.getInput().getType();
255 Type dstType = op.getType();
256 if (isa<IndexType>(srcType)) {
257 srcType = indexType;
258 }
259 if (isa<IndexType>(dstType)) {
260 dstType = indexType;
261 }
262
263 if (srcType == dstType) {
264 rewriter.replaceOp(op, adaptor.getInput());
265 } else {
266 rewriter.template replaceOpWithNewOp<ConvertOp>(op, dstType,
267 adaptor.getOperands());
268 }
269 return success();
270 }
271};
272
273using ConvertIndexCastS = ConvertIndexCast<CastSOp, spirv::SConvertOp>;
274using ConvertIndexCastU = ConvertIndexCast<CastUOp, spirv::UConvertOp>;
275
276//===----------------------------------------------------------------------===//
277// ConvertIndexCmp
278//===----------------------------------------------------------------------===//
279
280// Helper template to replace the operation
281template <typename ICmpOp>
282static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor,
283 ConversionPatternRewriter &rewriter) {
284 rewriter.replaceOpWithNewOp<ICmpOp>(op, adaptor.getLhs(), adaptor.getRhs());
285 return success();
286}
287
288struct ConvertIndexCmpPattern final : OpConversionPattern<CmpOp> {
289 using Base::Base;
290
291 LogicalResult
292 matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
293 ConversionPatternRewriter &rewriter) const override {
294 // We must convert the predicates to the corresponding int comparions.
295 switch (op.getPred()) {
296 case IndexCmpPredicate::EQ:
297 return rewriteCmpOp<spirv::IEqualOp>(op, adaptor, rewriter);
298 case IndexCmpPredicate::NE:
299 return rewriteCmpOp<spirv::INotEqualOp>(op, adaptor, rewriter);
300 case IndexCmpPredicate::SGE:
301 return rewriteCmpOp<spirv::SGreaterThanEqualOp>(op, adaptor, rewriter);
302 case IndexCmpPredicate::SGT:
303 return rewriteCmpOp<spirv::SGreaterThanOp>(op, adaptor, rewriter);
304 case IndexCmpPredicate::SLE:
305 return rewriteCmpOp<spirv::SLessThanEqualOp>(op, adaptor, rewriter);
306 case IndexCmpPredicate::SLT:
307 return rewriteCmpOp<spirv::SLessThanOp>(op, adaptor, rewriter);
308 case IndexCmpPredicate::UGE:
309 return rewriteCmpOp<spirv::UGreaterThanEqualOp>(op, adaptor, rewriter);
310 case IndexCmpPredicate::UGT:
311 return rewriteCmpOp<spirv::UGreaterThanOp>(op, adaptor, rewriter);
312 case IndexCmpPredicate::ULE:
313 return rewriteCmpOp<spirv::ULessThanEqualOp>(op, adaptor, rewriter);
314 case IndexCmpPredicate::ULT:
315 return rewriteCmpOp<spirv::ULessThanOp>(op, adaptor, rewriter);
316 }
317 llvm_unreachable("Unknown predicate in ConvertIndexCmpPattern");
318 }
319};
320
321//===----------------------------------------------------------------------===//
322// ConvertIndexSizeOf
323//===----------------------------------------------------------------------===//
324
325/// Lower `index.sizeof` to a constant with the value of the index bitwidth.
326struct ConvertIndexSizeOf final : OpConversionPattern<SizeOfOp> {
327 using Base::Base;
328
329 LogicalResult
330 matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
331 ConversionPatternRewriter &rewriter) const override {
332 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
333 Type indexType = typeConverter->getIndexType();
334 unsigned bitwidth = typeConverter->getIndexTypeBitwidth();
335 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
336 op, indexType, IntegerAttr::get(indexType, bitwidth));
337 return success();
338 }
339};
340} // namespace
341
342//===----------------------------------------------------------------------===//
343// Pattern Population
344//===----------------------------------------------------------------------===//
345
347 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
348 patterns.add<
349 // clang-format off
350 ConvertIndexAdd,
351 ConvertIndexSub,
352 ConvertIndexMul,
353 ConvertIndexDivS,
354 ConvertIndexDivU,
355 ConvertIndexRemS,
356 ConvertIndexRemU,
357 ConvertIndexShl,
358 ConvertIndexShrS,
359 ConvertIndexShrU,
360 ConvertIndexAnd,
361 ConvertIndexOr,
362 ConvertIndexXor,
363 ConvertIndexConstantBoolOpPattern,
364 ConvertIndexConstantOpPattern,
365 ConvertIndexCeilDivSPattern,
366 ConvertIndexCeilDivUPattern,
367 ConvertIndexFloorDivSPattern,
368 ConvertIndexCastS,
369 ConvertIndexCastU,
370 ConvertIndexCmpPattern,
371 ConvertIndexSizeOf
372 >(typeConverter, patterns.getContext());
373 // clang-format on
374
375 // GLSL min/max patterns.
376 patterns.add<ConvertIndexMaxSGL, ConvertIndexMaxUGL, ConvertIndexMinSGL,
377 ConvertIndexMinUGL>(typeConverter, patterns.getContext());
378
379 // OpenCL min/max patterns.
380 patterns.add<ConvertIndexMaxSCL, ConvertIndexMaxUCL, ConvertIndexMinSCL,
381 ConvertIndexMinUCL>(typeConverter, patterns.getContext());
382}
383
384//===----------------------------------------------------------------------===//
385// ODS-Generated Definitions
386//===----------------------------------------------------------------------===//
387
388namespace mlir {
389#define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS
390#include "mlir/Conversion/Passes.h.inc"
391} // namespace mlir
392
393//===----------------------------------------------------------------------===//
394// Pass Definition
395//===----------------------------------------------------------------------===//
396
397namespace {
398struct ConvertIndexToSPIRVPass
399 : public impl::ConvertIndexToSPIRVPassBase<ConvertIndexToSPIRVPass> {
400 using Base::Base;
401
402 void runOnOperation() override {
403 Operation *op = getOperation();
404 spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
405 std::unique_ptr<SPIRVConversionTarget> target =
406 SPIRVConversionTarget::get(targetAttr);
407
408 SPIRVConversionOptions options;
409 options.use64bitIndex = this->use64bitIndex;
410 SPIRVTypeConverter typeConverter(targetAttr, options);
411
412 // Use UnrealizedConversionCast as the bridge so that we don't need to pull
413 // in patterns for other dialects.
414 target->addLegalOp<UnrealizedConversionCastOp>();
415
416 // Fail hard when there are any remaining 'index' ops.
417 target->addIllegalDialect<index::IndexDialect>();
418
419 RewritePatternSet patterns(&getContext());
420 index::populateIndexToSPIRVPatterns(typeConverter, patterns);
421
422 if (failed(applyPartialConversion(op, *target, std::move(patterns))))
423 signalPassFailure();
424 }
425};
426} // namespace
return success()
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getType() const
Return the type of this value.
Definition Value.h:105
void populateIndexToSPIRVPatterns(const SPIRVTypeConverter &converter, RewritePatternSet &patterns)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition Pattern.h:24