MLIR  22.0.0git
IndexToSPIRV.cpp
Go to the documentation of this file.
1 //===- IndexToSPIRV.cpp - Index to SPIRV dialect conversion -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "../SPIRVCommon/Pattern.h"
16 
17 using namespace mlir;
18 using namespace index;
19 
20 namespace {
21 
22 //===----------------------------------------------------------------------===//
23 // Trivial Conversions
24 //===----------------------------------------------------------------------===//
25 
37 
38 using ConvertIndexShl =
40 using ConvertIndexShrS =
42 using ConvertIndexShrU =
44 
/// When converting bitwise operations to SPIR-V we must account for a special
/// rule: if the operands are boolean values, SPIR-V uses `SPIRVLogicalOp`;
/// otherwise, for non-boolean operands, it uses `SPIRVBitwiseOp`. However,
/// index.and and index.or never operate on boolean values, so we can convert
/// them directly to Bitwise[And|Or]Op.
54 
55 //===----------------------------------------------------------------------===//
56 // ConvertConstantBool
57 //===----------------------------------------------------------------------===//
58 
59 // Converts index.bool.constant operation to spirv.Constant.
60 struct ConvertIndexConstantBoolOpPattern final
61  : OpConversionPattern<BoolConstantOp> {
63 
64  LogicalResult
65  matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor,
66  ConversionPatternRewriter &rewriter) const override {
67  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, op.getType(),
68  op.getValueAttr());
69  return success();
70  }
71 };
72 
73 //===----------------------------------------------------------------------===//
74 // ConvertConstant
75 //===----------------------------------------------------------------------===//
76 
77 // Converts index.constant op to spirv.Constant. Will truncate from i64 to i32
78 // when required.
79 struct ConvertIndexConstantOpPattern final : OpConversionPattern<ConstantOp> {
81 
82  LogicalResult
83  matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
84  ConversionPatternRewriter &rewriter) const override {
85  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
86  Type indexType = typeConverter->getIndexType();
87 
88  APInt value = op.getValue().trunc(typeConverter->getIndexTypeBitwidth());
89  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
90  op, indexType, IntegerAttr::get(indexType, value));
91  return success();
92  }
93 };
94 
95 //===----------------------------------------------------------------------===//
96 // ConvertIndexCeilDivS
97 //===----------------------------------------------------------------------===//
98 
99 /// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then
100 /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. Formula taken from the equivalent
101 /// conversion in IndexToLLVM.
102 struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
104 
105  LogicalResult
106  matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
107  ConversionPatternRewriter &rewriter) const override {
108  Location loc = op.getLoc();
109  Value n = adaptor.getLhs();
110  Type n_type = n.getType();
111  Value m = adaptor.getRhs();
112 
113  // Define the constants
114  Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
115  IntegerAttr::get(n_type, 0));
116  Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type,
117  IntegerAttr::get(n_type, 1));
118  Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type,
119  IntegerAttr::get(n_type, -1));
120 
121  // Compute `x`.
122  Value mPos = spirv::SGreaterThanOp::create(rewriter, loc, m, zero);
123  Value x = spirv::SelectOp::create(rewriter, loc, mPos, negOne, posOne);
124 
125  // Compute the positive result.
126  Value nPlusX = spirv::IAddOp::create(rewriter, loc, n, x);
127  Value nPlusXDivM = spirv::SDivOp::create(rewriter, loc, nPlusX, m);
128  Value posRes = spirv::IAddOp::create(rewriter, loc, nPlusXDivM, posOne);
129 
130  // Compute the negative result.
131  Value negN = spirv::ISubOp::create(rewriter, loc, zero, n);
132  Value negNDivM = spirv::SDivOp::create(rewriter, loc, negN, m);
133  Value negRes = spirv::ISubOp::create(rewriter, loc, zero, negNDivM);
134 
135  // Pick the positive result if `n` and `m` have the same sign and `n` is
136  // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
137  Value nPos = spirv::SGreaterThanOp::create(rewriter, loc, n, zero);
138  Value sameSign = spirv::LogicalEqualOp::create(rewriter, loc, nPos, mPos);
139  Value nNonZero = spirv::INotEqualOp::create(rewriter, loc, n, zero);
140  Value cmp = spirv::LogicalAndOp::create(rewriter, loc, sameSign, nNonZero);
141  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
142  return success();
143  }
144 };
145 
146 //===----------------------------------------------------------------------===//
147 // ConvertIndexCeilDivU
148 //===----------------------------------------------------------------------===//
149 
150 /// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`. Formula taken
151 /// from the equivalent conversion in IndexToLLVM.
152 struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
154 
155  LogicalResult
156  matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
157  ConversionPatternRewriter &rewriter) const override {
158  Location loc = op.getLoc();
159  Value n = adaptor.getLhs();
160  Type n_type = n.getType();
161  Value m = adaptor.getRhs();
162 
163  // Define the constants
164  Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
165  IntegerAttr::get(n_type, 0));
166  Value one = spirv::ConstantOp::create(rewriter, loc, n_type,
167  IntegerAttr::get(n_type, 1));
168 
169  // Compute the non-zero result.
170  Value minusOne = spirv::ISubOp::create(rewriter, loc, n, one);
171  Value quotient = spirv::UDivOp::create(rewriter, loc, minusOne, m);
172  Value plusOne = spirv::IAddOp::create(rewriter, loc, quotient, one);
173 
174  // Pick the result
175  Value cmp = spirv::IEqualOp::create(rewriter, loc, n, zero);
176  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, zero, plusOne);
177  return success();
178  }
179 };
180 
181 //===----------------------------------------------------------------------===//
182 // ConvertIndexFloorDivS
183 //===----------------------------------------------------------------------===//
184 
185 /// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then
186 /// `n*m < 0 ? -1 - (x-n)/m : n/m`. Formula taken from the equivalent conversion
187 /// in IndexToLLVM.
188 struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
190 
191  LogicalResult
192  matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
193  ConversionPatternRewriter &rewriter) const override {
194  Location loc = op.getLoc();
195  Value n = adaptor.getLhs();
196  Type n_type = n.getType();
197  Value m = adaptor.getRhs();
198 
199  // Define the constants
200  Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
201  IntegerAttr::get(n_type, 0));
202  Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type,
203  IntegerAttr::get(n_type, 1));
204  Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type,
205  IntegerAttr::get(n_type, -1));
206 
207  // Compute `x`.
208  Value mNeg = spirv::SLessThanOp::create(rewriter, loc, m, zero);
209  Value x = spirv::SelectOp::create(rewriter, loc, mNeg, posOne, negOne);
210 
211  // Compute the negative result
212  Value xMinusN = spirv::ISubOp::create(rewriter, loc, x, n);
213  Value xMinusNDivM = spirv::SDivOp::create(rewriter, loc, xMinusN, m);
214  Value negRes = spirv::ISubOp::create(rewriter, loc, negOne, xMinusNDivM);
215 
216  // Compute the positive result.
217  Value posRes = spirv::SDivOp::create(rewriter, loc, n, m);
218 
219  // Pick the negative result if `n` and `m` have different signs and `n` is
220  // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
221  Value nNeg = spirv::SLessThanOp::create(rewriter, loc, n, zero);
222  Value diffSign =
223  spirv::LogicalNotEqualOp::create(rewriter, loc, nNeg, mNeg);
224  Value nNonZero = spirv::INotEqualOp::create(rewriter, loc, n, zero);
225 
226  Value cmp = spirv::LogicalAndOp::create(rewriter, loc, diffSign, nNonZero);
227  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
228  return success();
229  }
230 };
231 
232 //===----------------------------------------------------------------------===//
233 // ConvertIndexCast
234 //===----------------------------------------------------------------------===//
235 
236 /// Convert a cast op. If the materialized index type is the same as the other
237 /// type, fold away the op. Otherwise, use the Convert SPIR-V operation.
238 /// Signed casts sign extend when the result bitwidth is larger. Unsigned casts
239 /// zero extend when the result bitwidth is larger.
240 template <typename CastOp, typename ConvertOp>
241 struct ConvertIndexCast final : OpConversionPattern<CastOp> {
243 
244  LogicalResult
245  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
246  ConversionPatternRewriter &rewriter) const override {
247  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
248  Type indexType = typeConverter->getIndexType();
249 
250  Type srcType = adaptor.getInput().getType();
251  Type dstType = op.getType();
252  if (isa<IndexType>(srcType)) {
253  srcType = indexType;
254  }
255  if (isa<IndexType>(dstType)) {
256  dstType = indexType;
257  }
258 
259  if (srcType == dstType) {
260  rewriter.replaceOp(op, adaptor.getInput());
261  } else {
262  rewriter.template replaceOpWithNewOp<ConvertOp>(op, dstType,
263  adaptor.getOperands());
264  }
265  return success();
266  }
267 };
268 
269 using ConvertIndexCastS = ConvertIndexCast<CastSOp, spirv::SConvertOp>;
270 using ConvertIndexCastU = ConvertIndexCast<CastUOp, spirv::UConvertOp>;
271 
272 //===----------------------------------------------------------------------===//
273 // ConvertIndexCmp
274 //===----------------------------------------------------------------------===//
275 
276 // Helper template to replace the operation
277 template <typename ICmpOp>
278 static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor,
279  ConversionPatternRewriter &rewriter) {
280  rewriter.replaceOpWithNewOp<ICmpOp>(op, adaptor.getLhs(), adaptor.getRhs());
281  return success();
282 }
283 
284 struct ConvertIndexCmpPattern final : OpConversionPattern<CmpOp> {
286 
287  LogicalResult
288  matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
289  ConversionPatternRewriter &rewriter) const override {
290  // We must convert the predicates to the corresponding int comparions.
291  switch (op.getPred()) {
292  case IndexCmpPredicate::EQ:
293  return rewriteCmpOp<spirv::IEqualOp>(op, adaptor, rewriter);
294  case IndexCmpPredicate::NE:
295  return rewriteCmpOp<spirv::INotEqualOp>(op, adaptor, rewriter);
296  case IndexCmpPredicate::SGE:
297  return rewriteCmpOp<spirv::SGreaterThanEqualOp>(op, adaptor, rewriter);
298  case IndexCmpPredicate::SGT:
299  return rewriteCmpOp<spirv::SGreaterThanOp>(op, adaptor, rewriter);
300  case IndexCmpPredicate::SLE:
301  return rewriteCmpOp<spirv::SLessThanEqualOp>(op, adaptor, rewriter);
302  case IndexCmpPredicate::SLT:
303  return rewriteCmpOp<spirv::SLessThanOp>(op, adaptor, rewriter);
304  case IndexCmpPredicate::UGE:
305  return rewriteCmpOp<spirv::UGreaterThanEqualOp>(op, adaptor, rewriter);
306  case IndexCmpPredicate::UGT:
307  return rewriteCmpOp<spirv::UGreaterThanOp>(op, adaptor, rewriter);
308  case IndexCmpPredicate::ULE:
309  return rewriteCmpOp<spirv::ULessThanEqualOp>(op, adaptor, rewriter);
310  case IndexCmpPredicate::ULT:
311  return rewriteCmpOp<spirv::ULessThanOp>(op, adaptor, rewriter);
312  }
313  llvm_unreachable("Unknown predicate in ConvertIndexCmpPattern");
314  }
315 };
316 
317 //===----------------------------------------------------------------------===//
318 // ConvertIndexSizeOf
319 //===----------------------------------------------------------------------===//
320 
321 /// Lower `index.sizeof` to a constant with the value of the index bitwidth.
322 struct ConvertIndexSizeOf final : OpConversionPattern<SizeOfOp> {
324 
325  LogicalResult
326  matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
327  ConversionPatternRewriter &rewriter) const override {
328  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
329  Type indexType = typeConverter->getIndexType();
330  unsigned bitwidth = typeConverter->getIndexTypeBitwidth();
331  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
332  op, indexType, IntegerAttr::get(indexType, bitwidth));
333  return success();
334  }
335 };
336 } // namespace
337 
338 //===----------------------------------------------------------------------===//
339 // Pattern Population
340 //===----------------------------------------------------------------------===//
341 
343  const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
344  patterns.add<
345  // clang-format off
346  ConvertIndexAdd,
347  ConvertIndexSub,
348  ConvertIndexMul,
349  ConvertIndexDivS,
350  ConvertIndexDivU,
351  ConvertIndexRemS,
352  ConvertIndexRemU,
353  ConvertIndexMaxS,
354  ConvertIndexMaxU,
355  ConvertIndexMinS,
356  ConvertIndexMinU,
357  ConvertIndexShl,
358  ConvertIndexShrS,
359  ConvertIndexShrU,
360  ConvertIndexAnd,
361  ConvertIndexOr,
362  ConvertIndexXor,
363  ConvertIndexConstantBoolOpPattern,
364  ConvertIndexConstantOpPattern,
365  ConvertIndexCeilDivSPattern,
366  ConvertIndexCeilDivUPattern,
367  ConvertIndexFloorDivSPattern,
368  ConvertIndexCastS,
369  ConvertIndexCastU,
370  ConvertIndexCmpPattern,
371  ConvertIndexSizeOf
372  >(typeConverter, patterns.getContext());
373 }
374 
375 //===----------------------------------------------------------------------===//
376 // ODS-Generated Definitions
377 //===----------------------------------------------------------------------===//
378 
379 namespace mlir {
380 #define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS
381 #include "mlir/Conversion/Passes.h.inc"
382 } // namespace mlir
383 
384 //===----------------------------------------------------------------------===//
385 // Pass Definition
386 //===----------------------------------------------------------------------===//
387 
388 namespace {
389 struct ConvertIndexToSPIRVPass
390  : public impl::ConvertIndexToSPIRVPassBase<ConvertIndexToSPIRVPass> {
391  using Base::Base;
392 
393  void runOnOperation() override {
394  Operation *op = getOperation();
396  std::unique_ptr<SPIRVConversionTarget> target =
397  SPIRVConversionTarget::get(targetAttr);
398 
400  options.use64bitIndex = this->use64bitIndex;
401  SPIRVTypeConverter typeConverter(targetAttr, options);
402 
403  // Use UnrealizedConversionCast as the bridge so that we don't need to pull
404  // in patterns for other dialects.
405  target->addLegalOp<UnrealizedConversionCastOp>();
406 
407  // Allow the spirv operations we are converting to
408  target->addLegalDialect<spirv::SPIRVDialect>();
409  // Fail hard when there are any remaining 'index' ops.
410  target->addIllegalDialect<index::IndexDialect>();
411 
414 
415  if (failed(applyPartialConversion(op, *target, std::move(patterns))))
416  signalPassFailure();
417  }
418 };
419 } // namespace
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
An attribute that specifies the target version, allowed extensions and capabilities,...
void populateIndexToSPIRVPatterns(const SPIRVTypeConverter &converter, RewritePatternSet &patterns)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition: Pattern.h:23