MLIR  19.0.0git
IndexToSPIRV.cpp
Go to the documentation of this file.
1 //===- IndexToSPIRV.cpp - Index to SPIRV dialect conversion -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "../SPIRVCommon/Pattern.h"
16 #include "mlir/Pass/Pass.h"
17 
18 using namespace mlir;
19 using namespace index;
20 
21 namespace {
22 
23 //===----------------------------------------------------------------------===//
24 // Trivial Conversions
25 //===----------------------------------------------------------------------===//
26 
38 
39 using ConvertIndexShl =
41 using ConvertIndexShrS =
43 using ConvertIndexShrU =
45 
46 /// It is the case that when we convert bitwise operations to SPIR-V operations
47 /// we must take into account the special pattern in SPIR-V that if the
48 /// operands are boolean values, then SPIR-V uses `SPIRVLogicalOp`. Otherwise,
49 /// for non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. However,
50 /// index.and, index.or and index.xor always operate on `index` values, never
51 /// on booleans, so they can be converted directly to Bitwise[And|Or|Xor]Op.
55 
56 //===----------------------------------------------------------------------===//
57 // ConvertConstantBool
58 //===----------------------------------------------------------------------===//
59 
60 // Converts index.bool.constant operation to spirv.Constant.
61 struct ConvertIndexConstantBoolOpPattern final
62  : OpConversionPattern<BoolConstantOp> {
64 
66  matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor,
67  ConversionPatternRewriter &rewriter) const override {
68  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, op.getType(),
69  op.getValueAttr());
70  return success();
71  }
72 };
73 
74 //===----------------------------------------------------------------------===//
75 // ConvertConstant
76 //===----------------------------------------------------------------------===//
77 
78 // Converts index.constant op to spirv.Constant. Will truncate from i64 to i32
79 // when required.
80 struct ConvertIndexConstantOpPattern final : OpConversionPattern<ConstantOp> {
82 
84  matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
85  ConversionPatternRewriter &rewriter) const override {
86  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
87  Type indexType = typeConverter->getIndexType();
88 
89  APInt value = op.getValue().trunc(typeConverter->getIndexTypeBitwidth());
90  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
91  op, indexType, IntegerAttr::get(indexType, value));
92  return success();
93  }
94 };
95 
96 //===----------------------------------------------------------------------===//
97 // ConvertIndexCeilDivS
98 //===----------------------------------------------------------------------===//
99 
100 /// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then
101 /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. Formula taken from the equivalent
102 /// conversion in IndexToLLVM.
103 struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
105 
107  matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
108  ConversionPatternRewriter &rewriter) const override {
109  Location loc = op.getLoc();
110  Value n = adaptor.getLhs();
111  Type n_type = n.getType();
112  Value m = adaptor.getRhs();
113 
114  // Define the constants
115  Value zero = rewriter.create<spirv::ConstantOp>(
116  loc, n_type, IntegerAttr::get(n_type, 0));
117  Value posOne = rewriter.create<spirv::ConstantOp>(
118  loc, n_type, IntegerAttr::get(n_type, 1));
119  Value negOne = rewriter.create<spirv::ConstantOp>(
120  loc, n_type, IntegerAttr::get(n_type, -1));
121 
122  // Compute `x`.
123  Value mPos = rewriter.create<spirv::SGreaterThanOp>(loc, m, zero);
124  Value x = rewriter.create<spirv::SelectOp>(loc, mPos, negOne, posOne);
125 
126  // Compute the positive result.
127  Value nPlusX = rewriter.create<spirv::IAddOp>(loc, n, x);
128  Value nPlusXDivM = rewriter.create<spirv::SDivOp>(loc, nPlusX, m);
129  Value posRes = rewriter.create<spirv::IAddOp>(loc, nPlusXDivM, posOne);
130 
131  // Compute the negative result.
132  Value negN = rewriter.create<spirv::ISubOp>(loc, zero, n);
133  Value negNDivM = rewriter.create<spirv::SDivOp>(loc, negN, m);
134  Value negRes = rewriter.create<spirv::ISubOp>(loc, zero, negNDivM);
135 
136  // Pick the positive result if `n` and `m` have the same sign and `n` is
137  // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
138  Value nPos = rewriter.create<spirv::SGreaterThanOp>(loc, n, zero);
139  Value sameSign = rewriter.create<spirv::LogicalEqualOp>(loc, nPos, mPos);
140  Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
141  Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, sameSign, nNonZero);
142  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
143  return success();
144  }
145 };
146 
147 //===----------------------------------------------------------------------===//
148 // ConvertIndexCeilDivU
149 //===----------------------------------------------------------------------===//
150 
151 /// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`. Formula taken
152 /// from the equivalent conversion in IndexToLLVM.
153 struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
155 
157  matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
158  ConversionPatternRewriter &rewriter) const override {
159  Location loc = op.getLoc();
160  Value n = adaptor.getLhs();
161  Type n_type = n.getType();
162  Value m = adaptor.getRhs();
163 
164  // Define the constants
165  Value zero = rewriter.create<spirv::ConstantOp>(
166  loc, n_type, IntegerAttr::get(n_type, 0));
167  Value one = rewriter.create<spirv::ConstantOp>(loc, n_type,
168  IntegerAttr::get(n_type, 1));
169 
170  // Compute the non-zero result.
171  Value minusOne = rewriter.create<spirv::ISubOp>(loc, n, one);
172  Value quotient = rewriter.create<spirv::UDivOp>(loc, minusOne, m);
173  Value plusOne = rewriter.create<spirv::IAddOp>(loc, quotient, one);
174 
175  // Pick the result
176  Value cmp = rewriter.create<spirv::IEqualOp>(loc, n, zero);
177  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, zero, plusOne);
178  return success();
179  }
180 };
181 
182 //===----------------------------------------------------------------------===//
183 // ConvertIndexFloorDivS
184 //===----------------------------------------------------------------------===//
185 
186 /// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then
187 /// `n*m < 0 ? -1 - (x-n)/m : n/m`. Formula taken from the equivalent conversion
188 /// in IndexToLLVM.
189 struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
191 
193  matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
194  ConversionPatternRewriter &rewriter) const override {
195  Location loc = op.getLoc();
196  Value n = adaptor.getLhs();
197  Type n_type = n.getType();
198  Value m = adaptor.getRhs();
199 
200  // Define the constants
201  Value zero = rewriter.create<spirv::ConstantOp>(
202  loc, n_type, IntegerAttr::get(n_type, 0));
203  Value posOne = rewriter.create<spirv::ConstantOp>(
204  loc, n_type, IntegerAttr::get(n_type, 1));
205  Value negOne = rewriter.create<spirv::ConstantOp>(
206  loc, n_type, IntegerAttr::get(n_type, -1));
207 
208  // Compute `x`.
209  Value mNeg = rewriter.create<spirv::SLessThanOp>(loc, m, zero);
210  Value x = rewriter.create<spirv::SelectOp>(loc, mNeg, posOne, negOne);
211 
212  // Compute the negative result
213  Value xMinusN = rewriter.create<spirv::ISubOp>(loc, x, n);
214  Value xMinusNDivM = rewriter.create<spirv::SDivOp>(loc, xMinusN, m);
215  Value negRes = rewriter.create<spirv::ISubOp>(loc, negOne, xMinusNDivM);
216 
217  // Compute the positive result.
218  Value posRes = rewriter.create<spirv::SDivOp>(loc, n, m);
219 
220  // Pick the negative result if `n` and `m` have different signs and `n` is
221  // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
222  Value nNeg = rewriter.create<spirv::SLessThanOp>(loc, n, zero);
223  Value diffSign = rewriter.create<spirv::LogicalNotEqualOp>(loc, nNeg, mNeg);
224  Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
225 
226  Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, diffSign, nNonZero);
227  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
228  return success();
229  }
230 };
231 
232 //===----------------------------------------------------------------------===//
233 // ConvertIndexCast
234 //===----------------------------------------------------------------------===//
235 
236 /// Convert a cast op. If the materialized index type is the same as the other
237 /// type, fold away the op. Otherwise, use the Convert SPIR-V operation.
238 /// Signed casts sign extend when the result bitwidth is larger. Unsigned casts
239 /// zero extend when the result bitwidth is larger.
240 template <typename CastOp, typename ConvertOp>
241 struct ConvertIndexCast final : OpConversionPattern<CastOp> {
243 
245  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
246  ConversionPatternRewriter &rewriter) const override {
247  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
248  Type indexType = typeConverter->getIndexType();
249 
250  Type srcType = adaptor.getInput().getType();
251  Type dstType = op.getType();
252  if (isa<IndexType>(srcType)) {
253  srcType = indexType;
254  }
255  if (isa<IndexType>(dstType)) {
256  dstType = indexType;
257  }
258 
259  if (srcType == dstType) {
260  rewriter.replaceOp(op, adaptor.getInput());
261  } else {
262  rewriter.template replaceOpWithNewOp<ConvertOp>(op, dstType,
263  adaptor.getOperands());
264  }
265  return success();
266  }
267 };
268 
269 using ConvertIndexCastS = ConvertIndexCast<CastSOp, spirv::SConvertOp>;
270 using ConvertIndexCastU = ConvertIndexCast<CastUOp, spirv::UConvertOp>;
271 
272 //===----------------------------------------------------------------------===//
273 // ConvertIndexCmp
274 //===----------------------------------------------------------------------===//
275 
276 // Helper template to replace the operation
277 template <typename ICmpOp>
278 static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor,
279  ConversionPatternRewriter &rewriter) {
280  rewriter.replaceOpWithNewOp<ICmpOp>(op, adaptor.getLhs(), adaptor.getRhs());
281  return success();
282 }
283 
284 struct ConvertIndexCmpPattern final : OpConversionPattern<CmpOp> {
286 
288  matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
289  ConversionPatternRewriter &rewriter) const override {
290  // We must convert the predicates to the corresponding int comparions.
291  switch (op.getPred()) {
292  case IndexCmpPredicate::EQ:
293  return rewriteCmpOp<spirv::IEqualOp>(op, adaptor, rewriter);
294  case IndexCmpPredicate::NE:
295  return rewriteCmpOp<spirv::INotEqualOp>(op, adaptor, rewriter);
296  case IndexCmpPredicate::SGE:
297  return rewriteCmpOp<spirv::SGreaterThanEqualOp>(op, adaptor, rewriter);
298  case IndexCmpPredicate::SGT:
299  return rewriteCmpOp<spirv::SGreaterThanOp>(op, adaptor, rewriter);
300  case IndexCmpPredicate::SLE:
301  return rewriteCmpOp<spirv::SLessThanEqualOp>(op, adaptor, rewriter);
302  case IndexCmpPredicate::SLT:
303  return rewriteCmpOp<spirv::SLessThanOp>(op, adaptor, rewriter);
304  case IndexCmpPredicate::UGE:
305  return rewriteCmpOp<spirv::UGreaterThanEqualOp>(op, adaptor, rewriter);
306  case IndexCmpPredicate::UGT:
307  return rewriteCmpOp<spirv::UGreaterThanOp>(op, adaptor, rewriter);
308  case IndexCmpPredicate::ULE:
309  return rewriteCmpOp<spirv::ULessThanEqualOp>(op, adaptor, rewriter);
310  case IndexCmpPredicate::ULT:
311  return rewriteCmpOp<spirv::ULessThanOp>(op, adaptor, rewriter);
312  }
313  }
314 };
315 
316 //===----------------------------------------------------------------------===//
317 // ConvertIndexSizeOf
318 //===----------------------------------------------------------------------===//
319 
320 /// Lower `index.sizeof` to a constant with the value of the index bitwidth.
321 struct ConvertIndexSizeOf final : OpConversionPattern<SizeOfOp> {
323 
325  matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
326  ConversionPatternRewriter &rewriter) const override {
327  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
328  Type indexType = typeConverter->getIndexType();
329  unsigned bitwidth = typeConverter->getIndexTypeBitwidth();
330  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
331  op, indexType, IntegerAttr::get(indexType, bitwidth));
332  return success();
333  }
334 };
335 } // namespace
336 
337 //===----------------------------------------------------------------------===//
338 // Pattern Population
339 //===----------------------------------------------------------------------===//
340 
342  RewritePatternSet &patterns) {
343  patterns.add<
344  // clang-format off
345  ConvertIndexAdd,
346  ConvertIndexSub,
347  ConvertIndexMul,
348  ConvertIndexDivS,
349  ConvertIndexDivU,
350  ConvertIndexRemS,
351  ConvertIndexRemU,
352  ConvertIndexMaxS,
353  ConvertIndexMaxU,
354  ConvertIndexMinS,
355  ConvertIndexMinU,
356  ConvertIndexShl,
357  ConvertIndexShrS,
358  ConvertIndexShrU,
359  ConvertIndexAnd,
360  ConvertIndexOr,
361  ConvertIndexXor,
362  ConvertIndexConstantBoolOpPattern,
363  ConvertIndexConstantOpPattern,
364  ConvertIndexCeilDivSPattern,
365  ConvertIndexCeilDivUPattern,
366  ConvertIndexFloorDivSPattern,
367  ConvertIndexCastS,
368  ConvertIndexCastU,
369  ConvertIndexCmpPattern,
370  ConvertIndexSizeOf
371  >(typeConverter, patterns.getContext());
372 }
373 
374 //===----------------------------------------------------------------------===//
375 // ODS-Generated Definitions
376 //===----------------------------------------------------------------------===//
377 
378 namespace mlir {
379 #define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS
380 #include "mlir/Conversion/Passes.h.inc"
381 } // namespace mlir
382 
383 //===----------------------------------------------------------------------===//
384 // Pass Definition
385 //===----------------------------------------------------------------------===//
386 
387 namespace {
388 struct ConvertIndexToSPIRVPass
389  : public impl::ConvertIndexToSPIRVPassBase<ConvertIndexToSPIRVPass> {
390  using Base::Base;
391 
392  void runOnOperation() override {
393  Operation *op = getOperation();
395  std::unique_ptr<SPIRVConversionTarget> target =
396  SPIRVConversionTarget::get(targetAttr);
397 
399  options.use64bitIndex = this->use64bitIndex;
400  SPIRVTypeConverter typeConverter(targetAttr, options);
401 
402  // Use UnrealizedConversionCast as the bridge so that we don't need to pull
403  // in patterns for other dialects.
404  target->addLegalOp<UnrealizedConversionCastOp>();
405 
406  // Allow the spirv operations we are converting to
407  target->addLegalDialect<spirv::SPIRVDialect>();
408  // Fail hard when there are any remaining 'index' ops.
409  target->addIllegalDialect<index::IndexDialect>();
410 
411  RewritePatternSet patterns(&getContext());
412  index::populateIndexToSPIRVPatterns(typeConverter, patterns);
413 
414  if (failed(applyPartialConversion(op, *target, std::move(patterns))))
415  signalPassFailure();
416  }
417 };
418 } // namespace
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
An attribute that specifies the target version, allowed extensions and capabilities,...
void populateIndexToSPIRVPatterns(SPIRVTypeConverter &converter, RewritePatternSet &patterns)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition: Pattern.h:23