MLIR  20.0.0git
VectorToSPIRV.cpp
Go to the documentation of this file.
1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Vector dialect to SPIRV dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
21 #include "mlir/IR/Attributes.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Location.h"
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "mlir/IR/TypeUtilities.h"
29 #include "llvm/ADT/ArrayRef.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/SmallVectorExtras.h"
33 #include "llvm/Support/FormatVariadic.h"
34 #include <cassert>
35 #include <cstdint>
36 #include <numeric>
37 
38 using namespace mlir;
39 
40 /// Returns the integer value from the first valid input element, assuming Value
41 /// inputs are defined by a constant index ops and Attribute inputs are integer
42 /// attributes.
43 static uint64_t getFirstIntValue(ValueRange values) {
44  return values[0].getDefiningOp<arith::ConstantIndexOp>().value();
45 }
46 static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
47  return cast<IntegerAttr>(attr[0]).getInt();
48 }
49 static uint64_t getFirstIntValue(ArrayAttr attr) {
50  return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
51 }
52 static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
53  auto attr = foldResults[0].dyn_cast<Attribute>();
54  if (attr)
55  return getFirstIntValue(attr);
56 
57  return getFirstIntValue(ValueRange{foldResults[0].get<Value>()});
58 }
59 
60 /// Returns the number of bits for the given scalar/vector type.
61 static int getNumBits(Type type) {
62  // TODO: This does not take into account any memory layout or widening
63  // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even
64  // though in practice it will likely be stored as in a 4xi64 vector register.
65  if (auto vectorType = dyn_cast<VectorType>(type))
66  return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
67  return type.getIntOrFloatBitWidth();
68 }
69 
70 namespace {
71 
72 struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
74 
75  LogicalResult
76  matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
77  ConversionPatternRewriter &rewriter) const override {
78  Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
79  if (!dstType)
80  return failure();
81 
82  // If dstType is same as the source type or the vector size is 1, it can be
83  // directly replaced by the source.
84  if (dstType == adaptor.getSource().getType() ||
85  shapeCastOp.getResultVectorType().getNumElements() == 1) {
86  rewriter.replaceOp(shapeCastOp, adaptor.getSource());
87  return success();
88  }
89 
90  // Lowering for size-n vectors when n > 1 hasn't been implemented.
91  return failure();
92  }
93 };
94 
95 struct VectorBitcastConvert final
96  : public OpConversionPattern<vector::BitCastOp> {
98 
99  LogicalResult
100  matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
101  ConversionPatternRewriter &rewriter) const override {
102  Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
103  if (!dstType)
104  return failure();
105 
106  if (dstType == adaptor.getSource().getType()) {
107  rewriter.replaceOp(bitcastOp, adaptor.getSource());
108  return success();
109  }
110 
111  // Check that the source and destination type have the same bitwidth.
112  // Depending on the target environment, we may need to emulate certain
113  // types, which can cause issue with bitcast.
114  Type srcType = adaptor.getSource().getType();
115  if (getNumBits(dstType) != getNumBits(srcType)) {
116  return rewriter.notifyMatchFailure(
117  bitcastOp,
118  llvm::formatv("different source ({0}) and target ({1}) bitwidth",
119  srcType, dstType));
120  }
121 
122  rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
123  adaptor.getSource());
124  return success();
125  }
126 };
127 
128 struct VectorBroadcastConvert final
129  : public OpConversionPattern<vector::BroadcastOp> {
131 
132  LogicalResult
133  matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
134  ConversionPatternRewriter &rewriter) const override {
135  Type resultType =
136  getTypeConverter()->convertType(castOp.getResultVectorType());
137  if (!resultType)
138  return failure();
139 
140  if (isa<spirv::ScalarType>(resultType)) {
141  rewriter.replaceOp(castOp, adaptor.getSource());
142  return success();
143  }
144 
145  SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
146  adaptor.getSource());
147  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(castOp, resultType,
148  source);
149  return success();
150  }
151 };
152 
153 struct VectorExtractOpConvert final
154  : public OpConversionPattern<vector::ExtractOp> {
156 
157  LogicalResult
158  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
159  ConversionPatternRewriter &rewriter) const override {
160  if (extractOp.hasDynamicPosition())
161  return failure();
162 
163  Type dstType = getTypeConverter()->convertType(extractOp.getType());
164  if (!dstType)
165  return failure();
166 
167  if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
168  rewriter.replaceOp(extractOp, adaptor.getVector());
169  return success();
170  }
171 
172  int32_t id = getFirstIntValue(extractOp.getMixedPosition());
173  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
174  extractOp, adaptor.getVector(), id);
175  return success();
176  }
177 };
178 
179 struct VectorExtractStridedSliceOpConvert final
180  : public OpConversionPattern<vector::ExtractStridedSliceOp> {
182 
183  LogicalResult
184  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
185  ConversionPatternRewriter &rewriter) const override {
186  Type dstType = getTypeConverter()->convertType(extractOp.getType());
187  if (!dstType)
188  return failure();
189 
190  uint64_t offset = getFirstIntValue(extractOp.getOffsets());
191  uint64_t size = getFirstIntValue(extractOp.getSizes());
192  uint64_t stride = getFirstIntValue(extractOp.getStrides());
193  if (stride != 1)
194  return failure();
195 
196  Value srcVector = adaptor.getOperands().front();
197 
198  // Extract vector<1xT> case.
199  if (isa<spirv::ScalarType>(dstType)) {
200  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
201  srcVector, offset);
202  return success();
203  }
204 
205  SmallVector<int32_t, 2> indices(size);
206  std::iota(indices.begin(), indices.end(), offset);
207 
208  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
209  extractOp, dstType, srcVector, srcVector,
210  rewriter.getI32ArrayAttr(indices));
211 
212  return success();
213  }
214 };
215 
216 template <class SPIRVFMAOp>
217 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
219 
220  LogicalResult
221  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
222  ConversionPatternRewriter &rewriter) const override {
223  Type dstType = getTypeConverter()->convertType(fmaOp.getType());
224  if (!dstType)
225  return failure();
226  rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, dstType, adaptor.getLhs(),
227  adaptor.getRhs(), adaptor.getAcc());
228  return success();
229  }
230 };
231 
232 struct VectorInsertOpConvert final
233  : public OpConversionPattern<vector::InsertOp> {
235 
236  LogicalResult
237  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
238  ConversionPatternRewriter &rewriter) const override {
239  if (isa<VectorType>(insertOp.getSourceType()))
240  return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
241  if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
242  return rewriter.notifyMatchFailure(insertOp,
243  "unsupported dest vector type");
244 
245  // Special case for inserting scalar values into size-1 vectors.
246  if (insertOp.getSourceType().isIntOrFloat() &&
247  insertOp.getDestVectorType().getNumElements() == 1) {
248  rewriter.replaceOp(insertOp, adaptor.getSource());
249  return success();
250  }
251 
252  int32_t id = getFirstIntValue(insertOp.getMixedPosition());
253  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
254  insertOp, adaptor.getSource(), adaptor.getDest(), id);
255  return success();
256  }
257 };
258 
259 struct VectorExtractElementOpConvert final
260  : public OpConversionPattern<vector::ExtractElementOp> {
262 
263  LogicalResult
264  matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
265  ConversionPatternRewriter &rewriter) const override {
266  Type resultType = getTypeConverter()->convertType(extractOp.getType());
267  if (!resultType)
268  return failure();
269 
270  if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
271  rewriter.replaceOp(extractOp, adaptor.getVector());
272  return success();
273  }
274 
275  APInt cstPos;
276  if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
277  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
278  extractOp, resultType, adaptor.getVector(),
279  rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())}));
280  else
281  rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
282  extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
283  return success();
284  }
285 };
286 
287 struct VectorInsertElementOpConvert final
288  : public OpConversionPattern<vector::InsertElementOp> {
290 
291  LogicalResult
292  matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
293  ConversionPatternRewriter &rewriter) const override {
294  Type vectorType = getTypeConverter()->convertType(insertOp.getType());
295  if (!vectorType)
296  return failure();
297 
298  if (isa<spirv::ScalarType>(vectorType)) {
299  rewriter.replaceOp(insertOp, adaptor.getSource());
300  return success();
301  }
302 
303  APInt cstPos;
304  if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
305  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
306  insertOp, adaptor.getSource(), adaptor.getDest(),
307  cstPos.getSExtValue());
308  else
309  rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
310  insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
311  adaptor.getPosition());
312  return success();
313  }
314 };
315 
316 struct VectorInsertStridedSliceOpConvert final
317  : public OpConversionPattern<vector::InsertStridedSliceOp> {
319 
320  LogicalResult
321  matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
322  ConversionPatternRewriter &rewriter) const override {
323  Value srcVector = adaptor.getOperands().front();
324  Value dstVector = adaptor.getOperands().back();
325 
326  uint64_t stride = getFirstIntValue(insertOp.getStrides());
327  if (stride != 1)
328  return failure();
329  uint64_t offset = getFirstIntValue(insertOp.getOffsets());
330 
331  if (isa<spirv::ScalarType>(srcVector.getType())) {
332  assert(!isa<spirv::ScalarType>(dstVector.getType()));
333  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
334  insertOp, dstVector.getType(), srcVector, dstVector,
335  rewriter.getI32ArrayAttr(offset));
336  return success();
337  }
338 
339  uint64_t totalSize = cast<VectorType>(dstVector.getType()).getNumElements();
340  uint64_t insertSize =
341  cast<VectorType>(srcVector.getType()).getNumElements();
342 
343  SmallVector<int32_t, 2> indices(totalSize);
344  std::iota(indices.begin(), indices.end(), 0);
345  std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
346  totalSize);
347 
348  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
349  insertOp, dstVector.getType(), dstVector, srcVector,
350  rewriter.getI32ArrayAttr(indices));
351 
352  return success();
353  }
354 };
355 
356 static SmallVector<Value> extractAllElements(
357  vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
358  VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
359  int numElements = static_cast<int>(srcVectorType.getDimSize(0));
360  SmallVector<Value> values;
361  values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
362  Location loc = reduceOp.getLoc();
363 
364  for (int i = 0; i < numElements; ++i) {
365  values.push_back(rewriter.create<spirv::CompositeExtractOp>(
366  loc, srcVectorType.getElementType(), adaptor.getVector(),
367  rewriter.getI32ArrayAttr({i})));
368  }
369  if (Value acc = adaptor.getAcc())
370  values.push_back(acc);
371 
372  return values;
373 }
374 
375 struct ReductionRewriteInfo {
376  Type resultType;
377  SmallVector<Value> extractedElements;
378 };
379 
380 FailureOr<ReductionRewriteInfo> static getReductionInfo(
381  vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
382  ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) {
383  Type resultType = typeConverter.convertType(op.getType());
384  if (!resultType)
385  return failure();
386 
387  auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
388  if (!srcVectorType || srcVectorType.getRank() != 1)
389  return rewriter.notifyMatchFailure(op, "not a 1-D vector source");
390 
391  SmallVector<Value> extractedElements =
392  extractAllElements(op, adaptor, srcVectorType, rewriter);
393 
394  return ReductionRewriteInfo{resultType, std::move(extractedElements)};
395 }
396 
397 template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
398  typename SPIRVSMinOp>
399 struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
401 
402  LogicalResult
403  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
404  ConversionPatternRewriter &rewriter) const override {
405  auto reductionInfo =
406  getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
407  if (failed(reductionInfo))
408  return failure();
409 
410  auto [resultType, extractedElements] = *reductionInfo;
411  Location loc = reduceOp->getLoc();
412  Value result = extractedElements.front();
413  for (Value next : llvm::drop_begin(extractedElements)) {
414  switch (reduceOp.getKind()) {
415 
416 #define INT_AND_FLOAT_CASE(kind, iop, fop) \
417  case vector::CombiningKind::kind: \
418  if (llvm::isa<IntegerType>(resultType)) { \
419  result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
420  } else { \
421  assert(llvm::isa<FloatType>(resultType)); \
422  result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
423  } \
424  break
425 
426 #define INT_OR_FLOAT_CASE(kind, fop) \
427  case vector::CombiningKind::kind: \
428  result = rewriter.create<fop>(loc, resultType, result, next); \
429  break
430 
431  INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
432  INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
433  INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
434  INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
435  INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
436  INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp);
437 
438  case vector::CombiningKind::AND:
439  case vector::CombiningKind::OR:
440  case vector::CombiningKind::XOR:
441  return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
442  default:
443  return rewriter.notifyMatchFailure(reduceOp, "not handled here");
444  }
445 #undef INT_AND_FLOAT_CASE
446 #undef INT_OR_FLOAT_CASE
447  }
448 
449  rewriter.replaceOp(reduceOp, result);
450  return success();
451  }
452 };
453 
454 template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
455 struct VectorReductionFloatMinMax final
456  : OpConversionPattern<vector::ReductionOp> {
458 
459  LogicalResult
460  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
461  ConversionPatternRewriter &rewriter) const override {
462  auto reductionInfo =
463  getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
464  if (failed(reductionInfo))
465  return failure();
466 
467  auto [resultType, extractedElements] = *reductionInfo;
468  Location loc = reduceOp->getLoc();
469  Value result = extractedElements.front();
470  for (Value next : llvm::drop_begin(extractedElements)) {
471  switch (reduceOp.getKind()) {
472 
473 #define INT_OR_FLOAT_CASE(kind, fop) \
474  case vector::CombiningKind::kind: \
475  result = rewriter.create<fop>(loc, resultType, result, next); \
476  break
477 
478  INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
479  INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
480  INT_OR_FLOAT_CASE(MAXNUMF, SPIRVFMaxOp);
481  INT_OR_FLOAT_CASE(MINNUMF, SPIRVFMinOp);
482 
483  default:
484  return rewriter.notifyMatchFailure(reduceOp, "not handled here");
485  }
486 #undef INT_OR_FLOAT_CASE
487  }
488 
489  rewriter.replaceOp(reduceOp, result);
490  return success();
491  }
492 };
493 
494 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
495 public:
497 
498  LogicalResult
499  matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
500  ConversionPatternRewriter &rewriter) const override {
501  Type dstType = getTypeConverter()->convertType(op.getType());
502  if (!dstType)
503  return failure();
504  if (isa<spirv::ScalarType>(dstType)) {
505  rewriter.replaceOp(op, adaptor.getInput());
506  } else {
507  auto dstVecType = cast<VectorType>(dstType);
508  SmallVector<Value, 4> source(dstVecType.getNumElements(),
509  adaptor.getInput());
510  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
511  source);
512  }
513  return success();
514  }
515 };
516 
517 struct VectorShuffleOpConvert final
518  : public OpConversionPattern<vector::ShuffleOp> {
520 
521  LogicalResult
522  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
523  ConversionPatternRewriter &rewriter) const override {
524  VectorType oldResultType = shuffleOp.getResultVectorType();
525  Type newResultType = getTypeConverter()->convertType(oldResultType);
526  if (!newResultType)
527  return rewriter.notifyMatchFailure(shuffleOp,
528  "unsupported result vector type");
529 
530  SmallVector<int32_t, 4> mask = llvm::map_to_vector<4>(
531  shuffleOp.getMask(), [](Attribute attr) -> int32_t {
532  return cast<IntegerAttr>(attr).getValue().getZExtValue();
533  });
534 
535  VectorType oldV1Type = shuffleOp.getV1VectorType();
536  VectorType oldV2Type = shuffleOp.getV2VectorType();
537 
538  // When both operands and the result are SPIR-V vectors, emit a SPIR-V
539  // shuffle.
540  if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
541  oldResultType.getNumElements() > 1) {
542  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
543  shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
544  rewriter.getI32ArrayAttr(mask));
545  return success();
546  }
547 
548  // When at least one of the operands or the result becomes a scalar after
549  // type conversion for SPIR-V, extract all the required elements and
550  // construct the result vector.
551  auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
552  Value scalarOrVec, int32_t idx) -> Value {
553  if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
554  return rewriter.create<spirv::CompositeExtractOp>(loc, scalarOrVec,
555  idx);
556 
557  assert(idx == 0 && "Invalid scalar element index");
558  return scalarOrVec;
559  };
560 
561  int32_t numV1Elems = oldV1Type.getNumElements();
562  SmallVector<Value> newOperands(mask.size());
563  for (auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
564  Value vec = adaptor.getV1();
565  int32_t elementIdx = shuffleIdx;
566  if (elementIdx >= numV1Elems) {
567  vec = adaptor.getV2();
568  elementIdx -= numV1Elems;
569  }
570 
571  newOperand = getElementAtIdx(vec, elementIdx);
572  }
573 
574  // Handle the scalar result corner case.
575  if (newOperands.size() == 1) {
576  rewriter.replaceOp(shuffleOp, newOperands.front());
577  return success();
578  }
579 
580  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
581  shuffleOp, newResultType, newOperands);
582  return success();
583  }
584 };
585 
586 struct VectorInterleaveOpConvert final
587  : public OpConversionPattern<vector::InterleaveOp> {
589 
590  LogicalResult
591  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
592  ConversionPatternRewriter &rewriter) const override {
593  // Check the result vector type.
594  VectorType oldResultType = interleaveOp.getResultVectorType();
595  Type newResultType = getTypeConverter()->convertType(oldResultType);
596  if (!newResultType)
597  return rewriter.notifyMatchFailure(interleaveOp,
598  "unsupported result vector type");
599 
600  // Interleave the indices.
601  VectorType sourceType = interleaveOp.getSourceVectorType();
602  int n = sourceType.getNumElements();
603 
604  // Input vectors of size 1 are converted to scalars by the type converter.
605  // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
606  // use `spirv::CompositeConstructOp`.
607  if (n == 1) {
608  Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
609  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
610  interleaveOp, newResultType, newOperands);
611  return success();
612  }
613 
614  auto seq = llvm::seq<int64_t>(2 * n);
615  auto indices = llvm::map_to_vector(
616  seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; });
617 
618  // Emit a SPIR-V shuffle.
619  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
620  interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
621  rewriter.getI32ArrayAttr(indices));
622 
623  return success();
624  }
625 };
626 
627 struct VectorDeinterleaveOpConvert final
628  : public OpConversionPattern<vector::DeinterleaveOp> {
630 
631  LogicalResult
632  matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
633  ConversionPatternRewriter &rewriter) const override {
634 
635  // Check the result vector type.
636  VectorType oldResultType = deinterleaveOp.getResultVectorType();
637  Type newResultType = getTypeConverter()->convertType(oldResultType);
638  if (!newResultType)
639  return rewriter.notifyMatchFailure(deinterleaveOp,
640  "unsupported result vector type");
641 
642  Location loc = deinterleaveOp->getLoc();
643 
644  // Deinterleave the indices.
645  Value sourceVector = adaptor.getSource();
646  VectorType sourceType = deinterleaveOp.getSourceVectorType();
647  int n = sourceType.getNumElements();
648 
649  // Output vectors of size 1 are converted to scalars by the type converter.
650  // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
651  // use `spirv::CompositeExtractOp`.
652  if (n == 2) {
653  auto elem0 = rewriter.create<spirv::CompositeExtractOp>(
654  loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({0}));
655 
656  auto elem1 = rewriter.create<spirv::CompositeExtractOp>(
657  loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({1}));
658 
659  rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
660  return success();
661  }
662 
663  // Indices for `shuffleEven` (result 0).
664  auto seqEven = llvm::seq<int64_t>(n / 2);
665  auto indicesEven =
666  llvm::map_to_vector(seqEven, [](int i) { return i * 2; });
667 
668  // Indices for `shuffleOdd` (result 1).
669  auto seqOdd = llvm::seq<int64_t>(n / 2);
670  auto indicesOdd =
671  llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
672 
673  // Create two SPIR-V shuffles.
674  auto shuffleEven = rewriter.create<spirv::VectorShuffleOp>(
675  loc, newResultType, sourceVector, sourceVector,
676  rewriter.getI32ArrayAttr(indicesEven));
677 
678  auto shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
679  loc, newResultType, sourceVector, sourceVector,
680  rewriter.getI32ArrayAttr(indicesOdd));
681 
682  rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
683  return success();
684  }
685 };
686 
687 struct VectorLoadOpConverter final
688  : public OpConversionPattern<vector::LoadOp> {
690 
691  LogicalResult
692  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
693  ConversionPatternRewriter &rewriter) const override {
694  auto memrefType = loadOp.getMemRefType();
695  auto attr =
696  dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
697  if (!attr)
698  return rewriter.notifyMatchFailure(
699  loadOp, "expected spirv.storage_class memory space");
700 
701  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
702  auto loc = loadOp.getLoc();
703  Value accessChain =
704  spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
705  adaptor.getIndices(), loc, rewriter);
706  if (!accessChain)
707  return rewriter.notifyMatchFailure(
708  loadOp, "failed to get memref element pointer");
709 
710  spirv::StorageClass storageClass = attr.getValue();
711  auto vectorType = loadOp.getVectorType();
712  auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
713  Value castedAccessChain =
714  rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
715  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, vectorType,
716  castedAccessChain);
717 
718  return success();
719  }
720 };
721 
722 struct VectorStoreOpConverter final
723  : public OpConversionPattern<vector::StoreOp> {
725 
726  LogicalResult
727  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
728  ConversionPatternRewriter &rewriter) const override {
729  auto memrefType = storeOp.getMemRefType();
730  auto attr =
731  dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
732  if (!attr)
733  return rewriter.notifyMatchFailure(
734  storeOp, "expected spirv.storage_class memory space");
735 
736  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
737  auto loc = storeOp.getLoc();
738  Value accessChain =
739  spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
740  adaptor.getIndices(), loc, rewriter);
741  if (!accessChain)
742  return rewriter.notifyMatchFailure(
743  storeOp, "failed to get memref element pointer");
744 
745  spirv::StorageClass storageClass = attr.getValue();
746  auto vectorType = storeOp.getVectorType();
747  auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
748  Value castedAccessChain =
749  rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
750  rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
751  adaptor.getValueToStore());
752 
753  return success();
754  }
755 };
756 
757 struct VectorReductionToIntDotProd final
758  : OpRewritePattern<vector::ReductionOp> {
760 
761  LogicalResult matchAndRewrite(vector::ReductionOp op,
762  PatternRewriter &rewriter) const override {
763  if (op.getKind() != vector::CombiningKind::ADD)
764  return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
765 
766  auto resultType = dyn_cast<IntegerType>(op.getType());
767  if (!resultType)
768  return rewriter.notifyMatchFailure(op, "result is not an integer");
769 
770  int64_t resultBitwidth = resultType.getIntOrFloatBitWidth();
771  if (!llvm::is_contained({32, 64}, resultBitwidth))
772  return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth");
773 
774  VectorType inVecTy = op.getSourceVectorType();
775  if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
776  inVecTy.getShape().size() != 1 || inVecTy.isScalable())
777  return rewriter.notifyMatchFailure(op, "unsupported vector shape");
778 
779  auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
780  if (!mul)
781  return rewriter.notifyMatchFailure(
782  op, "reduction operand is not 'arith.muli'");
783 
784  if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
785  spirv::SDotAccSatOp, false>(op, mul, rewriter)))
786  return success();
787 
788  if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
789  spirv::UDotAccSatOp, false>(op, mul, rewriter)))
790  return success();
791 
792  if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
793  spirv::SUDotAccSatOp, false>(op, mul, rewriter)))
794  return success();
795 
796  if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
797  spirv::SUDotAccSatOp, true>(op, mul, rewriter)))
798  return success();
799 
800  return failure();
801  }
802 
803 private:
804  template <typename LhsExtensionOp, typename RhsExtensionOp, typename DotOp,
805  typename DotAccOp, bool SwapOperands>
806  static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
807  PatternRewriter &rewriter) {
808  auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
809  if (!lhs)
810  return failure();
811  Value lhsIn = lhs.getIn();
812  auto lhsInType = cast<VectorType>(lhsIn.getType());
813  if (!lhsInType.getElementType().isInteger(8))
814  return failure();
815 
816  auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
817  if (!rhs)
818  return failure();
819  Value rhsIn = rhs.getIn();
820  auto rhsInType = cast<VectorType>(rhsIn.getType());
821  if (!rhsInType.getElementType().isInteger(8))
822  return failure();
823 
824  if (op.getSourceVectorType().getNumElements() == 3) {
825  IntegerType i8Type = rewriter.getI8Type();
826  auto v4i8Type = VectorType::get({4}, i8Type);
827  Location loc = op.getLoc();
828  Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter);
829  lhsIn = rewriter.create<spirv::CompositeConstructOp>(
830  loc, v4i8Type, ValueRange{lhsIn, zero});
831  rhsIn = rewriter.create<spirv::CompositeConstructOp>(
832  loc, v4i8Type, ValueRange{rhsIn, zero});
833  }
834 
835  // There's no variant of dot prod ops for unsigned LHS and signed RHS, so
836  // we have to swap operands instead in that case.
837  if (SwapOperands)
838  std::swap(lhsIn, rhsIn);
839 
840  if (Value acc = op.getAcc()) {
841  rewriter.replaceOpWithNewOp<DotAccOp>(op, op.getType(), lhsIn, rhsIn, acc,
842  nullptr);
843  } else {
844  rewriter.replaceOpWithNewOp<DotOp>(op, op.getType(), lhsIn, rhsIn,
845  nullptr);
846  }
847 
848  return success();
849  }
850 };
851 
852 struct VectorReductionToFPDotProd final
853  : OpConversionPattern<vector::ReductionOp> {
855 
856  LogicalResult
857  matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
858  ConversionPatternRewriter &rewriter) const override {
859  if (op.getKind() != vector::CombiningKind::ADD)
860  return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
861 
862  auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
863  if (!resultType)
864  return rewriter.notifyMatchFailure(op, "result is not a float");
865 
866  Value vec = adaptor.getVector();
867  Value acc = adaptor.getAcc();
868 
869  auto vectorType = dyn_cast<VectorType>(vec.getType());
870  if (!vectorType) {
871  assert(isa<FloatType>(vec.getType()) &&
872  "Expected the vector to be scalarized");
873  if (acc) {
874  rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec);
875  return success();
876  }
877 
878  rewriter.replaceOp(op, vec);
879  return success();
880  }
881 
882  Location loc = op.getLoc();
883  Value lhs;
884  Value rhs;
885  if (auto mul = vec.getDefiningOp<arith::MulFOp>()) {
886  lhs = mul.getLhs();
887  rhs = mul.getRhs();
888  } else {
889  // If the operand is not a mul, use a vector of ones for the dot operand
890  // to just sum up all values.
891  lhs = vec;
892  Attribute oneAttr =
893  rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
894  oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
895  rhs = rewriter.create<spirv::ConstantOp>(loc, vectorType, oneAttr);
896  }
897  assert(lhs);
898  assert(rhs);
899 
900  Value res = rewriter.create<spirv::DotOp>(loc, resultType, lhs, rhs);
901  if (acc)
902  res = rewriter.create<spirv::FAddOp>(loc, acc, res);
903 
904  rewriter.replaceOp(op, res);
905  return success();
906  }
907 };
908 
909 struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
911 
912  LogicalResult
913  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
914  ConversionPatternRewriter &rewriter) const override {
915  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
916  Type dstType = typeConverter.convertType(stepOp.getType());
917  if (!dstType)
918  return failure();
919 
920  Location loc = stepOp.getLoc();
921  int64_t numElements = stepOp.getType().getNumElements();
922  auto intType =
923  rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
924 
925  // Input vectors of size 1 are converted to scalars by the type converter.
926  // We just create a constant in this case.
927  if (numElements == 1) {
928  Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
929  rewriter.replaceOp(stepOp, zero);
930  return success();
931  }
932 
933  SmallVector<Value> source;
934  source.reserve(numElements);
935  for (int64_t i = 0; i < numElements; ++i) {
936  Attribute intAttr = rewriter.getIntegerAttr(intType, i);
937  Value constOp = rewriter.create<spirv::ConstantOp>(loc, intType, intAttr);
938  source.push_back(constOp);
939  }
940  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
941  source);
942  return success();
943  }
944 };
945 
946 } // namespace
// Op lists passed as template arguments to the reduction patterns. The integer
// lists follow VectorReductionPattern's parameter order (UMax, UMin, SMax,
// SMin); the float lists follow VectorReductionFloatMinMax's (FMax, FMin).
#define CL_INT_MAX_MIN_OPS                                                     \
  spirv::CLUMaxOp, spirv::CLUMinOp,                                            \
  spirv::CLSMaxOp, spirv::CLSMinOp

