VectorEmulateNarrowType.cpp
1 //===- VectorEmulateNarrowType.cpp - Narrow type emulation ----*- C++
2 //-*-===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #include "mlir/Dialect/Affine/IR/AffineOps.h"
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
13 #include "mlir/Dialect/Arith/Utils/Utils.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
16 #include "mlir/Dialect/Vector/IR/VectorOps.h"
17 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
18 #include "mlir/IR/BuiltinAttributes.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/IR/Value.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include <cstdint>
27 
28 using namespace mlir;
29 
30 #define DEBUG_TYPE "vector-narrow-type-emulation"
31 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
32 #define DBGSNL() (llvm::dbgs() << "\n")
33 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
34 
35 /// Returns a compressed mask: a bit of the result mask is set if any of the
36 /// `scale` corresponding bits of the original mask is set. E.g., if `scale` equals 2, the following mask:
37 ///
38 /// %mask = [1, 1, 1, 0, 0, 0]
39 ///
40 /// will return the following new compressed mask:
41 ///
42 /// %mask = [1, 1, 0]
43 static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
44  Location loc, Value mask,
45  int origElements, int scale) {
46  auto numElements = (origElements + scale - 1) / scale;
47 
48  Operation *maskOp = mask.getDefiningOp();
49  SmallVector<vector::ExtractOp, 2> extractOps;
50  // Finding the mask creation operation.
51  while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
52  if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
53  maskOp = extractOp.getVector().getDefiningOp();
54  extractOps.push_back(extractOp);
55  }
56  }
57  auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
58  auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
59  if (!createMaskOp && !constantMaskOp)
60  return failure();
61 
62  // Computing the "compressed" mask. All the emulation logic (i.e. computing
63  // new mask index) only happens on the last dimension of the vectors.
64  Operation *newMask = nullptr;
65  SmallVector<int64_t> shape(
66  cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
67  shape.back() = numElements;
68  auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
69  if (createMaskOp) {
70  OperandRange maskOperands = createMaskOp.getOperands();
71  size_t numMaskOperands = maskOperands.size();
72  AffineExpr s0;
73  bindSymbols(rewriter.getContext(), s0);
74  s0 = s0 + scale - 1;
75  s0 = s0.floorDiv(scale);
76  OpFoldResult origIndex =
77  getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
78  OpFoldResult maskIndex =
79  affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
80  SmallVector<Value> newMaskOperands(maskOperands.drop_back());
81  newMaskOperands.push_back(
82  getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
83  newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
84  newMaskOperands);
85  } else if (constantMaskOp) {
86  ArrayRef<Attribute> maskDimSizes =
87  constantMaskOp.getMaskDimSizes().getValue();
88  size_t numMaskOperands = maskDimSizes.size();
89  auto origIndex =
90  cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
91  IntegerAttr maskIndexAttr =
92  rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
93  SmallVector<Attribute> newMaskDimSizes(maskDimSizes.drop_back());
94  newMaskDimSizes.push_back(maskIndexAttr);
95  newMask = rewriter.create<vector::ConstantMaskOp>(
96  loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
97  }
98 
99  while (!extractOps.empty()) {
100  newMask = rewriter.create<vector::ExtractOp>(
101  loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
102  extractOps.pop_back();
103  }
104 
105  return newMask;
106 }
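// Illustrative example (not from the upstream file): for the
// vector.create_mask case with `scale` == 2, the helper above rewrites the
// mask roughly as follows. The SSA names are invented and the ceil-division
// is materialized through the folded affine apply built above.
//
//   %c3 = arith.constant 3 : index
//   %mask = vector.create_mask %c3 : vector<6xi1>       // [1, 1, 1, 0, 0, 0]
//
//   ~~>
//
//   %c2 = arith.constant 2 : index                      // ceildiv(3, 2)
//   %new_mask = vector.create_mask %c2 : vector<3xi1>   // [1, 1, 0]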
107 
108 namespace {
109 
110 //===----------------------------------------------------------------------===//
111 // ConvertVectorStore
112 //===----------------------------------------------------------------------===//
113 
114 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
115  using OpConversionPattern::OpConversionPattern;
116 
117  LogicalResult
118  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
119  ConversionPatternRewriter &rewriter) const override {
120 
121  auto loc = op.getLoc();
122  auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
123  Type oldElementType = op.getValueToStore().getType().getElementType();
124  Type newElementType = convertedType.getElementType();
125  int srcBits = oldElementType.getIntOrFloatBitWidth();
126  int dstBits = newElementType.getIntOrFloatBitWidth();
127 
128  if (dstBits % srcBits != 0) {
129  return rewriter.notifyMatchFailure(
130  op, "only dstBits % srcBits == 0 supported");
131  }
132  int scale = dstBits / srcBits;
133 
134  // Adjust the number of elements to store when emulating narrow types.
135  // Here only the 1-D vector store is considered, and the N-D memref types
136  // should be linearized.
137  // For example, to emulate i4 to i8, the following op:
138  //
139  // vector.store %arg1, %0[%arg2, %arg3] : memref<4x8xi4>, vector<8xi4>
140  //
141  // can be replaced with
142  //
143  // %bitcast = vector.bitcast %arg1 : vector<8xi4> to vector<4xi8>
144  // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
145  // vector<4xi8>
146 
147  auto origElements = op.getValueToStore().getType().getNumElements();
148  if (origElements % scale != 0)
149  return failure();
150 
151  auto stridedMetadata =
152  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
153 
154  OpFoldResult linearizedIndices;
155  std::tie(std::ignore, linearizedIndices) =
156  memref::getLinearizedMemRefOffsetAndSize(
157  rewriter, loc, srcBits, dstBits,
158  stridedMetadata.getConstifiedMixedOffset(),
159  stridedMetadata.getConstifiedMixedSizes(),
160  stridedMetadata.getConstifiedMixedStrides(),
161  getAsOpFoldResult(adaptor.getIndices()));
162 
163  auto numElements = origElements / scale;
164  auto bitCast = rewriter.create<vector::BitCastOp>(
165  loc, VectorType::get(numElements, newElementType),
166  op.getValueToStore());
167 
168  rewriter.replaceOpWithNewOp<vector::StoreOp>(
169  op, bitCast.getResult(), adaptor.getBase(),
170  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
171  return success();
172  }
173 };
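// Worked index computation for the example above (illustrative, not upstream
// code): with srcBits = 4, dstBits = 8 and scale = 2, memref<4x8xi4> is
// reinterpreted as memref<16xi8> (4 * 8 / 2 bytes), and the linearized index
// produced by memref::getLinearizedMemRefOffsetAndSize is
// (8 * %arg2 + %arg3) floordiv 2 in i8 units. For instance, %arg2 = 1 and
// %arg3 = 4 yields byte offset (8 + 4) / 2 = 6, where the 4 bitcast i8
// elements are stored.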
174 
175 //===----------------------------------------------------------------------===//
176 // ConvertVectorMaskedStore
177 //===----------------------------------------------------------------------===//
178 
179 struct ConvertVectorMaskedStore final
180  : OpConversionPattern<vector::MaskedStoreOp> {
181  using OpConversionPattern::OpConversionPattern;
182 
183  LogicalResult
184  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
185  ConversionPatternRewriter &rewriter) const override {
186 
187  auto loc = op.getLoc();
188  auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
189  Type oldElementType = op.getValueToStore().getType().getElementType();
190  Type newElementType = convertedType.getElementType();
191  int srcBits = oldElementType.getIntOrFloatBitWidth();
192  int dstBits = newElementType.getIntOrFloatBitWidth();
193 
194  if (dstBits % srcBits != 0) {
195  return rewriter.notifyMatchFailure(
196  op, "only dstBits % srcBits == 0 supported");
197  }
198 
199  int scale = dstBits / srcBits;
200  int origElements = op.getValueToStore().getType().getNumElements();
201  if (origElements % scale != 0)
202  return failure();
203 
204  auto stridedMetadata =
205  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
206  OpFoldResult linearizedIndicesOfr;
207  std::tie(std::ignore, linearizedIndicesOfr) =
208  memref::getLinearizedMemRefOffsetAndSize(
209  rewriter, loc, srcBits, dstBits,
210  stridedMetadata.getConstifiedMixedOffset(),
211  stridedMetadata.getConstifiedMixedSizes(),
212  stridedMetadata.getConstifiedMixedStrides(),
213  getAsOpFoldResult(adaptor.getIndices()));
214  Value linearizedIndices =
215  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
216 
217  // Load the whole data and use arith.select to handle the corner cases.
218  // E.g., given these input values:
219  //
220  // %mask = [1, 1, 1, 0, 0, 0]
221  // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
222  // %value_to_store = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
223  //
224  // we'll have
225  //
226  // expected output: [0x7, 0x8, 0x9, 0x4, 0x5, 0x6]
227  //
228  // %new_mask = [1, 1, 0]
229  // %maskedload = [0x12, 0x34, 0x0]
230  // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x0, 0x0]
231  // %select_using_original_mask = [0x7, 0x8, 0x9, 0x4, 0x0, 0x0]
232  // %packed_data = [0x78, 0x94, 0x00]
233  //
234  // Using the new mask to store %packed_data results in expected output.
235  FailureOr<Operation *> newMask =
236  getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
237  if (failed(newMask))
238  return failure();
239 
240  auto numElements = (origElements + scale - 1) / scale;
241  auto newType = VectorType::get(numElements, newElementType);
242  auto passThru = rewriter.create<arith::ConstantOp>(
243  loc, newType, rewriter.getZeroAttr(newType));
244 
245  auto newLoad = rewriter.create<vector::MaskedLoadOp>(
246  loc, newType, adaptor.getBase(), linearizedIndices,
247  newMask.value()->getResult(0), passThru);
248 
249  Value valueToStore = rewriter.create<vector::BitCastOp>(
250  loc, op.getValueToStore().getType(), newLoad);
251  valueToStore = rewriter.create<arith::SelectOp>(
252  loc, op.getMask(), op.getValueToStore(), valueToStore);
253  valueToStore =
254  rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);
255 
256  rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
257  op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
258  valueToStore);
259  return success();
260  }
261 };
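// Putting the steps together, the rewrite above produces IR of roughly this
// shape for the i4 -> i8 example from the comment (illustrative sketch; SSA
// names are invented and memref types elided):
//
//   %new_mask = vector.constant_mask [2] : vector<3xi1>
//   %load     = vector.maskedload %base[%lin], %new_mask, %zero : ... into vector<3xi8>
//   %unpacked = vector.bitcast %load : vector<3xi8> to vector<6xi4>
//   %merged   = arith.select %mask, %value_to_store, %unpacked : vector<6xi1>, vector<6xi4>
//   %packed   = vector.bitcast %merged : vector<6xi4> to vector<3xi8>
//   vector.maskedstore %base[%lin], %new_mask, %packed : ...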
262 
263 //===----------------------------------------------------------------------===//
264 // ConvertVectorLoad
265 //===----------------------------------------------------------------------===//
266 
267 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
268  using OpConversionPattern::OpConversionPattern;
269 
270  LogicalResult
271  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
272  ConversionPatternRewriter &rewriter) const override {
273 
274  auto loc = op.getLoc();
275  auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
276  Type oldElementType = op.getType().getElementType();
277  Type newElementType = convertedType.getElementType();
278  int srcBits = oldElementType.getIntOrFloatBitWidth();
279  int dstBits = newElementType.getIntOrFloatBitWidth();
280 
281  if (dstBits % srcBits != 0) {
282  return rewriter.notifyMatchFailure(
283  op, "only dstBits % srcBits == 0 supported");
284  }
285  int scale = dstBits / srcBits;
286 
287  // Adjust the number of elements to load when emulating narrow types,
288  // and then cast back to the original type with vector.bitcast op.
289  // Here only the 1-D vector load is considered, and the N-D memref types
290  // should be linearized.
291  // For example, to emulate i4 to i8, the following op:
292  //
293  // %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
294  //
295  // can be replaced with
296  //
297  // %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
298  // %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
299  //
300  // TODO: Currently, only the even number of elements loading is supported.
301  // To deal with the odd number of elements, one has to extract the
302  // subvector at the proper offset after bit-casting.
303 
304  auto origElements = op.getVectorType().getNumElements();
305  if (origElements % scale != 0)
306  return failure();
307 
308  auto stridedMetadata =
309  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
310 
311  OpFoldResult linearizedIndices;
312  std::tie(std::ignore, linearizedIndices) =
313  memref::getLinearizedMemRefOffsetAndSize(
314  rewriter, loc, srcBits, dstBits,
315  stridedMetadata.getConstifiedMixedOffset(),
316  stridedMetadata.getConstifiedMixedSizes(),
317  stridedMetadata.getConstifiedMixedStrides(),
318  getAsOpFoldResult(adaptor.getIndices()));
319 
320  auto numElements = (origElements + scale - 1) / scale;
321  auto newLoad = rewriter.create<vector::LoadOp>(
322  loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
323  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
324 
325  auto bitCast =
326  rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
327 
328  rewriter.replaceOp(op, bitCast->getResult(0));
329  return success();
330  }
331 };
332 
333 //===----------------------------------------------------------------------===//
334 // ConvertVectorMaskedLoad
335 //===----------------------------------------------------------------------===//
336 
337 struct ConvertVectorMaskedLoad final
338  : OpConversionPattern<vector::MaskedLoadOp> {
339  using OpConversionPattern::OpConversionPattern;
340 
341  LogicalResult
342  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
343  ConversionPatternRewriter &rewriter) const override {
344 
345  auto loc = op.getLoc();
346  auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
347  Type oldElementType = op.getType().getElementType();
348  Type newElementType = convertedType.getElementType();
349  int srcBits = oldElementType.getIntOrFloatBitWidth();
350  int dstBits = newElementType.getIntOrFloatBitWidth();
351 
352  if (dstBits % srcBits != 0) {
353  return rewriter.notifyMatchFailure(
354  op, "only dstBits % srcBits == 0 supported");
355  }
356  int scale = dstBits / srcBits;
357 
358  // Adjust the number of elements to load when emulating narrow types,
359  // and then cast back to the original type with vector.bitcast op.
360  // For example, to emulate i4 to i8, the following op:
361  //
362  // %mask = vector.constant_mask [3] : vector<6xi1>
363  // %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
364  // memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
365  //
366  // can be replaced with
367  //
368  // %new_mask = vector.constant_mask [2] : vector<3xi1>
369  // %new_pass_thru = vector.bitcast %pass_thru :
370  // vector<6xi4> to vector<3xi8>
371  // %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
372  // memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
373  // %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
374  //
375  // Since we are effectively loading 16 bits (2xi8) from the memref with the
376  // new mask, while originally we only wanted to effectively load 12 bits
377  // (3xi4) from the memref, we need to set the second half of the last i8
378  // that was effectively loaded (i.e. the second i8) to %pass_thru.
379  //
380  // %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
381  //
382  // Given these input values:
383  // %mask = [1, 1, 1, 0, 0, 0]
384  // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
385  // %pass_thru = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
386  //
387  // we'll have:
388  //
389  // expected output: [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
390  //
391  // %new_mask = [1, 1, 0]
392  // %new_pass_thru = [0x78, 0x9A, 0xBC]
393  // %1 = [0x12, 0x34, 0xBC]
394  // %2 = [0x1, 0x2, 0x3, 0x4, 0xB, 0xC]
395  // %3 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
396  //
397  // TODO: Currently, only the even number of elements loading is supported.
398  // To deal with the odd number of elements, one has to extract the
399  // subvector at the proper offset after bit-casting.
400  auto origType = op.getVectorType();
401  auto origElements = origType.getNumElements();
402  if (origElements % scale != 0)
403  return failure();
404 
405  auto stridedMetadata =
406  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
407  OpFoldResult linearizedIndices;
408  std::tie(std::ignore, linearizedIndices) =
409  memref::getLinearizedMemRefOffsetAndSize(
410  rewriter, loc, srcBits, dstBits,
411  stridedMetadata.getConstifiedMixedOffset(),
412  stridedMetadata.getConstifiedMixedSizes(),
413  stridedMetadata.getConstifiedMixedStrides(),
414  getAsOpFoldResult(adaptor.getIndices()));
415 
416  FailureOr<Operation *> newMask =
417  getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
418  if (failed(newMask))
419  return failure();
420 
421  auto numElements = (origElements + scale - 1) / scale;
422  auto newType = VectorType::get(numElements, newElementType);
423  auto newPassThru =
424  rewriter.create<vector::BitCastOp>(loc, newType, op.getPassThru());
425 
426  // Generating the new masked load.
427  auto newLoad = rewriter.create<vector::MaskedLoadOp>(
428  loc, newType, adaptor.getBase(),
429  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
430  newMask.value()->getResult(0), newPassThru);
431 
432  // Setting the part that originally was not effectively loaded from memory
433  // to pass through.
434  auto bitCast =
435  rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
436  auto select = rewriter.create<arith::SelectOp>(loc, op.getMask(), bitCast,
437  op.getPassThru());
438  rewriter.replaceOp(op, select->getResult(0));
439 
440  return success();
441  }
442 };
443 
444 //===----------------------------------------------------------------------===//
445 // ConvertVectorTransferRead
446 //===----------------------------------------------------------------------===//
447 
448 struct ConvertVectorTransferRead final
449  : OpConversionPattern<vector::TransferReadOp> {
450  using OpConversionPattern::OpConversionPattern;
451 
452  LogicalResult
453  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
454  ConversionPatternRewriter &rewriter) const override {
455 
456  auto loc = op.getLoc();
457  auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
458  Type oldElementType = op.getType().getElementType();
459  Type newElementType = convertedType.getElementType();
460  int srcBits = oldElementType.getIntOrFloatBitWidth();
461  int dstBits = newElementType.getIntOrFloatBitWidth();
462 
463  if (dstBits % srcBits != 0) {
464  return rewriter.notifyMatchFailure(
465  op, "only dstBits % srcBits == 0 supported");
466  }
467  int scale = dstBits / srcBits;
468 
469  auto origElements = op.getVectorType().getNumElements();
470  if (origElements % scale != 0)
471  return failure();
472 
473  auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
474  adaptor.getPadding());
475 
476  auto stridedMetadata =
477  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());
478 
479  OpFoldResult linearizedIndices;
480  std::tie(std::ignore, linearizedIndices) =
481  memref::getLinearizedMemRefOffsetAndSize(
482  rewriter, loc, srcBits, dstBits,
483  stridedMetadata.getConstifiedMixedOffset(),
484  stridedMetadata.getConstifiedMixedSizes(),
485  stridedMetadata.getConstifiedMixedStrides(),
486  getAsOpFoldResult(adaptor.getIndices()));
487 
488  auto numElements = (origElements + scale - 1) / scale;
489  auto newReadType = VectorType::get(numElements, newElementType);
490 
491  auto newRead = rewriter.create<vector::TransferReadOp>(
492  loc, newReadType, adaptor.getSource(),
493  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
494  newPadding);
495 
496  auto bitCast =
497  rewriter.create<vector::BitCastOp>(loc, op.getType(), newRead);
498 
499  rewriter.replaceOp(op, bitCast->getResult(0));
500  return success();
501  }
502 };
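// Illustrative before/after for this pattern (not from the upstream file; the
// names and shapes are made up, assuming i4 emulated on i8):
//
//   %1 = vector.transfer_read %src[%i, %j], %pad : memref<3x8xi4>, vector<8xi4>
//
//   ~~>
//
//   %pad_i8 = arith.extui %pad : i4 to i8
//   %r = vector.transfer_read %converted[%lin], %pad_i8 : memref<12xi8>, vector<4xi8>
//   %1 = vector.bitcast %r : vector<4xi8> to vector<8xi4>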
503 } // end anonymous namespace
504 
505 //===----------------------------------------------------------------------===//
506 // RewriteBitCastOfTruncI
507 //===----------------------------------------------------------------------===//
508 
509 namespace {
510 
511 /// Helper struct to keep track of the provenance of a contiguous set of bits
512 /// in a source vector.
513 struct SourceElementRange {
514  /// The index of the source vector element that contributes bits to *this.
515  int64_t sourceElementIdx;
516  /// The range of bits in the source vector element that contribute to *this.
517  int64_t sourceBitBegin;
518  int64_t sourceBitEnd;
519 };
520 
521 struct SourceElementRangeList : public SmallVector<SourceElementRange> {
522  /// Given the index of a SourceElementRange in the SourceElementRangeList,
523  /// compute the amount of bits that need to be shifted to the left to get the
524  /// bits in their final location. This shift amount is simply the sum of the
525  /// bits *before* `shuffleIdx` (i.e. the bits of `shuffleIdx = 0` are always
526 /// the LSBs, the bits of `shuffleIdx = 1` come next, etc.).
527  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
528  int64_t res = 0;
529  for (int64_t i = 0; i < shuffleIdx; ++i)
530  res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
531  return res;
532  }
533 };
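// Small worked example (illustrative): for a row holding two ranges of widths
// 5 and 3 bits, e.g. [{0, [0, 5)}, {1, [0, 3)}], computeLeftShiftAmount(0) is
// 0 and computeLeftShiftAmount(1) is 5: the second contribution is shifted
// past the 5 bits already occupied by the first one.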
534 
535 /// Helper struct to enumerate the source elements and bit ranges that are
536 /// involved in a bitcast operation.
537 /// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
538 /// any 1-D vector shape and any source/target bitwidths.
539 /// This creates and holds a mapping of the form:
540 /// [dstVectorElementJ] ==
541 /// [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ]
542 /// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as:
543 /// [0] = {0, [0-8)}
544 /// [1] = {0, [8-16)}
545 /// [2] = {0, [16-24)}
546 /// and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as:
547 /// [0] = {0, [0, 10)}
548 /// [1] = {0, [10, 15)}, {1, [0, 5)}
/// [2] = {1, [5, 15)}
549 struct BitCastBitsEnumerator {
550  BitCastBitsEnumerator(VectorType sourceVectorType,
551  VectorType targetVectorType);
552 
553  int64_t getMaxNumberOfEntries() {
554  int64_t numVectors = 0;
555  for (const auto &l : sourceElementRanges)
556  numVectors = std::max(numVectors, (int64_t)l.size());
557  return numVectors;
558  }
559 
560  VectorType sourceVectorType;
561  VectorType targetVectorType;
562  SmallVector<SourceElementRangeList> sourceElementRanges;
563 };
564 
565 /// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
566 /// advantage of high-level information to avoid leaving LLVM to scramble with
567 /// peephole optimizations.
568 /// BitCastBitsEnumerator encodes for each element of the target vector the
569 /// provenance of the bits in the source vector. We can "transpose" this
570 /// information to build a sequence of shuffles and bitwise ops that will
571 /// produce the desired result.
572 //
573 /// Consider the following motivating example:
574 /// ```
575 /// %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
576 /// ```
577 //
578 /// BitCastBitsEnumerator contains the following information:
579 /// ```
580 /// { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
581 /// { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
582 /// { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
583 /// { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
584 /// { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
585 /// { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
586 /// { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
587 /// {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
588 /// {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
589 /// {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
590 /// {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
591 /// {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
592 /// {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
593 /// {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
594 /// {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
595 /// {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
596 /// {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
597 /// {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
598 /// {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
599 /// {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
600 /// ```
601 ///
602 /// In the above, each row represents one target vector element and each
603 /// column represents one bit contribution from a source vector element.
604 /// The algorithm creates vector.shuffle operations (in this case there are 3
605 /// shuffles, i.e. the max number of columns in BitCastBitsEnumerator). The
606 /// algorithm populates the bits as follows:
607 /// ```
608 /// src bits 0 ...
609 /// 1st shuffle |xxxxx |xx |...
610 /// 2nd shuffle | xxx| xxxxx |...
611 /// 3rd shuffle | | x|...
612 /// ```
613 //
614 /// The algorithm proceeds as follows:
615 /// 1. for each vector.shuffle, collect the source vectors that participate in
616 /// this shuffle. One source vector per target element of the resulting
617 /// vector.shuffle. If there is no source element contributing bits for the
618 /// current vector.shuffle, take 0 (i.e. row 0 in the above example has only
619 /// 2 columns).
620 /// 2. represent the bitrange in the source vector as a mask. If there is no
621 /// source element contributing bits for the current vector.shuffle, take 0.
622 /// 3. shift right by the proper amount to align the source bitrange at
623 /// position 0. This is exactly the low end of the bitrange. For instance,
624 /// the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
625 /// shift right by 3 to get the bits contributed by the source element #1
626 /// into position 0.
627 /// 4. shift left by the proper amount to align to the desired position in
628 /// the result element vector. For instance, the contribution of the second
629 /// source element for the first row needs to be shifted by `5` to form the
630 /// first i8 result element.
631 ///
632 /// Eventually, we end up building the sequence
633 /// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
634 /// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
635 /// bits extracted from the source vector (i.e. the `shuffle -> and` part).
636 struct BitCastRewriter {
637  /// Helper metadata struct to hold the static quantities for the rewrite.
638  struct Metadata {
639  SmallVector<int64_t> shuffles;
640  SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
641  };
642 
643  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
644 
645  /// Verify that general preconditions for the rewrite are met.
646  LogicalResult commonPrecondition(PatternRewriter &rewriter,
647  VectorType preconditionType, Operation *op);
648 
649  /// Precompute the metadata for the rewrite.
650  SmallVector<BitCastRewriter::Metadata>
651  precomputeMetadata(IntegerType shuffledElementType);
652 
653  /// Rewrite one step of the sequence:
654  /// `(shuffle -> and -> shiftright -> shiftleft -> or)`.
655  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
656  Value initialValue, Value runningResult,
657  const BitCastRewriter::Metadata &metadata);
658 
659 private:
660  /// Underlying enumerator that encodes the provenance of the bits in the each
661  /// element of the result vector.
662  BitCastBitsEnumerator enumerator;
663 };
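// For each Metadata entry, genericRewriteStep below emits one round of the
// `shuffle -> and -> shiftright -> shiftleft -> or` sequence. Schematically
// (illustrative only; the constants and element type depend on the
// precomputed metadata):
//
//   %s   = vector.shuffle %src, %src [ ...shuffles... ]
//   %m   = arith.andi %s, %mask_cst
//   %r   = arith.shrui %m, %shr_cst
//   %l   = arith.shli %r, %shl_cst
//   %acc = arith.ori %acc_prev, %l    // omitted on the first step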
664 
665 } // namespace
666 
667 [[maybe_unused]] static raw_ostream &
668 operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
669  for (const auto &l : vec) {
670  for (auto it : llvm::enumerate(l)) {
671  os << "{ " << it.value().sourceElementIdx << ": b@["
672  << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
673  << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
674  }
675  os << "\n";
676  }
677  return os;
678 }
679 
680 BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
681  VectorType targetVectorType)
682  : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
683 
684  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
685  "requires 1-D non-scalable vector type");
686  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
687  "requires 1-D non-scalable vector type");
688  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
689  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
690  LDBG("sourceVectorType: " << sourceVectorType);
691 
692  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
693  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
694  LDBG("targetVectorType: " << targetVectorType);
695 
696  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
697  (void)mostMinorSourceDim;
698  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
699  "source and target bitwidths must match");
700 
701  // Prepopulate one source element range per target element.
702  sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
703  for (int64_t resultBit = 0; resultBit < bitwidth;) {
704  int64_t resultElement = resultBit / targetBitWidth;
705  int64_t resultBitInElement = resultBit % targetBitWidth;
706  int64_t sourceElementIdx = resultBit / sourceBitWidth;
707  int64_t sourceBitInElement = resultBit % sourceBitWidth;
708  int64_t step = std::min(sourceBitWidth - sourceBitInElement,
709  targetBitWidth - resultBitInElement);
710  sourceElementRanges[resultElement].push_back(
711  {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
712  resultBit += step;
713  }
714 }
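// Worked trace of the loop above for vector<2xi15> to vector<3xi10>
// (illustrative): bitwidth = 30 and resultBit advances by steps 10, 5, 5, 10:
//   resultBit  0: [0] += {0, [0, 10)}
//   resultBit 10: [1] += {0, [10, 15)}
//   resultBit 15: [1] += {1, [0, 5)}
//   resultBit 20: [2] += {1, [5, 15)}
// which matches the decomposition documented on BitCastBitsEnumerator.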
715 
716 BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
717  VectorType targetVectorType)
718  : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
719  LDBG("\n" << enumerator.sourceElementRanges);
720 }
721 
722 /// Verify that the precondition type meets the common preconditions for any
723 /// conversion.
724 static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
725  VectorType preconditionType,
726  Operation *op) {
727  if (!preconditionType || preconditionType.isScalable())
728  return rewriter.notifyMatchFailure(op, "scalable vector");
729 
730  // TODO: consider relaxing this restriction in the future if we find ways
731  // to really work with subbyte elements across the MLIR/LLVM boundary.
732  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
733  if (bitwidth % 8 != 0)
734  return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
735 
736  return success();
737 }
738 
739 LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
740  VectorType preconditionType,
741  Operation *op) {
742  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
743  return rewriter.notifyMatchFailure(op, "types are not vector");
744 
745  if (!preconditionType || preconditionType.getRank() != 1)
746  return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
747 
748  return commonConversionPrecondition(rewriter, preconditionType, op);
749 }
750 
751 /// Verify that source and destination element types meet the precondition for
752 /// the supported aligned conversion cases. Alignment means that either the
753 /// source element type is a multiple of the destination element type or the
754 /// other way around.
755 ///
756 /// NOTE: This method assumes that common conversion preconditions are met.
757 static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
758  VectorType srcType,
759  VectorType dstType,
760  Operation *op) {
761  if (!srcType || !dstType)
762  return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
763  unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
764  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
765 
766  // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
767  if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
768  (dstElemBitwidth % srcElemBitwidth) != 0)
769  return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
770 
771  if ((srcType.getShape().back() % 2) != 0)
772  return rewriter.notifyMatchFailure(
773  op, "Not an even number of i4 elements in trailing dim");
774 
775  return success();
776 }
777 
778 SmallVector<BitCastRewriter::Metadata>
779 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
780  SmallVector<BitCastRewriter::Metadata> result;
781  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
782  shuffleIdx < e; ++shuffleIdx) {
783  SmallVector<int64_t> shuffles;
784  SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
785 
786  // Create the attribute quantities for the shuffle / mask / shift ops.
787  for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
788  int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
789  ? srcEltRangeList[shuffleIdx].sourceElementIdx
790  : 0;
791  shuffles.push_back(sourceElement);
792 
793  int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
794  ? srcEltRangeList[shuffleIdx].sourceBitBegin
795  : 0;
796  int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
797  ? srcEltRangeList[shuffleIdx].sourceBitEnd
798  : 0;
799  IntegerAttr mask = IntegerAttr::get(
800  shuffledElementType,
801  llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
802  bitLo, bitHi));
803  masks.push_back(mask);
804 
805  int64_t shiftRight = bitLo;
806  shiftRightAmounts.push_back(
807  IntegerAttr::get(shuffledElementType, shiftRight));
808 
809  int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
810  shiftLeftAmounts.push_back(
811  IntegerAttr::get(shuffledElementType, shiftLeft));
812  }
813 
814  result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
815  }
816  return result;
817 }
818 
819 Value BitCastRewriter::genericRewriteStep(
820  PatternRewriter &rewriter, Location loc, Value initialValue,
821  Value runningResult, const BitCastRewriter::Metadata &metadata) {
822  // Create vector.shuffle from the metadata.
823  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
824  loc, initialValue, initialValue, metadata.shuffles);
825 
826  // Intersect with the mask.
827  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
828  auto constOp = rewriter.create<arith::ConstantOp>(
829  loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
830  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);
831 
832  // Align right on 0.
833  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
834  loc,
835  DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
836  Value shiftedRight =
837  rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
838 
839  // Shift bits left into their final position.
840  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
841  loc,
842  DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
843  Value shiftedLeft =
844  rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
845 
846  runningResult =
847  runningResult
848  ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
849  : shiftedLeft;
850 
851  return runningResult;
852 }
853 
854 /// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
855 /// bitwise ops that take advantage of high-level information to avoid leaving
856 /// LLVM to scramble with peephole optimizations.
857 static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
858  Value srcValue) {
859  VectorType srcVecType = cast<VectorType>(srcValue.getType());
860  assert(srcVecType.getElementType().isSignlessInteger(4) &&
861  "Expected i4 type");
862 
863  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
864  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
865  constexpr int64_t i4Toi8BitwidthFactor = 2;
866  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
867  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
868  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
869 
870  // 2. Extend i4 elements to i8 elements using shifts. The low i4 elements of
871  // each byte are placed in one vector and the high i4 elements in another.
872  constexpr int8_t bitsToShift = 4;
873  auto shiftValues = rewriter.create<arith::ConstantOp>(
874  loc, DenseElementsAttr::get(i8VecType, bitsToShift));
875  Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
876  Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
877  Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
878 
879  // 3. Interleave low and high i8 elements.
880  return rewriter.create<vector::InterleaveOp>(loc, low, high);
881 }
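// Bit-level example (illustrative): take the byte 0xA3, i.e. the two i4
// values 0x3 (low) and 0xA = -6 (high, interpreted as signed):
//   shl   0xA3, 4 -> 0x30;  shrsi 0x30, 4 -> 0x03 =  3   (low nibble, sign-extended)
//   shrsi 0xA3, 4 -> 0xFA = -6                            (high nibble, sign-extended)
// Interleaving the two results restores the original element order as i8s.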
882 
883 /// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
884 /// that take advantage of high-level information to avoid leaving LLVM to
885 /// scramble with peephole optimizations.
886 static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
887  Value srcValue) {
888  VectorType srcVecType = cast<VectorType>(srcValue.getType());
889  assert(srcVecType.getElementType().isSignlessInteger(8) &&
890  "Expected i8 type");
891 
892  // 1. De-interleave low and high i8 elements.
893  int64_t vecDimSize = srcVecType.getShape().back();
894  SmallVector<int64_t> deinterleaveLowMaskValues;
895  SmallVector<int64_t> deinterleaveHighMaskValues;
896  assert((vecDimSize % 2) == 0 && "Odd number of i4 elements");
897  deinterleaveLowMaskValues.reserve(vecDimSize / 2);
898  deinterleaveHighMaskValues.reserve(vecDimSize / 2);
899  for (int i = 0, end = vecDimSize; i < end; i += 2) {
900  deinterleaveLowMaskValues.push_back(i);
901  deinterleaveHighMaskValues.push_back(i + 1);
902  }
903 
904  auto lowShuffleOp = rewriter.create<vector::ShuffleOp>(
905  loc, srcValue, srcValue,
906  rewriter.getI64ArrayAttr(deinterleaveLowMaskValues));
907  auto highShuffleOp = rewriter.create<vector::ShuffleOp>(
908  loc, srcValue, srcValue,
909  rewriter.getI64ArrayAttr(deinterleaveHighMaskValues));
910 
911  // 2. Zero out the upper side of each low i8 element.
912  constexpr int8_t i8LowBitMask = 0x0F;
913  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
914  loc,
915  DenseElementsAttr::get(lowShuffleOp.getResultVectorType(), i8LowBitMask));
916  Value zeroOutLow =
917  rewriter.create<arith::AndIOp>(loc, lowShuffleOp, zeroOutMask);
918 
919  // 3. Move high i4 values to upper side of the byte.
920  constexpr int8_t bitsToShift = 4;
921  VectorType deinterI8VecType = highShuffleOp.getResultVectorType();
922  auto shiftValues = rewriter.create<arith::ConstantOp>(
923  loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
924  Value shlHigh =
925  rewriter.create<arith::ShLIOp>(loc, highShuffleOp, shiftValues);
926 
927  // 4. Merge high and low i4 values.
928  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
929 
930  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
931  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
932  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
933 }
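// Bit-level example (illustrative), inverse of the extension above: given the
// adjacent i8 values 0x03 and 0xFA (i.e. 3 and -6):
//   0x03 & 0x0F -> 0x03   (low nibble kept)
//   0xFA << 4   -> 0xA0   (high nibble moved into place)
//   0x03 | 0xA0 -> 0xA3
// and bitcasting 0xA3 to two i4 values yields [3, -6] again.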
934 
935 namespace {
936 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
937 /// advantage of high-level information to avoid leaving LLVM to scramble with
938 /// peephole optimizations.
939 struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
940  using OpRewritePattern::OpRewritePattern;
941 
942  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
943  PatternRewriter &rewriter) const override {
944  // The source must be a trunc op.
945  auto truncOp =
946  bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
947  if (!truncOp)
948  return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
949 
950  // Set up the BitCastRewriter and verify the precondition.
951  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
952  VectorType targetVectorType = bitCastOp.getResultVectorType();
953  BitCastRewriter bcr(sourceVectorType, targetVectorType);
954  if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
955  return failure();
956 
957  // Perform the rewrite.
958  Value truncValue = truncOp.getIn();
959  auto shuffledElementType =
960  cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
961  Value runningResult;
962  for (const BitCastRewriter::Metadata &metadata :
963  bcr.precomputeMetadata(shuffledElementType)) {
964  runningResult = bcr.genericRewriteStep(
965  rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
966  }
967 
968  // Finalize the rewrite.
969  bool narrowing = targetVectorType.getElementTypeBitWidth() <=
970  shuffledElementType.getIntOrFloatBitWidth();
971  if (narrowing) {
972  if (runningResult.getType() == bitCastOp.getResultVectorType()) {
973  rewriter.replaceOp(bitCastOp, runningResult);
974  } else {
975  rewriter.replaceOpWithNewOp<arith::TruncIOp>(
976  bitCastOp, bitCastOp.getResultVectorType(), runningResult);
977  }
978  } else {
979  if (runningResult.getType() == bitCastOp.getResultVectorType()) {
980  rewriter.replaceOp(bitCastOp, runningResult);
981  } else {
982  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
983  bitCastOp, bitCastOp.getResultVectorType(), runningResult);
984  }
985  }
986 
987  return success();
988  }
989 };
990 } // namespace
991 
992 //===----------------------------------------------------------------------===//
993 // RewriteExtOfBitCast
994 //===----------------------------------------------------------------------===//
995 
996 namespace {
997 /// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
998 /// take advantage of high-level information to avoid leaving LLVM to scramble
999 /// with peephole optimizations.
1000 template <typename ExtOpType>
1001 struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
1002  using OpRewritePattern<ExtOpType>::OpRewritePattern;
1003 
1004  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
1005  : OpRewritePattern<ExtOpType>(context, benefit) {}
1006 
1007  LogicalResult matchAndRewrite(ExtOpType extOp,
1008  PatternRewriter &rewriter) const override {
1009  // The source must be a bitcast op.
1010  auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
1011  if (!bitCastOp)
1012  return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
1013 
1014  // Set up the BitCastRewriter and verify the precondition.
1015  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1016  VectorType targetVectorType = bitCastOp.getResultVectorType();
1017  BitCastRewriter bcr(sourceVectorType, targetVectorType);
1018  if (failed(bcr.commonPrecondition(
1019  rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
1020  return failure();
1021 
1022  // Perform the rewrite.
1023  Value runningResult;
1024  Value sourceValue = bitCastOp.getSource();
1025  auto shuffledElementType =
1026  cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
1027  for (const BitCastRewriter::Metadata &metadata :
1028  bcr.precomputeMetadata(shuffledElementType)) {
1029  runningResult = bcr.genericRewriteStep(
1030  rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
1031  }
1032 
1033  // Finalize the rewrite.
1034  bool narrowing =
1035  cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
1036  shuffledElementType.getIntOrFloatBitWidth();
1037  if (narrowing) {
1038  rewriter.replaceOpWithNewOp<arith::TruncIOp>(
1039  extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
1040  } else {
1041  rewriter.replaceOpWithNewOp<ExtOpType>(
1042  extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
1043  }
1044 
1045  return success();
1046  }
1047 };
1048 
1049 /// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
1050 /// bitwise ops that take advantage of high-level information to avoid leaving
1051 /// LLVM to scramble with peephole optimizations.
1052 ///
1053 /// For example:
1054 /// arith.extsi %in : vector<8xi4> to vector<8xi32>
1055 /// is rewritten as
1056 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1057 /// %1 = arith.shli %0, 4 : vector<4xi8>
1058 /// %2 = arith.shrsi %1, 4 : vector<4xi8>
1059 /// %3 = arith.shrsi %0, 4 : vector<4xi8>
1060 /// %4 = vector.interleave %2, %3 : vector<4xi8>
1061 /// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
1062 ///
1063 /// arith.sitofp %in : vector<8xi4> to vector<8xf32>
1064 /// is rewritten as
1065 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1066 /// %1 = arith.shli %0, 4 : vector<4xi8>
1067 /// %2 = arith.shrsi %1, 4 : vector<4xi8>
1068 /// %3 = arith.shrsi %0, 4 : vector<4xi8>
1069 /// %4 = vector.interleave %2, %3 : vector<4xi8>
1070 /// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
1071 ///
1072 template <typename ConversionOpType>
1073 struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
1074  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
1075 
1076  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
1077  PatternRewriter &rewriter) const override {
1078  // Verify the preconditions.
1079  Value srcValue = conversionOp.getIn();
1080  auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
1081  auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
1082  if (failed(
1083  commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
1084  return failure();
1085 
1086  // Check general alignment preconditions.
1087  if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
1088  conversionOp)))
1089  return failure();
1090 
1091  // Perform the rewrite.
1092  Value subByteExt =
1093  rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
1094 
1095  // Finalize the rewrite.
1096  rewriter.replaceOpWithNewOp<ConversionOpType>(
1097  conversionOp, conversionOp.getType(), subByteExt);
1098  return success();
1099  }
1100 };
1101 
1102 /// Rewrite the i8 -> i4 part of any truncation into a sequence of shuffles and
1103 /// bitwise ops that take advantage of high-level information to avoid leaving
1104 /// LLVM to scramble with peephole optimizations.
1105 ///
1106 /// For example:
1107 /// arith.trunci %in : vector<8xi32> to vector<8xi4>
1108 /// is rewritten as
1109 ///
1110 /// %cst = arith.constant dense<15> : vector<4xi8>
1111 /// %cst_0 = arith.constant dense<4> : vector<4xi8>
1112 /// %0 = arith.trunci %in : vector<8xi32> to vector<8xi8>
1113 /// %1 = vector.shuffle %0, %0 [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
1114 /// %2 = vector.shuffle %0, %0 [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
1115 /// %3 = arith.andi %1, %cst : vector<4xi8>
1116 /// %4 = arith.shli %2, %cst_0 : vector<4xi8>
1117 /// %5 = arith.ori %3, %4 : vector<4xi8>
1118 /// %6 = vector.bitcast %5 : vector<4xi8> to vector<8xi4>
1119 ///
1120 struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
1121  using OpRewritePattern::OpRewritePattern;
1122 
1123  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
1124  PatternRewriter &rewriter) const override {
1125  // Verify the preconditions.
1126  Value srcValue = truncOp.getIn();
1127  auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
1128  auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
1129  if (!srcVecType || !dstVecType)
1130  return failure();
1131 
1132  // Only single dim vectors are supported until we have
1133  // `vector.deinterleave`.
1134  if (srcVecType.getRank() != 1)
1135  return failure();
1136 
1137  if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
1138  return failure();
1139 
1140  // Check general alignment preconditions. We invert the src/dst type order
1141  // to reuse the existing precondition logic.
1142  if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
1143  truncOp)))
1144  return failure();
1145 
1146  // Create a new iX -> i8 truncation op.
1147  Location loc = truncOp.getLoc();
1148  auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
1149  Value i8TruncVal =
1150  rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
1151 
1152  // Rewrite the i8 -> i4 truncation part.
1153  Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
1154 
1155  // Finalize the rewrite.
1156  rewriter.replaceOp(truncOp, subByteTrunc);
1157  return success();
1158  }
1159 };
1160 
1161 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
1162 /// perform the transpose on wider (byte) element types.
1163 /// For example:
1164 /// %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
1165 ///
1166 /// is rewritten as:
1167 ///
1168 /// %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
1169 /// %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
1170 /// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
1171 ///
1172 struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
1173  using OpRewritePattern::OpRewritePattern;
1174 
1175  RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
1176  : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
1177 
1178  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
1179  PatternRewriter &rewriter) const override {
1180  // Precondition: sub-byte integer transpose.
1181  constexpr unsigned minNativeBitwidth = 8;
1182  VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
1183  if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
1184  srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
1185  return rewriter.notifyMatchFailure(transposeOp,
1186  "not a sub-byte transpose");
1187  }
1188 
1189  // Perform the rewrite.
1190  Location loc = transposeOp.getLoc();
1191  // Signed/unsigned interpretation shouldn't matter here as we are just
1192  // transposing the elements and truncating them back to the original size.
1193  // TODO: Use unsigned extension (more efficient) when emulation or backend
1194  // support is available.
1195  auto srcNativeVecType = srcSubByteVecType.cloneWith(
1196  std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
1197  Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
1198  transposeOp.getVector());
1199  Value newTranspose = rewriter.create<vector::TransposeOp>(
1200  loc, extOp, transposeOp.getPermutation());
1201  VectorType dstSubByteVecType = transposeOp.getResultVectorType();
1202  rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
1203  newTranspose);
1204  return success();
1205  }
1206 };
1207 
1208 } // namespace
1209 
1210 //===----------------------------------------------------------------------===//
1211 // Public Interface Definition
1212 //===----------------------------------------------------------------------===//
1213 
1214 void vector::populateVectorNarrowTypeEmulationPatterns(
1215  arith::NarrowTypeEmulationConverter &typeConverter,
1216  RewritePatternSet &patterns) {
1217 
1218  // Populate `vector.*` conversion patterns.
1219  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
1220  ConvertVectorMaskedStore, ConvertVectorTransferRead>(
1221  typeConverter, patterns.getContext());
1222 }
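// Usage sketch (illustrative, not part of this file): these conversion
// patterns are meant to run under a dialect conversion driven by an
// arith::NarrowTypeEmulationConverter, roughly along these lines (target and
// legality setup elided):
//
//   arith::NarrowTypeEmulationConverter converter(/*targetBitwidth=*/8);
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorNarrowTypeEmulationPatterns(converter, patterns);
//   // ... also add the matching arith/memref emulation patterns, set up a
//   // ConversionTarget, then call applyPartialConversion(...).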
1223 
1224 void vector::populateVectorNarrowTypeRewritePatterns(
1225  RewritePatternSet &patterns, PatternBenefit benefit) {
1226  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
1227  RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
1228  benefit);
1229 
1230  // Patterns for aligned cases. We set higher priority as they are expected to
1231  // generate better performance for aligned cases.
1232  patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
1233  RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
1234  RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
1235  benefit.getBenefit() + 1);
1236 }
1237 
1238 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
1239  RewritePatternSet &patterns, PatternBenefit benefit) {
1240  patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
1241 }
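// Usage sketch (illustrative, not part of this file): the rewrite-pattern
// populate functions above plug into the greedy driver, e.g.:
//
//   RewritePatternSet patterns(op->getContext());
//   vector::populateVectorNarrowTypeRewritePatterns(patterns);
//   vector::populateVectorTransposeNarrowTypeRewritePatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
//
// applyPatternsAndFoldGreedily lives in
// mlir/Transforms/GreedyPatternRewriteDriver.h, which this file does not
// include; the snippet shows caller-side wiring only.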