MLIR  20.0.0git
VectorToSPIRV.cpp
Go to the documentation of this file.
1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Vector dialect to SPIRV dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
22 #include "mlir/IR/Attributes.h"
24 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/Location.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/IR/TypeUtilities.h"
30 #include "llvm/ADT/ArrayRef.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/SmallVectorExtras.h"
34 #include "llvm/Support/FormatVariadic.h"
35 #include <cassert>
36 #include <cstdint>
37 #include <numeric>
38 
39 using namespace mlir;
40 
41 /// Returns the integer value from the first valid input element, assuming Value
42 /// inputs are defined by a constant index ops and Attribute inputs are integer
43 /// attributes.
44 static uint64_t getFirstIntValue(ArrayAttr attr) {
45  return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
46 }
47 
48 /// Returns the number of bits for the given scalar/vector type.
49 static int getNumBits(Type type) {
50  // TODO: This does not take into account any memory layout or widening
51  // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even
52  // though in practice it will likely be stored as in a 4xi64 vector register.
53  if (auto vectorType = dyn_cast<VectorType>(type))
54  return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
55  return type.getIntOrFloatBitWidth();
56 }
57 
58 namespace {
59 
60 struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
62 
63  LogicalResult
64  matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
65  ConversionPatternRewriter &rewriter) const override {
66  Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
67  if (!dstType)
68  return failure();
69 
70  // If dstType is same as the source type or the vector size is 1, it can be
71  // directly replaced by the source.
72  if (dstType == adaptor.getSource().getType() ||
73  shapeCastOp.getResultVectorType().getNumElements() == 1) {
74  rewriter.replaceOp(shapeCastOp, adaptor.getSource());
75  return success();
76  }
77 
78  // Lowering for size-n vectors when n > 1 hasn't been implemented.
79  return failure();
80  }
81 };
82 
83 struct VectorBitcastConvert final
84  : public OpConversionPattern<vector::BitCastOp> {
86 
87  LogicalResult
88  matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
89  ConversionPatternRewriter &rewriter) const override {
90  Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
91  if (!dstType)
92  return failure();
93 
94  if (dstType == adaptor.getSource().getType()) {
95  rewriter.replaceOp(bitcastOp, adaptor.getSource());
96  return success();
97  }
98 
99  // Check that the source and destination type have the same bitwidth.
100  // Depending on the target environment, we may need to emulate certain
101  // types, which can cause issue with bitcast.
102  Type srcType = adaptor.getSource().getType();
103  if (getNumBits(dstType) != getNumBits(srcType)) {
104  return rewriter.notifyMatchFailure(
105  bitcastOp,
106  llvm::formatv("different source ({0}) and target ({1}) bitwidth",
107  srcType, dstType));
108  }
109 
110  rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
111  adaptor.getSource());
112  return success();
113  }
114 };
115 
116 struct VectorBroadcastConvert final
117  : public OpConversionPattern<vector::BroadcastOp> {
119 
120  LogicalResult
121  matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
122  ConversionPatternRewriter &rewriter) const override {
123  Type resultType =
124  getTypeConverter()->convertType(castOp.getResultVectorType());
125  if (!resultType)
126  return failure();
127 
128  if (isa<spirv::ScalarType>(resultType)) {
129  rewriter.replaceOp(castOp, adaptor.getSource());
130  return success();
131  }
132 
133  SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
134  adaptor.getSource());
135  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(castOp, resultType,
136  source);
137  return success();
138  }
139 };
140 
141 struct VectorExtractOpConvert final
142  : public OpConversionPattern<vector::ExtractOp> {
144 
145  LogicalResult
146  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
147  ConversionPatternRewriter &rewriter) const override {
148  Type dstType = getTypeConverter()->convertType(extractOp.getType());
149  if (!dstType)
150  return failure();
151 
152  if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
153  rewriter.replaceOp(extractOp, adaptor.getVector());
154  return success();
155  }
156 
157  if (std::optional<int64_t> id =
158  getConstantIntValue(extractOp.getMixedPosition()[0]))
159  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
160  extractOp, dstType, adaptor.getVector(),
161  rewriter.getI32ArrayAttr(id.value()));
162  else
163  rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
164  extractOp, dstType, adaptor.getVector(),
165  adaptor.getDynamicPosition()[0]);
166  return success();
167  }
168 };
169 
170 struct VectorExtractStridedSliceOpConvert final
171  : public OpConversionPattern<vector::ExtractStridedSliceOp> {
173 
174  LogicalResult
175  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
176  ConversionPatternRewriter &rewriter) const override {
177  Type dstType = getTypeConverter()->convertType(extractOp.getType());
178  if (!dstType)
179  return failure();
180 
181  uint64_t offset = getFirstIntValue(extractOp.getOffsets());
182  uint64_t size = getFirstIntValue(extractOp.getSizes());
183  uint64_t stride = getFirstIntValue(extractOp.getStrides());
184  if (stride != 1)
185  return failure();
186 
187  Value srcVector = adaptor.getOperands().front();
188 
189  // Extract vector<1xT> case.
190  if (isa<spirv::ScalarType>(dstType)) {
191  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
192  srcVector, offset);
193  return success();
194  }
195 
196  SmallVector<int32_t, 2> indices(size);
197  std::iota(indices.begin(), indices.end(), offset);
198 
199  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
200  extractOp, dstType, srcVector, srcVector,
201  rewriter.getI32ArrayAttr(indices));
202 
203  return success();
204  }
205 };
206 
207 template <class SPIRVFMAOp>
208 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
210 
211  LogicalResult
212  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
213  ConversionPatternRewriter &rewriter) const override {
214  Type dstType = getTypeConverter()->convertType(fmaOp.getType());
215  if (!dstType)
216  return failure();
217  rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, dstType, adaptor.getLhs(),
218  adaptor.getRhs(), adaptor.getAcc());
219  return success();
220  }
221 };
222 
223 struct VectorFromElementsOpConvert final
224  : public OpConversionPattern<vector::FromElementsOp> {
226 
227  LogicalResult
228  matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
229  ConversionPatternRewriter &rewriter) const override {
230  Type resultType = getTypeConverter()->convertType(op.getType());
231  if (!resultType)
232  return failure();
233  OperandRange elements = op.getElements();
234  if (isa<spirv::ScalarType>(resultType)) {
235  // In the case with a single scalar operand / single-element result,
236  // pass through the scalar.
237  rewriter.replaceOp(op, elements[0]);
238  return success();
239  }
240  // SPIRVTypeConverter rejects vectors with rank > 1, so multi-dimensional
241  // vector.from_elements cases should not need to be handled, only 1d.
242  assert(cast<VectorType>(resultType).getRank() == 1);
243  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, resultType,
244  elements);
245  return success();
246  }
247 };
248 
249 struct VectorInsertOpConvert final
250  : public OpConversionPattern<vector::InsertOp> {
252 
253  LogicalResult
254  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
255  ConversionPatternRewriter &rewriter) const override {
256  if (isa<VectorType>(insertOp.getSourceType()))
257  return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
258  if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
259  return rewriter.notifyMatchFailure(insertOp,
260  "unsupported dest vector type");
261 
262  // Special case for inserting scalar values into size-1 vectors.
263  if (insertOp.getSourceType().isIntOrFloat() &&
264  insertOp.getDestVectorType().getNumElements() == 1) {
265  rewriter.replaceOp(insertOp, adaptor.getSource());
266  return success();
267  }
268 
269  if (std::optional<int64_t> id =
270  getConstantIntValue(insertOp.getMixedPosition()[0]))
271  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
272  insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
273  else
274  rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
275  insertOp, insertOp.getDest(), adaptor.getSource(),
276  adaptor.getDynamicPosition()[0]);
277  return success();
278  }
279 };
280 
281 struct VectorExtractElementOpConvert final
282  : public OpConversionPattern<vector::ExtractElementOp> {
284 
285  LogicalResult
286  matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
287  ConversionPatternRewriter &rewriter) const override {
288  Type resultType = getTypeConverter()->convertType(extractOp.getType());
289  if (!resultType)
290  return failure();
291 
292  if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
293  rewriter.replaceOp(extractOp, adaptor.getVector());
294  return success();
295  }
296 
297  APInt cstPos;
298  if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
299  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
300  extractOp, resultType, adaptor.getVector(),
301  rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())}));
302  else
303  rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
304  extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
305  return success();
306  }
307 };
308 
309 struct VectorInsertElementOpConvert final
310  : public OpConversionPattern<vector::InsertElementOp> {
312 
313  LogicalResult
314  matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
315  ConversionPatternRewriter &rewriter) const override {
316  Type vectorType = getTypeConverter()->convertType(insertOp.getType());
317  if (!vectorType)
318  return failure();
319 
320  if (isa<spirv::ScalarType>(vectorType)) {
321  rewriter.replaceOp(insertOp, adaptor.getSource());
322  return success();
323  }
324 
325  APInt cstPos;
326  if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
327  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
328  insertOp, adaptor.getSource(), adaptor.getDest(),
329  cstPos.getSExtValue());
330  else
331  rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
332  insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
333  adaptor.getPosition());
334  return success();
335  }
336 };
337 
338 struct VectorInsertStridedSliceOpConvert final
339  : public OpConversionPattern<vector::InsertStridedSliceOp> {
341 
342  LogicalResult
343  matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
344  ConversionPatternRewriter &rewriter) const override {
345  Value srcVector = adaptor.getOperands().front();
346  Value dstVector = adaptor.getOperands().back();
347 
348  uint64_t stride = getFirstIntValue(insertOp.getStrides());
349  if (stride != 1)
350  return failure();
351  uint64_t offset = getFirstIntValue(insertOp.getOffsets());
352 
353  if (isa<spirv::ScalarType>(srcVector.getType())) {
354  assert(!isa<spirv::ScalarType>(dstVector.getType()));
355  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
356  insertOp, dstVector.getType(), srcVector, dstVector,
357  rewriter.getI32ArrayAttr(offset));
358  return success();
359  }
360 
361  uint64_t totalSize = cast<VectorType>(dstVector.getType()).getNumElements();
362  uint64_t insertSize =
363  cast<VectorType>(srcVector.getType()).getNumElements();
364 
365  SmallVector<int32_t, 2> indices(totalSize);
366  std::iota(indices.begin(), indices.end(), 0);
367  std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
368  totalSize);
369 
370  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
371  insertOp, dstVector.getType(), dstVector, srcVector,
372  rewriter.getI32ArrayAttr(indices));
373 
374  return success();
375  }
376 };
377 
378 static SmallVector<Value> extractAllElements(
379  vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
380  VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
381  int numElements = static_cast<int>(srcVectorType.getDimSize(0));
382  SmallVector<Value> values;
383  values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
384  Location loc = reduceOp.getLoc();
385 
386  for (int i = 0; i < numElements; ++i) {
387  values.push_back(rewriter.create<spirv::CompositeExtractOp>(
388  loc, srcVectorType.getElementType(), adaptor.getVector(),
389  rewriter.getI32ArrayAttr({i})));
390  }
391  if (Value acc = adaptor.getAcc())
392  values.push_back(acc);
393 
394  return values;
395 }
396 
/// Shared preprocessing result for the vector.reduction lowerings below.
struct ReductionRewriteInfo {
  /// The type-converted result type of the reduction.
  Type resultType;
  /// The individually extracted source elements in index order, followed by
  /// the accumulator when the reduction has one.
  SmallVector<Value> extractedElements;
};
401 
402 FailureOr<ReductionRewriteInfo> static getReductionInfo(
403  vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
404  ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) {
405  Type resultType = typeConverter.convertType(op.getType());
406  if (!resultType)
407  return failure();
408 
409  auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
410  if (!srcVectorType || srcVectorType.getRank() != 1)
411  return rewriter.notifyMatchFailure(op, "not a 1-D vector source");
412 
413  SmallVector<Value> extractedElements =
414  extractAllElements(op, adaptor, srcVectorType, rewriter);
415 
416  return ReductionRewriteInfo{resultType, std::move(extractedElements)};
417 }
418 
419 template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
420  typename SPIRVSMinOp>
421 struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
423 
424  LogicalResult
425  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
426  ConversionPatternRewriter &rewriter) const override {
427  auto reductionInfo =
428  getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
429  if (failed(reductionInfo))
430  return failure();
431 
432  auto [resultType, extractedElements] = *reductionInfo;
433  Location loc = reduceOp->getLoc();
434  Value result = extractedElements.front();
435  for (Value next : llvm::drop_begin(extractedElements)) {
436  switch (reduceOp.getKind()) {
437 
438 #define INT_AND_FLOAT_CASE(kind, iop, fop) \
439  case vector::CombiningKind::kind: \
440  if (llvm::isa<IntegerType>(resultType)) { \
441  result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
442  } else { \
443  assert(llvm::isa<FloatType>(resultType)); \
444  result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
445  } \
446  break
447 
448 #define INT_OR_FLOAT_CASE(kind, fop) \
449  case vector::CombiningKind::kind: \
450  result = rewriter.create<fop>(loc, resultType, result, next); \
451  break
452 
453  INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
454  INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
455  INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
456  INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
457  INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
458  INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp);
459 
460  case vector::CombiningKind::AND:
461  case vector::CombiningKind::OR:
462  case vector::CombiningKind::XOR:
463  return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
464  default:
465  return rewriter.notifyMatchFailure(reduceOp, "not handled here");
466  }
467 #undef INT_AND_FLOAT_CASE
468 #undef INT_OR_FLOAT_CASE
469  }
470 
471  rewriter.replaceOp(reduceOp, result);
472  return success();
473  }
474 };
475 
476 template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
477 struct VectorReductionFloatMinMax final
478  : OpConversionPattern<vector::ReductionOp> {
480 
481  LogicalResult
482  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
483  ConversionPatternRewriter &rewriter) const override {
484  auto reductionInfo =
485  getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
486  if (failed(reductionInfo))
487  return failure();
488 
489  auto [resultType, extractedElements] = *reductionInfo;
490  Location loc = reduceOp->getLoc();
491  Value result = extractedElements.front();
492  for (Value next : llvm::drop_begin(extractedElements)) {
493  switch (reduceOp.getKind()) {
494 
495 #define INT_OR_FLOAT_CASE(kind, fop) \
496  case vector::CombiningKind::kind: \
497  result = rewriter.create<fop>(loc, resultType, result, next); \
498  break
499 
500  INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
501  INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
502  INT_OR_FLOAT_CASE(MAXNUMF, SPIRVFMaxOp);
503  INT_OR_FLOAT_CASE(MINNUMF, SPIRVFMinOp);
504 
505  default:
506  return rewriter.notifyMatchFailure(reduceOp, "not handled here");
507  }
508 #undef INT_OR_FLOAT_CASE
509  }
510 
511  rewriter.replaceOp(reduceOp, result);
512  return success();
513  }
514 };
515 
516 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
517 public:
519 
520  LogicalResult
521  matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
522  ConversionPatternRewriter &rewriter) const override {
523  Type dstType = getTypeConverter()->convertType(op.getType());
524  if (!dstType)
525  return failure();
526  if (isa<spirv::ScalarType>(dstType)) {
527  rewriter.replaceOp(op, adaptor.getInput());
528  } else {
529  auto dstVecType = cast<VectorType>(dstType);
530  SmallVector<Value, 4> source(dstVecType.getNumElements(),
531  adaptor.getInput());
532  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
533  source);
534  }
535  return success();
536  }
537 };
538 
539 struct VectorShuffleOpConvert final
540  : public OpConversionPattern<vector::ShuffleOp> {
542 
543  LogicalResult
544  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
545  ConversionPatternRewriter &rewriter) const override {
546  VectorType oldResultType = shuffleOp.getResultVectorType();
547  Type newResultType = getTypeConverter()->convertType(oldResultType);
548  if (!newResultType)
549  return rewriter.notifyMatchFailure(shuffleOp,
550  "unsupported result vector type");
551 
552  auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());
553 
554  VectorType oldV1Type = shuffleOp.getV1VectorType();
555  VectorType oldV2Type = shuffleOp.getV2VectorType();
556 
557  // When both operands and the result are SPIR-V vectors, emit a SPIR-V
558  // shuffle.
559  if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
560  oldResultType.getNumElements() > 1) {
561  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
562  shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
563  rewriter.getI32ArrayAttr(mask));
564  return success();
565  }
566 
567  // When at least one of the operands or the result becomes a scalar after
568  // type conversion for SPIR-V, extract all the required elements and
569  // construct the result vector.
570  auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
571  Value scalarOrVec, int32_t idx) -> Value {
572  if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
573  return rewriter.create<spirv::CompositeExtractOp>(loc, scalarOrVec,
574  idx);
575 
576  assert(idx == 0 && "Invalid scalar element index");
577  return scalarOrVec;
578  };
579 
580  int32_t numV1Elems = oldV1Type.getNumElements();
581  SmallVector<Value> newOperands(mask.size());
582  for (auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
583  Value vec = adaptor.getV1();
584  int32_t elementIdx = shuffleIdx;
585  if (elementIdx >= numV1Elems) {
586  vec = adaptor.getV2();
587  elementIdx -= numV1Elems;
588  }
589 
590  newOperand = getElementAtIdx(vec, elementIdx);
591  }
592 
593  // Handle the scalar result corner case.
594  if (newOperands.size() == 1) {
595  rewriter.replaceOp(shuffleOp, newOperands.front());
596  return success();
597  }
598 
599  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
600  shuffleOp, newResultType, newOperands);
601  return success();
602  }
603 };
604 
605 struct VectorInterleaveOpConvert final
606  : public OpConversionPattern<vector::InterleaveOp> {
608 
609  LogicalResult
610  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
611  ConversionPatternRewriter &rewriter) const override {
612  // Check the result vector type.
613  VectorType oldResultType = interleaveOp.getResultVectorType();
614  Type newResultType = getTypeConverter()->convertType(oldResultType);
615  if (!newResultType)
616  return rewriter.notifyMatchFailure(interleaveOp,
617  "unsupported result vector type");
618 
619  // Interleave the indices.
620  VectorType sourceType = interleaveOp.getSourceVectorType();
621  int n = sourceType.getNumElements();
622 
623  // Input vectors of size 1 are converted to scalars by the type converter.
624  // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
625  // use `spirv::CompositeConstructOp`.
626  if (n == 1) {
627  Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
628  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
629  interleaveOp, newResultType, newOperands);
630  return success();
631  }
632 
633  auto seq = llvm::seq<int64_t>(2 * n);
634  auto indices = llvm::map_to_vector(
635  seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; });
636 
637  // Emit a SPIR-V shuffle.
638  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
639  interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
640  rewriter.getI32ArrayAttr(indices));
641 
642  return success();
643  }
644 };
645 
646 struct VectorDeinterleaveOpConvert final
647  : public OpConversionPattern<vector::DeinterleaveOp> {
649 
650  LogicalResult
651  matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
652  ConversionPatternRewriter &rewriter) const override {
653 
654  // Check the result vector type.
655  VectorType oldResultType = deinterleaveOp.getResultVectorType();
656  Type newResultType = getTypeConverter()->convertType(oldResultType);
657  if (!newResultType)
658  return rewriter.notifyMatchFailure(deinterleaveOp,
659  "unsupported result vector type");
660 
661  Location loc = deinterleaveOp->getLoc();
662 
663  // Deinterleave the indices.
664  Value sourceVector = adaptor.getSource();
665  VectorType sourceType = deinterleaveOp.getSourceVectorType();
666  int n = sourceType.getNumElements();
667 
668  // Output vectors of size 1 are converted to scalars by the type converter.
669  // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
670  // use `spirv::CompositeExtractOp`.
671  if (n == 2) {
672  auto elem0 = rewriter.create<spirv::CompositeExtractOp>(
673  loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({0}));
674 
675  auto elem1 = rewriter.create<spirv::CompositeExtractOp>(
676  loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({1}));
677 
678  rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
679  return success();
680  }
681 
682  // Indices for `shuffleEven` (result 0).
683  auto seqEven = llvm::seq<int64_t>(n / 2);
684  auto indicesEven =
685  llvm::map_to_vector(seqEven, [](int i) { return i * 2; });
686 
687  // Indices for `shuffleOdd` (result 1).
688  auto seqOdd = llvm::seq<int64_t>(n / 2);
689  auto indicesOdd =
690  llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
691 
692  // Create two SPIR-V shuffles.
693  auto shuffleEven = rewriter.create<spirv::VectorShuffleOp>(
694  loc, newResultType, sourceVector, sourceVector,
695  rewriter.getI32ArrayAttr(indicesEven));
696 
697  auto shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
698  loc, newResultType, sourceVector, sourceVector,
699  rewriter.getI32ArrayAttr(indicesOdd));
700 
701  rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
702  return success();
703  }
704 };
705 
706 struct VectorLoadOpConverter final
707  : public OpConversionPattern<vector::LoadOp> {
709 
710  LogicalResult
711  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
712  ConversionPatternRewriter &rewriter) const override {
713  auto memrefType = loadOp.getMemRefType();
714  auto attr =
715  dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
716  if (!attr)
717  return rewriter.notifyMatchFailure(
718  loadOp, "expected spirv.storage_class memory space");
719 
720  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
721  auto loc = loadOp.getLoc();
722  Value accessChain =
723  spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
724  adaptor.getIndices(), loc, rewriter);
725  if (!accessChain)
726  return rewriter.notifyMatchFailure(
727  loadOp, "failed to get memref element pointer");
728 
729  spirv::StorageClass storageClass = attr.getValue();
730  auto vectorType = loadOp.getVectorType();
731  auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
732  Value castedAccessChain =
733  rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
734  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, vectorType,
735  castedAccessChain);
736 
737  return success();
738  }
739 };
740 
741 struct VectorStoreOpConverter final
742  : public OpConversionPattern<vector::StoreOp> {
744 
745  LogicalResult
746  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
747  ConversionPatternRewriter &rewriter) const override {
748  auto memrefType = storeOp.getMemRefType();
749  auto attr =
750  dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
751  if (!attr)
752  return rewriter.notifyMatchFailure(
753  storeOp, "expected spirv.storage_class memory space");
754 
755  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
756  auto loc = storeOp.getLoc();
757  Value accessChain =
758  spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
759  adaptor.getIndices(), loc, rewriter);
760  if (!accessChain)
761  return rewriter.notifyMatchFailure(
762  storeOp, "failed to get memref element pointer");
763 
764  spirv::StorageClass storageClass = attr.getValue();
765  auto vectorType = storeOp.getVectorType();
766  auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
767  Value castedAccessChain =
768  rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
769  rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
770  adaptor.getValueToStore());
771 
772  return success();
773  }
774 };
775 
776 struct VectorReductionToIntDotProd final
777  : OpRewritePattern<vector::ReductionOp> {
779 
780  LogicalResult matchAndRewrite(vector::ReductionOp op,
781  PatternRewriter &rewriter) const override {
782  if (op.getKind() != vector::CombiningKind::ADD)
783  return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
784 
785  auto resultType = dyn_cast<IntegerType>(op.getType());
786  if (!resultType)
787  return rewriter.notifyMatchFailure(op, "result is not an integer");
788 
789  int64_t resultBitwidth = resultType.getIntOrFloatBitWidth();
790  if (!llvm::is_contained({32, 64}, resultBitwidth))
791  return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth");
792 
793  VectorType inVecTy = op.getSourceVectorType();
794  if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
795  inVecTy.getShape().size() != 1 || inVecTy.isScalable())
796  return rewriter.notifyMatchFailure(op, "unsupported vector shape");
797 
798  auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
799  if (!mul)
800  return rewriter.notifyMatchFailure(
801  op, "reduction operand is not 'arith.muli'");
802 
803  if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
804  spirv::SDotAccSatOp, false>(op, mul, rewriter)))
805  return success();
806 
807  if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
808  spirv::UDotAccSatOp, false>(op, mul, rewriter)))
809  return success();
810 
811  if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
812  spirv::SUDotAccSatOp, false>(op, mul, rewriter)))
813  return success();
814 
815  if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
816  spirv::SUDotAccSatOp, true>(op, mul, rewriter)))
817  return success();
818 
819  return failure();
820  }
821 
822 private:
823  template <typename LhsExtensionOp, typename RhsExtensionOp, typename DotOp,
824  typename DotAccOp, bool SwapOperands>
825  static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
826  PatternRewriter &rewriter) {
827  auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
828  if (!lhs)
829  return failure();
830  Value lhsIn = lhs.getIn();
831  auto lhsInType = cast<VectorType>(lhsIn.getType());
832  if (!lhsInType.getElementType().isInteger(8))
833  return failure();
834 
835  auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
836  if (!rhs)
837  return failure();
838  Value rhsIn = rhs.getIn();
839  auto rhsInType = cast<VectorType>(rhsIn.getType());
840  if (!rhsInType.getElementType().isInteger(8))
841  return failure();
842 
843  if (op.getSourceVectorType().getNumElements() == 3) {
844  IntegerType i8Type = rewriter.getI8Type();
845  auto v4i8Type = VectorType::get({4}, i8Type);
846  Location loc = op.getLoc();
847  Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter);
848  lhsIn = rewriter.create<spirv::CompositeConstructOp>(
849  loc, v4i8Type, ValueRange{lhsIn, zero});
850  rhsIn = rewriter.create<spirv::CompositeConstructOp>(
851  loc, v4i8Type, ValueRange{rhsIn, zero});
852  }
853 
854  // There's no variant of dot prod ops for unsigned LHS and signed RHS, so
855  // we have to swap operands instead in that case.
856  if (SwapOperands)
857  std::swap(lhsIn, rhsIn);
858 
859  if (Value acc = op.getAcc()) {
860  rewriter.replaceOpWithNewOp<DotAccOp>(op, op.getType(), lhsIn, rhsIn, acc,
861  nullptr);
862  } else {
863  rewriter.replaceOpWithNewOp<DotOp>(op, op.getType(), lhsIn, rhsIn,
864  nullptr);
865  }
866 
867  return success();
868  }
869 };
870 
871 struct VectorReductionToFPDotProd final
872  : OpConversionPattern<vector::ReductionOp> {
874 
875  LogicalResult
876  matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
877  ConversionPatternRewriter &rewriter) const override {
878  if (op.getKind() != vector::CombiningKind::ADD)
879  return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
880 
881  auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
882  if (!resultType)
883  return rewriter.notifyMatchFailure(op, "result is not a float");
884 
885  Value vec = adaptor.getVector();
886  Value acc = adaptor.getAcc();
887 
888  auto vectorType = dyn_cast<VectorType>(vec.getType());
889  if (!vectorType) {
890  assert(isa<FloatType>(vec.getType()) &&
891  "Expected the vector to be scalarized");
892  if (acc) {
893  rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec);
894  return success();
895  }
896 
897  rewriter.replaceOp(op, vec);
898  return success();
899  }
900 
901  Location loc = op.getLoc();
902  Value lhs;
903  Value rhs;
904  if (auto mul = vec.getDefiningOp<arith::MulFOp>()) {
905  lhs = mul.getLhs();
906  rhs = mul.getRhs();
907  } else {
908  // If the operand is not a mul, use a vector of ones for the dot operand
909  // to just sum up all values.
910  lhs = vec;
911  Attribute oneAttr =
912  rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
913  oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
914  rhs = rewriter.create<spirv::ConstantOp>(loc, vectorType, oneAttr);
915  }
916  assert(lhs);
917  assert(rhs);
918 
919  Value res = rewriter.create<spirv::DotOp>(loc, resultType, lhs, rhs);
920  if (acc)
921  res = rewriter.create<spirv::FAddOp>(loc, acc, res);
922 
923  rewriter.replaceOp(op, res);
924  return success();
925  }
926 };
927 
928 struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
930 
931  LogicalResult
932  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
933  ConversionPatternRewriter &rewriter) const override {
934  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
935  Type dstType = typeConverter.convertType(stepOp.getType());
936  if (!dstType)
937  return failure();
938 
939  Location loc = stepOp.getLoc();
940  int64_t numElements = stepOp.getType().getNumElements();
941  auto intType =
942  rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
943 
944  // Input vectors of size 1 are converted to scalars by the type converter.
945  // We just create a constant in this case.
946  if (numElements == 1) {
947  Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
948  rewriter.replaceOp(stepOp, zero);
949  return success();
950  }
951 
952  SmallVector<Value> source;
953  source.reserve(numElements);
954  for (int64_t i = 0; i < numElements; ++i) {
955  Attribute intAttr = rewriter.getIntegerAttr(intType, i);
956  Value constOp = rewriter.create<spirv::ConstantOp>(loc, intType, intAttr);
957  source.push_back(constOp);
958  }
959  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
960  source);
961  return success();
962  }
963 };
964 
965 } // namespace
// Template-argument lists used when instantiating the reduction patterns in
// the populate function below. The order must match the template parameters:
// VectorReductionPattern takes (UMax, UMin, SMax, SMin) and
// VectorReductionFloatMinMax takes (FMax, FMin). Each list has a CL-prefixed
// and a GL-prefixed variant.
#define CL_INT_MAX_MIN_OPS                                                     \
  spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp

