MLIR  19.0.0git
VectorToSPIRV.cpp
Go to the documentation of this file.
1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Vector dialect to SPIRV dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
21 #include "mlir/IR/Attributes.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Location.h"
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "mlir/IR/TypeUtilities.h"
30 #include "llvm/ADT/ArrayRef.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/SmallVectorExtras.h"
34 #include "llvm/Support/FormatVariadic.h"
35 #include <cassert>
36 #include <cstdint>
37 #include <numeric>
38 
39 using namespace mlir;
40 
41 /// Returns the integer value from the first valid input element, assuming Value
42 /// inputs are defined by a constant index ops and Attribute inputs are integer
43 /// attributes.
44 static uint64_t getFirstIntValue(ValueRange values) {
45  return values[0].getDefiningOp<arith::ConstantIndexOp>().value();
46 }
47 static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
48  return cast<IntegerAttr>(attr[0]).getInt();
49 }
50 static uint64_t getFirstIntValue(ArrayAttr attr) {
51  return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
52 }
53 static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
54  auto attr = foldResults[0].dyn_cast<Attribute>();
55  if (attr)
56  return getFirstIntValue(attr);
57 
58  return getFirstIntValue(ValueRange{foldResults[0].get<Value>()});
59 }
60 
61 /// Returns the number of bits for the given scalar/vector type.
62 static int getNumBits(Type type) {
63  // TODO: This does not take into account any memory layout or widening
64  // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even
65  // though in practice it will likely be stored as in a 4xi64 vector register.
66  if (auto vectorType = dyn_cast<VectorType>(type))
67  return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
68  return type.getIntOrFloatBitWidth();
69 }
70 
71 namespace {
72 
73 struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
75 
77  matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
78  ConversionPatternRewriter &rewriter) const override {
79  Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
80  if (!dstType)
81  return failure();
82 
83  // If dstType is same as the source type or the vector size is 1, it can be
84  // directly replaced by the source.
85  if (dstType == adaptor.getSource().getType() ||
86  shapeCastOp.getResultVectorType().getNumElements() == 1) {
87  rewriter.replaceOp(shapeCastOp, adaptor.getSource());
88  return success();
89  }
90 
91  // Lowering for size-n vectors when n > 1 hasn't been implemented.
92  return failure();
93  }
94 };
95 
96 struct VectorBitcastConvert final
97  : public OpConversionPattern<vector::BitCastOp> {
99 
101  matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
102  ConversionPatternRewriter &rewriter) const override {
103  Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
104  if (!dstType)
105  return failure();
106 
107  if (dstType == adaptor.getSource().getType()) {
108  rewriter.replaceOp(bitcastOp, adaptor.getSource());
109  return success();
110  }
111 
112  // Check that the source and destination type have the same bitwidth.
113  // Depending on the target environment, we may need to emulate certain
114  // types, which can cause issue with bitcast.
115  Type srcType = adaptor.getSource().getType();
116  if (getNumBits(dstType) != getNumBits(srcType)) {
117  return rewriter.notifyMatchFailure(
118  bitcastOp,
119  llvm::formatv("different source ({0}) and target ({1}) bitwidth",
120  srcType, dstType));
121  }
122 
123  rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
124  adaptor.getSource());
125  return success();
126  }
127 };
128 
129 struct VectorBroadcastConvert final
130  : public OpConversionPattern<vector::BroadcastOp> {
132 
134  matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
135  ConversionPatternRewriter &rewriter) const override {
136  Type resultType =
137  getTypeConverter()->convertType(castOp.getResultVectorType());
138  if (!resultType)
139  return failure();
140 
141  if (isa<spirv::ScalarType>(resultType)) {
142  rewriter.replaceOp(castOp, adaptor.getSource());
143  return success();
144  }
145 
146  SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
147  adaptor.getSource());
148  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
149  castOp, castOp.getResultVectorType(), source);
150  return success();
151  }
152 };
153 
154 struct VectorExtractOpConvert final
155  : public OpConversionPattern<vector::ExtractOp> {
157 
159  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
160  ConversionPatternRewriter &rewriter) const override {
161  if (extractOp.hasDynamicPosition())
162  return failure();
163 
164  Type dstType = getTypeConverter()->convertType(extractOp.getType());
165  if (!dstType)
166  return failure();
167 
168  if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
169  rewriter.replaceOp(extractOp, adaptor.getVector());
170  return success();
171  }
172 
173  int32_t id = getFirstIntValue(extractOp.getMixedPosition());
174  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
175  extractOp, adaptor.getVector(), id);
176  return success();
177  }
178 };
179 
180 struct VectorExtractStridedSliceOpConvert final
181  : public OpConversionPattern<vector::ExtractStridedSliceOp> {
183 
185  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
186  ConversionPatternRewriter &rewriter) const override {
187  Type dstType = getTypeConverter()->convertType(extractOp.getType());
188  if (!dstType)
189  return failure();
190 
191  uint64_t offset = getFirstIntValue(extractOp.getOffsets());
192  uint64_t size = getFirstIntValue(extractOp.getSizes());
193  uint64_t stride = getFirstIntValue(extractOp.getStrides());
194  if (stride != 1)
195  return failure();
196 
197  Value srcVector = adaptor.getOperands().front();
198 
199  // Extract vector<1xT> case.
200  if (isa<spirv::ScalarType>(dstType)) {
201  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
202  srcVector, offset);
203  return success();
204  }
205 
206  SmallVector<int32_t, 2> indices(size);
207  std::iota(indices.begin(), indices.end(), offset);
208 
209  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
210  extractOp, dstType, srcVector, srcVector,
211  rewriter.getI32ArrayAttr(indices));
212 
213  return success();
214  }
215 };
216 
217 template <class SPIRVFMAOp>
218 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
220 
222  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
223  ConversionPatternRewriter &rewriter) const override {
224  Type dstType = getTypeConverter()->convertType(fmaOp.getType());
225  if (!dstType)
226  return failure();
227  rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, dstType, adaptor.getLhs(),
228  adaptor.getRhs(), adaptor.getAcc());
229  return success();
230  }
231 };
232 
233 struct VectorInsertOpConvert final
234  : public OpConversionPattern<vector::InsertOp> {
236 
238  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
239  ConversionPatternRewriter &rewriter) const override {
240  if (isa<VectorType>(insertOp.getSourceType()))
241  return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
242  if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
243  return rewriter.notifyMatchFailure(insertOp,
244  "unsupported dest vector type");
245 
246  // Special case for inserting scalar values into size-1 vectors.
247  if (insertOp.getSourceType().isIntOrFloat() &&
248  insertOp.getDestVectorType().getNumElements() == 1) {
249  rewriter.replaceOp(insertOp, adaptor.getSource());
250  return success();
251  }
252 
253  int32_t id = getFirstIntValue(insertOp.getMixedPosition());
254  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
255  insertOp, adaptor.getSource(), adaptor.getDest(), id);
256  return success();
257  }
258 };
259 
260 struct VectorExtractElementOpConvert final
261  : public OpConversionPattern<vector::ExtractElementOp> {
263 
265  matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
266  ConversionPatternRewriter &rewriter) const override {
267  Type resultType = getTypeConverter()->convertType(extractOp.getType());
268  if (!resultType)
269  return failure();
270 
271  if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
272  rewriter.replaceOp(extractOp, adaptor.getVector());
273  return success();
274  }
275 
276  APInt cstPos;
277  if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
278  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
279  extractOp, resultType, adaptor.getVector(),
280  rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())}));
281  else
282  rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
283  extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
284  return success();
285  }
286 };
287 
288 struct VectorInsertElementOpConvert final
289  : public OpConversionPattern<vector::InsertElementOp> {
291 
293  matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
294  ConversionPatternRewriter &rewriter) const override {
295  Type vectorType = getTypeConverter()->convertType(insertOp.getType());
296  if (!vectorType)
297  return failure();
298 
299  if (isa<spirv::ScalarType>(vectorType)) {
300  rewriter.replaceOp(insertOp, adaptor.getSource());
301  return success();
302  }
303 
304  APInt cstPos;
305  if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
306  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
307  insertOp, adaptor.getSource(), adaptor.getDest(),
308  cstPos.getSExtValue());
309  else
310  rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
311  insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
312  adaptor.getPosition());
313  return success();
314  }
315 };
316 
317 struct VectorInsertStridedSliceOpConvert final
318  : public OpConversionPattern<vector::InsertStridedSliceOp> {
320 
322  matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
323  ConversionPatternRewriter &rewriter) const override {
324  Value srcVector = adaptor.getOperands().front();
325  Value dstVector = adaptor.getOperands().back();
326 
327  uint64_t stride = getFirstIntValue(insertOp.getStrides());
328  if (stride != 1)
329  return failure();
330  uint64_t offset = getFirstIntValue(insertOp.getOffsets());
331 
332  if (isa<spirv::ScalarType>(srcVector.getType())) {
333  assert(!isa<spirv::ScalarType>(dstVector.getType()));
334  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
335  insertOp, dstVector.getType(), srcVector, dstVector,
336  rewriter.getI32ArrayAttr(offset));
337  return success();
338  }
339 
340  uint64_t totalSize = cast<VectorType>(dstVector.getType()).getNumElements();
341  uint64_t insertSize =
342  cast<VectorType>(srcVector.getType()).getNumElements();
343 
344  SmallVector<int32_t, 2> indices(totalSize);
345  std::iota(indices.begin(), indices.end(), 0);
346  std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
347  totalSize);
348 
349  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
350  insertOp, dstVector.getType(), dstVector, srcVector,
351  rewriter.getI32ArrayAttr(indices));
352 
353  return success();
354  }
355 };
356 
357 static SmallVector<Value> extractAllElements(
358  vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
359  VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
360  int numElements = static_cast<int>(srcVectorType.getDimSize(0));
361  SmallVector<Value> values;
362  values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
363  Location loc = reduceOp.getLoc();
364 
365  for (int i = 0; i < numElements; ++i) {
366  values.push_back(rewriter.create<spirv::CompositeExtractOp>(
367  loc, srcVectorType.getElementType(), adaptor.getVector(),
368  rewriter.getI32ArrayAttr({i})));
369  }
370  if (Value acc = adaptor.getAcc())
371  values.push_back(acc);
372 
373  return values;
374 }
375 
376 struct ReductionRewriteInfo {
377  Type resultType;
378  SmallVector<Value> extractedElements;
379 };
380 
381 FailureOr<ReductionRewriteInfo> static getReductionInfo(
382  vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
383  ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) {
384  Type resultType = typeConverter.convertType(op.getType());
385  if (!resultType)
386  return failure();
387 
388  auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
389  if (!srcVectorType || srcVectorType.getRank() != 1)
390  return rewriter.notifyMatchFailure(op, "not a 1-D vector source");
391 
392  SmallVector<Value> extractedElements =
393  extractAllElements(op, adaptor, srcVectorType, rewriter);
394 
395  return ReductionRewriteInfo{resultType, std::move(extractedElements)};
396 }
397 
398 template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
399  typename SPIRVSMinOp>
400 struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
402 
404  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
405  ConversionPatternRewriter &rewriter) const override {
406  auto reductionInfo =
407  getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
408  if (failed(reductionInfo))
409  return failure();
410 
411  auto [resultType, extractedElements] = *reductionInfo;
412  Location loc = reduceOp->getLoc();
413  Value result = extractedElements.front();
414  for (Value next : llvm::drop_begin(extractedElements)) {
415  switch (reduceOp.getKind()) {
416 
417 #define INT_AND_FLOAT_CASE(kind, iop, fop) \
418  case vector::CombiningKind::kind: \
419  if (llvm::isa<IntegerType>(resultType)) { \
420  result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
421  } else { \
422  assert(llvm::isa<FloatType>(resultType)); \
423  result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
424  } \
425  break
426 
427 #define INT_OR_FLOAT_CASE(kind, fop) \
428  case vector::CombiningKind::kind: \
429  result = rewriter.create<fop>(loc, resultType, result, next); \
430  break
431 
432  INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
433  INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
434  INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
435  INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
436  INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
437  INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp);
438 
439  case vector::CombiningKind::AND:
440  case vector::CombiningKind::OR:
441  case vector::CombiningKind::XOR:
442  return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
443  default:
444  return rewriter.notifyMatchFailure(reduceOp, "not handled here");
445  }
446 #undef INT_AND_FLOAT_CASE
447 #undef INT_OR_FLOAT_CASE
448  }
449 
450  rewriter.replaceOp(reduceOp, result);
451  return success();
452  }
453 };
454 
455 template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
456 struct VectorReductionFloatMinMax final
457  : OpConversionPattern<vector::ReductionOp> {
459 
461  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
462  ConversionPatternRewriter &rewriter) const override {
463  auto reductionInfo =
464  getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
465  if (failed(reductionInfo))
466  return failure();
467 
468  auto [resultType, extractedElements] = *reductionInfo;
469  Location loc = reduceOp->getLoc();
470  Value result = extractedElements.front();
471  for (Value next : llvm::drop_begin(extractedElements)) {
472  switch (reduceOp.getKind()) {
473 
474 #define INT_OR_FLOAT_CASE(kind, fop) \
475  case vector::CombiningKind::kind: \
476  result = rewriter.create<fop>(loc, resultType, result, next); \
477  break
478 
479  INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
480  INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
481  INT_OR_FLOAT_CASE(MAXNUMF, SPIRVFMaxOp);
482  INT_OR_FLOAT_CASE(MINNUMF, SPIRVFMinOp);
483 
484  default:
485  return rewriter.notifyMatchFailure(reduceOp, "not handled here");
486  }
487 #undef INT_OR_FLOAT_CASE
488  }
489 
490  rewriter.replaceOp(reduceOp, result);
491  return success();
492  }
493 };
494 
495 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
496 public:
498 
500  matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
501  ConversionPatternRewriter &rewriter) const override {
502  Type dstType = getTypeConverter()->convertType(op.getType());
503  if (!dstType)
504  return failure();
505  if (isa<spirv::ScalarType>(dstType)) {
506  rewriter.replaceOp(op, adaptor.getInput());
507  } else {
508  auto dstVecType = cast<VectorType>(dstType);
509  SmallVector<Value, 4> source(dstVecType.getNumElements(),
510  adaptor.getInput());
511  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
512  source);
513  }
514  return success();
515  }
516 };
517 
518 struct VectorShuffleOpConvert final
519  : public OpConversionPattern<vector::ShuffleOp> {
521 
523  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
524  ConversionPatternRewriter &rewriter) const override {
525  auto oldResultType = shuffleOp.getResultVectorType();
526  Type newResultType = getTypeConverter()->convertType(oldResultType);
527  if (!newResultType)
528  return rewriter.notifyMatchFailure(shuffleOp,
529  "unsupported result vector type");
530 
531  SmallVector<int32_t, 4> mask = llvm::map_to_vector<4>(
532  shuffleOp.getMask(), [](Attribute attr) -> int32_t {
533  return cast<IntegerAttr>(attr).getValue().getZExtValue();
534  });
535 
536  auto oldV1Type = shuffleOp.getV1VectorType();
537  auto oldV2Type = shuffleOp.getV2VectorType();
538 
539  // When both operands are SPIR-V vectors, emit a SPIR-V shuffle.
540  if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1) {
541  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
542  shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
543  rewriter.getI32ArrayAttr(mask));
544  return success();
545  }
546 
547  // When at least one of the operands becomes a scalar after type conversion
548  // for SPIR-V, extract all the required elements and construct the result
549  // vector.
550  auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
551  Value scalarOrVec, int32_t idx) -> Value {
552  if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
553  return rewriter.create<spirv::CompositeExtractOp>(loc, scalarOrVec,
554  idx);
555 
556  assert(idx == 0 && "Invalid scalar element index");
557  return scalarOrVec;
558  };
559 
560  int32_t numV1Elems = oldV1Type.getNumElements();
561  SmallVector<Value> newOperands(mask.size());
562  for (auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
563  Value vec = adaptor.getV1();
564  int32_t elementIdx = shuffleIdx;
565  if (elementIdx >= numV1Elems) {
566  vec = adaptor.getV2();
567  elementIdx -= numV1Elems;
568  }
569 
570  newOperand = getElementAtIdx(vec, elementIdx);
571  }
572 
573  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
574  shuffleOp, newResultType, newOperands);
575 
576  return success();
577  }
578 };
579 
580 struct VectorLoadOpConverter final
581  : public OpConversionPattern<vector::LoadOp> {
583 
585  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
586  ConversionPatternRewriter &rewriter) const override {
587  auto memrefType = loadOp.getMemRefType();
588  auto attr =
589  dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
590  if (!attr)
591  return rewriter.notifyMatchFailure(
592  loadOp, "expected spirv.storage_class memory space");
593 
594  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
595  auto loc = loadOp.getLoc();
596  Value accessChain =
597  spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
598  adaptor.getIndices(), loc, rewriter);
599  if (!accessChain)
600  return rewriter.notifyMatchFailure(
601  loadOp, "failed to get memref element pointer");
602 
603  spirv::StorageClass storageClass = attr.getValue();
604  auto vectorType = loadOp.getVectorType();
605  auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
606  Value castedAccessChain =
607  rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
608  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, vectorType,
609  castedAccessChain);
610 
611  return success();
612  }
613 };
614 
615 struct VectorStoreOpConverter final
616  : public OpConversionPattern<vector::StoreOp> {
618 
620  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
621  ConversionPatternRewriter &rewriter) const override {
622  auto memrefType = storeOp.getMemRefType();
623  auto attr =
624  dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
625  if (!attr)
626  return rewriter.notifyMatchFailure(
627  storeOp, "expected spirv.storage_class memory space");
628 
629  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
630  auto loc = storeOp.getLoc();
631  Value accessChain =
632  spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
633  adaptor.getIndices(), loc, rewriter);
634  if (!accessChain)
635  return rewriter.notifyMatchFailure(
636  storeOp, "failed to get memref element pointer");
637 
638  spirv::StorageClass storageClass = attr.getValue();
639  auto vectorType = storeOp.getVectorType();
640  auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
641  Value castedAccessChain =
642  rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
643  rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
644  adaptor.getValueToStore());
645 
646  return success();
647  }
648 };
649 
650 struct VectorReductionToIntDotProd final
651  : OpRewritePattern<vector::ReductionOp> {
653 
654  LogicalResult matchAndRewrite(vector::ReductionOp op,
655  PatternRewriter &rewriter) const override {
656  if (op.getKind() != vector::CombiningKind::ADD)
657  return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
658 
659  auto resultType = dyn_cast<IntegerType>(op.getType());
660  if (!resultType)
661  return rewriter.notifyMatchFailure(op, "result is not an integer");
662 
663  int64_t resultBitwidth = resultType.getIntOrFloatBitWidth();
664  if (!llvm::is_contained({32, 64}, resultBitwidth))
665  return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth");
666 
667  VectorType inVecTy = op.getSourceVectorType();
668  if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
669  inVecTy.getShape().size() != 1 || inVecTy.isScalable())
670  return rewriter.notifyMatchFailure(op, "unsupported vector shape");
671 
672  auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
673  if (!mul)
674  return rewriter.notifyMatchFailure(
675  op, "reduction operand is not 'arith.muli'");
676 
677  if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
678  spirv::SDotAccSatOp, false>(op, mul, rewriter)))
679  return success();
680 
681  if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
682  spirv::UDotAccSatOp, false>(op, mul, rewriter)))
683  return success();
684 
685  if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
686  spirv::SUDotAccSatOp, false>(op, mul, rewriter)))
687  return success();
688 
689  if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
690  spirv::SUDotAccSatOp, true>(op, mul, rewriter)))
691  return success();
692 
693  return failure();
694  }
695 
696 private:
697  template <typename LhsExtensionOp, typename RhsExtensionOp, typename DotOp,
698  typename DotAccOp, bool SwapOperands>
699  static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
700  PatternRewriter &rewriter) {
701  auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
702  if (!lhs)
703  return failure();
704  Value lhsIn = lhs.getIn();
705  auto lhsInType = cast<VectorType>(lhsIn.getType());
706  if (!lhsInType.getElementType().isInteger(8))
707  return failure();
708 
709  auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
710  if (!rhs)
711  return failure();
712  Value rhsIn = rhs.getIn();
713  auto rhsInType = cast<VectorType>(rhsIn.getType());
714  if (!rhsInType.getElementType().isInteger(8))
715  return failure();
716 
717  if (op.getSourceVectorType().getNumElements() == 3) {
718  IntegerType i8Type = rewriter.getI8Type();
719  auto v4i8Type = VectorType::get({4}, i8Type);
720  Location loc = op.getLoc();
721  Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter);
722  lhsIn = rewriter.create<spirv::CompositeConstructOp>(
723  loc, v4i8Type, ValueRange{lhsIn, zero});
724  rhsIn = rewriter.create<spirv::CompositeConstructOp>(
725  loc, v4i8Type, ValueRange{rhsIn, zero});
726  }
727 
728  // There's no variant of dot prod ops for unsigned LHS and signed RHS, so
729  // we have to swap operands instead in that case.
730  if (SwapOperands)
731  std::swap(lhsIn, rhsIn);
732 
733  if (Value acc = op.getAcc()) {
734  rewriter.replaceOpWithNewOp<DotAccOp>(op, op.getType(), lhsIn, rhsIn, acc,
735  nullptr);
736  } else {
737  rewriter.replaceOpWithNewOp<DotOp>(op, op.getType(), lhsIn, rhsIn,
738  nullptr);
739  }
740 
741  return success();
742  }
743 };
744 
745 struct VectorReductionToFPDotProd final
746  : OpConversionPattern<vector::ReductionOp> {
748 
750  matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
751  ConversionPatternRewriter &rewriter) const override {
752  if (op.getKind() != vector::CombiningKind::ADD)
753  return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
754 
755  auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
756  if (!resultType)
757  return rewriter.notifyMatchFailure(op, "result is not a float");
758 
759  Value vec = adaptor.getVector();
760  Value acc = adaptor.getAcc();
761 
762  auto vectorType = dyn_cast<VectorType>(vec.getType());
763  if (!vectorType) {
764  assert(isa<FloatType>(vec.getType()) &&
765  "Expected the vector to be scalarized");
766  if (acc) {
767  rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec);
768  return success();
769  }
770 
771  rewriter.replaceOp(op, vec);
772  return success();
773  }
774 
775  Location loc = op.getLoc();
776  Value lhs;
777  Value rhs;
778  if (auto mul = vec.getDefiningOp<arith::MulFOp>()) {
779  lhs = mul.getLhs();
780  rhs = mul.getRhs();
781  } else {
782  // If the operand is not a mul, use a vector of ones for the dot operand
783  // to just sum up all values.
784  lhs = vec;
785  Attribute oneAttr =
786  rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
787  oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
788  rhs = rewriter.create<spirv::ConstantOp>(loc, vectorType, oneAttr);
789  }
790  assert(lhs);
791  assert(rhs);
792 
793  Value res = rewriter.create<spirv::DotOp>(loc, resultType, lhs, rhs);
794  if (acc)
795  res = rewriter.create<spirv::FAddOp>(loc, acc, res);
796 
797  rewriter.replaceOp(op, res);
798  return success();
799  }
800 };
801 
802 } // namespace
803 #define CL_INT_MAX_MIN_OPS \
804  spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
805 
806 #define GL_INT_MAX_MIN_OPS \
807  spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
808 
809 #define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
810 #define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
811 
813  RewritePatternSet &patterns) {
814  patterns.add<
815  VectorBitcastConvert, VectorBroadcastConvert,
816  VectorExtractElementOpConvert, VectorExtractOpConvert,
817  VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
818  VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
819  VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
820  VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
821  VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
822  VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
823  VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
824  VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
825  typeConverter, patterns.getContext(), PatternBenefit(1));
826 
827  // Make sure that the more specialized dot product pattern has higher benefit
828  // than the generic one that extracts all elements.
829  patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
830  PatternBenefit(2));
831 }
832 
834  RewritePatternSet &patterns) {
835  patterns.add<VectorReductionToIntDotProd>(patterns.getContext());
836 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
#define MINUI(lhs, rhs)
static uint64_t getFirstIntValue(ValueRange values)
Returns the integer value from the first valid input element, assuming Value inputs are defined by a ...
static int getNumBits(Type type)
Returns the number of bits for the given scalar/vector type.
#define INT_AND_FLOAT_CASE(kind, iop, fop)
#define INT_OR_FLOAT_CASE(kind, fop)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:283
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:261
IntegerType getI8Type()
Definition: Builders.cpp:79
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
MLIRContext * getContext() const
Definition: PatternMatch.h:785
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:809
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:685
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:537
Type conversion from builtin types to SPIR-V types for shader interface.
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:123
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:481
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
void populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Vector Ops to SPIR-V ops.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateVectorReductionToSPIRVDotProductPatterns(RewritePatternSet &patterns)
Appends patterns to convert vector reduction of the form:
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:361