MLIR  20.0.0git
VectorToSPIRV.cpp
Go to the documentation of this file.
1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Vector dialect to SPIRV dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
22 #include "mlir/IR/Attributes.h"
24 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/Location.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/IR/TypeUtilities.h"
30 #include "llvm/ADT/ArrayRef.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/SmallVectorExtras.h"
34 #include "llvm/Support/FormatVariadic.h"
35 #include <cassert>
36 #include <cstdint>
37 #include <numeric>
38 
39 using namespace mlir;
40 
41 /// Returns the integer value from the first valid input element, assuming Value
42 /// inputs are defined by a constant index ops and Attribute inputs are integer
43 /// attributes.
44 static uint64_t getFirstIntValue(ArrayAttr attr) {
45  return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
46 }
47 
48 /// Returns the number of bits for the given scalar/vector type.
49 static int getNumBits(Type type) {
50  // TODO: This does not take into account any memory layout or widening
51  // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even
52  // though in practice it will likely be stored as in a 4xi64 vector register.
53  if (auto vectorType = dyn_cast<VectorType>(type))
54  return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
55  return type.getIntOrFloatBitWidth();
56 }
57 
58 namespace {
59 
60 struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
62 
63  LogicalResult
64  matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
65  ConversionPatternRewriter &rewriter) const override {
66  Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
67  if (!dstType)
68  return failure();
69 
70  // If dstType is same as the source type or the vector size is 1, it can be
71  // directly replaced by the source.
72  if (dstType == adaptor.getSource().getType() ||
73  shapeCastOp.getResultVectorType().getNumElements() == 1) {
74  rewriter.replaceOp(shapeCastOp, adaptor.getSource());
75  return success();
76  }
77 
78  // Lowering for size-n vectors when n > 1 hasn't been implemented.
79  return failure();
80  }
81 };
82 
83 struct VectorBitcastConvert final
84  : public OpConversionPattern<vector::BitCastOp> {
86 
87  LogicalResult
88  matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
89  ConversionPatternRewriter &rewriter) const override {
90  Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
91  if (!dstType)
92  return failure();
93 
94  if (dstType == adaptor.getSource().getType()) {
95  rewriter.replaceOp(bitcastOp, adaptor.getSource());
96  return success();
97  }
98 
99  // Check that the source and destination type have the same bitwidth.
100  // Depending on the target environment, we may need to emulate certain
101  // types, which can cause issue with bitcast.
102  Type srcType = adaptor.getSource().getType();
103  if (getNumBits(dstType) != getNumBits(srcType)) {
104  return rewriter.notifyMatchFailure(
105  bitcastOp,
106  llvm::formatv("different source ({0}) and target ({1}) bitwidth",
107  srcType, dstType));
108  }
109 
110  rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
111  adaptor.getSource());
112  return success();
113  }
114 };
115 
116 struct VectorBroadcastConvert final
117  : public OpConversionPattern<vector::BroadcastOp> {
119 
120  LogicalResult
121  matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
122  ConversionPatternRewriter &rewriter) const override {
123  Type resultType =
124  getTypeConverter()->convertType(castOp.getResultVectorType());
125  if (!resultType)
126  return failure();
127 
128  if (isa<spirv::ScalarType>(resultType)) {
129  rewriter.replaceOp(castOp, adaptor.getSource());
130  return success();
131  }
132 
133  SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
134  adaptor.getSource());
135  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(castOp, resultType,
136  source);
137  return success();
138  }
139 };
140 
141 struct VectorExtractOpConvert final
142  : public OpConversionPattern<vector::ExtractOp> {
144 
145  LogicalResult
146  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
147  ConversionPatternRewriter &rewriter) const override {
148  Type dstType = getTypeConverter()->convertType(extractOp.getType());
149  if (!dstType)
150  return failure();
151 
152  if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
153  rewriter.replaceOp(extractOp, adaptor.getVector());
154  return success();
155  }
156 
157  if (std::optional<int64_t> id =
158  getConstantIntValue(extractOp.getMixedPosition()[0]))
159  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
160  extractOp, dstType, adaptor.getVector(),
161  rewriter.getI32ArrayAttr(id.value()));
162  else
163  rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
164  extractOp, dstType, adaptor.getVector(),
165  adaptor.getDynamicPosition()[0]);
166  return success();
167  }
168 };
169 
170 struct VectorExtractStridedSliceOpConvert final
171  : public OpConversionPattern<vector::ExtractStridedSliceOp> {
173 
174  LogicalResult
175  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
176  ConversionPatternRewriter &rewriter) const override {
177  Type dstType = getTypeConverter()->convertType(extractOp.getType());
178  if (!dstType)
179  return failure();
180 
181  uint64_t offset = getFirstIntValue(extractOp.getOffsets());
182  uint64_t size = getFirstIntValue(extractOp.getSizes());
183  uint64_t stride = getFirstIntValue(extractOp.getStrides());
184  if (stride != 1)
185  return failure();
186 
187  Value srcVector = adaptor.getOperands().front();
188 
189  // Extract vector<1xT> case.
190  if (isa<spirv::ScalarType>(dstType)) {
191  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
192  srcVector, offset);
193  return success();
194  }
195 
196  SmallVector<int32_t, 2> indices(size);
197  std::iota(indices.begin(), indices.end(), offset);
198 
199  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
200  extractOp, dstType, srcVector, srcVector,
201  rewriter.getI32ArrayAttr(indices));
202 
203  return success();
204  }
205 };
206 
207 template <class SPIRVFMAOp>
208 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
210 
211  LogicalResult
212  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
213  ConversionPatternRewriter &rewriter) const override {
214  Type dstType = getTypeConverter()->convertType(fmaOp.getType());
215  if (!dstType)
216  return failure();
217  rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, dstType, adaptor.getLhs(),
218  adaptor.getRhs(), adaptor.getAcc());
219  return success();
220  }
221 };
222 
223 struct VectorInsertOpConvert final
224  : public OpConversionPattern<vector::InsertOp> {
226 
227  LogicalResult
228  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
229  ConversionPatternRewriter &rewriter) const override {
230  if (isa<VectorType>(insertOp.getSourceType()))
231  return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
232  if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
233  return rewriter.notifyMatchFailure(insertOp,
234  "unsupported dest vector type");
235 
236  // Special case for inserting scalar values into size-1 vectors.
237  if (insertOp.getSourceType().isIntOrFloat() &&
238  insertOp.getDestVectorType().getNumElements() == 1) {
239  rewriter.replaceOp(insertOp, adaptor.getSource());
240  return success();
241  }
242 
243  if (std::optional<int64_t> id =
244  getConstantIntValue(insertOp.getMixedPosition()[0]))
245  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
246  insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
247  else
248  rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
249  insertOp, insertOp.getDest(), adaptor.getSource(),
250  adaptor.getDynamicPosition()[0]);
251  return success();
252  }
253 };
254 
255 struct VectorExtractElementOpConvert final
256  : public OpConversionPattern<vector::ExtractElementOp> {
258 
259  LogicalResult
260  matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
261  ConversionPatternRewriter &rewriter) const override {
262  Type resultType = getTypeConverter()->convertType(extractOp.getType());
263  if (!resultType)
264  return failure();
265 
266  if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
267  rewriter.replaceOp(extractOp, adaptor.getVector());
268  return success();
269  }
270 
271  APInt cstPos;
272  if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
273  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
274  extractOp, resultType, adaptor.getVector(),
275  rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())}));
276  else
277  rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
278  extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
279  return success();
280  }
281 };
282 
283 struct VectorInsertElementOpConvert final
284  : public OpConversionPattern<vector::InsertElementOp> {
286 
287  LogicalResult
288  matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
289  ConversionPatternRewriter &rewriter) const override {
290  Type vectorType = getTypeConverter()->convertType(insertOp.getType());
291  if (!vectorType)
292  return failure();
293 
294  if (isa<spirv::ScalarType>(vectorType)) {
295  rewriter.replaceOp(insertOp, adaptor.getSource());
296  return success();
297  }
298 
299  APInt cstPos;
300  if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
301  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
302  insertOp, adaptor.getSource(), adaptor.getDest(),
303  cstPos.getSExtValue());
304  else
305  rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
306  insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
307  adaptor.getPosition());
308  return success();
309  }
310 };
311 
312 struct VectorInsertStridedSliceOpConvert final
313  : public OpConversionPattern<vector::InsertStridedSliceOp> {
315 
316  LogicalResult
317  matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
318  ConversionPatternRewriter &rewriter) const override {
319  Value srcVector = adaptor.getOperands().front();
320  Value dstVector = adaptor.getOperands().back();
321 
322  uint64_t stride = getFirstIntValue(insertOp.getStrides());
323  if (stride != 1)
324  return failure();
325  uint64_t offset = getFirstIntValue(insertOp.getOffsets());
326 
327  if (isa<spirv::ScalarType>(srcVector.getType())) {
328  assert(!isa<spirv::ScalarType>(dstVector.getType()));
329  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
330  insertOp, dstVector.getType(), srcVector, dstVector,
331  rewriter.getI32ArrayAttr(offset));
332  return success();
333  }
334 
335  uint64_t totalSize = cast<VectorType>(dstVector.getType()).getNumElements();
336  uint64_t insertSize =
337  cast<VectorType>(srcVector.getType()).getNumElements();
338 
339  SmallVector<int32_t, 2> indices(totalSize);
340  std::iota(indices.begin(), indices.end(), 0);
341  std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
342  totalSize);
343 
344  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
345  insertOp, dstVector.getType(), dstVector, srcVector,
346  rewriter.getI32ArrayAttr(indices));
347 
348  return success();
349  }
350 };
351 
352 static SmallVector<Value> extractAllElements(
353  vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
354  VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
355  int numElements = static_cast<int>(srcVectorType.getDimSize(0));
356  SmallVector<Value> values;
357  values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
358  Location loc = reduceOp.getLoc();
359 
360  for (int i = 0; i < numElements; ++i) {
361  values.push_back(rewriter.create<spirv::CompositeExtractOp>(
362  loc, srcVectorType.getElementType(), adaptor.getVector(),
363  rewriter.getI32ArrayAttr({i})));
364  }
365  if (Value acc = adaptor.getAcc())
366  values.push_back(acc);
367 
368  return values;
369 }
370 
371 struct ReductionRewriteInfo {
372  Type resultType;
373  SmallVector<Value> extractedElements;
374 };
375 
376 FailureOr<ReductionRewriteInfo> static getReductionInfo(
377  vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
378  ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) {
379  Type resultType = typeConverter.convertType(op.getType());
380  if (!resultType)
381  return failure();
382 
383  auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
384  if (!srcVectorType || srcVectorType.getRank() != 1)
385  return rewriter.notifyMatchFailure(op, "not a 1-D vector source");
386 
387  SmallVector<Value> extractedElements =
388  extractAllElements(op, adaptor, srcVectorType, rewriter);
389 
390  return ReductionRewriteInfo{resultType, std::move(extractedElements)};
391 }
392 
393 template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
394  typename SPIRVSMinOp>
395 struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
397 
398  LogicalResult
399  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
400  ConversionPatternRewriter &rewriter) const override {
401  auto reductionInfo =
402  getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
403  if (failed(reductionInfo))
404  return failure();
405 
406  auto [resultType, extractedElements] = *reductionInfo;
407  Location loc = reduceOp->getLoc();
408  Value result = extractedElements.front();
409  for (Value next : llvm::drop_begin(extractedElements)) {
410  switch (reduceOp.getKind()) {
411 
412 #define INT_AND_FLOAT_CASE(kind, iop, fop) \
413  case vector::CombiningKind::kind: \
414  if (llvm::isa<IntegerType>(resultType)) { \
415  result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
416  } else { \
417  assert(llvm::isa<FloatType>(resultType)); \
418  result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
419  } \
420  break
421 
422 #define INT_OR_FLOAT_CASE(kind, fop) \
423  case vector::CombiningKind::kind: \
424  result = rewriter.create<fop>(loc, resultType, result, next); \
425  break
426 
427  INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
428  INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
429  INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
430  INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
431  INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
432  INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp);
433 
434  case vector::CombiningKind::AND:
435  case vector::CombiningKind::OR:
436  case vector::CombiningKind::XOR:
437  return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
438  default:
439  return rewriter.notifyMatchFailure(reduceOp, "not handled here");
440  }
441 #undef INT_AND_FLOAT_CASE
442 #undef INT_OR_FLOAT_CASE
443  }
444 
445  rewriter.replaceOp(reduceOp, result);
446  return success();
447  }
448 };
449 
450 template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
451 struct VectorReductionFloatMinMax final
452  : OpConversionPattern<vector::ReductionOp> {
454 
455  LogicalResult
456  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
457  ConversionPatternRewriter &rewriter) const override {
458  auto reductionInfo =
459  getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
460  if (failed(reductionInfo))
461  return failure();
462 
463  auto [resultType, extractedElements] = *reductionInfo;
464  Location loc = reduceOp->getLoc();
465  Value result = extractedElements.front();
466  for (Value next : llvm::drop_begin(extractedElements)) {
467  switch (reduceOp.getKind()) {
468 
469 #define INT_OR_FLOAT_CASE(kind, fop) \
470  case vector::CombiningKind::kind: \
471  result = rewriter.create<fop>(loc, resultType, result, next); \
472  break
473 
474  INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
475  INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
476  INT_OR_FLOAT_CASE(MAXNUMF, SPIRVFMaxOp);
477  INT_OR_FLOAT_CASE(MINNUMF, SPIRVFMinOp);
478 
479  default:
480  return rewriter.notifyMatchFailure(reduceOp, "not handled here");
481  }
482 #undef INT_OR_FLOAT_CASE
483  }
484 
485  rewriter.replaceOp(reduceOp, result);
486  return success();
487  }
488 };
489 
490 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
491 public:
493 
494  LogicalResult
495  matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
496  ConversionPatternRewriter &rewriter) const override {
497  Type dstType = getTypeConverter()->convertType(op.getType());
498  if (!dstType)
499  return failure();
500  if (isa<spirv::ScalarType>(dstType)) {
501  rewriter.replaceOp(op, adaptor.getInput());
502  } else {
503  auto dstVecType = cast<VectorType>(dstType);
504  SmallVector<Value, 4> source(dstVecType.getNumElements(),
505  adaptor.getInput());
506  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
507  source);
508  }
509  return success();
510  }
511 };
512 
513 struct VectorShuffleOpConvert final
514  : public OpConversionPattern<vector::ShuffleOp> {
516 
517  LogicalResult
518  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
519  ConversionPatternRewriter &rewriter) const override {
520  VectorType oldResultType = shuffleOp.getResultVectorType();
521  Type newResultType = getTypeConverter()->convertType(oldResultType);
522  if (!newResultType)
523  return rewriter.notifyMatchFailure(shuffleOp,
524  "unsupported result vector type");
525 
526  auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());
527 
528  VectorType oldV1Type = shuffleOp.getV1VectorType();
529  VectorType oldV2Type = shuffleOp.getV2VectorType();
530 
531  // When both operands and the result are SPIR-V vectors, emit a SPIR-V
532  // shuffle.
533  if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
534  oldResultType.getNumElements() > 1) {
535  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
536  shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
537  rewriter.getI32ArrayAttr(mask));
538  return success();
539  }
540 
541  // When at least one of the operands or the result becomes a scalar after
542  // type conversion for SPIR-V, extract all the required elements and
543  // construct the result vector.
544  auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
545  Value scalarOrVec, int32_t idx) -> Value {
546  if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
547  return rewriter.create<spirv::CompositeExtractOp>(loc, scalarOrVec,
548  idx);
549 
550  assert(idx == 0 && "Invalid scalar element index");
551  return scalarOrVec;
552  };
553 
554  int32_t numV1Elems = oldV1Type.getNumElements();
555  SmallVector<Value> newOperands(mask.size());
556  for (auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
557  Value vec = adaptor.getV1();
558  int32_t elementIdx = shuffleIdx;
559  if (elementIdx >= numV1Elems) {
560  vec = adaptor.getV2();
561  elementIdx -= numV1Elems;
562  }
563 
564  newOperand = getElementAtIdx(vec, elementIdx);
565  }
566 
567  // Handle the scalar result corner case.
568  if (newOperands.size() == 1) {
569  rewriter.replaceOp(shuffleOp, newOperands.front());
570  return success();
571  }
572 
573  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
574  shuffleOp, newResultType, newOperands);
575  return success();
576  }
577 };
578 
579 struct VectorInterleaveOpConvert final
580  : public OpConversionPattern<vector::InterleaveOp> {
582 
583  LogicalResult
584  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
585  ConversionPatternRewriter &rewriter) const override {
586  // Check the result vector type.
587  VectorType oldResultType = interleaveOp.getResultVectorType();
588  Type newResultType = getTypeConverter()->convertType(oldResultType);
589  if (!newResultType)
590  return rewriter.notifyMatchFailure(interleaveOp,
591  "unsupported result vector type");
592 
593  // Interleave the indices.
594  VectorType sourceType = interleaveOp.getSourceVectorType();
595  int n = sourceType.getNumElements();
596 
597  // Input vectors of size 1 are converted to scalars by the type converter.
598  // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
599  // use `spirv::CompositeConstructOp`.
600  if (n == 1) {
601  Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
602  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
603  interleaveOp, newResultType, newOperands);
604  return success();
605  }
606 
607  auto seq = llvm::seq<int64_t>(2 * n);
608  auto indices = llvm::map_to_vector(
609  seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; });
610 
611  // Emit a SPIR-V shuffle.
612  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
613  interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
614  rewriter.getI32ArrayAttr(indices));
615 
616  return success();
617  }
618 };
619 
620 struct VectorDeinterleaveOpConvert final
621  : public OpConversionPattern<vector::DeinterleaveOp> {
623 
624  LogicalResult
625  matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
626  ConversionPatternRewriter &rewriter) const override {
627 
628  // Check the result vector type.
629  VectorType oldResultType = deinterleaveOp.getResultVectorType();
630  Type newResultType = getTypeConverter()->convertType(oldResultType);
631  if (!newResultType)
632  return rewriter.notifyMatchFailure(deinterleaveOp,
633  "unsupported result vector type");
634 
635  Location loc = deinterleaveOp->getLoc();
636 
637  // Deinterleave the indices.
638  Value sourceVector = adaptor.getSource();
639  VectorType sourceType = deinterleaveOp.getSourceVectorType();
640  int n = sourceType.getNumElements();
641 
642  // Output vectors of size 1 are converted to scalars by the type converter.
643  // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
644  // use `spirv::CompositeExtractOp`.
645  if (n == 2) {
646  auto elem0 = rewriter.create<spirv::CompositeExtractOp>(
647  loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({0}));
648 
649  auto elem1 = rewriter.create<spirv::CompositeExtractOp>(
650  loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({1}));
651 
652  rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
653  return success();
654  }
655 
656  // Indices for `shuffleEven` (result 0).
657  auto seqEven = llvm::seq<int64_t>(n / 2);
658  auto indicesEven =
659  llvm::map_to_vector(seqEven, [](int i) { return i * 2; });
660 
661  // Indices for `shuffleOdd` (result 1).
662  auto seqOdd = llvm::seq<int64_t>(n / 2);
663  auto indicesOdd =
664  llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
665 
666  // Create two SPIR-V shuffles.
667  auto shuffleEven = rewriter.create<spirv::VectorShuffleOp>(
668  loc, newResultType, sourceVector, sourceVector,
669  rewriter.getI32ArrayAttr(indicesEven));
670 
671  auto shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
672  loc, newResultType, sourceVector, sourceVector,
673  rewriter.getI32ArrayAttr(indicesOdd));
674 
675  rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
676  return success();
677  }
678 };
679 
680 struct VectorLoadOpConverter final
681  : public OpConversionPattern<vector::LoadOp> {
683 
684  LogicalResult
685  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
686  ConversionPatternRewriter &rewriter) const override {
687  auto memrefType = loadOp.getMemRefType();
688  auto attr =
689  dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
690  if (!attr)
691  return rewriter.notifyMatchFailure(
692  loadOp, "expected spirv.storage_class memory space");
693 
694  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
695  auto loc = loadOp.getLoc();
696  Value accessChain =
697  spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
698  adaptor.getIndices(), loc, rewriter);
699  if (!accessChain)
700  return rewriter.notifyMatchFailure(
701  loadOp, "failed to get memref element pointer");
702 
703  spirv::StorageClass storageClass = attr.getValue();
704  auto vectorType = loadOp.getVectorType();
705  auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
706  Value castedAccessChain =
707  rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
708  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, vectorType,
709  castedAccessChain);
710 
711  return success();
712  }
713 };
714 
715 struct VectorStoreOpConverter final
716  : public OpConversionPattern<vector::StoreOp> {
718 
719  LogicalResult
720  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
721  ConversionPatternRewriter &rewriter) const override {
722  auto memrefType = storeOp.getMemRefType();
723  auto attr =
724  dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
725  if (!attr)
726  return rewriter.notifyMatchFailure(
727  storeOp, "expected spirv.storage_class memory space");
728 
729  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
730  auto loc = storeOp.getLoc();
731  Value accessChain =
732  spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
733  adaptor.getIndices(), loc, rewriter);
734  if (!accessChain)
735  return rewriter.notifyMatchFailure(
736  storeOp, "failed to get memref element pointer");
737 
738  spirv::StorageClass storageClass = attr.getValue();
739  auto vectorType = storeOp.getVectorType();
740  auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
741  Value castedAccessChain =
742  rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
743  rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
744  adaptor.getValueToStore());
745 
746  return success();
747  }
748 };
749 
750 struct VectorReductionToIntDotProd final
751  : OpRewritePattern<vector::ReductionOp> {
753 
754  LogicalResult matchAndRewrite(vector::ReductionOp op,
755  PatternRewriter &rewriter) const override {
756  if (op.getKind() != vector::CombiningKind::ADD)
757  return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
758 
759  auto resultType = dyn_cast<IntegerType>(op.getType());
760  if (!resultType)
761  return rewriter.notifyMatchFailure(op, "result is not an integer");
762 
763  int64_t resultBitwidth = resultType.getIntOrFloatBitWidth();
764  if (!llvm::is_contained({32, 64}, resultBitwidth))
765  return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth");
766 
767  VectorType inVecTy = op.getSourceVectorType();
768  if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
769  inVecTy.getShape().size() != 1 || inVecTy.isScalable())
770  return rewriter.notifyMatchFailure(op, "unsupported vector shape");
771 
772  auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
773  if (!mul)
774  return rewriter.notifyMatchFailure(
775  op, "reduction operand is not 'arith.muli'");
776 
777  if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
778  spirv::SDotAccSatOp, false>(op, mul, rewriter)))
779  return success();
780 
781  if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
782  spirv::UDotAccSatOp, false>(op, mul, rewriter)))
783  return success();
784 
785  if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
786  spirv::SUDotAccSatOp, false>(op, mul, rewriter)))
787  return success();
788 
789  if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
790  spirv::SUDotAccSatOp, true>(op, mul, rewriter)))
791  return success();
792 
793  return failure();
794  }
795 
796 private:
797  template <typename LhsExtensionOp, typename RhsExtensionOp, typename DotOp,
798  typename DotAccOp, bool SwapOperands>
799  static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
800  PatternRewriter &rewriter) {
801  auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
802  if (!lhs)
803  return failure();
804  Value lhsIn = lhs.getIn();
805  auto lhsInType = cast<VectorType>(lhsIn.getType());
806  if (!lhsInType.getElementType().isInteger(8))
807  return failure();
808 
809  auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
810  if (!rhs)
811  return failure();
812  Value rhsIn = rhs.getIn();
813  auto rhsInType = cast<VectorType>(rhsIn.getType());
814  if (!rhsInType.getElementType().isInteger(8))
815  return failure();
816 
817  if (op.getSourceVectorType().getNumElements() == 3) {
818  IntegerType i8Type = rewriter.getI8Type();
819  auto v4i8Type = VectorType::get({4}, i8Type);
820  Location loc = op.getLoc();
821  Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter);
822  lhsIn = rewriter.create<spirv::CompositeConstructOp>(
823  loc, v4i8Type, ValueRange{lhsIn, zero});
824  rhsIn = rewriter.create<spirv::CompositeConstructOp>(
825  loc, v4i8Type, ValueRange{rhsIn, zero});
826  }
827 
828  // There's no variant of dot prod ops for unsigned LHS and signed RHS, so
829  // we have to swap operands instead in that case.
830  if (SwapOperands)
831  std::swap(lhsIn, rhsIn);
832 
833  if (Value acc = op.getAcc()) {
834  rewriter.replaceOpWithNewOp<DotAccOp>(op, op.getType(), lhsIn, rhsIn, acc,
835  nullptr);
836  } else {
837  rewriter.replaceOpWithNewOp<DotOp>(op, op.getType(), lhsIn, rhsIn,
838  nullptr);
839  }
840 
841  return success();
842  }
843 };
844 
845 struct VectorReductionToFPDotProd final
846  : OpConversionPattern<vector::ReductionOp> {
848 
849  LogicalResult
850  matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
851  ConversionPatternRewriter &rewriter) const override {
852  if (op.getKind() != vector::CombiningKind::ADD)
853  return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
854 
855  auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
856  if (!resultType)
857  return rewriter.notifyMatchFailure(op, "result is not a float");
858 
859  Value vec = adaptor.getVector();
860  Value acc = adaptor.getAcc();
861 
862  auto vectorType = dyn_cast<VectorType>(vec.getType());
863  if (!vectorType) {
864  assert(isa<FloatType>(vec.getType()) &&
865  "Expected the vector to be scalarized");
866  if (acc) {
867  rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec);
868  return success();
869  }
870 
871  rewriter.replaceOp(op, vec);
872  return success();
873  }
874 
875  Location loc = op.getLoc();
876  Value lhs;
877  Value rhs;
878  if (auto mul = vec.getDefiningOp<arith::MulFOp>()) {
879  lhs = mul.getLhs();
880  rhs = mul.getRhs();
881  } else {
882  // If the operand is not a mul, use a vector of ones for the dot operand
883  // to just sum up all values.
884  lhs = vec;
885  Attribute oneAttr =
886  rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
887  oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
888  rhs = rewriter.create<spirv::ConstantOp>(loc, vectorType, oneAttr);
889  }
890  assert(lhs);
891  assert(rhs);
892 
893  Value res = rewriter.create<spirv::DotOp>(loc, resultType, lhs, rhs);
894  if (acc)
895  res = rewriter.create<spirv::FAddOp>(loc, acc, res);
896 
897  rewriter.replaceOp(op, res);
898  return success();
899  }
900 };
901 
902 struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
904 
905  LogicalResult
906  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
907  ConversionPatternRewriter &rewriter) const override {
908  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
909  Type dstType = typeConverter.convertType(stepOp.getType());
910  if (!dstType)
911  return failure();
912 
913  Location loc = stepOp.getLoc();
914  int64_t numElements = stepOp.getType().getNumElements();
915  auto intType =
916  rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
917 
918  // Input vectors of size 1 are converted to scalars by the type converter.
919  // We just create a constant in this case.
920  if (numElements == 1) {
921  Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
922  rewriter.replaceOp(stepOp, zero);
923  return success();
924  }
925 
926  SmallVector<Value> source;
927  source.reserve(numElements);
928  for (int64_t i = 0; i < numElements; ++i) {
929  Attribute intAttr = rewriter.getIntegerAttr(intType, i);
930  Value constOp = rewriter.create<spirv::ConstantOp>(loc, intType, intAttr);
931  source.push_back(constOp);
932  }
933  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
934  source);
935  return success();
936  }
937 };
938 
939 } // namespace
940 #define CL_INT_MAX_MIN_OPS \
941  spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
942 
943 #define GL_INT_MAX_MIN_OPS \
944  spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
945 
946 #define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
947 #define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
948 
950  const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
951  patterns.add<
952  VectorBitcastConvert, VectorBroadcastConvert,
953  VectorExtractElementOpConvert, VectorExtractOpConvert,
954  VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
955  VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
956  VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
957  VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
958  VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
959  VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
960  VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
961  VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
962  VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
963  VectorStepOpConvert>(typeConverter, patterns.getContext(),
964  PatternBenefit(1));
965 
966  // Make sure that the more specialized dot product pattern has higher benefit
967  // than the generic one that extracts all elements.
968  patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
969  PatternBenefit(2));
970 }
971 
973  RewritePatternSet &patterns) {
974  patterns.add<VectorReductionToIntDotProd>(patterns.getContext());
975 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
#define MINUI(lhs, rhs)
static uint64_t getFirstIntValue(ArrayAttr attr)
Returns the integer value from the first valid input element, assuming Value inputs are defined by a ...
static int getNumBits(Type type)
Returns the number of bits for the given scalar/vector type.
#define INT_AND_FLOAT_CASE(kind, iop, fop)
#define INT_OR_FLOAT_CASE(kind, fop)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:316
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:294
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
IntegerType getI8Type()
Definition: Builders.cpp:103
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Type conversion from builtin types to SPIR-V types for shader interface.
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:406
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:485
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:522
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void populateVectorReductionToSPIRVDotProductPatterns(RewritePatternSet &patterns)
Appends patterns to convert vector reduction of the form:
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateVectorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Vector Ops to SPIR-V ops.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362