#define GL_INT_MAX_MIN_OPS                                                     \
  spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp

#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
974 
976  const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
977  patterns.add<
978  VectorBitcastConvert, VectorBroadcastConvert,
979  VectorExtractElementOpConvert, VectorExtractOpConvert,
980  VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
981  VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
982  VectorInsertElementOpConvert, VectorInsertOpConvert,
983  VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
984  VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
985  VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
986  VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
987  VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
988  VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
989  VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
990  VectorStepOpConvert>(typeConverter, patterns.getContext(),
991  PatternBenefit(1));
992 
993  // Make sure that the more specialized dot product pattern has higher benefit
994  // than the generic one that extracts all elements.
995  patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
996  PatternBenefit(2));
997 }
998 
1001  patterns.add<VectorReductionToIntDotProd>(patterns.getContext());
1002 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
#define MINUI(lhs, rhs)
static uint64_t getFirstIntValue(ArrayAttr attr)
Returns the integer value of the first element of the given ArrayAttr (zero-extended), assuming that element is an IntegerAttr.
static int getNumBits(Type type)
Returns the number of bits for the given scalar/vector type.
#define INT_AND_FLOAT_CASE(kind, iop, fop)
#define INT_OR_FLOAT_CASE(kind, fop)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:316
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:294
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
IntegerType getI8Type()
Definition: Builders.cpp:103
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Type conversion from builtin types to SPIR-V types for shader interface.
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:406
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void populateVectorReductionToSPIRVDotProductPatterns(RewritePatternSet &patterns)
Appends patterns to convert vector reduction of the form:
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateVectorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Vector Ops to SPIR-V ops.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362