MLIR  22.0.0git
VectorToSPIRV.cpp
Go to the documentation of this file.
1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Vector dialect to SPIRV dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
21 #include "mlir/IR/Attributes.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Location.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
28 #include "llvm/ADT/ArrayRef.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallVector.h"
31 #include "llvm/ADT/SmallVectorExtras.h"
32 #include "llvm/Support/FormatVariadic.h"
33 #include <cassert>
34 #include <cstdint>
35 #include <numeric>
36 
37 using namespace mlir;
38 
39 /// Returns the integer value from the first valid input element, assuming Value
40 /// inputs are defined by a constant index ops and Attribute inputs are integer
41 /// attributes.
42 static uint64_t getFirstIntValue(ArrayAttr attr) {
43  return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
44 }
45 
46 /// Returns the number of bits for the given scalar/vector type.
47 static int getNumBits(Type type) {
48  // TODO: This does not take into account any memory layout or widening
49  // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even
50  // though in practice it will likely be stored as in a 4xi64 vector register.
51  if (auto vectorType = dyn_cast<VectorType>(type))
52  return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
53  return type.getIntOrFloatBitWidth();
54 }
55 
56 namespace {
57 
58 struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
59  using Base::Base;
60 
61  LogicalResult
62  matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
63  ConversionPatternRewriter &rewriter) const override {
64  Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
65  if (!dstType)
66  return failure();
67 
68  // If dstType is same as the source type or the vector size is 1, it can be
69  // directly replaced by the source.
70  if (dstType == adaptor.getSource().getType() ||
71  shapeCastOp.getResultVectorType().getNumElements() == 1) {
72  rewriter.replaceOp(shapeCastOp, adaptor.getSource());
73  return success();
74  }
75 
76  // Lowering for size-n vectors when n > 1 hasn't been implemented.
77  return failure();
78  }
79 };
80 
81 struct VectorBitcastConvert final
82  : public OpConversionPattern<vector::BitCastOp> {
83  using Base::Base;
84 
85  LogicalResult
86  matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
87  ConversionPatternRewriter &rewriter) const override {
88  Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
89  if (!dstType)
90  return failure();
91 
92  if (dstType == adaptor.getSource().getType()) {
93  rewriter.replaceOp(bitcastOp, adaptor.getSource());
94  return success();
95  }
96 
97  // Check that the source and destination type have the same bitwidth.
98  // Depending on the target environment, we may need to emulate certain
99  // types, which can cause issue with bitcast.
100  Type srcType = adaptor.getSource().getType();
101  if (getNumBits(dstType) != getNumBits(srcType)) {
102  return rewriter.notifyMatchFailure(
103  bitcastOp,
104  llvm::formatv("different source ({0}) and target ({1}) bitwidth",
105  srcType, dstType));
106  }
107 
108  rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
109  adaptor.getSource());
110  return success();
111  }
112 };
113 
114 struct VectorBroadcastConvert final
115  : public OpConversionPattern<vector::BroadcastOp> {
116  using Base::Base;
117 
118  LogicalResult
119  matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
120  ConversionPatternRewriter &rewriter) const override {
121  Type resultType =
122  getTypeConverter()->convertType(castOp.getResultVectorType());
123  if (!resultType)
124  return failure();
125 
126  if (isa<spirv::ScalarType>(resultType)) {
127  rewriter.replaceOp(castOp, adaptor.getSource());
128  return success();
129  }
130 
131  SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
132  adaptor.getSource());
133  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(castOp, resultType,
134  source);
135  return success();
136  }
137 };
138 
139 // SPIR-V does not have a concept of a poison index for certain instructions,
140 // which creates a UB hazard when lowering from otherwise equivalent Vector
141 // dialect instructions, because this index will be considered out-of-bounds.
142 // To avoid this, this function implements a dynamic sanitization that returns
143 // some arbitrary safe index. For power-of-two vector sizes, this uses a bitmask
144 // (presumably more efficient), and otherwise index 0 (always in-bounds).
145 static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter,
146  Location loc, Value dynamicIndex,
147  int64_t kPoisonIndex, unsigned vectorSize) {
148  if (llvm::isPowerOf2_32(vectorSize)) {
149  Value inBoundsMask = spirv::ConstantOp::create(
150  rewriter, loc, dynamicIndex.getType(),
151  rewriter.getIntegerAttr(dynamicIndex.getType(), vectorSize - 1));
152  return spirv::BitwiseAndOp::create(rewriter, loc, dynamicIndex,
153  inBoundsMask);
154  }
155  Value poisonIndex = spirv::ConstantOp::create(
156  rewriter, loc, dynamicIndex.getType(),
157  rewriter.getIntegerAttr(dynamicIndex.getType(), kPoisonIndex));
158  Value cmpResult =
159  spirv::IEqualOp::create(rewriter, loc, dynamicIndex, poisonIndex);
160  return spirv::SelectOp::create(
161  rewriter, loc, cmpResult,
162  spirv::ConstantOp::getZero(dynamicIndex.getType(), loc, rewriter),
163  dynamicIndex);
164 }
165 
166 struct VectorExtractOpConvert final
167  : public OpConversionPattern<vector::ExtractOp> {
168  using Base::Base;
169 
170  LogicalResult
171  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
172  ConversionPatternRewriter &rewriter) const override {
173  Type dstType = getTypeConverter()->convertType(extractOp.getType());
174  if (!dstType)
175  return failure();
176 
177  if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
178  rewriter.replaceOp(extractOp, adaptor.getSource());
179  return success();
180  }
181 
182  if (std::optional<int64_t> id =
183  getConstantIntValue(extractOp.getMixedPosition()[0])) {
184  if (id == vector::ExtractOp::kPoisonIndex)
185  return rewriter.notifyMatchFailure(
186  extractOp,
187  "Static use of poison index handled elsewhere (folded to poison)");
188  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
189  extractOp, dstType, adaptor.getSource(),
190  rewriter.getI32ArrayAttr(id.value()));
191  } else {
192  Value sanitizedIndex = sanitizeDynamicIndex(
193  rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
194  vector::ExtractOp::kPoisonIndex,
195  extractOp.getSourceVectorType().getNumElements());
196  rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
197  extractOp, dstType, adaptor.getSource(), sanitizedIndex);
198  }
199  return success();
200  }
201 };
202 
203 struct VectorExtractStridedSliceOpConvert final
204  : public OpConversionPattern<vector::ExtractStridedSliceOp> {
205  using Base::Base;
206 
207  LogicalResult
208  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
209  ConversionPatternRewriter &rewriter) const override {
210  Type dstType = getTypeConverter()->convertType(extractOp.getType());
211  if (!dstType)
212  return failure();
213 
214  uint64_t offset = getFirstIntValue(extractOp.getOffsets());
215  uint64_t size = getFirstIntValue(extractOp.getSizes());
216  uint64_t stride = getFirstIntValue(extractOp.getStrides());
217  if (stride != 1)
218  return failure();
219 
220  Value srcVector = adaptor.getOperands().front();
221 
222  // Extract vector<1xT> case.
223  if (isa<spirv::ScalarType>(dstType)) {
224  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
225  srcVector, offset);
226  return success();
227  }
228 
229  SmallVector<int32_t, 2> indices(size);
230  std::iota(indices.begin(), indices.end(), offset);
231 
232  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
233  extractOp, dstType, srcVector, srcVector,
234  rewriter.getI32ArrayAttr(indices));
235 
236  return success();
237  }
238 };
239 
240 template <class SPIRVFMAOp>
241 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
242  using Base::Base;
243 
244  LogicalResult
245  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
246  ConversionPatternRewriter &rewriter) const override {
247  Type dstType = getTypeConverter()->convertType(fmaOp.getType());
248  if (!dstType)
249  return failure();
250  rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, dstType, adaptor.getLhs(),
251  adaptor.getRhs(), adaptor.getAcc());
252  return success();
253  }
254 };
255 
256 struct VectorFromElementsOpConvert final
257  : public OpConversionPattern<vector::FromElementsOp> {
258  using Base::Base;
259 
260  LogicalResult
261  matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
262  ConversionPatternRewriter &rewriter) const override {
263  Type resultType = getTypeConverter()->convertType(op.getType());
264  if (!resultType)
265  return failure();
266  ValueRange elements = adaptor.getElements();
267  if (isa<spirv::ScalarType>(resultType)) {
268  // In the case with a single scalar operand / single-element result,
269  // pass through the scalar.
270  rewriter.replaceOp(op, elements[0]);
271  return success();
272  }
273  // SPIRVTypeConverter rejects vectors with rank > 1, so multi-dimensional
274  // vector.from_elements cases should not need to be handled, only 1d.
275  assert(cast<VectorType>(resultType).getRank() == 1);
276  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, resultType,
277  elements);
278  return success();
279  }
280 };
281 
282 struct VectorInsertOpConvert final
283  : public OpConversionPattern<vector::InsertOp> {
284  using Base::Base;
285 
286  LogicalResult
287  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
288  ConversionPatternRewriter &rewriter) const override {
289  if (isa<VectorType>(insertOp.getValueToStoreType()))
290  return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
291  if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
292  return rewriter.notifyMatchFailure(insertOp,
293  "unsupported dest vector type");
294 
295  // Special case for inserting scalar values into size-1 vectors.
296  if (insertOp.getValueToStoreType().isIntOrFloat() &&
297  insertOp.getDestVectorType().getNumElements() == 1) {
298  rewriter.replaceOp(insertOp, adaptor.getValueToStore());
299  return success();
300  }
301 
302  if (std::optional<int64_t> id =
303  getConstantIntValue(insertOp.getMixedPosition()[0])) {
304  if (id == vector::InsertOp::kPoisonIndex)
305  return rewriter.notifyMatchFailure(
306  insertOp,
307  "Static use of poison index handled elsewhere (folded to poison)");
308  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
309  insertOp, adaptor.getValueToStore(), adaptor.getDest(), id.value());
310  } else {
311  Value sanitizedIndex = sanitizeDynamicIndex(
312  rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
313  vector::InsertOp::kPoisonIndex,
314  insertOp.getDestVectorType().getNumElements());
315  rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
316  insertOp, insertOp.getDest(), adaptor.getValueToStore(),
317  sanitizedIndex);
318  }
319  return success();
320  }
321 };
322 
323 struct VectorInsertStridedSliceOpConvert final
324  : public OpConversionPattern<vector::InsertStridedSliceOp> {
325  using Base::Base;
326 
327  LogicalResult
328  matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
329  ConversionPatternRewriter &rewriter) const override {
330  Value srcVector = adaptor.getOperands().front();
331  Value dstVector = adaptor.getOperands().back();
332 
333  uint64_t stride = getFirstIntValue(insertOp.getStrides());
334  if (stride != 1)
335  return failure();
336  uint64_t offset = getFirstIntValue(insertOp.getOffsets());
337 
338  if (isa<spirv::ScalarType>(srcVector.getType())) {
339  assert(!isa<spirv::ScalarType>(dstVector.getType()));
340  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
341  insertOp, dstVector.getType(), srcVector, dstVector,
342  rewriter.getI32ArrayAttr(offset));
343  return success();
344  }
345 
346  uint64_t totalSize = cast<VectorType>(dstVector.getType()).getNumElements();
347  uint64_t insertSize =
348  cast<VectorType>(srcVector.getType()).getNumElements();
349 
350  SmallVector<int32_t, 2> indices(totalSize);
351  std::iota(indices.begin(), indices.end(), 0);
352  std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
353  totalSize);
354 
355  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
356  insertOp, dstVector.getType(), dstVector, srcVector,
357  rewriter.getI32ArrayAttr(indices));
358 
359  return success();
360  }
361 };
362 
363 static SmallVector<Value> extractAllElements(
364  vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
365  VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
366  int numElements = static_cast<int>(srcVectorType.getDimSize(0));
367  SmallVector<Value> values;
368  values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
369  Location loc = reduceOp.getLoc();
370 
371  for (int i = 0; i < numElements; ++i) {
372  values.push_back(spirv::CompositeExtractOp::create(
373  rewriter, loc, srcVectorType.getElementType(), adaptor.getVector(),
374  rewriter.getI32ArrayAttr({i})));
375  }
376  if (Value acc = adaptor.getAcc())
377  values.push_back(acc);
378 
379  return values;
380 }
381 
382 struct ReductionRewriteInfo {
383  Type resultType;
384  SmallVector<Value> extractedElements;
385 };
386 
387 FailureOr<ReductionRewriteInfo> static getReductionInfo(
388  vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
389  ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) {
390  Type resultType = typeConverter.convertType(op.getType());
391  if (!resultType)
392  return failure();
393 
394  auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
395  if (!srcVectorType || srcVectorType.getRank() != 1)
396  return rewriter.notifyMatchFailure(op, "not a 1-D vector source");
397 
398  SmallVector<Value> extractedElements =
399  extractAllElements(op, adaptor, srcVectorType, rewriter);
400 
401  return ReductionRewriteInfo{resultType, std::move(extractedElements)};
402 }
403 
404 template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
405  typename SPIRVSMinOp>
406 struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
407  using Base::Base;
408 
409  LogicalResult
410  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
411  ConversionPatternRewriter &rewriter) const override {
412  auto reductionInfo =
413  getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
414  if (failed(reductionInfo))
415  return failure();
416 
417  auto [resultType, extractedElements] = *reductionInfo;
418  Location loc = reduceOp->getLoc();
419  Value result = extractedElements.front();
420  for (Value next : llvm::drop_begin(extractedElements)) {
421  switch (reduceOp.getKind()) {
422 
423 #define INT_AND_FLOAT_CASE(kind, iop, fop) \
424  case vector::CombiningKind::kind: \
425  if (llvm::isa<IntegerType>(resultType)) { \
426  result = spirv::iop::create(rewriter, loc, resultType, result, next); \
427  } else { \
428  assert(llvm::isa<FloatType>(resultType)); \
429  result = spirv::fop::create(rewriter, loc, resultType, result, next); \
430  } \
431  break
432 
433 #define INT_OR_FLOAT_CASE(kind, fop) \
434  case vector::CombiningKind::kind: \
435  result = fop::create(rewriter, loc, resultType, result, next); \
436  break
437 
438  INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
439  INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
440  INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
441  INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
442  INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
443  INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp);
444 
445  case vector::CombiningKind::AND:
446  case vector::CombiningKind::OR:
447  case vector::CombiningKind::XOR:
448  return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
449  default:
450  return rewriter.notifyMatchFailure(reduceOp, "not handled here");
451  }
452 #undef INT_AND_FLOAT_CASE
453 #undef INT_OR_FLOAT_CASE
454  }
455 
456  rewriter.replaceOp(reduceOp, result);
457  return success();
458  }
459 };
460 
461 template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
462 struct VectorReductionFloatMinMax final
463  : OpConversionPattern<vector::ReductionOp> {
464  using Base::Base;
465 
466  LogicalResult
467  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
468  ConversionPatternRewriter &rewriter) const override {
469  auto reductionInfo =
470  getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
471  if (failed(reductionInfo))
472  return failure();
473 
474  auto [resultType, extractedElements] = *reductionInfo;
475  Location loc = reduceOp->getLoc();
476  Value result = extractedElements.front();
477  for (Value next : llvm::drop_begin(extractedElements)) {
478  switch (reduceOp.getKind()) {
479 
480 #define INT_OR_FLOAT_CASE(kind, fop) \
481  case vector::CombiningKind::kind: \
482  result = fop::create(rewriter, loc, resultType, result, next); \
483  break
484 
485  INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
486  INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
487  INT_OR_FLOAT_CASE(MAXNUMF, SPIRVFMaxOp);
488  INT_OR_FLOAT_CASE(MINNUMF, SPIRVFMinOp);
489 
490  default:
491  return rewriter.notifyMatchFailure(reduceOp, "not handled here");
492  }
493 #undef INT_OR_FLOAT_CASE
494  }
495 
496  rewriter.replaceOp(reduceOp, result);
497  return success();
498  }
499 };
500 
501 class VectorScalarBroadcastPattern final
502  : public OpConversionPattern<vector::BroadcastOp> {
503 public:
504  using Base::Base;
505 
506  LogicalResult
507  matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
508  ConversionPatternRewriter &rewriter) const override {
509  if (isa<VectorType>(op.getSourceType())) {
510  return rewriter.notifyMatchFailure(
511  op, "only conversion of 'broadcast from scalar' is supported");
512  }
513  Type dstType = getTypeConverter()->convertType(op.getType());
514  if (!dstType)
515  return failure();
516  if (isa<spirv::ScalarType>(dstType)) {
517  rewriter.replaceOp(op, adaptor.getSource());
518  } else {
519  auto dstVecType = cast<VectorType>(dstType);
520  SmallVector<Value, 4> source(dstVecType.getNumElements(),
521  adaptor.getSource());
522  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
523  source);
524  }
525  return success();
526  }
527 };
528 
529 struct VectorShuffleOpConvert final
530  : public OpConversionPattern<vector::ShuffleOp> {
531  using Base::Base;
532 
533  LogicalResult
534  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
535  ConversionPatternRewriter &rewriter) const override {
536  VectorType oldResultType = shuffleOp.getResultVectorType();
537  Type newResultType = getTypeConverter()->convertType(oldResultType);
538  if (!newResultType)
539  return rewriter.notifyMatchFailure(shuffleOp,
540  "unsupported result vector type");
541 
542  auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());
543 
544  VectorType oldV1Type = shuffleOp.getV1VectorType();
545  VectorType oldV2Type = shuffleOp.getV2VectorType();
546 
547  // When both operands and the result are SPIR-V vectors, emit a SPIR-V
548  // shuffle.
549  if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
550  oldResultType.getNumElements() > 1) {
551  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
552  shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
553  rewriter.getI32ArrayAttr(mask));
554  return success();
555  }
556 
557  // When at least one of the operands or the result becomes a scalar after
558  // type conversion for SPIR-V, extract all the required elements and
559  // construct the result vector.
560  auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
561  Value scalarOrVec, int32_t idx) -> Value {
562  if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
563  return spirv::CompositeExtractOp::create(rewriter, loc, scalarOrVec,
564  idx);
565 
566  assert(idx == 0 && "Invalid scalar element index");
567  return scalarOrVec;
568  };
569 
570  int32_t numV1Elems = oldV1Type.getNumElements();
571  SmallVector<Value> newOperands(mask.size());
572  for (auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
573  Value vec = adaptor.getV1();
574  int32_t elementIdx = shuffleIdx;
575  if (elementIdx >= numV1Elems) {
576  vec = adaptor.getV2();
577  elementIdx -= numV1Elems;
578  }
579 
580  newOperand = getElementAtIdx(vec, elementIdx);
581  }
582 
583  // Handle the scalar result corner case.
584  if (newOperands.size() == 1) {
585  rewriter.replaceOp(shuffleOp, newOperands.front());
586  return success();
587  }
588 
589  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
590  shuffleOp, newResultType, newOperands);
591  return success();
592  }
593 };
594 
595 struct VectorInterleaveOpConvert final
596  : public OpConversionPattern<vector::InterleaveOp> {
597  using Base::Base;
598 
599  LogicalResult
600  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
601  ConversionPatternRewriter &rewriter) const override {
602  // Check the result vector type.
603  VectorType oldResultType = interleaveOp.getResultVectorType();
604  Type newResultType = getTypeConverter()->convertType(oldResultType);
605  if (!newResultType)
606  return rewriter.notifyMatchFailure(interleaveOp,
607  "unsupported result vector type");
608 
609  // Interleave the indices.
610  VectorType sourceType = interleaveOp.getSourceVectorType();
611  int n = sourceType.getNumElements();
612 
613  // Input vectors of size 1 are converted to scalars by the type converter.
614  // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
615  // use `spirv::CompositeConstructOp`.
616  if (n == 1) {
617  Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
618  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
619  interleaveOp, newResultType, newOperands);
620  return success();
621  }
622 
623  auto seq = llvm::seq<int64_t>(2 * n);
624  auto indices = llvm::map_to_vector(
625  seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; });
626 
627  // Emit a SPIR-V shuffle.
628  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
629  interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
630  rewriter.getI32ArrayAttr(indices));
631 
632  return success();
633  }
634 };
635 
636 struct VectorDeinterleaveOpConvert final
637  : public OpConversionPattern<vector::DeinterleaveOp> {
638  using Base::Base;
639 
640  LogicalResult
641  matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
642  ConversionPatternRewriter &rewriter) const override {
643 
644  // Check the result vector type.
645  VectorType oldResultType = deinterleaveOp.getResultVectorType();
646  Type newResultType = getTypeConverter()->convertType(oldResultType);
647  if (!newResultType)
648  return rewriter.notifyMatchFailure(deinterleaveOp,
649  "unsupported result vector type");
650 
651  Location loc = deinterleaveOp->getLoc();
652 
653  // Deinterleave the indices.
654  Value sourceVector = adaptor.getSource();
655  VectorType sourceType = deinterleaveOp.getSourceVectorType();
656  int n = sourceType.getNumElements();
657 
658  // Output vectors of size 1 are converted to scalars by the type converter.
659  // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
660  // use `spirv::CompositeExtractOp`.
661  if (n == 2) {
662  auto elem0 = spirv::CompositeExtractOp::create(
663  rewriter, loc, newResultType, sourceVector,
664  rewriter.getI32ArrayAttr({0}));
665 
666  auto elem1 = spirv::CompositeExtractOp::create(
667  rewriter, loc, newResultType, sourceVector,
668  rewriter.getI32ArrayAttr({1}));
669 
670  rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
671  return success();
672  }
673 
674  // Indices for `shuffleEven` (result 0).
675  auto seqEven = llvm::seq<int64_t>(n / 2);
676  auto indicesEven =
677  llvm::map_to_vector(seqEven, [](int i) { return i * 2; });
678 
679  // Indices for `shuffleOdd` (result 1).
680  auto seqOdd = llvm::seq<int64_t>(n / 2);
681  auto indicesOdd =
682  llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
683 
684  // Create two SPIR-V shuffles.
685  auto shuffleEven = spirv::VectorShuffleOp::create(
686  rewriter, loc, newResultType, sourceVector, sourceVector,
687  rewriter.getI32ArrayAttr(indicesEven));
688 
689  auto shuffleOdd = spirv::VectorShuffleOp::create(
690  rewriter, loc, newResultType, sourceVector, sourceVector,
691  rewriter.getI32ArrayAttr(indicesOdd));
692 
693  rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
694  return success();
695  }
696 };
697 
698 struct VectorLoadOpConverter final
699  : public OpConversionPattern<vector::LoadOp> {
700  using Base::Base;
701 
702  LogicalResult
703  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
704  ConversionPatternRewriter &rewriter) const override {
705  auto memrefType = loadOp.getMemRefType();
706  auto attr =
707  dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
708  if (!attr)
709  return rewriter.notifyMatchFailure(
710  loadOp, "expected spirv.storage_class memory space");
711 
712  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
713  auto loc = loadOp.getLoc();
714  Value accessChain =
715  spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
716  adaptor.getIndices(), loc, rewriter);
717  if (!accessChain)
718  return rewriter.notifyMatchFailure(
719  loadOp, "failed to get memref element pointer");
720 
721  spirv::StorageClass storageClass = attr.getValue();
722  auto vectorType = loadOp.getVectorType();
723  // Use the converted vector type instead of original (single element vector
724  // would get converted to scalar).
725  auto spirvVectorType = typeConverter.convertType(vectorType);
726  if (!spirvVectorType)
727  return rewriter.notifyMatchFailure(loadOp, "unsupported vector type");
728 
729  auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
730 
731  std::optional<uint64_t> alignment = loadOp.getAlignment();
732  if (alignment > std::numeric_limits<uint32_t>::max()) {
733  return rewriter.notifyMatchFailure(loadOp,
734  "invalid alignment requirement");
735  }
736 
737  auto memoryAccess = spirv::MemoryAccess::None;
738  spirv::MemoryAccessAttr memoryAccessAttr;
739  IntegerAttr alignmentAttr;
740  if (alignment.has_value()) {
741  memoryAccess |= spirv::MemoryAccess::Aligned;
742  memoryAccessAttr =
743  spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
744  alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
745  }
746 
747  // For single element vectors, we don't need to bitcast the access chain to
748  // the original vector type. Both is going to be the same, a pointer
749  // to a scalar.
750  Value castedAccessChain =
751  (vectorType.getNumElements() == 1)
752  ? accessChain
753  : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
754  accessChain);
755 
756  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
757  castedAccessChain,
758  memoryAccessAttr, alignmentAttr);
759 
760  return success();
761  }
762 };
763 
764 struct VectorStoreOpConverter final
765  : public OpConversionPattern<vector::StoreOp> {
766  using Base::Base;
767 
768  LogicalResult
769  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
770  ConversionPatternRewriter &rewriter) const override {
771  auto memrefType = storeOp.getMemRefType();
772  auto attr =
773  dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
774  if (!attr)
775  return rewriter.notifyMatchFailure(
776  storeOp, "expected spirv.storage_class memory space");
777 
778  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
779  auto loc = storeOp.getLoc();
780  Value accessChain =
781  spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
782  adaptor.getIndices(), loc, rewriter);
783  if (!accessChain)
784  return rewriter.notifyMatchFailure(
785  storeOp, "failed to get memref element pointer");
786 
787  std::optional<uint64_t> alignment = storeOp.getAlignment();
788  if (alignment > std::numeric_limits<uint32_t>::max()) {
789  return rewriter.notifyMatchFailure(storeOp,
790  "invalid alignment requirement");
791  }
792 
793  spirv::StorageClass storageClass = attr.getValue();
794  auto vectorType = storeOp.getVectorType();
795  auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
796 
797  // For single element vectors, we don't need to bitcast the access chain to
798  // the original vector type. Both is going to be the same, a pointer
799  // to a scalar.
800  Value castedAccessChain =
801  (vectorType.getNumElements() == 1)
802  ? accessChain
803  : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
804  accessChain);
805 
806  auto memoryAccess = spirv::MemoryAccess::None;
807  spirv::MemoryAccessAttr memoryAccessAttr;
808  IntegerAttr alignmentAttr;
809  if (alignment.has_value()) {
810  memoryAccess |= spirv::MemoryAccess::Aligned;
811  memoryAccessAttr =
812  spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
813  alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
814  }
815 
816  rewriter.replaceOpWithNewOp<spirv::StoreOp>(
817  storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr,
818  alignmentAttr);
819 
820  return success();
821  }
822 };
823 
824 struct VectorReductionToIntDotProd final
825  : OpRewritePattern<vector::ReductionOp> {
826  using Base::Base;
827 
828  LogicalResult matchAndRewrite(vector::ReductionOp op,
829  PatternRewriter &rewriter) const override {
830  if (op.getKind() != vector::CombiningKind::ADD)
831  return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
832 
833  auto resultType = dyn_cast<IntegerType>(op.getType());
834  if (!resultType)
835  return rewriter.notifyMatchFailure(op, "result is not an integer");
836 
837  int64_t resultBitwidth = resultType.getIntOrFloatBitWidth();
838  if (!llvm::is_contained({32, 64}, resultBitwidth))
839  return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth");
840 
841  VectorType inVecTy = op.getSourceVectorType();
842  if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
843  inVecTy.getShape().size() != 1 || inVecTy.isScalable())
844  return rewriter.notifyMatchFailure(op, "unsupported vector shape");
845 
846  auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
847  if (!mul)
848  return rewriter.notifyMatchFailure(
849  op, "reduction operand is not 'arith.muli'");
850 
851  if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
852  spirv::SDotAccSatOp, false>(op, mul, rewriter)))
853  return success();
854 
855  if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
856  spirv::UDotAccSatOp, false>(op, mul, rewriter)))
857  return success();
858 
859  if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
860  spirv::SUDotAccSatOp, false>(op, mul, rewriter)))
861  return success();
862 
863  if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
864  spirv::SUDotAccSatOp, true>(op, mul, rewriter)))
865  return success();
866 
867  return failure();
868  }
869 
870 private:
871  template <typename LhsExtensionOp, typename RhsExtensionOp, typename DotOp,
872  typename DotAccOp, bool SwapOperands>
873  static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
874  PatternRewriter &rewriter) {
875  auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
876  if (!lhs)
877  return failure();
878  Value lhsIn = lhs.getIn();
879  auto lhsInType = cast<VectorType>(lhsIn.getType());
880  if (!lhsInType.getElementType().isInteger(8))
881  return failure();
882 
883  auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
884  if (!rhs)
885  return failure();
886  Value rhsIn = rhs.getIn();
887  auto rhsInType = cast<VectorType>(rhsIn.getType());
888  if (!rhsInType.getElementType().isInteger(8))
889  return failure();
890 
891  if (op.getSourceVectorType().getNumElements() == 3) {
892  IntegerType i8Type = rewriter.getI8Type();
893  auto v4i8Type = VectorType::get({4}, i8Type);
894  Location loc = op.getLoc();
895  Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter);
896  lhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type,
897  ValueRange{lhsIn, zero});
898  rhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type,
899  ValueRange{rhsIn, zero});
900  }
901 
902  // There's no variant of dot prod ops for unsigned LHS and signed RHS, so
903  // we have to swap operands instead in that case.
904  if (SwapOperands)
905  std::swap(lhsIn, rhsIn);
906 
907  if (Value acc = op.getAcc()) {
908  rewriter.replaceOpWithNewOp<DotAccOp>(op, op.getType(), lhsIn, rhsIn, acc,
909  nullptr);
910  } else {
911  rewriter.replaceOpWithNewOp<DotOp>(op, op.getType(), lhsIn, rhsIn,
912  nullptr);
913  }
914 
915  return success();
916  }
917 };
918 
919 struct VectorReductionToFPDotProd final
920  : OpConversionPattern<vector::ReductionOp> {
921  using Base::Base;
922 
923  LogicalResult
924  matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
925  ConversionPatternRewriter &rewriter) const override {
926  if (op.getKind() != vector::CombiningKind::ADD)
927  return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
928 
929  auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
930  if (!resultType)
931  return rewriter.notifyMatchFailure(op, "result is not a float");
932 
933  Value vec = adaptor.getVector();
934  Value acc = adaptor.getAcc();
935 
936  auto vectorType = dyn_cast<VectorType>(vec.getType());
937  if (!vectorType) {
938  assert(isa<FloatType>(vec.getType()) &&
939  "Expected the vector to be scalarized");
940  if (acc) {
941  rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec);
942  return success();
943  }
944 
945  rewriter.replaceOp(op, vec);
946  return success();
947  }
948 
949  Location loc = op.getLoc();
950  Value lhs;
951  Value rhs;
952  if (auto mul = vec.getDefiningOp<arith::MulFOp>()) {
953  lhs = mul.getLhs();
954  rhs = mul.getRhs();
955  } else {
956  // If the operand is not a mul, use a vector of ones for the dot operand
957  // to just sum up all values.
958  lhs = vec;
959  Attribute oneAttr =
960  rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
961  oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
962  rhs = spirv::ConstantOp::create(rewriter, loc, vectorType, oneAttr);
963  }
964  assert(lhs);
965  assert(rhs);
966 
967  Value res = spirv::DotOp::create(rewriter, loc, resultType, lhs, rhs);
968  if (acc)
969  res = spirv::FAddOp::create(rewriter, loc, acc, res);
970 
971  rewriter.replaceOp(op, res);
972  return success();
973  }
974 };
975 
976 struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
977  using Base::Base;
978 
979  LogicalResult
980  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
981  ConversionPatternRewriter &rewriter) const override {
982  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
983  Type dstType = typeConverter.convertType(stepOp.getType());
984  if (!dstType)
985  return failure();
986 
987  Location loc = stepOp.getLoc();
988  int64_t numElements = stepOp.getType().getNumElements();
989  auto intType =
990  rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
991 
992  // Input vectors of size 1 are converted to scalars by the type converter.
993  // We just create a constant in this case.
994  if (numElements == 1) {
995  Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
996  rewriter.replaceOp(stepOp, zero);
997  return success();
998  }
999 
1000  SmallVector<Value> source;
1001  source.reserve(numElements);
1002  for (int64_t i = 0; i < numElements; ++i) {
1003  Attribute intAttr = rewriter.getIntegerAttr(intType, i);
1004  Value constOp =
1005  spirv::ConstantOp::create(rewriter, loc, intType, intAttr);
1006  source.push_back(constOp);
1007  }
1008  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
1009  source);
1010  return success();
1011  }
1012 };
1013 
1014 struct VectorToElementOpConvert final
1015  : OpConversionPattern<vector::ToElementsOp> {
1016  using Base::Base;
1017 
1018  LogicalResult
1019  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1020  ConversionPatternRewriter &rewriter) const override {
1021 
1022  SmallVector<Value> results(toElementsOp->getNumResults());
1023  Location loc = toElementsOp.getLoc();
1024 
1025  // Input vectors of size 1 are converted to scalars by the type converter.
1026  // We cannot use `spirv::CompositeExtractOp` directly in this case.
1027  // For a scalar source, the result is just the scalar itself.
1028  if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
1029  results[0] = adaptor.getSource();
1030  rewriter.replaceOp(toElementsOp, results);
1031  return success();
1032  }
1033 
1034  Type srcElementType = toElementsOp.getElements().getType().front();
1035  Type elementType = getTypeConverter()->convertType(srcElementType);
1036  if (!elementType)
1037  return rewriter.notifyMatchFailure(
1038  toElementsOp,
1039  llvm::formatv("failed to convert element type '{0}' to SPIR-V",
1040  srcElementType));
1041 
1042  for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
1043  // Create an CompositeExtract operation only for results that are not
1044  // dead.
1045  if (element.use_empty())
1046  continue;
1047 
1048  Value result = spirv::CompositeExtractOp::create(
1049  rewriter, loc, elementType, adaptor.getSource(),
1050  rewriter.getI32ArrayAttr({static_cast<int32_t>(idx)}));
1051  results[idx] = result;
1052  }
1053 
1054  rewriter.replaceOp(toElementsOp, results);
1055  return success();
1056  }
1057 };
1058 
1059 } // namespace
// Comma-separated lists of SPIR-V extended-instruction-set min/max ops
// (`spirv.CL.*` and `spirv.GL.*` variants). They are expanded as template
// argument lists for VectorReductionPattern / VectorReductionFloatMinMax
// in populateVectorToSPIRVPatterns. Each reduction pattern is therefore
// instantiated once per instruction set.
// NOTE(review): the element order (unsigned max, unsigned min, signed max,
// signed min) presumably must match the template parameter order of
// VectorReductionPattern. Verify against its definition before reordering.

