MLIR  22.0.0git
VectorToSPIRV.cpp
Go to the documentation of this file.
1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Vector dialect to SPIRV dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
21 #include "mlir/IR/Attributes.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Location.h"
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "mlir/IR/TypeUtilities.h"
29 #include "llvm/ADT/ArrayRef.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/SmallVectorExtras.h"
33 #include "llvm/Support/FormatVariadic.h"
34 #include <cassert>
35 #include <cstdint>
36 #include <numeric>
37 
38 using namespace mlir;
39 
40 /// Returns the integer value from the first valid input element, assuming Value
41 /// inputs are defined by a constant index ops and Attribute inputs are integer
42 /// attributes.
43 static uint64_t getFirstIntValue(ArrayAttr attr) {
44  return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
45 }
46 
47 /// Returns the number of bits for the given scalar/vector type.
48 static int getNumBits(Type type) {
49  // TODO: This does not take into account any memory layout or widening
50  // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even
51  // though in practice it will likely be stored as in a 4xi64 vector register.
52  if (auto vectorType = dyn_cast<VectorType>(type))
53  return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
54  return type.getIntOrFloatBitWidth();
55 }
56 
57 namespace {
58 
59 struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
61 
62  LogicalResult
63  matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
64  ConversionPatternRewriter &rewriter) const override {
65  Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
66  if (!dstType)
67  return failure();
68 
69  // If dstType is same as the source type or the vector size is 1, it can be
70  // directly replaced by the source.
71  if (dstType == adaptor.getSource().getType() ||
72  shapeCastOp.getResultVectorType().getNumElements() == 1) {
73  rewriter.replaceOp(shapeCastOp, adaptor.getSource());
74  return success();
75  }
76 
77  // Lowering for size-n vectors when n > 1 hasn't been implemented.
78  return failure();
79  }
80 };
81 
82 // Convert `vector.splat` to `vector.broadcast`. There is a path from
83 // `vector.broadcast` to SPIRV via other patterns.
84 struct VectorSplatToBroadcast final
85  : public OpConversionPattern<vector::SplatOp> {
87  LogicalResult
88  matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
89  ConversionPatternRewriter &rewriter) const override {
90  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
91  adaptor.getInput());
92  return success();
93  }
94 };
95 
96 struct VectorBitcastConvert final
97  : public OpConversionPattern<vector::BitCastOp> {
99 
100  LogicalResult
101  matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
102  ConversionPatternRewriter &rewriter) const override {
103  Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
104  if (!dstType)
105  return failure();
106 
107  if (dstType == adaptor.getSource().getType()) {
108  rewriter.replaceOp(bitcastOp, adaptor.getSource());
109  return success();
110  }
111 
112  // Check that the source and destination type have the same bitwidth.
113  // Depending on the target environment, we may need to emulate certain
114  // types, which can cause issue with bitcast.
115  Type srcType = adaptor.getSource().getType();
116  if (getNumBits(dstType) != getNumBits(srcType)) {
117  return rewriter.notifyMatchFailure(
118  bitcastOp,
119  llvm::formatv("different source ({0}) and target ({1}) bitwidth",
120  srcType, dstType));
121  }
122 
123  rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
124  adaptor.getSource());
125  return success();
126  }
127 };
128 
129 struct VectorBroadcastConvert final
130  : public OpConversionPattern<vector::BroadcastOp> {
132 
133  LogicalResult
134  matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
135  ConversionPatternRewriter &rewriter) const override {
136  Type resultType =
137  getTypeConverter()->convertType(castOp.getResultVectorType());
138  if (!resultType)
139  return failure();
140 
141  if (isa<spirv::ScalarType>(resultType)) {
142  rewriter.replaceOp(castOp, adaptor.getSource());
143  return success();
144  }
145 
146  SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
147  adaptor.getSource());
148  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(castOp, resultType,
149  source);
150  return success();
151  }
152 };
153 
154 // SPIR-V does not have a concept of a poison index for certain instructions,
155 // which creates a UB hazard when lowering from otherwise equivalent Vector
156 // dialect instructions, because this index will be considered out-of-bounds.
157 // To avoid this, this function implements a dynamic sanitization that returns
158 // some arbitrary safe index. For power-of-two vector sizes, this uses a bitmask
159 // (presumably more efficient), and otherwise index 0 (always in-bounds).
160 static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter,
161  Location loc, Value dynamicIndex,
162  int64_t kPoisonIndex, unsigned vectorSize) {
163  if (llvm::isPowerOf2_32(vectorSize)) {
164  Value inBoundsMask = spirv::ConstantOp::create(
165  rewriter, loc, dynamicIndex.getType(),
166  rewriter.getIntegerAttr(dynamicIndex.getType(), vectorSize - 1));
167  return spirv::BitwiseAndOp::create(rewriter, loc, dynamicIndex,
168  inBoundsMask);
169  }
170  Value poisonIndex = spirv::ConstantOp::create(
171  rewriter, loc, dynamicIndex.getType(),
172  rewriter.getIntegerAttr(dynamicIndex.getType(), kPoisonIndex));
173  Value cmpResult =
174  spirv::IEqualOp::create(rewriter, loc, dynamicIndex, poisonIndex);
175  return spirv::SelectOp::create(
176  rewriter, loc, cmpResult,
177  spirv::ConstantOp::getZero(dynamicIndex.getType(), loc, rewriter),
178  dynamicIndex);
179 }
180 
181 struct VectorExtractOpConvert final
182  : public OpConversionPattern<vector::ExtractOp> {
184 
185  LogicalResult
186  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
187  ConversionPatternRewriter &rewriter) const override {
188  Type dstType = getTypeConverter()->convertType(extractOp.getType());
189  if (!dstType)
190  return failure();
191 
192  if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
193  rewriter.replaceOp(extractOp, adaptor.getVector());
194  return success();
195  }
196 
197  if (std::optional<int64_t> id =
198  getConstantIntValue(extractOp.getMixedPosition()[0])) {
199  if (id == vector::ExtractOp::kPoisonIndex)
200  return rewriter.notifyMatchFailure(
201  extractOp,
202  "Static use of poison index handled elsewhere (folded to poison)");
203  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
204  extractOp, dstType, adaptor.getVector(),
205  rewriter.getI32ArrayAttr(id.value()));
206  } else {
207  Value sanitizedIndex = sanitizeDynamicIndex(
208  rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
209  vector::ExtractOp::kPoisonIndex,
210  extractOp.getSourceVectorType().getNumElements());
211  rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
212  extractOp, dstType, adaptor.getVector(), sanitizedIndex);
213  }
214  return success();
215  }
216 };
217 
218 struct VectorExtractStridedSliceOpConvert final
219  : public OpConversionPattern<vector::ExtractStridedSliceOp> {
221 
222  LogicalResult
223  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
224  ConversionPatternRewriter &rewriter) const override {
225  Type dstType = getTypeConverter()->convertType(extractOp.getType());
226  if (!dstType)
227  return failure();
228 
229  uint64_t offset = getFirstIntValue(extractOp.getOffsets());
230  uint64_t size = getFirstIntValue(extractOp.getSizes());
231  uint64_t stride = getFirstIntValue(extractOp.getStrides());
232  if (stride != 1)
233  return failure();
234 
235  Value srcVector = adaptor.getOperands().front();
236 
237  // Extract vector<1xT> case.
238  if (isa<spirv::ScalarType>(dstType)) {
239  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
240  srcVector, offset);
241  return success();
242  }
243 
244  SmallVector<int32_t, 2> indices(size);
245  std::iota(indices.begin(), indices.end(), offset);
246 
247  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
248  extractOp, dstType, srcVector, srcVector,
249  rewriter.getI32ArrayAttr(indices));
250 
251  return success();
252  }
253 };
254 
255 template <class SPIRVFMAOp>
256 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
258 
259  LogicalResult
260  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
261  ConversionPatternRewriter &rewriter) const override {
262  Type dstType = getTypeConverter()->convertType(fmaOp.getType());
263  if (!dstType)
264  return failure();
265  rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, dstType, adaptor.getLhs(),
266  adaptor.getRhs(), adaptor.getAcc());
267  return success();
268  }
269 };
270 
271 struct VectorFromElementsOpConvert final
272  : public OpConversionPattern<vector::FromElementsOp> {
274 
275  LogicalResult
276  matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
277  ConversionPatternRewriter &rewriter) const override {
278  Type resultType = getTypeConverter()->convertType(op.getType());
279  if (!resultType)
280  return failure();
281  OperandRange elements = op.getElements();
282  if (isa<spirv::ScalarType>(resultType)) {
283  // In the case with a single scalar operand / single-element result,
284  // pass through the scalar.
285  rewriter.replaceOp(op, elements[0]);
286  return success();
287  }
288  // SPIRVTypeConverter rejects vectors with rank > 1, so multi-dimensional
289  // vector.from_elements cases should not need to be handled, only 1d.
290  assert(cast<VectorType>(resultType).getRank() == 1);
291  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, resultType,
292  elements);
293  return success();
294  }
295 };
296 
297 struct VectorInsertOpConvert final
298  : public OpConversionPattern<vector::InsertOp> {
300 
301  LogicalResult
302  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
303  ConversionPatternRewriter &rewriter) const override {
304  if (isa<VectorType>(insertOp.getValueToStoreType()))
305  return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
306  if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
307  return rewriter.notifyMatchFailure(insertOp,
308  "unsupported dest vector type");
309 
310  // Special case for inserting scalar values into size-1 vectors.
311  if (insertOp.getValueToStoreType().isIntOrFloat() &&
312  insertOp.getDestVectorType().getNumElements() == 1) {
313  rewriter.replaceOp(insertOp, adaptor.getValueToStore());
314  return success();
315  }
316 
317  if (std::optional<int64_t> id =
318  getConstantIntValue(insertOp.getMixedPosition()[0])) {
319  if (id == vector::InsertOp::kPoisonIndex)
320  return rewriter.notifyMatchFailure(
321  insertOp,
322  "Static use of poison index handled elsewhere (folded to poison)");
323  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
324  insertOp, adaptor.getValueToStore(), adaptor.getDest(), id.value());
325  } else {
326  Value sanitizedIndex = sanitizeDynamicIndex(
327  rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
328  vector::InsertOp::kPoisonIndex,
329  insertOp.getDestVectorType().getNumElements());
330  rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
331  insertOp, insertOp.getDest(), adaptor.getValueToStore(),
332  sanitizedIndex);
333  }
334  return success();
335  }
336 };
337 
338 struct VectorInsertStridedSliceOpConvert final
339  : public OpConversionPattern<vector::InsertStridedSliceOp> {
341 
342  LogicalResult
343  matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
344  ConversionPatternRewriter &rewriter) const override {
345  Value srcVector = adaptor.getOperands().front();
346  Value dstVector = adaptor.getOperands().back();
347 
348  uint64_t stride = getFirstIntValue(insertOp.getStrides());
349  if (stride != 1)
350  return failure();
351  uint64_t offset = getFirstIntValue(insertOp.getOffsets());
352 
353  if (isa<spirv::ScalarType>(srcVector.getType())) {
354  assert(!isa<spirv::ScalarType>(dstVector.getType()));
355  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
356  insertOp, dstVector.getType(), srcVector, dstVector,
357  rewriter.getI32ArrayAttr(offset));
358  return success();
359  }
360 
361  uint64_t totalSize = cast<VectorType>(dstVector.getType()).getNumElements();
362  uint64_t insertSize =
363  cast<VectorType>(srcVector.getType()).getNumElements();
364 
365  SmallVector<int32_t, 2> indices(totalSize);
366  std::iota(indices.begin(), indices.end(), 0);
367  std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
368  totalSize);
369 
370  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
371  insertOp, dstVector.getType(), dstVector, srcVector,
372  rewriter.getI32ArrayAttr(indices));
373 
374  return success();
375  }
376 };
377 
378 static SmallVector<Value> extractAllElements(
379  vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
380  VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
381  int numElements = static_cast<int>(srcVectorType.getDimSize(0));
382  SmallVector<Value> values;
383  values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
384  Location loc = reduceOp.getLoc();
385 
386  for (int i = 0; i < numElements; ++i) {
387  values.push_back(spirv::CompositeExtractOp::create(
388  rewriter, loc, srcVectorType.getElementType(), adaptor.getVector(),
389  rewriter.getI32ArrayAttr({i})));
390  }
391  if (Value acc = adaptor.getAcc())
392  values.push_back(acc);
393 
394  return values;
395 }
396 
397 struct ReductionRewriteInfo {
398  Type resultType;
399  SmallVector<Value> extractedElements;
400 };
401 
402 FailureOr<ReductionRewriteInfo> static getReductionInfo(
403  vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
404  ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) {
405  Type resultType = typeConverter.convertType(op.getType());
406  if (!resultType)
407  return failure();
408 
409  auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
410  if (!srcVectorType || srcVectorType.getRank() != 1)
411  return rewriter.notifyMatchFailure(op, "not a 1-D vector source");
412 
413  SmallVector<Value> extractedElements =
414  extractAllElements(op, adaptor, srcVectorType, rewriter);
415 
416  return ReductionRewriteInfo{resultType, std::move(extractedElements)};
417 }
418 
419 template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
420  typename SPIRVSMinOp>
421 struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
423 
424  LogicalResult
425  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
426  ConversionPatternRewriter &rewriter) const override {
427  auto reductionInfo =
428  getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
429  if (failed(reductionInfo))
430  return failure();
431 
432  auto [resultType, extractedElements] = *reductionInfo;
433  Location loc = reduceOp->getLoc();
434  Value result = extractedElements.front();
435  for (Value next : llvm::drop_begin(extractedElements)) {
436  switch (reduceOp.getKind()) {
437 
438 #define INT_AND_FLOAT_CASE(kind, iop, fop) \
439  case vector::CombiningKind::kind: \
440  if (llvm::isa<IntegerType>(resultType)) { \
441  result = spirv::iop::create(rewriter, loc, resultType, result, next); \
442  } else { \
443  assert(llvm::isa<FloatType>(resultType)); \
444  result = spirv::fop::create(rewriter, loc, resultType, result, next); \
445  } \
446  break
447 
448 #define INT_OR_FLOAT_CASE(kind, fop) \
449  case vector::CombiningKind::kind: \
450  result = fop::create(rewriter, loc, resultType, result, next); \
451  break
452 
453  INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
454  INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
455  INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
456  INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
457  INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
458  INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp);
459 
460  case vector::CombiningKind::AND:
461  case vector::CombiningKind::OR:
462  case vector::CombiningKind::XOR:
463  return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
464  default:
465  return rewriter.notifyMatchFailure(reduceOp, "not handled here");
466  }
467 #undef INT_AND_FLOAT_CASE
468 #undef INT_OR_FLOAT_CASE
469  }
470 
471  rewriter.replaceOp(reduceOp, result);
472  return success();
473  }
474 };
475 
476 template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
477 struct VectorReductionFloatMinMax final
478  : OpConversionPattern<vector::ReductionOp> {
480 
481  LogicalResult
482  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
483  ConversionPatternRewriter &rewriter) const override {
484  auto reductionInfo =
485  getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
486  if (failed(reductionInfo))
487  return failure();
488 
489  auto [resultType, extractedElements] = *reductionInfo;
490  Location loc = reduceOp->getLoc();
491  Value result = extractedElements.front();
492  for (Value next : llvm::drop_begin(extractedElements)) {
493  switch (reduceOp.getKind()) {
494 
495 #define INT_OR_FLOAT_CASE(kind, fop) \
496  case vector::CombiningKind::kind: \
497  result = fop::create(rewriter, loc, resultType, result, next); \
498  break
499 
500  INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
501  INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
502  INT_OR_FLOAT_CASE(MAXNUMF, SPIRVFMaxOp);
503  INT_OR_FLOAT_CASE(MINNUMF, SPIRVFMinOp);
504 
505  default:
506  return rewriter.notifyMatchFailure(reduceOp, "not handled here");
507  }
508 #undef INT_OR_FLOAT_CASE
509  }
510 
511  rewriter.replaceOp(reduceOp, result);
512  return success();
513  }
514 };
515 
516 class VectorScalarBroadcastPattern final
517  : public OpConversionPattern<vector::BroadcastOp> {
518 public:
520 
521  LogicalResult
522  matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
523  ConversionPatternRewriter &rewriter) const override {
524  if (isa<VectorType>(op.getSourceType())) {
525  return rewriter.notifyMatchFailure(
526  op, "only conversion of 'broadcast from scalar' is supported");
527  }
528  Type dstType = getTypeConverter()->convertType(op.getType());
529  if (!dstType)
530  return failure();
531  if (isa<spirv::ScalarType>(dstType)) {
532  rewriter.replaceOp(op, adaptor.getSource());
533  } else {
534  auto dstVecType = cast<VectorType>(dstType);
535  SmallVector<Value, 4> source(dstVecType.getNumElements(),
536  adaptor.getSource());
537  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
538  source);
539  }
540  return success();
541  }
542 };
543 
544 struct VectorShuffleOpConvert final
545  : public OpConversionPattern<vector::ShuffleOp> {
547 
548  LogicalResult
549  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
550  ConversionPatternRewriter &rewriter) const override {
551  VectorType oldResultType = shuffleOp.getResultVectorType();
552  Type newResultType = getTypeConverter()->convertType(oldResultType);
553  if (!newResultType)
554  return rewriter.notifyMatchFailure(shuffleOp,
555  "unsupported result vector type");
556 
557  auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());
558 
559  VectorType oldV1Type = shuffleOp.getV1VectorType();
560  VectorType oldV2Type = shuffleOp.getV2VectorType();
561 
562  // When both operands and the result are SPIR-V vectors, emit a SPIR-V
563  // shuffle.
564  if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
565  oldResultType.getNumElements() > 1) {
566  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
567  shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
568  rewriter.getI32ArrayAttr(mask));
569  return success();
570  }
571 
572  // When at least one of the operands or the result becomes a scalar after
573  // type conversion for SPIR-V, extract all the required elements and
574  // construct the result vector.
575  auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
576  Value scalarOrVec, int32_t idx) -> Value {
577  if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
578  return spirv::CompositeExtractOp::create(rewriter, loc, scalarOrVec,
579  idx);
580 
581  assert(idx == 0 && "Invalid scalar element index");
582  return scalarOrVec;
583  };
584 
585  int32_t numV1Elems = oldV1Type.getNumElements();
586  SmallVector<Value> newOperands(mask.size());
587  for (auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
588  Value vec = adaptor.getV1();
589  int32_t elementIdx = shuffleIdx;
590  if (elementIdx >= numV1Elems) {
591  vec = adaptor.getV2();
592  elementIdx -= numV1Elems;
593  }
594 
595  newOperand = getElementAtIdx(vec, elementIdx);
596  }
597 
598  // Handle the scalar result corner case.
599  if (newOperands.size() == 1) {
600  rewriter.replaceOp(shuffleOp, newOperands.front());
601  return success();
602  }
603 
604  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
605  shuffleOp, newResultType, newOperands);
606  return success();
607  }
608 };
609 
610 struct VectorInterleaveOpConvert final
611  : public OpConversionPattern<vector::InterleaveOp> {
613 
614  LogicalResult
615  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
616  ConversionPatternRewriter &rewriter) const override {
617  // Check the result vector type.
618  VectorType oldResultType = interleaveOp.getResultVectorType();
619  Type newResultType = getTypeConverter()->convertType(oldResultType);
620  if (!newResultType)
621  return rewriter.notifyMatchFailure(interleaveOp,
622  "unsupported result vector type");
623 
624  // Interleave the indices.
625  VectorType sourceType = interleaveOp.getSourceVectorType();
626  int n = sourceType.getNumElements();
627 
628  // Input vectors of size 1 are converted to scalars by the type converter.
629  // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
630  // use `spirv::CompositeConstructOp`.
631  if (n == 1) {
632  Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
633  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
634  interleaveOp, newResultType, newOperands);
635  return success();
636  }
637 
638  auto seq = llvm::seq<int64_t>(2 * n);
639  auto indices = llvm::map_to_vector(
640  seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; });
641 
642  // Emit a SPIR-V shuffle.
643  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
644  interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
645  rewriter.getI32ArrayAttr(indices));
646 
647  return success();
648  }
649 };
650 
651 struct VectorDeinterleaveOpConvert final
652  : public OpConversionPattern<vector::DeinterleaveOp> {
654 
655  LogicalResult
656  matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
657  ConversionPatternRewriter &rewriter) const override {
658 
659  // Check the result vector type.
660  VectorType oldResultType = deinterleaveOp.getResultVectorType();
661  Type newResultType = getTypeConverter()->convertType(oldResultType);
662  if (!newResultType)
663  return rewriter.notifyMatchFailure(deinterleaveOp,
664  "unsupported result vector type");
665 
666  Location loc = deinterleaveOp->getLoc();
667 
668  // Deinterleave the indices.
669  Value sourceVector = adaptor.getSource();
670  VectorType sourceType = deinterleaveOp.getSourceVectorType();
671  int n = sourceType.getNumElements();
672 
673  // Output vectors of size 1 are converted to scalars by the type converter.
674  // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
675  // use `spirv::CompositeExtractOp`.
676  if (n == 2) {
677  auto elem0 = spirv::CompositeExtractOp::create(
678  rewriter, loc, newResultType, sourceVector,
679  rewriter.getI32ArrayAttr({0}));
680 
681  auto elem1 = spirv::CompositeExtractOp::create(
682  rewriter, loc, newResultType, sourceVector,
683  rewriter.getI32ArrayAttr({1}));
684 
685  rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
686  return success();
687  }
688 
689  // Indices for `shuffleEven` (result 0).
690  auto seqEven = llvm::seq<int64_t>(n / 2);
691  auto indicesEven =
692  llvm::map_to_vector(seqEven, [](int i) { return i * 2; });
693 
694  // Indices for `shuffleOdd` (result 1).
695  auto seqOdd = llvm::seq<int64_t>(n / 2);
696  auto indicesOdd =
697  llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
698 
699  // Create two SPIR-V shuffles.
700  auto shuffleEven = spirv::VectorShuffleOp::create(
701  rewriter, loc, newResultType, sourceVector, sourceVector,
702  rewriter.getI32ArrayAttr(indicesEven));
703 
704  auto shuffleOdd = spirv::VectorShuffleOp::create(
705  rewriter, loc, newResultType, sourceVector, sourceVector,
706  rewriter.getI32ArrayAttr(indicesOdd));
707 
708  rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
709  return success();
710  }
711 };
712 
713 struct VectorLoadOpConverter final
714  : public OpConversionPattern<vector::LoadOp> {
716 
717  LogicalResult
718  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
719  ConversionPatternRewriter &rewriter) const override {
720  auto memrefType = loadOp.getMemRefType();
721  auto attr =
722  dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
723  if (!attr)
724  return rewriter.notifyMatchFailure(
725  loadOp, "expected spirv.storage_class memory space");
726 
727  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
728  auto loc = loadOp.getLoc();
729  Value accessChain =
730  spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
731  adaptor.getIndices(), loc, rewriter);
732  if (!accessChain)
733  return rewriter.notifyMatchFailure(
734  loadOp, "failed to get memref element pointer");
735 
736  spirv::StorageClass storageClass = attr.getValue();
737  auto vectorType = loadOp.getVectorType();
738  // Use the converted vector type instead of original (single element vector
739  // would get converted to scalar).
740  auto spirvVectorType = typeConverter.convertType(vectorType);
741  if (!spirvVectorType)
742  return rewriter.notifyMatchFailure(loadOp, "unsupported vector type");
743 
744  auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
745 
746  std::optional<uint64_t> alignment = loadOp.getAlignment();
747  if (alignment > std::numeric_limits<uint32_t>::max()) {
748  return rewriter.notifyMatchFailure(loadOp,
749  "invalid alignment requirement");
750  }
751 
752  auto memoryAccess = spirv::MemoryAccess::None;
753  spirv::MemoryAccessAttr memoryAccessAttr;
754  IntegerAttr alignmentAttr;
755  if (alignment.has_value()) {
756  memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
757  memoryAccessAttr =
758  spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
759  alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
760  }
761 
762  // For single element vectors, we don't need to bitcast the access chain to
763  // the original vector type. Both is going to be the same, a pointer
764  // to a scalar.
765  Value castedAccessChain =
766  (vectorType.getNumElements() == 1)
767  ? accessChain
768  : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
769  accessChain);
770 
771  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
772  castedAccessChain,
773  memoryAccessAttr, alignmentAttr);
774 
775  return success();
776  }
777 };
778 
779 struct VectorStoreOpConverter final
780  : public OpConversionPattern<vector::StoreOp> {
782 
783  LogicalResult
784  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
785  ConversionPatternRewriter &rewriter) const override {
786  auto memrefType = storeOp.getMemRefType();
787  auto attr =
788  dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
789  if (!attr)
790  return rewriter.notifyMatchFailure(
791  storeOp, "expected spirv.storage_class memory space");
792 
793  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
794  auto loc = storeOp.getLoc();
795  Value accessChain =
796  spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
797  adaptor.getIndices(), loc, rewriter);
798  if (!accessChain)
799  return rewriter.notifyMatchFailure(
800  storeOp, "failed to get memref element pointer");
801 
802  std::optional<uint64_t> alignment = storeOp.getAlignment();
803  if (alignment > std::numeric_limits<uint32_t>::max()) {
804  return rewriter.notifyMatchFailure(storeOp,
805  "invalid alignment requirement");
806  }
807 
808  spirv::StorageClass storageClass = attr.getValue();
809  auto vectorType = storeOp.getVectorType();
810  auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
811 
812  // For single element vectors, we don't need to bitcast the access chain to
813  // the original vector type. Both is going to be the same, a pointer
814  // to a scalar.
815  Value castedAccessChain =
816  (vectorType.getNumElements() == 1)
817  ? accessChain
818  : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
819  accessChain);
820 
821  auto memoryAccess = spirv::MemoryAccess::None;
822  spirv::MemoryAccessAttr memoryAccessAttr;
823  IntegerAttr alignmentAttr;
824  if (alignment.has_value()) {
825  memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
826  memoryAccessAttr =
827  spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
828  alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
829  }
830 
831  rewriter.replaceOpWithNewOp<spirv::StoreOp>(
832  storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr,
833  alignmentAttr);
834 
835  return success();
836  }
837 };
838 
839 struct VectorReductionToIntDotProd final
840  : OpRewritePattern<vector::ReductionOp> {
842 
843  LogicalResult matchAndRewrite(vector::ReductionOp op,
844  PatternRewriter &rewriter) const override {
845  if (op.getKind() != vector::CombiningKind::ADD)
846  return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
847 
848  auto resultType = dyn_cast<IntegerType>(op.getType());
849  if (!resultType)
850  return rewriter.notifyMatchFailure(op, "result is not an integer");
851 
852  int64_t resultBitwidth = resultType.getIntOrFloatBitWidth();
853  if (!llvm::is_contained({32, 64}, resultBitwidth))
854  return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth");
855 
856  VectorType inVecTy = op.getSourceVectorType();
857  if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
858  inVecTy.getShape().size() != 1 || inVecTy.isScalable())
859  return rewriter.notifyMatchFailure(op, "unsupported vector shape");
860 
861  auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
862  if (!mul)
863  return rewriter.notifyMatchFailure(
864  op, "reduction operand is not 'arith.muli'");
865 
866  if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
867  spirv::SDotAccSatOp, false>(op, mul, rewriter)))
868  return success();
869 
870  if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
871  spirv::UDotAccSatOp, false>(op, mul, rewriter)))
872  return success();
873 
874  if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
875  spirv::SUDotAccSatOp, false>(op, mul, rewriter)))
876  return success();
877 
878  if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
879  spirv::SUDotAccSatOp, true>(op, mul, rewriter)))
880  return success();
881 
882  return failure();
883  }
884 
885 private:
886  template <typename LhsExtensionOp, typename RhsExtensionOp, typename DotOp,
887  typename DotAccOp, bool SwapOperands>
888  static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
889  PatternRewriter &rewriter) {
890  auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
891  if (!lhs)
892  return failure();
893  Value lhsIn = lhs.getIn();
894  auto lhsInType = cast<VectorType>(lhsIn.getType());
895  if (!lhsInType.getElementType().isInteger(8))
896  return failure();
897 
898  auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
899  if (!rhs)
900  return failure();
901  Value rhsIn = rhs.getIn();
902  auto rhsInType = cast<VectorType>(rhsIn.getType());
903  if (!rhsInType.getElementType().isInteger(8))
904  return failure();
905 
906  if (op.getSourceVectorType().getNumElements() == 3) {
907  IntegerType i8Type = rewriter.getI8Type();
908  auto v4i8Type = VectorType::get({4}, i8Type);
909  Location loc = op.getLoc();
910  Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter);
911  lhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type,
912  ValueRange{lhsIn, zero});
913  rhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type,
914  ValueRange{rhsIn, zero});
915  }
916 
917  // There's no variant of dot prod ops for unsigned LHS and signed RHS, so
918  // we have to swap operands instead in that case.
919  if (SwapOperands)
920  std::swap(lhsIn, rhsIn);
921 
922  if (Value acc = op.getAcc()) {
923  rewriter.replaceOpWithNewOp<DotAccOp>(op, op.getType(), lhsIn, rhsIn, acc,
924  nullptr);
925  } else {
926  rewriter.replaceOpWithNewOp<DotOp>(op, op.getType(), lhsIn, rhsIn,
927  nullptr);
928  }
929 
930  return success();
931  }
932 };
933 
934 struct VectorReductionToFPDotProd final
935  : OpConversionPattern<vector::ReductionOp> {
937 
938  LogicalResult
939  matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
940  ConversionPatternRewriter &rewriter) const override {
941  if (op.getKind() != vector::CombiningKind::ADD)
942  return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
943 
944  auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
945  if (!resultType)
946  return rewriter.notifyMatchFailure(op, "result is not a float");
947 
948  Value vec = adaptor.getVector();
949  Value acc = adaptor.getAcc();
950 
951  auto vectorType = dyn_cast<VectorType>(vec.getType());
952  if (!vectorType) {
953  assert(isa<FloatType>(vec.getType()) &&
954  "Expected the vector to be scalarized");
955  if (acc) {
956  rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec);
957  return success();
958  }
959 
960  rewriter.replaceOp(op, vec);
961  return success();
962  }
963 
964  Location loc = op.getLoc();
965  Value lhs;
966  Value rhs;
967  if (auto mul = vec.getDefiningOp<arith::MulFOp>()) {
968  lhs = mul.getLhs();
969  rhs = mul.getRhs();
970  } else {
971  // If the operand is not a mul, use a vector of ones for the dot operand
972  // to just sum up all values.
973  lhs = vec;
974  Attribute oneAttr =
975  rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
976  oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
977  rhs = spirv::ConstantOp::create(rewriter, loc, vectorType, oneAttr);
978  }
979  assert(lhs);
980  assert(rhs);
981 
982  Value res = spirv::DotOp::create(rewriter, loc, resultType, lhs, rhs);
983  if (acc)
984  res = spirv::FAddOp::create(rewriter, loc, acc, res);
985 
986  rewriter.replaceOp(op, res);
987  return success();
988  }
989 };
990 
991 struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
993 
994  LogicalResult
995  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
996  ConversionPatternRewriter &rewriter) const override {
997  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
998  Type dstType = typeConverter.convertType(stepOp.getType());
999  if (!dstType)
1000  return failure();
1001 
1002  Location loc = stepOp.getLoc();
1003  int64_t numElements = stepOp.getType().getNumElements();
1004  auto intType =
1005  rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
1006 
1007  // Input vectors of size 1 are converted to scalars by the type converter.
1008  // We just create a constant in this case.
1009  if (numElements == 1) {
1010  Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
1011  rewriter.replaceOp(stepOp, zero);
1012  return success();
1013  }
1014 
1015  SmallVector<Value> source;
1016  source.reserve(numElements);
1017  for (int64_t i = 0; i < numElements; ++i) {
1018  Attribute intAttr = rewriter.getIntegerAttr(intType, i);
1019  Value constOp =
1020  spirv::ConstantOp::create(rewriter, loc, intType, intAttr);
1021  source.push_back(constOp);
1022  }
1023  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
1024  source);
1025  return success();
1026  }
1027 };
1028 
1029 struct VectorToElementOpConvert final
1030  : OpConversionPattern<vector::ToElementsOp> {
1032 
1033  LogicalResult
1034  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1035  ConversionPatternRewriter &rewriter) const override {
1036 
1037  SmallVector<Value> results(toElementsOp->getNumResults());
1038  Location loc = toElementsOp.getLoc();
1039 
1040  // Input vectors of size 1 are converted to scalars by the type converter.
1041  // We cannot use `spirv::CompositeExtractOp` directly in this case.
1042  // For a scalar source, the result is just the scalar itself.
1043  if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
1044  results[0] = adaptor.getSource();
1045  rewriter.replaceOp(toElementsOp, results);
1046  return success();
1047  }
1048 
1049  Type srcElementType = toElementsOp.getElements().getType().front();
1050  Type elementType = getTypeConverter()->convertType(srcElementType);
1051  if (!elementType)
1052  return rewriter.notifyMatchFailure(
1053  toElementsOp,
1054  llvm::formatv("failed to convert element type '{0}' to SPIR-V",
1055  srcElementType));
1056 
1057  for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
1058  // Create an CompositeExtract operation only for results that are not
1059  // dead.
1060  if (element.use_empty())
1061  continue;
1062 
1063  Value result = spirv::CompositeExtractOp::create(
1064  rewriter, loc, elementType, adaptor.getSource(),
1065  rewriter.getI32ArrayAttr({static_cast<int32_t>(idx)}));
1066  results[idx] = result;
1067  }
1068 
1069  rewriter.replaceOp(toElementsOp, results);
1070  return success();
1071  }
1072 };
1073 
1074 } // namespace
1075 #define CL_INT_MAX_MIN_OPS \
1076  spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
1077 
1078 #define GL_INT_MAX_MIN_OPS \
1079  spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
1080 
1081 #define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
1082 #define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
1083 
1085  const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1086  patterns.add<
1087  VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert,
1088  VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
1089  VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
1090  VectorToElementOpConvert, VectorInsertOpConvert,
1091  VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
1092  VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
1093  VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
1094  VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
1095  VectorSplatToBroadcast, VectorInsertStridedSliceOpConvert,
1096  VectorShuffleOpConvert, VectorInterleaveOpConvert,
1097  VectorDeinterleaveOpConvert, VectorScalarBroadcastPattern,
1098  VectorLoadOpConverter, VectorStoreOpConverter, VectorStepOpConvert>(
1099  typeConverter, patterns.getContext(), PatternBenefit(1));
1100 
1101  // Make sure that the more specialized dot product pattern has higher benefit
1102  // than the generic one that extracts all elements.
1103  patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
1104  PatternBenefit(2));
1105 }
1106 
1109  patterns.add<VectorReductionToIntDotProd>(patterns.getContext());
1110 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
@ None
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define MINUI(lhs, rhs)
static uint64_t getFirstIntValue(ArrayAttr attr)
Returns the integer value from the first valid input element, assuming Value inputs are defined by a ...
static int getNumBits(Type type)
Returns the number of bits for the given scalar/vector type.
#define INT_AND_FLOAT_CASE(kind, iop, fop)
#define INT_OR_FLOAT_CASE(kind, fop)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:195
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:271
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:249
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI8Type()
Definition: Builders.cpp:58
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Type conversion from builtin types to SPIR-V types for shader interface.
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:447
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void populateVectorReductionToSPIRVDotProductPatterns(RewritePatternSet &patterns)
Appends patterns to convert vector reduction of the form:
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateVectorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Vector Ops to SPIR-V ops.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319