MLIR  20.0.0git
IndexToSPIRV.cpp
Go to the documentation of this file.
1 //===- IndexToSPIRV.cpp - Index to SPIRV dialect conversion -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "../SPIRVCommon/Pattern.h"
16 #include "mlir/Pass/Pass.h"
17 
18 using namespace mlir;
19 using namespace index;
20 
21 namespace {
22 
23 //===----------------------------------------------------------------------===//
24 // Trivial Conversions
25 //===----------------------------------------------------------------------===//
26 
38 
39 using ConvertIndexShl =
41 using ConvertIndexShrS =
43 using ConvertIndexShrU =
45 
/// When converting bitwise operations to SPIR-V we normally have to account for
/// a SPIR-V rule: boolean operands must use a `SPIRVLogicalOp`, while
/// non-boolean operands use a `SPIRVBitwiseOp`. However, index.and and index.or
/// never operate on boolean values, so they can be converted directly to the
/// corresponding Bitwise[And|Or]Op.
55 
56 //===----------------------------------------------------------------------===//
57 // ConvertConstantBool
58 //===----------------------------------------------------------------------===//
59 
60 // Converts index.bool.constant operation to spirv.Constant.
61 struct ConvertIndexConstantBoolOpPattern final
62  : OpConversionPattern<BoolConstantOp> {
64 
65  LogicalResult
66  matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor,
67  ConversionPatternRewriter &rewriter) const override {
68  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, op.getType(),
69  op.getValueAttr());
70  return success();
71  }
72 };
73 
74 //===----------------------------------------------------------------------===//
75 // ConvertConstant
76 //===----------------------------------------------------------------------===//
77 
78 // Converts index.constant op to spirv.Constant. Will truncate from i64 to i32
79 // when required.
80 struct ConvertIndexConstantOpPattern final : OpConversionPattern<ConstantOp> {
82 
83  LogicalResult
84  matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
85  ConversionPatternRewriter &rewriter) const override {
86  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
87  Type indexType = typeConverter->getIndexType();
88 
89  APInt value = op.getValue().trunc(typeConverter->getIndexTypeBitwidth());
90  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
91  op, indexType, IntegerAttr::get(indexType, value));
92  return success();
93  }
94 };
95 
96 //===----------------------------------------------------------------------===//
97 // ConvertIndexCeilDivS
98 //===----------------------------------------------------------------------===//
99 
100 /// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then
101 /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. Formula taken from the equivalent
102 /// conversion in IndexToLLVM.
103 struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
105 
106  LogicalResult
107  matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
108  ConversionPatternRewriter &rewriter) const override {
109  Location loc = op.getLoc();
110  Value n = adaptor.getLhs();
111  Type n_type = n.getType();
112  Value m = adaptor.getRhs();
113 
114  // Define the constants
115  Value zero = rewriter.create<spirv::ConstantOp>(
116  loc, n_type, IntegerAttr::get(n_type, 0));
117  Value posOne = rewriter.create<spirv::ConstantOp>(
118  loc, n_type, IntegerAttr::get(n_type, 1));
119  Value negOne = rewriter.create<spirv::ConstantOp>(
120  loc, n_type, IntegerAttr::get(n_type, -1));
121 
122  // Compute `x`.
123  Value mPos = rewriter.create<spirv::SGreaterThanOp>(loc, m, zero);
124  Value x = rewriter.create<spirv::SelectOp>(loc, mPos, negOne, posOne);
125 
126  // Compute the positive result.
127  Value nPlusX = rewriter.create<spirv::IAddOp>(loc, n, x);
128  Value nPlusXDivM = rewriter.create<spirv::SDivOp>(loc, nPlusX, m);
129  Value posRes = rewriter.create<spirv::IAddOp>(loc, nPlusXDivM, posOne);
130 
131  // Compute the negative result.
132  Value negN = rewriter.create<spirv::ISubOp>(loc, zero, n);
133  Value negNDivM = rewriter.create<spirv::SDivOp>(loc, negN, m);
134  Value negRes = rewriter.create<spirv::ISubOp>(loc, zero, negNDivM);
135 
136  // Pick the positive result if `n` and `m` have the same sign and `n` is
137  // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
138  Value nPos = rewriter.create<spirv::SGreaterThanOp>(loc, n, zero);
139  Value sameSign = rewriter.create<spirv::LogicalEqualOp>(loc, nPos, mPos);
140  Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
141  Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, sameSign, nNonZero);
142  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
143  return success();
144  }
145 };
146 
147 //===----------------------------------------------------------------------===//
148 // ConvertIndexCeilDivU
149 //===----------------------------------------------------------------------===//
150 
151 /// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`. Formula taken
152 /// from the equivalent conversion in IndexToLLVM.
153 struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
155 
156  LogicalResult
157  matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
158  ConversionPatternRewriter &rewriter) const override {
159  Location loc = op.getLoc();
160  Value n = adaptor.getLhs();
161  Type n_type = n.getType();
162  Value m = adaptor.getRhs();
163 
164  // Define the constants
165  Value zero = rewriter.create<spirv::ConstantOp>(
166  loc, n_type, IntegerAttr::get(n_type, 0));
167  Value one = rewriter.create<spirv::ConstantOp>(loc, n_type,
168  IntegerAttr::get(n_type, 1));
169 
170  // Compute the non-zero result.
171  Value minusOne = rewriter.create<spirv::ISubOp>(loc, n, one);
172  Value quotient = rewriter.create<spirv::UDivOp>(loc, minusOne, m);
173  Value plusOne = rewriter.create<spirv::IAddOp>(loc, quotient, one);
174 
175  // Pick the result
176  Value cmp = rewriter.create<spirv::IEqualOp>(loc, n, zero);
177  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, zero, plusOne);
178  return success();
179  }
180 };
181 
182 //===----------------------------------------------------------------------===//
183 // ConvertIndexFloorDivS
184 //===----------------------------------------------------------------------===//
185 
186 /// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then
187 /// `n*m < 0 ? -1 - (x-n)/m : n/m`. Formula taken from the equivalent conversion
188 /// in IndexToLLVM.
189 struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
191 
192  LogicalResult
193  matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
194  ConversionPatternRewriter &rewriter) const override {
195  Location loc = op.getLoc();
196  Value n = adaptor.getLhs();
197  Type n_type = n.getType();
198  Value m = adaptor.getRhs();
199 
200  // Define the constants
201  Value zero = rewriter.create<spirv::ConstantOp>(
202  loc, n_type, IntegerAttr::get(n_type, 0));
203  Value posOne = rewriter.create<spirv::ConstantOp>(
204  loc, n_type, IntegerAttr::get(n_type, 1));
205  Value negOne = rewriter.create<spirv::ConstantOp>(
206  loc, n_type, IntegerAttr::get(n_type, -1));
207 
208  // Compute `x`.
209  Value mNeg = rewriter.create<spirv::SLessThanOp>(loc, m, zero);
210  Value x = rewriter.create<spirv::SelectOp>(loc, mNeg, posOne, negOne);
211 
212  // Compute the negative result
213  Value xMinusN = rewriter.create<spirv::ISubOp>(loc, x, n);
214  Value xMinusNDivM = rewriter.create<spirv::SDivOp>(loc, xMinusN, m);
215  Value negRes = rewriter.create<spirv::ISubOp>(loc, negOne, xMinusNDivM);
216 
217  // Compute the positive result.
218  Value posRes = rewriter.create<spirv::SDivOp>(loc, n, m);
219 
220  // Pick the negative result if `n` and `m` have different signs and `n` is
221  // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
222  Value nNeg = rewriter.create<spirv::SLessThanOp>(loc, n, zero);
223  Value diffSign = rewriter.create<spirv::LogicalNotEqualOp>(loc, nNeg, mNeg);
224  Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
225 
226  Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, diffSign, nNonZero);
227  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
228  return success();
229  }
230 };
231 
232 //===----------------------------------------------------------------------===//
233 // ConvertIndexCast
234 //===----------------------------------------------------------------------===//
235 
236 /// Convert a cast op. If the materialized index type is the same as the other
237 /// type, fold away the op. Otherwise, use the Convert SPIR-V operation.
238 /// Signed casts sign extend when the result bitwidth is larger. Unsigned casts
239 /// zero extend when the result bitwidth is larger.
240 template <typename CastOp, typename ConvertOp>
241 struct ConvertIndexCast final : OpConversionPattern<CastOp> {
243 
244  LogicalResult
245  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
246  ConversionPatternRewriter &rewriter) const override {
247  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
248  Type indexType = typeConverter->getIndexType();
249 
250  Type srcType = adaptor.getInput().getType();
251  Type dstType = op.getType();
252  if (isa<IndexType>(srcType)) {
253  srcType = indexType;
254  }
255  if (isa<IndexType>(dstType)) {
256  dstType = indexType;
257  }
258 
259  if (srcType == dstType) {
260  rewriter.replaceOp(op, adaptor.getInput());
261  } else {
262  rewriter.template replaceOpWithNewOp<ConvertOp>(op, dstType,
263  adaptor.getOperands());
264  }
265  return success();
266  }
267 };
268 
269 using ConvertIndexCastS = ConvertIndexCast<CastSOp, spirv::SConvertOp>;
270 using ConvertIndexCastU = ConvertIndexCast<CastUOp, spirv::UConvertOp>;
271 
272 //===----------------------------------------------------------------------===//
273 // ConvertIndexCmp
274 //===----------------------------------------------------------------------===//
275 
276 // Helper template to replace the operation
277 template <typename ICmpOp>
278 static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor,
279  ConversionPatternRewriter &rewriter) {
280  rewriter.replaceOpWithNewOp<ICmpOp>(op, adaptor.getLhs(), adaptor.getRhs());
281  return success();
282 }
283 
284 struct ConvertIndexCmpPattern final : OpConversionPattern<CmpOp> {
286 
287  LogicalResult
288  matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
289  ConversionPatternRewriter &rewriter) const override {
290  // We must convert the predicates to the corresponding int comparions.
291  switch (op.getPred()) {
292  case IndexCmpPredicate::EQ:
293  return rewriteCmpOp<spirv::IEqualOp>(op, adaptor, rewriter);
294  case IndexCmpPredicate::NE:
295  return rewriteCmpOp<spirv::INotEqualOp>(op, adaptor, rewriter);
296  case IndexCmpPredicate::SGE:
297  return rewriteCmpOp<spirv::SGreaterThanEqualOp>(op, adaptor, rewriter);
298  case IndexCmpPredicate::SGT:
299  return rewriteCmpOp<spirv::SGreaterThanOp>(op, adaptor, rewriter);
300  case IndexCmpPredicate::SLE:
301  return rewriteCmpOp<spirv::SLessThanEqualOp>(op, adaptor, rewriter);
302  case IndexCmpPredicate::SLT:
303  return rewriteCmpOp<spirv::SLessThanOp>(op, adaptor, rewriter);
304  case IndexCmpPredicate::UGE:
305  return rewriteCmpOp<spirv::UGreaterThanEqualOp>(op, adaptor, rewriter);
306  case IndexCmpPredicate::UGT:
307  return rewriteCmpOp<spirv::UGreaterThanOp>(op, adaptor, rewriter);
308  case IndexCmpPredicate::ULE:
309  return rewriteCmpOp<spirv::ULessThanEqualOp>(op, adaptor, rewriter);
310  case IndexCmpPredicate::ULT:
311  return rewriteCmpOp<spirv::ULessThanOp>(op, adaptor, rewriter);
312  }
313  llvm_unreachable("Unknown predicate in ConvertIndexCmpPattern");
314  }
315 };
316 
317 //===----------------------------------------------------------------------===//
318 // ConvertIndexSizeOf
319 //===----------------------------------------------------------------------===//
320 
321 /// Lower `index.sizeof` to a constant with the value of the index bitwidth.
322 struct ConvertIndexSizeOf final : OpConversionPattern<SizeOfOp> {
324 
325  LogicalResult
326  matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
327  ConversionPatternRewriter &rewriter) const override {
328  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
329  Type indexType = typeConverter->getIndexType();
330  unsigned bitwidth = typeConverter->getIndexTypeBitwidth();
331  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
332  op, indexType, IntegerAttr::get(indexType, bitwidth));
333  return success();
334  }
335 };
336 } // namespace
337 
338 //===----------------------------------------------------------------------===//
339 // Pattern Population
340 //===----------------------------------------------------------------------===//
341 
343  const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
344  patterns.add<
345  // clang-format off
346  ConvertIndexAdd,
347  ConvertIndexSub,
348  ConvertIndexMul,
349  ConvertIndexDivS,
350  ConvertIndexDivU,
351  ConvertIndexRemS,
352  ConvertIndexRemU,
353  ConvertIndexMaxS,
354  ConvertIndexMaxU,
355  ConvertIndexMinS,
356  ConvertIndexMinU,
357  ConvertIndexShl,
358  ConvertIndexShrS,
359  ConvertIndexShrU,
360  ConvertIndexAnd,
361  ConvertIndexOr,
362  ConvertIndexXor,
363  ConvertIndexConstantBoolOpPattern,
364  ConvertIndexConstantOpPattern,
365  ConvertIndexCeilDivSPattern,
366  ConvertIndexCeilDivUPattern,
367  ConvertIndexFloorDivSPattern,
368  ConvertIndexCastS,
369  ConvertIndexCastU,
370  ConvertIndexCmpPattern,
371  ConvertIndexSizeOf
372  >(typeConverter, patterns.getContext());
373 }
374 
375 //===----------------------------------------------------------------------===//
376 // ODS-Generated Definitions
377 //===----------------------------------------------------------------------===//
378 
379 namespace mlir {
380 #define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS
381 #include "mlir/Conversion/Passes.h.inc"
382 } // namespace mlir
383 
384 //===----------------------------------------------------------------------===//
385 // Pass Definition
386 //===----------------------------------------------------------------------===//
387 
388 namespace {
389 struct ConvertIndexToSPIRVPass
390  : public impl::ConvertIndexToSPIRVPassBase<ConvertIndexToSPIRVPass> {
391  using Base::Base;
392 
393  void runOnOperation() override {
394  Operation *op = getOperation();
396  std::unique_ptr<SPIRVConversionTarget> target =
397  SPIRVConversionTarget::get(targetAttr);
398 
400  options.use64bitIndex = this->use64bitIndex;
401  SPIRVTypeConverter typeConverter(targetAttr, options);
402 
403  // Use UnrealizedConversionCast as the bridge so that we don't need to pull
404  // in patterns for other dialects.
405  target->addLegalOp<UnrealizedConversionCastOp>();
406 
407  // Allow the spirv operations we are converting to
408  target->addLegalDialect<spirv::SPIRVDialect>();
409  // Fail hard when there are any remaining 'index' ops.
410  target->addIllegalDialect<index::IndexDialect>();
411 
412  RewritePatternSet patterns(&getContext());
413  index::populateIndexToSPIRVPatterns(typeConverter, patterns);
414 
415  if (failed(applyPartialConversion(op, *target, std::move(patterns))))
416  signalPassFailure();
417  }
418 };
419 } // namespace
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
An attribute that specifies the target version, allowed extensions and capabilities,...
void populateIndexToSPIRVPatterns(const SPIRVTypeConverter &converter, RewritePatternSet &patterns)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition: Pattern.h:23