// Integer min/max ops from the CL instruction set.
#define CL_INT_MAX_MIN_OPS \
  spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp

// Integer min/max ops from the GL instruction set.
#define GL_INT_MAX_MIN_OPS \
  spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp

// Floating-point max/min ops from the CL and GL instruction sets.
#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
1068 
1070  const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1071  patterns.add<
1072  VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert,
1073  VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
1074  VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
1075  VectorToElementOpConvert, VectorInsertOpConvert,
1076  VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
1077  VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
1078  VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
1079  VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
1080  VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
1081  VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
1082  VectorScalarBroadcastPattern, VectorLoadOpConverter,
1083  VectorStoreOpConverter, VectorStepOpConvert>(
1084  typeConverter, patterns.getContext(), PatternBenefit(1));
1085 
1086  // Make sure that the more specialized dot product pattern has higher benefit
1087  // than the generic one that extracts all elements.
1088  patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
1089  PatternBenefit(2));
1090 }
1091 
1094  patterns.add<VectorReductionToIntDotProd>(patterns.getContext());
1095 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
@ None
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define MINUI(lhs, rhs)
static uint64_t getFirstIntValue(ArrayAttr attr)
Returns the integer value from the first valid input element, assuming Value inputs are defined by a ...
static int getNumBits(Type type)
Returns the number of bits for the given scalar/vector type.
#define INT_AND_FLOAT_CASE(kind, iop, fop)
#define INT_OR_FLOAT_CASE(kind, fop)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:200
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:228
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:276
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:254
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
MLIRContext * getContext() const
Definition: Builders.h:56
IntegerType getI8Type()
Definition: Builders.cpp:59
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
Type conversion from builtin types to SPIR-V types for shader interface.
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:451
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void populateVectorReductionToSPIRVDotProductPatterns(RewritePatternSet &patterns)
Appends patterns to convert vector reduction of the form:
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateVectorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Vector Ops to SPIR-V ops.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314