#define GL_INT_MAX_MIN_OPS                                                     \
  spirv::GLUMaxOp, spirv::GLUMinOp,                                            \
  spirv::GLSMaxOp, spirv::GLSMinOp

#define CL_FLOAT_MAX_MIN_OPS                                                   \
  spirv::CLFMaxOp, spirv::CLFMinOp
#define GL_FLOAT_MAX_MIN_OPS                                                   \
  spirv::GLFMaxOp, spirv::GLFMinOp
955 
957  RewritePatternSet &patterns) {
958  patterns.add<
959  VectorBitcastConvert, VectorBroadcastConvert,
960  VectorExtractElementOpConvert, VectorExtractOpConvert,
961  VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
962  VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
963  VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
964  VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
965  VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
966  VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
967  VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
968  VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
969  VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
970  VectorStepOpConvert>(typeConverter, patterns.getContext(),
971  PatternBenefit(1));
972 
973  // Make sure that the more specialized dot product pattern has higher benefit
974  // than the generic one that extracts all elements.
975  patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
976  PatternBenefit(2));
977 }
978 
980  RewritePatternSet &patterns) {
981  patterns.add<VectorReductionToIntDotProd>(patterns.getContext());
982 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
#define MINUI(lhs, rhs)
static uint64_t getFirstIntValue(ValueRange values)
Returns the integer value from the first valid input element, assuming Value inputs are defined by a ...
static int getNumBits(Type type)
Returns the number of bits for the given scalar/vector type.
#define INT_AND_FLOAT_CASE(kind, iop, fop)
#define INT_OR_FLOAT_CASE(kind, fop)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:242
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:287
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:265
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:91
IntegerType getI8Type()
Definition: Builders.cpp:83
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Type conversion from builtin types to SPIR-V types for shader interface.
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:126
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:481
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
void populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Vector Ops to SPIR-V ops.
void populateVectorReductionToSPIRVDotProductPatterns(RewritePatternSet &patterns)
Appends patterns to convert vector reduction of the form:
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362