VectorEmulateNarrowType.cpp
1 //===- VectorEmulateNarrowType.cpp - Narrow type emulation ----*- C++
2 //-*-===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #include "mlir/Dialect/Affine/IR/AffineOps.h"
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
13 #include "mlir/Dialect/Arith/Utils/Utils.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
16 #include "mlir/Dialect/Vector/IR/VectorOps.h"
17 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
18 #include "mlir/IR/BuiltinAttributes.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/IR/Value.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include <cstdint>
27 
28 using namespace mlir;
29 
30 #define DEBUG_TYPE "vector-narrow-type-emulation"
31 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
32 #define DBGSNL() (llvm::dbgs() << "\n")
33 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
34 
35 /// Returns a compressed mask: a result bit is set only if at least one of the
36 /// `scale` corresponding source bits is set. E.g., if `scale` equals 2, the following mask:
37 ///
38 /// %mask = [1, 1, 1, 0, 0, 0]
39 ///
40 /// will return the following new compressed mask:
41 ///
42 /// %mask = [1, 1, 0]
43 static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
44  Location loc, Value mask,
45  int origElements, int scale) {
46  auto numElements = (origElements + scale - 1) / scale;
47 
48  Operation *maskOp = mask.getDefiningOp();
49  SmallVector<vector::ExtractOp, 2> extractOps;
50  // Finding the mask creation operation.
51  while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
52  if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
53  maskOp = extractOp.getVector().getDefiningOp();
54  extractOps.push_back(extractOp);
55  }
56  }
57  auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
58  auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
59  if (!createMaskOp && !constantMaskOp)
60  return failure();
61 
62  // Computing the "compressed" mask. All the emulation logic (i.e. computing
63  // new mask index) only happens on the last dimension of the vectors.
64  Operation *newMask = nullptr;
65  SmallVector<int64_t> shape(
66      cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
67  shape.back() = numElements;
68  auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
69  if (createMaskOp) {
70  OperandRange maskOperands = createMaskOp.getOperands();
71  size_t numMaskOperands = maskOperands.size();
72  AffineExpr s0;
73  bindSymbols(rewriter.getContext(), s0);
74  s0 = s0 + scale - 1;
75  s0 = s0.floorDiv(scale);
76  OpFoldResult origIndex =
77  getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
78  OpFoldResult maskIndex =
79  affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
80  SmallVector<Value> newMaskOperands(maskOperands.drop_back());
81  newMaskOperands.push_back(
82  getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
83  newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
84  newMaskOperands);
85  } else if (constantMaskOp) {
86  ArrayRef<Attribute> maskDimSizes =
87  constantMaskOp.getMaskDimSizes().getValue();
88  size_t numMaskOperands = maskDimSizes.size();
89  auto origIndex =
90  cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
91  IntegerAttr maskIndexAttr =
92  rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
93  SmallVector<Attribute> newMaskDimSizes(maskDimSizes.drop_back());
94  newMaskDimSizes.push_back(maskIndexAttr);
95  newMask = rewriter.create<vector::ConstantMaskOp>(
96  loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
97  }
98 
99  while (!extractOps.empty()) {
100  newMask = rewriter.create<vector::ExtractOp>(
101  loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
102  extractOps.pop_back();
103  }
104 
105  return newMask;
106 }
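
// Illustrative sketch of the compression (value names are made up): with
// `scale` = 2 (e.g. i4 packed into i8), a dynamic mask such as
//
//   %m = vector.create_mask %c5 : vector<8xi1>
//
// is rewritten, after folding the affine map (s0 + 2 - 1) floordiv 2, into
//
//   %c3 = arith.constant 3 : index      // ceil(5 / 2)
//   %m_new = vector.create_mask %c3 : vector<4xi1>
//
// so each bit of the compressed mask covers one full `scale`-element group.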
107 
108 namespace {
109 
110 //===----------------------------------------------------------------------===//
111 // ConvertVectorStore
112 //===----------------------------------------------------------------------===//
113 
114 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
115  using OpConversionPattern::OpConversionPattern;
116 
117  LogicalResult
118  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
119  ConversionPatternRewriter &rewriter) const override {
120 
121  auto loc = op.getLoc();
122  auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
123  Type oldElementType = op.getValueToStore().getType().getElementType();
124  Type newElementType = convertedType.getElementType();
125  int srcBits = oldElementType.getIntOrFloatBitWidth();
126  int dstBits = newElementType.getIntOrFloatBitWidth();
127 
128  if (dstBits % srcBits != 0) {
129  return rewriter.notifyMatchFailure(
130  op, "only dstBits % srcBits == 0 supported");
131  }
132  int scale = dstBits / srcBits;
133 
134  // Adjust the number of elements to store when emulating narrow types.
135  // Here only the 1-D vector store is considered, and the N-D memref types
136  // should be linearized.
137  // For example, to emulate i4 to i8, the following op:
138  //
139  // vector.store %arg1, %0[%arg2, %arg3] : memref<4x8xi4>, vector<8xi4>
140  //
141  // can be replaced with
142  //
143  // %bitcast = vector.bitcast %arg1 : vector<8xi4> to vector<4xi8>
144  // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
145  // vector<4xi8>
146 
147  auto origElements = op.getValueToStore().getType().getNumElements();
148  if (origElements % scale != 0)
149  return failure();
150 
151  auto stridedMetadata =
152  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
153 
154  OpFoldResult linearizedIndices;
155  std::tie(std::ignore, linearizedIndices) =
156  memref::getLinearizedMemRefOffsetAndSize(
157  rewriter, loc, srcBits, dstBits,
158  stridedMetadata.getConstifiedMixedOffset(),
159  stridedMetadata.getConstifiedMixedSizes(),
160  stridedMetadata.getConstifiedMixedStrides(),
161  getAsOpFoldResult(adaptor.getIndices()));
162 
163  auto numElements = origElements / scale;
164  auto bitCast = rewriter.create<vector::BitCastOp>(
165  loc, VectorType::get(numElements, newElementType),
166  op.getValueToStore());
167 
168  rewriter.replaceOpWithNewOp<vector::StoreOp>(
169  op, bitCast.getResult(), adaptor.getBase(),
170  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
171  return success();
172  }
173 };
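
// To make the linearization above concrete (an illustrative sketch; the exact
// affine form is produced by memref::getLinearizedMemRefOffsetAndSize): for
// memref<4x8xi4> with srcBits = 4 and dstBits = 8, the converted buffer is
// memref<16xi8> and the element index [%arg2, %arg3] maps to the byte index
//
//   %linear_index = (%arg2 * 8 + %arg3) * 4 / 8 = (%arg2 * 8 + %arg3) / 2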
174 
175 //===----------------------------------------------------------------------===//
176 // ConvertVectorMaskedStore
177 //===----------------------------------------------------------------------===//
178 
179 struct ConvertVectorMaskedStore final
180  : OpConversionPattern<vector::MaskedStoreOp> {
181  using OpConversionPattern::OpConversionPattern;
182 
183  LogicalResult
184  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
185  ConversionPatternRewriter &rewriter) const override {
186 
187  auto loc = op.getLoc();
188  auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
189  Type oldElementType = op.getValueToStore().getType().getElementType();
190  Type newElementType = convertedType.getElementType();
191  int srcBits = oldElementType.getIntOrFloatBitWidth();
192  int dstBits = newElementType.getIntOrFloatBitWidth();
193 
194  if (dstBits % srcBits != 0) {
195  return rewriter.notifyMatchFailure(
196  op, "only dstBits % srcBits == 0 supported");
197  }
198 
199  int scale = dstBits / srcBits;
200  int origElements = op.getValueToStore().getType().getNumElements();
201  if (origElements % scale != 0)
202  return failure();
203 
204  auto stridedMetadata =
205  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
206  OpFoldResult linearizedIndicesOfr;
207  std::tie(std::ignore, linearizedIndicesOfr) =
208  memref::getLinearizedMemRefOffsetAndSize(
209  rewriter, loc, srcBits, dstBits,
210  stridedMetadata.getConstifiedMixedOffset(),
211  stridedMetadata.getConstifiedMixedSizes(),
212  stridedMetadata.getConstifiedMixedStrides(),
213  getAsOpFoldResult(adaptor.getIndices()));
214  Value linearizedIndices =
215  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
216 
217  // Load the whole data and use arith.select to handle the corner cases.
218  // E.g., given these input values:
219  //
220  // %mask = [1, 1, 1, 0, 0, 0]
221  // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
222  // %value_to_store = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
223  //
224  // we'll have
225  //
226  // expected output: [0x7, 0x8, 0x9, 0x4, 0x5, 0x6]
227  //
228  // %new_mask = [1, 1, 0]
229  // %maskedload = [0x12, 0x34, 0x0]
230  // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x0, 0x0]
231  // %select_using_original_mask = [0x7, 0x8, 0x9, 0x4, 0x0, 0x0]
232  // %packed_data = [0x78, 0x94, 0x00]
233  //
234  // Using the new mask to store %packed_data results in expected output.
235  FailureOr<Operation *> newMask =
236  getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
237  if (failed(newMask))
238  return failure();
239 
240  auto numElements = (origElements + scale - 1) / scale;
241  auto newType = VectorType::get(numElements, newElementType);
242  auto passThru = rewriter.create<arith::ConstantOp>(
243  loc, newType, rewriter.getZeroAttr(newType));
244 
245  auto newLoad = rewriter.create<vector::MaskedLoadOp>(
246  loc, newType, adaptor.getBase(), linearizedIndices,
247  newMask.value()->getResult(0), passThru);
248 
249  Value valueToStore = rewriter.create<vector::BitCastOp>(
250  loc, op.getValueToStore().getType(), newLoad);
251  valueToStore = rewriter.create<arith::SelectOp>(
252  loc, op.getMask(), op.getValueToStore(), valueToStore);
253  valueToStore =
254  rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);
255 
256  rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
257  op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
258  valueToStore);
259  return success();
260  }
261 };
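
// Put together, for the vector<6xi4> example in the comment above this pattern
// emits IR of roughly the following shape (value names and the converted
// memref type are illustrative):
//
//   %new_mask = vector.constant_mask [2] : vector<3xi1>
//   %zero     = arith.constant dense<0> : vector<3xi8>
//   %load     = vector.maskedload %alloc[%lin], %new_mask, %zero
//                 : memref<?xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
//   %old      = vector.bitcast %load : vector<3xi8> to vector<6xi4>
//   %sel      = arith.select %mask, %value_to_store, %old
//   %packed   = vector.bitcast %sel : vector<6xi4> to vector<3xi8>
//   vector.maskedstore %alloc[%lin], %new_mask, %packed
//                 : memref<?xi8>, vector<3xi1>, vector<3xi8>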
262 
263 //===----------------------------------------------------------------------===//
264 // ConvertVectorLoad
265 //===----------------------------------------------------------------------===//
266 
267 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
268  using OpConversionPattern::OpConversionPattern;
269 
270  LogicalResult
271  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
272  ConversionPatternRewriter &rewriter) const override {
273 
274  auto loc = op.getLoc();
275  auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
276  Type oldElementType = op.getType().getElementType();
277  Type newElementType = convertedType.getElementType();
278  int srcBits = oldElementType.getIntOrFloatBitWidth();
279  int dstBits = newElementType.getIntOrFloatBitWidth();
280 
281  if (dstBits % srcBits != 0) {
282  return rewriter.notifyMatchFailure(
283  op, "only dstBits % srcBits == 0 supported");
284  }
285  int scale = dstBits / srcBits;
286 
287  // Adjust the number of elements to load when emulating narrow types,
288  // and then cast back to the original type with vector.bitcast op.
289  // Here only the 1-D vector load is considered, and the N-D memref types
290  // should be linearized.
291  // For example, to emulate i4 to i8, the following op:
292  //
293  // %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
294  //
295  // can be replaced with
296  //
297  // %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
298  // %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
299  //
300  // TODO: Currently, only loading an even number of elements is supported.
301  // To deal with the odd number of elements, one has to extract the
302  // subvector at the proper offset after bit-casting.
303 
304  auto origElements = op.getVectorType().getNumElements();
305  if (origElements % scale != 0)
306  return failure();
307 
308  auto stridedMetadata =
309  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
310 
311  OpFoldResult linearizedIndices;
312  std::tie(std::ignore, linearizedIndices) =
313  memref::getLinearizedMemRefOffsetAndSize(
314  rewriter, loc, srcBits, dstBits,
315  stridedMetadata.getConstifiedMixedOffset(),
316  stridedMetadata.getConstifiedMixedSizes(),
317  stridedMetadata.getConstifiedMixedStrides(),
318  getAsOpFoldResult(adaptor.getIndices()));
319 
320  auto numElements = (origElements + scale - 1) / scale;
321  auto newLoad = rewriter.create<vector::LoadOp>(
322  loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
323  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
324 
325  auto bitCast =
326  rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
327 
328  rewriter.replaceOp(op, bitCast->getResult(0));
329  return success();
330  }
331 };
332 
333 //===----------------------------------------------------------------------===//
334 // ConvertVectorMaskedLoad
335 //===----------------------------------------------------------------------===//
336 
337 struct ConvertVectorMaskedLoad final
338  : OpConversionPattern<vector::MaskedLoadOp> {
339  using OpConversionPattern::OpConversionPattern;
340 
341  LogicalResult
342  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
343  ConversionPatternRewriter &rewriter) const override {
344 
345  auto loc = op.getLoc();
346  auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
347  Type oldElementType = op.getType().getElementType();
348  Type newElementType = convertedType.getElementType();
349  int srcBits = oldElementType.getIntOrFloatBitWidth();
350  int dstBits = newElementType.getIntOrFloatBitWidth();
351 
352  if (dstBits % srcBits != 0) {
353  return rewriter.notifyMatchFailure(
354  op, "only dstBits % srcBits == 0 supported");
355  }
356  int scale = dstBits / srcBits;
357 
358  // Adjust the number of elements to load when emulating narrow types,
359  // and then cast back to the original type with vector.bitcast op.
360  // For example, to emulate i4 to i8, the following op:
361  //
362  // %mask = vector.constant_mask [3] : vector<6xi1>
363  // %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
364  // memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
365  //
366  // can be replaced with
367  //
368  // %new_mask = vector.constant_mask [2] : vector<3xi1>
369  // %new_pass_thru = vector.bitcast %pass_thru :
370  // vector<6xi4> to vector<3xi8>
371  // %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
372  // memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
373  // %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
374  //
375  // Since we are effectively loading 16 bits (2xi8) from the memref with the
376  // new mask, while originally we only wanted to effectively load 12 bits
377  // (3xi4) from the memref, we need to set the second half of the last i8
378  // that was effectively loaded (i.e. the second i8) to %pass_thru.
379  //
380  // %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
381  //
382  // Given these input values:
383  // %mask = [1, 1, 1, 0, 0, 0]
384  // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
385  // %pass_thru = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
386  //
387  // we'll have:
388  //
389  // expected output: [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
390  //
391  // %new_mask = [1, 1, 0]
392  // %new_pass_thru = [0x78, 0x9A, 0xBC]
393  // %1 = [0x12, 0x34, 0xBC]
394  // %2 = [0x1, 0x2, 0x3, 0x4, 0xB, 0xC]
395  // %3 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
396  //
397  // TODO: Currently, only loading an even number of elements is supported.
398  // To deal with the odd number of elements, one has to extract the
399  // subvector at the proper offset after bit-casting.
400  auto origType = op.getVectorType();
401  auto origElements = origType.getNumElements();
402  if (origElements % scale != 0)
403  return failure();
404 
405  auto stridedMetadata =
406  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
407  OpFoldResult linearizedIndices;
408  std::tie(std::ignore, linearizedIndices) =
409  memref::getLinearizedMemRefOffsetAndSize(
410  rewriter, loc, srcBits, dstBits,
411  stridedMetadata.getConstifiedMixedOffset(),
412  stridedMetadata.getConstifiedMixedSizes(),
413  stridedMetadata.getConstifiedMixedStrides(),
414  getAsOpFoldResult(adaptor.getIndices()));
415 
416  FailureOr<Operation *> newMask =
417  getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
418  if (failed(newMask))
419  return failure();
420 
421  auto numElements = (origElements + scale - 1) / scale;
422  auto newType = VectorType::get(numElements, newElementType);
423  auto newPassThru =
424  rewriter.create<vector::BitCastOp>(loc, newType, op.getPassThru());
425 
426  // Generating the new masked load.
427  auto newLoad = rewriter.create<vector::MaskedLoadOp>(
428  loc, newType, adaptor.getBase(),
429  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
430  newMask.value()->getResult(0), newPassThru);
431 
432  // Setting the part that originally was not effectively loaded from memory
433  // to pass through.
434  auto bitCast =
435  rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
436  auto select = rewriter.create<arith::SelectOp>(loc, op.getMask(), bitCast,
437  op.getPassThru());
438  rewriter.replaceOp(op, select->getResult(0));
439 
440  return success();
441  }
442 };
443 
444 //===----------------------------------------------------------------------===//
445 // ConvertVectorTransferRead
446 //===----------------------------------------------------------------------===//
447 
448 struct ConvertVectorTransferRead final
449  : OpConversionPattern<vector::TransferReadOp> {
450  using OpConversionPattern::OpConversionPattern;
451 
452  LogicalResult
453  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
454  ConversionPatternRewriter &rewriter) const override {
455 
456  auto loc = op.getLoc();
457  auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
458  Type oldElementType = op.getType().getElementType();
459  Type newElementType = convertedType.getElementType();
460  int srcBits = oldElementType.getIntOrFloatBitWidth();
461  int dstBits = newElementType.getIntOrFloatBitWidth();
462 
463  if (dstBits % srcBits != 0) {
464  return rewriter.notifyMatchFailure(
465  op, "only dstBits % srcBits == 0 supported");
466  }
467  int scale = dstBits / srcBits;
468 
469  auto origElements = op.getVectorType().getNumElements();
470  if (origElements % scale != 0)
471  return failure();
472 
473  auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
474  adaptor.getPadding());
475 
476  auto stridedMetadata =
477  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());
478 
479  OpFoldResult linearizedIndices;
480  std::tie(std::ignore, linearizedIndices) =
481  memref::getLinearizedMemRefOffsetAndSize(
482  rewriter, loc, srcBits, dstBits,
483  stridedMetadata.getConstifiedMixedOffset(),
484  stridedMetadata.getConstifiedMixedSizes(),
485  stridedMetadata.getConstifiedMixedStrides(),
486  getAsOpFoldResult(adaptor.getIndices()));
487 
488  auto numElements = (origElements + scale - 1) / scale;
489  auto newReadType = VectorType::get(numElements, newElementType);
490 
491  auto newRead = rewriter.create<vector::TransferReadOp>(
492  loc, newReadType, adaptor.getSource(),
493  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
494  newPadding);
495 
496  auto bitCast =
497  rewriter.create<vector::BitCastOp>(loc, op.getType(), newRead);
498 
499  rewriter.replaceOp(op, bitCast->getResult(0));
500  return success();
501  }
502 };
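
// For example (illustrative, mirroring the load patterns above), emulating i4
// with i8 rewrites
//
//   %1 = vector.transfer_read %0[%c0, %c0], %pad : memref<3x8xi4>, vector<8xi4>
//
// into roughly
//
//   %pad_i8 = arith.extui %pad : i4 to i8
//   %2 = vector.transfer_read %alloc[%lin], %pad_i8 : memref<12xi8>, vector<4xi8>
//   %3 = vector.bitcast %2 : vector<4xi8> to vector<8xi4>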
503 } // end anonymous namespace
504 
505 //===----------------------------------------------------------------------===//
506 // RewriteBitCastOfTruncI
507 //===----------------------------------------------------------------------===//
508 
509 namespace {
510 
511 /// Helper struct to keep track of the provenance of a contiguous set of bits
512 /// in a source vector.
513 struct SourceElementRange {
514  /// The index of the source vector element that contributes bits to *this.
515  int64_t sourceElementIdx;
516  /// The range of bits in the source vector element that contribute to *this.
517  int64_t sourceBitBegin;
518  int64_t sourceBitEnd;
519 };
520 
521 struct SourceElementRangeList : public SmallVector<SourceElementRange> {
522  /// Given the index of a SourceElementRange in the SourceElementRangeList,
523  /// compute the amount of bits that need to be shifted to the left to get the
524  /// bits in their final location. This shift amount is simply the sum of the
525  /// bits *before* `shuffleIdx` (i.e. the bits of `shuffleIdx = 0` are always
526 /// the LSBs, the bits of `shuffleIdx = 1` come next, etc.).
527  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
528  int64_t res = 0;
529  for (int64_t i = 0; i < shuffleIdx; ++i)
530  res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
531  return res;
532  }
533 };
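
// Worked example (illustrative): if destination element [1] draws
// {1, [5, 10)} followed by {2, [0, 10)} from the source, as in the
// vector<3xi10> -> vector<2xi15> decomposition documented below, then
// computeLeftShiftAmount(1) returns 10 - 5 = 5, the number of bits already
// occupied by the entry at shuffleIdx = 0.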
534 
535 /// Helper struct to enumerate the source elements and bit ranges that are
536 /// involved in a bitcast operation.
537 /// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
538 /// any 1-D vector shape and any source/target bitwidths.
539 /// This creates and holds a mapping of the form:
540 /// [dstVectorElementJ] ==
541 /// [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ]
542 /// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as:
543 /// [0] = {0, [0-8)}
544 /// [1] = {0, [8-16)}
545 /// [2] = {0, [16-24)}
546 /// and `vector.bitcast ... : vector<3xi10> to vector<2xi15>` is decomposed as:
547 /// [0] = {0, [0, 10)}, {1, [0, 5)}
548 /// [1] = {1, [5, 10)}, {2, [0, 10)}
549 struct BitCastBitsEnumerator {
550  BitCastBitsEnumerator(VectorType sourceVectorType,
551  VectorType targetVectorType);
552 
553  int64_t getMaxNumberOfEntries() {
554  int64_t numVectors = 0;
555  for (const auto &l : sourceElementRanges)
556  numVectors = std::max(numVectors, (int64_t)l.size());
557  return numVectors;
558  }
559 
560  VectorType sourceVectorType;
561  VectorType targetVectorType;
562  SmallVector<SourceElementRangeList> sourceElementRanges;
563 };
564 
565 /// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
566 /// advantage of high-level information to avoid leaving LLVM to scramble with
567 /// peephole optimizations.
568 /// BitCastBitsEnumerator encodes for each element of the target vector the
569 /// provenance of the bits in the source vector. We can "transpose" this
570 /// information to build a sequence of shuffles and bitwise ops that will
571 /// produce the desired result.
572 //
573 /// Consider the following motivating example:
574 /// ```
575 /// %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
576 /// ```
577 //
578 /// BitCastBitsEnumerator contains the following information:
579 /// ```
580 /// { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
581 /// { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
582 /// { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
583 /// { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
584 /// { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
585 /// { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
586 /// { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
587 /// {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
588 /// {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
589 /// {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
590 /// {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
591 /// {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
592 /// {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
593 /// {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
594 /// {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
595 /// {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
596 /// {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
597 /// {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
598 /// {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
599 /// {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
600 /// ```
601 ///
602 /// In the above, each row represents one target vector element and each
603 /// column represents one bit contribution from a source vector element.
604 /// The algorithm creates vector.shuffle operations (in this case there are 3
605 /// shuffles, i.e. the max number of columns in BitCastBitsEnumerator). The
606 /// algorithm populates the bits as follows:
607 /// ```
608 /// src bits 0 ...
609 /// 1st shuffle |xxxxx |xx |...
610 /// 2nd shuffle | xxx| xxxxx |...
611 /// 3rd shuffle | | x|...
612 /// ```
613 //
614 /// The algorithm proceeds as follows:
615 /// 1. for each vector.shuffle, collect the source vectors that participate in
616 /// this shuffle. One source vector per target element of the resulting
617 /// vector.shuffle. If there is no source element contributing bits for the
618 /// current vector.shuffle, take 0 (i.e. row 0 in the above example has only
619 /// 2 columns).
620 /// 2. represent the bitrange in the source vector as a mask. If there is no
621 /// source element contributing bits for the current vector.shuffle, take 0.
622 /// 3. shift right by the proper amount to align the source bitrange at
623 /// position 0. This is exactly the low end of the bitrange. For instance,
624 /// the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
625 /// shift right by 3 to get the bits contributed by the source element #1
626 /// into position 0.
627 /// 4. shift left by the proper amount to align to the desired position in
628 /// the result element vector. For instance, the contribution of the second
629 /// source element for the first row needs to be shifted by `5` to form the
630 /// first i8 result element.
631 ///
632 /// Eventually, we end up building the sequence
633 /// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
634 /// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
635 /// bits extracted from the source vector (i.e. the `shuffle -> and` part).
636 struct BitCastRewriter {
637  /// Helper metadata struct to hold the static quantities for the rewrite.
638  struct Metadata {
639  SmallVector<int64_t> shuffles;
640  SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
641  };
642 
643  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
644 
645  /// Verify that general preconditions for the rewrite are met.
646  LogicalResult commonPrecondition(PatternRewriter &rewriter,
647  VectorType preconditionType, Operation *op);
648 
649  /// Precompute the metadata for the rewrite.
650  SmallVector<BitCastRewriter::Metadata>
651  precomputeMetadata(IntegerType shuffledElementType);
652 
653  /// Rewrite one step of the sequence:
654  /// `(shuffle -> and -> shiftright -> shiftleft -> or)`.
655  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
656  Value initialValue, Value runningResult,
657  const BitCastRewriter::Metadata &metadata);
658 
659 private:
660  /// Underlying enumerator that encodes the provenance of the bits in each
661  /// element of the result vector.
662  BitCastBitsEnumerator enumerator;
663 };
664 
665 } // namespace
666 
667 [[maybe_unused]] static raw_ostream &
668 operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
669  for (const auto &l : vec) {
670  for (auto it : llvm::enumerate(l)) {
671  os << "{ " << it.value().sourceElementIdx << ": b@["
672  << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
673  << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
674  }
675  os << "\n";
676  }
677  return os;
678 }
679 
680 BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
681  VectorType targetVectorType)
682  : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
683 
684  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
685  "requires 1-D non-scalable vector type");
686  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
687  "requires 1-D non-scalable vector type");
688  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
689  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
690  LDBG("sourceVectorType: " << sourceVectorType);
691 
692  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
693  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
694  LDBG("targetVectorType: " << targetVectorType);
695 
696  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
697  (void)mostMinorSourceDim;
698  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
699  "source and target bitwidths must match");
700 
701  // Prepopulate one source element range per target element.
702  sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
703  for (int64_t resultBit = 0; resultBit < bitwidth;) {
704  int64_t resultElement = resultBit / targetBitWidth;
705  int64_t resultBitInElement = resultBit % targetBitWidth;
706  int64_t sourceElementIdx = resultBit / sourceBitWidth;
707  int64_t sourceBitInElement = resultBit % sourceBitWidth;
708  int64_t step = std::min(sourceBitWidth - sourceBitInElement,
709  targetBitWidth - resultBitInElement);
710  sourceElementRanges[resultElement].push_back(
711  {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
712  resultBit += step;
713  }
714 }
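
// Quick trace (illustrative) for vector<1xi24> -> vector<3xi8>:
// sourceBitWidth = 24, targetBitWidth = 8, bitwidth = 24, and each iteration
// takes a step of min(24 - sourceBitInElement, 8 - resultBitInElement) = 8
// bits, producing
//   sourceElementRanges[0] = {0, 0, 8}
//   sourceElementRanges[1] = {0, 8, 16}
//   sourceElementRanges[2] = {0, 16, 24}
// which matches the decomposition documented on BitCastBitsEnumerator.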
715 
716 BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
717  VectorType targetVectorType)
718  : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
719  LDBG("\n" << enumerator.sourceElementRanges);
720 }
721 
722 /// Verify that the precondition type meets the common preconditions for any
723 /// conversion.
724 static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
725  VectorType preconditionType,
726  Operation *op) {
727  if (!preconditionType || preconditionType.isScalable())
728  return rewriter.notifyMatchFailure(op, "scalable vector");
729 
730  // TODO: consider relaxing this restriction in the future if we find ways
731  // to really work with subbyte elements across the MLIR/LLVM boundary.
732  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
733  if (bitwidth % 8 != 0)
734  return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
735 
736  return success();
737 }
738 
739 LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
740  VectorType preconditionType,
741  Operation *op) {
742  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
743  return rewriter.notifyMatchFailure(op, "types are not vector");
744 
745  if (!preconditionType || preconditionType.getRank() != 1)
746  return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
747 
748  return commonConversionPrecondition(rewriter, preconditionType, op);
749 }
750 
751 /// Verify that source and destination element types meet the precondition for
752 /// the supported aligned conversion cases. Alignment means that either the
753 /// source element type is a multiple of the destination element type or the other
754 /// way around.
755 ///
756 /// NOTE: This method assumes that common conversion preconditions are met.
757 static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
758  VectorType srcType,
759  VectorType dstType,
760  Operation *op) {
761  if (!srcType || !dstType)
762  return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
763  unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
764  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
765 
766  // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
767  if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
768  (dstElemBitwidth % srcElemBitwidth) != 0)
769  return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
770 
771  if ((srcType.getShape().back() % 2) != 0)
772  return rewriter.notifyMatchFailure(
773  op, "Not an even number of i4 elements in trailing dim");
774 
775  return success();
776 }
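
// Concretely (illustrative): vector<8xi4> -> vector<8xi32> is accepted
// (srcElemBitwidth == 4, 32 is a multiple of 4, trailing dim 8 is even),
// vector<3xi4> -> vector<3xi8> is rejected because the trailing dimension is
// odd, and vector<8xi2> -> vector<8xi8> is rejected because only 4-bit
// sources are handled for now.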
777 
778 SmallVector<BitCastRewriter::Metadata>
779 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
780  SmallVector<BitCastRewriter::Metadata> result;
781  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
782  shuffleIdx < e; ++shuffleIdx) {
783  SmallVector<int64_t> shuffles;
784  SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
785 
786  // Create the attribute quantities for the shuffle / mask / shift ops.
787  for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
788  int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
789  ? srcEltRangeList[shuffleIdx].sourceElementIdx
790  : 0;
791  shuffles.push_back(sourceElement);
792 
793  int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
794  ? srcEltRangeList[shuffleIdx].sourceBitBegin
795  : 0;
796  int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
797  ? srcEltRangeList[shuffleIdx].sourceBitEnd
798  : 0;
799  IntegerAttr mask = IntegerAttr::get(
800  shuffledElementType,
801  llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
802  bitLo, bitHi));
803  masks.push_back(mask);
804 
805  int64_t shiftRight = bitLo;
806  shiftRightAmounts.push_back(
807  IntegerAttr::get(shuffledElementType, shiftRight));
808 
809  int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
810  shiftLeftAmounts.push_back(
811  IntegerAttr::get(shuffledElementType, shiftLeft));
812  }
813 
814  result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
815  }
816  return result;
817 }
818 
819 Value BitCastRewriter::genericRewriteStep(
820  PatternRewriter &rewriter, Location loc, Value initialValue,
821  Value runningResult, const BitCastRewriter::Metadata &metadata) {
822  // Create vector.shuffle from the metadata.
823  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
824  loc, initialValue, initialValue, metadata.shuffles);
825 
826  // Intersect with the mask.
827  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
828  auto constOp = rewriter.create<arith::ConstantOp>(
829  loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
830  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);
831 
832  // Align right on 0.
833  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
834  loc,
835  DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
836  Value shiftedRight =
837  rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
838 
839  // Shift bits left into their final position.
840  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
841  loc,
842  DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
843  Value shiftedLeft =
844  rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
845 
846  runningResult =
847  runningResult
848  ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
849  : shiftedLeft;
850 
851  return runningResult;
852 }
853 
854 /// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
855 /// bitwise ops that take advantage of high-level information to avoid leaving
856 /// LLVM to scramble with peephole optimizations.
857 static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
858  Value srcValue) {
859  VectorType srcVecType = cast<VectorType>(srcValue.getType());
860  assert(srcVecType.getElementType().isSignlessInteger(4) &&
861  "Expected i4 type");
862 
863  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
864  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
865  constexpr int64_t i4Toi8BitwidthFactor = 2;
866  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
867  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
868  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
869 
870  // 2. Extend i4 elements to i8 elements using shifts. Low i4 elements of each
871  // byte are placed in one vector and the high i4 elements in another vector.
872  constexpr int8_t bitsToShift = 4;
873  auto shiftValues = rewriter.create<arith::ConstantOp>(
874  loc, DenseElementsAttr::get(i8VecType, bitsToShift));
875  Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
876  Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
877  Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
878 
879  // 3. Interleave low and high i8 elements.
880  return rewriter.create<vector::InterleaveOp>(loc, low, high);
881 }
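
// Numeric sketch (illustrative): take the byte 0xBA, i.e. low nibble 0xA
// (-6 as i4) and high nibble 0xB (-5 as i4). Then
//   shli  0xBA, 4 = 0xA0;  shrsi 0xA0, 4 = 0xFA (-6)   // low nibble, sign-extended
//   shrsi 0xBA, 4 = 0xFB (-5)                          // high nibble, sign-extended
// and interleaving the two vectors restores the original i4 element order.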
882 
883 /// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
884 /// bitwise ops that take advantage of high-level information to avoid leaving
885 /// LLVM to scramble with peephole optimizations.
886 static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
887  Value srcValue) {
888  VectorType srcVecType = cast<VectorType>(srcValue.getType());
889  assert(srcVecType.getElementType().isSignlessInteger(4) &&
890  "Expected i4 type");
891 
892  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
893  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
894  constexpr int64_t i4Toi8BitwidthFactor = 2;
895  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
896  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
897  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
898 
899  // 2. Extend the i4 elements using shifts & masking. Low i4 elements of each
900  // byte are placed in one vector and the high i4 elements in another vector.
901  constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask
902  auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
903  loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
904  Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
905  lowBitsMaskValues);
906  constexpr int8_t highBitsToShift = 4;
907  auto highShiftValues = rewriter.create<arith::ConstantOp>(
908  loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
909  Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
910 
911  // 3. Interleave low and high i8 elements.
912  return rewriter.create<vector::InterleaveOp>(loc, low, high);
913 }
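
// Numeric sketch (illustrative): for the same byte 0xBA,
//   andi  0xBA, 0x0F = 0x0A (10)   // low nibble, zero-extended
//   shrui 0xBA, 4    = 0x0B (11)   // high nibble, zero-extended
// followed by the interleave to restore the original i4 element order.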
914 
915 /// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
916 /// ops that take advantage of high-level information to avoid leaving LLVM to
917 /// scramble with peephole optimizations.
918 static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
919  Value srcValue) {
920  VectorType srcVecType = cast<VectorType>(srcValue.getType());
921  assert(srcVecType.getElementType().isSignlessInteger(8) &&
922  "Expected i8 type");
923 
924  // 1. De-interleave low and high i8 elements.
925  auto deinterleaveOp = rewriter.create<vector::DeinterleaveOp>(loc, srcValue);
926 
927  // 2. Zero out the upper side of each low i8 element.
928  constexpr int8_t i8LowBitMask = 0x0F;
929  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
930  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
931  loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
932  Value zeroOutLow = rewriter.create<arith::AndIOp>(
933  loc, deinterleaveOp.getRes1(), zeroOutMask);
934 
935  // 3. Move high i4 values to upper side of the byte.
936  constexpr int8_t bitsToShift = 4;
937  auto shiftValues = rewriter.create<arith::ConstantOp>(
938  loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
939  Value shlHigh = rewriter.create<arith::ShLIOp>(loc, deinterleaveOp.getRes2(),
940  shiftValues);
941 
942  // 4. Merge high and low i4 values.
943  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
944 
945  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
946  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
947  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
948 }
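
// Numeric sketch (illustrative): to pack the deinterleaved i8 pair
// (0x0A, 0x0B) back into one byte, the two halves are combined as
//   (0x0A & 0x0F) | (0x0B << 4) = 0x0A | 0xB0 = 0xBA
// before the final bitcast back to i4 elements.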
949 
950 namespace {
951 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
952 /// advantage of high-level information to avoid leaving LLVM to scramble with
953 /// peephole optimizations.
954 struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
955  using OpRewritePattern::OpRewritePattern;
956 
957  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
958  PatternRewriter &rewriter) const override {
959  // The source must be a trunc op.
960  auto truncOp =
961  bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
962  if (!truncOp)
963  return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
964 
965  // Set up the BitCastRewriter and verify the precondition.
966  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
967  VectorType targetVectorType = bitCastOp.getResultVectorType();
968  BitCastRewriter bcr(sourceVectorType, targetVectorType);
969  if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
970  return failure();
971 
972  // Perform the rewrite.
973  Value truncValue = truncOp.getIn();
974  auto shuffledElementType =
975  cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
976  Value runningResult;
977  for (const BitCastRewriter::Metadata &metadata :
978  bcr.precomputeMetadata(shuffledElementType)) {
979  runningResult = bcr.genericRewriteStep(
980  rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
981  }
982 
983  // Finalize the rewrite.
984  bool narrowing = targetVectorType.getElementTypeBitWidth() <=
985  shuffledElementType.getIntOrFloatBitWidth();
986  if (narrowing) {
987  if (runningResult.getType() == bitCastOp.getResultVectorType()) {
988  rewriter.replaceOp(bitCastOp, runningResult);
989  } else {
990  rewriter.replaceOpWithNewOp<arith::TruncIOp>(
991  bitCastOp, bitCastOp.getResultVectorType(), runningResult);
992  }
993  } else {
994  if (runningResult.getType() == bitCastOp.getResultVectorType()) {
995  rewriter.replaceOp(bitCastOp, runningResult);
996  } else {
997  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
998  bitCastOp, bitCastOp.getResultVectorType(), runningResult);
999  }
1000  }
1001 
1002  return success();
1003  }
1004 };
1005 } // namespace
1006 
1007 //===----------------------------------------------------------------------===//
1008 // RewriteExtOfBitCast
1009 //===----------------------------------------------------------------------===//
1010 
1011 namespace {
1012 /// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
1013 /// take advantage of high-level information to avoid leaving LLVM to scramble
1014 /// with peephole optimizations.
1015 template <typename ExtOpType>
1016 struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
1017  using OpRewritePattern<ExtOpType>::OpRewritePattern;
1018 
1019  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
1020  : OpRewritePattern<ExtOpType>(context, benefit) {}
1021 
1022  LogicalResult matchAndRewrite(ExtOpType extOp,
1023  PatternRewriter &rewriter) const override {
1024  // The source must be a bitcast op.
1025  auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
1026  if (!bitCastOp)
1027  return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
1028 
1029  // Set up the BitCastRewriter and verify the precondition.
1030  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1031  VectorType targetVectorType = bitCastOp.getResultVectorType();
1032  BitCastRewriter bcr(sourceVectorType, targetVectorType);
1033  if (failed(bcr.commonPrecondition(
1034  rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
1035  return failure();
1036 
1037  // Perform the rewrite.
1038  Value runningResult;
1039  Value sourceValue = bitCastOp.getSource();
1040  auto shuffledElementType =
1041  cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
1042  for (const BitCastRewriter::Metadata &metadata :
1043  bcr.precomputeMetadata(shuffledElementType)) {
1044  runningResult = bcr.genericRewriteStep(
1045  rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
1046  }
1047 
1048  // Finalize the rewrite.
1049  bool narrowing =
1050  cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
1051  shuffledElementType.getIntOrFloatBitWidth();
1052  if (narrowing) {
1053  rewriter.replaceOpWithNewOp<arith::TruncIOp>(
1054  extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
1055  } else {
1056  rewriter.replaceOpWithNewOp<ExtOpType>(
1057  extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
1058  }
1059 
1060  return success();
1061  }
1062 };
1063 
1064 /// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
1065 /// bitwise ops that take advantage of high-level information to avoid leaving
1066 /// LLVM to scramble with peephole optimizations. Templated to choose between
1067 /// signed and unsigned conversions.
1068 ///
1069 /// For example (signed):
1070 /// arith.extsi %in : vector<8xi4> to vector<8xi32>
1071 /// is rewritten as
1072 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1073 /// %1 = arith.shli %0, 4 : vector<4xi8>
1074 /// %2 = arith.shrsi %1, 4 : vector<4xi8>
1075 /// %3 = arith.shrsi %0, 4 : vector<4xi8>
1076 /// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
1077 /// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
1078 ///
1079 /// arith.sitofp %in : vector<8xi4> to vector<8xf32>
1080 /// is rewritten as
1081 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1082 /// %1 = arith.shli %0, 4 : vector<4xi8>
1083 /// %2 = arith.shrsi %1, 4 : vector<4xi8>
1084 /// %3 = arith.shrsi %0, 4 : vector<4xi8>
1085 /// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
1086 /// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
1087 ///
1088 /// Example (unsigned):
1089 /// arith.extui %in : vector<8xi4> to vector<8xi32>
1090 /// is rewritten as
1091 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1092 /// %1 = arith.andi %0, 15 : vector<4xi8>
1093 /// %2 = arith.shrui %0, 4 : vector<4xi8>
1094 /// %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
1095 /// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
1096 ///
1097 template <typename ConversionOpType, bool isSigned>
1098 struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
1099  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
1100 
1101  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
1102  PatternRewriter &rewriter) const override {
1103  // Verify the preconditions.
1104  Value srcValue = conversionOp.getIn();
1105  auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
1106  auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
1107 
1108  if (failed(
1109  commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
1110  return failure();
1111 
1112  // Check general alignment preconditions.
1113  if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
1114  conversionOp)))
1115  return failure();
1116 
1117  // Perform the rewrite.
1118  Value subByteExt;
1119  if (isSigned) {
1120  subByteExt =
1121  rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
1122  } else {
1123  subByteExt =
1124  rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
1125  }
1126 
1127  // Finalize the rewrite.
1128  rewriter.replaceOpWithNewOp<ConversionOpType>(
1129  conversionOp, conversionOp.getType(), subByteExt);
1130  return success();
1131  }
1132 };
1133 
1134 /// Rewrite the i8 -> i4 part of any truncation into a deinterleave and
1135 /// bitwise ops that take advantage of high-level information to avoid leaving
1136 /// LLVM to scramble with peephole optimizations.
1137 ///
1138 /// For example:
1139 /// arith.trunci %in : vector<8xi32> to vector<8xi4>
1140 /// is rewritten as
1141 ///
1142 /// %cst = arith.constant dense<15> : vector<4xi8>
1143 /// %cst_0 = arith.constant dense<4> : vector<4xi8>
1144 /// %0, %1 = vector.deinterleave %in : vector<8xi8> -> vector<4xi8>
1145 /// %2 = arith.andi %0, %cst : vector<4xi8>
1146 /// %3 = arith.shli %1, %cst_0 : vector<4xi8>
1147 /// %4 = arith.ori %2, %3 : vector<4xi8>
1148 /// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
1149 ///
1150 struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
1151  using OpRewritePattern::OpRewritePattern;
1152 
1153  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
1154  PatternRewriter &rewriter) const override {
1155  // Verify the preconditions.
1156  Value srcValue = truncOp.getIn();
1157  auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
1158  auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
1159  if (!srcVecType || !dstVecType)
1160  return failure();
1161 
1162  if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
1163  return failure();
1164 
1165  // Check general alignment preconditions. We invert the src/dst type order
1166  // to reuse the existing precondition logic.
1167  if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
1168  truncOp)))
1169  return failure();
1170 
1171  // Create a new iX -> i8 truncation op.
1172  Location loc = truncOp.getLoc();
1173  auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
1174  Value i8TruncVal =
1175  rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
1176 
1177  // Rewrite the i8 -> i4 truncation part.
1178  Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
1179 
1180  // Finalize the rewrite.
1181  rewriter.replaceOp(truncOp, subByteTrunc);
1182  return success();
1183  }
1184 };
1185 
1186 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
1187 /// perform the transpose on wider (byte) element types.
1188 /// For example:
1189 /// %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
1190 ///
1191 /// is rewritten as:
1192 ///
1193 /// %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
1194 /// %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
1195 /// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
1196 ///
1197 struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
1198  using OpRewritePattern::OpRewritePattern;
1199 
1200  RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
1201  : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
1202 
1203  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
1204  PatternRewriter &rewriter) const override {
1205  // Precondition: sub-byte integer transpose.
1206  constexpr unsigned minNativeBitwidth = 8;
1207  VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
1208  if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
1209  srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
1210  return rewriter.notifyMatchFailure(transposeOp,
1211  "not a sub-byte transpose");
1212  }
1213 
1214  // Perform the rewrite.
1215  Location loc = transposeOp.getLoc();
1216  // Signed/unsigned interpretation shouldn't matter here as we are just
1217  // transposing the elements and truncating them back to the original size.
1218  // TODO: Use unsigned extension (more efficient) when emulation or backend
1219  // support is available.
1220  auto srcNativeVecType = srcSubByteVecType.cloneWith(
1221  std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
1222  Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
1223  transposeOp.getVector());
1224  Value newTranspose = rewriter.create<vector::TransposeOp>(
1225  loc, extOp, transposeOp.getPermutation());
1226  VectorType dstSubByteVecType = transposeOp.getResultVectorType();
1227  rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
1228  newTranspose);
1229  return success();
1230  }
1231 };
1232 
1233 } // namespace
1234 
1235 //===----------------------------------------------------------------------===//
1236 // Public Interface Definition
1237 //===----------------------------------------------------------------------===//
1238 
1239 void vector::populateVectorNarrowTypeEmulationPatterns(
1240  arith::NarrowTypeEmulationConverter &typeConverter,
1241  RewritePatternSet &patterns) {
1242 
1243  // Populate `vector.*` conversion patterns.
1244  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
1245  ConvertVectorMaskedStore, ConvertVectorTransferRead>(
1246  typeConverter, patterns.getContext());
1247 }
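
// Minimal usage sketch (illustrative only): inside a pass, these patterns are
// typically driven by the dialect-conversion framework together with
// arith::NarrowTypeEmulationConverter; legality setup on the ConversionTarget
// is omitted here and `op` stands for the root operation being converted.
//
//   arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
//   ConversionTarget target(*op->getContext());
//   RewritePatternSet patterns(op->getContext());
//   vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();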
1248 
1249 void vector::populateVectorNarrowTypeRewritePatterns(
1250  RewritePatternSet &patterns, PatternBenefit benefit) {
1251  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
1252  RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
1253  benefit);
1254 
1255  // Patterns for aligned cases. We set higher priority as they are expected to
1256  // generate better performance for aligned cases.
1257  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
1258  RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
1259  RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
1260  benefit.getBenefit() + 1);
1261  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>>(
1262  patterns.getContext(), benefit.getBenefit() + 1);
1263 }
1264 
1265 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
1266  RewritePatternSet &patterns, PatternBenefit benefit) {
1267  patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
1268 }