1 //===- VectorEmulateNarrowType.cpp - Narrow type emulation ----*- C++
2 //-*-===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #include "mlir/Dialect/Affine/IR/AffineOps.h"
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
13 #include "mlir/Dialect/Arith/Utils/Utils.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
16 #include "mlir/Dialect/Vector/IR/VectorOps.h"
17 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
18 #include "mlir/IR/BuiltinAttributes.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/IR/Value.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include <cstdint>
27 
28 using namespace mlir;
29 
30 #define DEBUG_TYPE "vector-narrow-type-emulation"
31 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
32 #define DBGSNL() (llvm::dbgs() << "\n")
33 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
34 
35 /// Returns a compressed mask. The mask value is set only if any mask bit is
36 /// set in its `scale`-sized range. E.g., if `scale` equals 2, the following mask:
37 ///
38 /// %mask = [1, 1, 1, 0, 0, 0]
39 ///
40 /// will return the following new compressed mask:
41 ///
42 /// %mask = [1, 1, 0]
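///
/// As an illustrative IR-level sketch (assuming `scale` is 2, mirroring the
/// logic below), a mask produced by
///
///   %mask = vector.create_mask %c3 : vector<6xi1>
///
/// is compressed to
///
///   %new_mask = vector.create_mask %c2 : vector<3xi1>
///
/// where the new bound is ceilDiv(3, 2) == 2; vector.constant_mask is handled
/// analogously on its static mask sizes.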
43 static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
44  Location loc, Value mask,
45  int origElements, int scale) {
46  auto numElements = (origElements + scale - 1) / scale;
47 
48  Operation *maskOp = mask.getDefiningOp();
49  SmallVector<vector::ExtractOp, 2> extractOps;
50  // Finding the mask creation operation.
51  while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
52  if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
53  maskOp = extractOp.getVector().getDefiningOp();
54  extractOps.push_back(extractOp);
55  }
56  }
57  auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
58  auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
59  if (!createMaskOp && !constantMaskOp)
60  return failure();
61 
62  // Computing the "compressed" mask. All the emulation logic (i.e. computing
63  // new mask index) only happens on the last dimension of the vectors.
64  Operation *newMask = nullptr;
65  SmallVector<int64_t> shape(
66  cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
67  shape.back() = numElements;
68  auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
69  if (createMaskOp) {
70  OperandRange maskOperands = createMaskOp.getOperands();
71  size_t numMaskOperands = maskOperands.size();
72  AffineExpr s0;
73  bindSymbols(rewriter.getContext(), s0);
74  s0 = s0 + scale - 1;
75  s0 = s0.floorDiv(scale);
76  OpFoldResult origIndex =
77  getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
78  OpFoldResult maskIndex =
79  affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
80  SmallVector<Value> newMaskOperands(maskOperands.drop_back());
81  newMaskOperands.push_back(
82  getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
83  newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
84  newMaskOperands);
85  } else if (constantMaskOp) {
86  ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
87  size_t numMaskOperands = maskDimSizes.size();
88  int64_t origIndex = maskDimSizes[numMaskOperands - 1];
89  int64_t maskIndex = (origIndex + scale - 1) / scale;
90  SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
91  newMaskDimSizes.push_back(maskIndex);
92  newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
93  newMaskDimSizes);
94  }
95 
96  while (!extractOps.empty()) {
97  newMask = rewriter.create<vector::ExtractOp>(
98  loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
99  extractOps.pop_back();
100  }
101 
102  return newMask;
103 }
104 
105 namespace {
106 
107 //===----------------------------------------------------------------------===//
108 // ConvertVectorStore
109 //===----------------------------------------------------------------------===//
110 
111 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
112  using OpConversionPattern::OpConversionPattern;
113 
114  LogicalResult
115  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
116  ConversionPatternRewriter &rewriter) const override {
117 
118  auto loc = op.getLoc();
119  auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
120  Type oldElementType = op.getValueToStore().getType().getElementType();
121  Type newElementType = convertedType.getElementType();
122  int srcBits = oldElementType.getIntOrFloatBitWidth();
123  int dstBits = newElementType.getIntOrFloatBitWidth();
124 
125  if (dstBits % srcBits != 0) {
126  return rewriter.notifyMatchFailure(
127  op, "only dstBits % srcBits == 0 supported");
128  }
129  int scale = dstBits / srcBits;
130 
131  // Adjust the number of elements to store when emulating narrow types.
132  // Here only the 1-D vector store is considered, and the N-D memref types
133  // should be linearized.
134  // For example, to emulate i4 to i8, the following op:
135  //
136  // vector.store %arg1, %0[%arg2, %arg3] : memref<4x8xi4>, vector<8xi4>
137  //
138  // can be replaced with
139  //
140  // %bitcast = vector.bitcast %arg1 : vector<8xi4> to vector<4xi8>
141  // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
142  // vector<4xi8>
143 
144  auto origElements = op.getValueToStore().getType().getNumElements();
145  if (origElements % scale != 0)
146  return failure();
147 
148  auto stridedMetadata =
149  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
150 
151  OpFoldResult linearizedIndices;
152  std::tie(std::ignore, linearizedIndices) =
153  memref::getLinearizedMemRefOffsetAndSize(
154  rewriter, loc, srcBits, dstBits,
155  stridedMetadata.getConstifiedMixedOffset(),
156  stridedMetadata.getConstifiedMixedSizes(),
157  stridedMetadata.getConstifiedMixedStrides(),
158  getAsOpFoldResult(adaptor.getIndices()));
159 
160  auto numElements = origElements / scale;
161  auto bitCast = rewriter.create<vector::BitCastOp>(
162  loc, VectorType::get(numElements, newElementType),
163  op.getValueToStore());
164 
165  rewriter.replaceOpWithNewOp<vector::StoreOp>(
166  op, bitCast.getResult(), adaptor.getBase(),
167  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
168  return success();
169  }
170 };
171 
172 //===----------------------------------------------------------------------===//
173 // ConvertVectorMaskedStore
174 //===----------------------------------------------------------------------===//
175 
176 struct ConvertVectorMaskedStore final
177  : OpConversionPattern<vector::MaskedStoreOp> {
178  using OpConversionPattern::OpConversionPattern;
179 
180  LogicalResult
181  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
182  ConversionPatternRewriter &rewriter) const override {
183 
184  auto loc = op.getLoc();
185  auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
186  Type oldElementType = op.getValueToStore().getType().getElementType();
187  Type newElementType = convertedType.getElementType();
188  int srcBits = oldElementType.getIntOrFloatBitWidth();
189  int dstBits = newElementType.getIntOrFloatBitWidth();
190 
191  if (dstBits % srcBits != 0) {
192  return rewriter.notifyMatchFailure(
193  op, "only dstBits % srcBits == 0 supported");
194  }
195 
196  int scale = dstBits / srcBits;
197  int origElements = op.getValueToStore().getType().getNumElements();
198  if (origElements % scale != 0)
199  return failure();
200 
201  auto stridedMetadata =
202  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
203  OpFoldResult linearizedIndicesOfr;
204  std::tie(std::ignore, linearizedIndicesOfr) =
205  memref::getLinearizedMemRefOffsetAndSize(
206  rewriter, loc, srcBits, dstBits,
207  stridedMetadata.getConstifiedMixedOffset(),
208  stridedMetadata.getConstifiedMixedSizes(),
209  stridedMetadata.getConstifiedMixedStrides(),
210  getAsOpFoldResult(adaptor.getIndices()));
211  Value linearizedIndices =
212  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
213 
214  // Load the whole data and use arith.select to handle the corner cases.
215  // E.g., given these input values:
216  //
217  // %mask = [1, 1, 1, 0, 0, 0]
218  // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
219  // %value_to_store = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
220  //
221  // we'll have
222  //
223  // expected output: [0x7, 0x8, 0x9, 0x4, 0x5, 0x6]
224  //
225  // %new_mask = [1, 1, 0]
226  // %maskedload = [0x12, 0x34, 0x0]
227  // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x0, 0x0]
228  // %select_using_original_mask = [0x7, 0x8, 0x9, 0x4, 0x0, 0x0]
229  // %packed_data = [0x78, 0x94, 0x00]
230  //
231  // Using the new mask to store %packed_data results in expected output.
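  //
  // An illustrative sketch of the IR emitted for the example above (i4
  // emulated on i8; value names are made up):
  //
  //   %load = vector.maskedload %alloc[%lin], %new_mask, %zeros
  //       : memref<3xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
  //   %orig = vector.bitcast %load : vector<3xi8> to vector<6xi4>
  //   %sel = arith.select %mask, %value_to_store, %orig
  //       : vector<6xi1>, vector<6xi4>
  //   %packed = vector.bitcast %sel : vector<6xi4> to vector<3xi8>
  //   vector.maskedstore %alloc[%lin], %new_mask, %packed
  //       : memref<3xi8>, vector<3xi1>, vector<3xi8>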
232  FailureOr<Operation *> newMask =
233  getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
234  if (failed(newMask))
235  return failure();
236 
237  auto numElements = (origElements + scale - 1) / scale;
238  auto newType = VectorType::get(numElements, newElementType);
239  auto passThru = rewriter.create<arith::ConstantOp>(
240  loc, newType, rewriter.getZeroAttr(newType));
241 
242  auto newLoad = rewriter.create<vector::MaskedLoadOp>(
243  loc, newType, adaptor.getBase(), linearizedIndices,
244  newMask.value()->getResult(0), passThru);
245 
246  Value valueToStore = rewriter.create<vector::BitCastOp>(
247  loc, op.getValueToStore().getType(), newLoad);
248  valueToStore = rewriter.create<arith::SelectOp>(
249  loc, op.getMask(), op.getValueToStore(), valueToStore);
250  valueToStore =
251  rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);
252 
253  rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
254  op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
255  valueToStore);
256  return success();
257  }
258 };
259 
260 //===----------------------------------------------------------------------===//
261 // ConvertVectorLoad
262 //===----------------------------------------------------------------------===//
263 
264 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
265  using OpConversionPattern::OpConversionPattern;
266 
267  LogicalResult
268  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
269  ConversionPatternRewriter &rewriter) const override {
270 
271  auto loc = op.getLoc();
272  auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
273  Type oldElementType = op.getType().getElementType();
274  Type newElementType = convertedType.getElementType();
275  int srcBits = oldElementType.getIntOrFloatBitWidth();
276  int dstBits = newElementType.getIntOrFloatBitWidth();
277 
278  if (dstBits % srcBits != 0) {
279  return rewriter.notifyMatchFailure(
280  op, "only dstBits % srcBits == 0 supported");
281  }
282  int scale = dstBits / srcBits;
283 
284  // Adjust the number of elements to load when emulating narrow types,
285  // and then cast back to the original type with vector.bitcast op.
286  // Here only the 1-D vector load is considered, and the N-D memref types
287  // should be linearized.
288  // For example, to emulate i4 to i8, the following op:
289  //
290  // %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
291  //
292  // can be replaced with
293  //
294  // %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
295  // %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
296  //
297  // TODO: Currently, only loading an even number of elements is supported.
298  // To deal with the odd number of elements, one has to extract the
299  // subvector at the proper offset after bit-casting.
300 
301  auto origElements = op.getVectorType().getNumElements();
302  if (origElements % scale != 0)
303  return failure();
304 
305  auto stridedMetadata =
306  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
307 
308  OpFoldResult linearizedIndices;
309  std::tie(std::ignore, linearizedIndices) =
310  memref::getLinearizedMemRefOffsetAndSize(
311  rewriter, loc, srcBits, dstBits,
312  stridedMetadata.getConstifiedMixedOffset(),
313  stridedMetadata.getConstifiedMixedSizes(),
314  stridedMetadata.getConstifiedMixedStrides(),
315  getAsOpFoldResult(adaptor.getIndices()));
316 
317  auto numElements = (origElements + scale - 1) / scale;
318  auto newLoad = rewriter.create<vector::LoadOp>(
319  loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
320  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
321 
322  auto bitCast =
323  rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
324 
325  rewriter.replaceOp(op, bitCast->getResult(0));
326  return success();
327  }
328 };
329 
330 //===----------------------------------------------------------------------===//
331 // ConvertVectorMaskedLoad
332 //===----------------------------------------------------------------------===//
333 
334 struct ConvertVectorMaskedLoad final
335  : OpConversionPattern<vector::MaskedLoadOp> {
336  using OpConversionPattern::OpConversionPattern;
337 
338  LogicalResult
339  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
340  ConversionPatternRewriter &rewriter) const override {
341 
342  auto loc = op.getLoc();
343  auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
344  Type oldElementType = op.getType().getElementType();
345  Type newElementType = convertedType.getElementType();
346  int srcBits = oldElementType.getIntOrFloatBitWidth();
347  int dstBits = newElementType.getIntOrFloatBitWidth();
348 
349  if (dstBits % srcBits != 0) {
350  return rewriter.notifyMatchFailure(
351  op, "only dstBits % srcBits == 0 supported");
352  }
353  int scale = dstBits / srcBits;
354 
355  // Adjust the number of elements to load when emulating narrow types,
356  // and then cast back to the original type with vector.bitcast op.
357  // For example, to emulate i4 to i8, the following op:
358  //
359  // %mask = vector.constant_mask [3] : vector<6xi1>
360  // %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
361  // memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
362  //
363  // can be replaced with
364  //
365  // %new_mask = vector.constant_mask [2] : vector<3xi1>
366  // %new_pass_thru = vector.bitcast %pass_thru :
367  // vector<6xi4> to vector<3xi8>
368  // %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
369  // memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
370  // %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
371  //
372  // Since we are effectively loading 16 bits (2xi8) from the memref with the
373  // new mask, while originally we only wanted to effectively load 12 bits
374  // (3xi4) from the memref, we need to set the second half of the last i8
375  // that was effectively loaded (i.e. the second i8) to %pass_thru.
376  //
377  // %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
378  //
379  // Given these input values:
380  // %mask = [1, 1, 1, 0, 0, 0]
381  // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
382  // %pass_thru = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
383  //
384  // we'll have:
385  //
386  // expected output: [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
387  //
388  // %new_mask = [1, 1, 0]
389  // %new_pass_thru = [0x78, 0x9A, 0xBC]
390  // %1 = [0x12, 0x34, 0xBC]
391  // %2 = [0x1, 0x2, 0x3, 0x4, 0xB, 0xC]
392  // %3 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
393  //
394  // TODO: Currently, only loading an even number of elements is supported.
395  // To deal with the odd number of elements, one has to extract the
396  // subvector at the proper offset after bit-casting.
397  auto origType = op.getVectorType();
398  auto origElements = origType.getNumElements();
399  if (origElements % scale != 0)
400  return failure();
401 
402  auto stridedMetadata =
403  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
404  OpFoldResult linearizedIndices;
405  std::tie(std::ignore, linearizedIndices) =
406  memref::getLinearizedMemRefOffsetAndSize(
407  rewriter, loc, srcBits, dstBits,
408  stridedMetadata.getConstifiedMixedOffset(),
409  stridedMetadata.getConstifiedMixedSizes(),
410  stridedMetadata.getConstifiedMixedStrides(),
411  getAsOpFoldResult(adaptor.getIndices()));
412 
413  FailureOr<Operation *> newMask =
414  getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
415  if (failed(newMask))
416  return failure();
417 
418  auto numElements = (origElements + scale - 1) / scale;
419  auto newType = VectorType::get(numElements, newElementType);
420  auto newPassThru =
421  rewriter.create<vector::BitCastOp>(loc, newType, op.getPassThru());
422 
423  // Generating the new masked load.
424  auto newLoad = rewriter.create<vector::MaskedLoadOp>(
425  loc, newType, adaptor.getBase(),
426  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
427  newMask.value()->getResult(0), newPassThru);
428 
429  // Setting the part that originally was not effectively loaded from memory
430  // to pass through.
431  auto bitCast =
432  rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
433  auto select = rewriter.create<arith::SelectOp>(loc, op.getMask(), bitCast,
434  op.getPassThru());
435  rewriter.replaceOp(op, select->getResult(0));
436 
437  return success();
438  }
439 };
440 
441 //===----------------------------------------------------------------------===//
442 // ConvertVectorTransferRead
443 //===----------------------------------------------------------------------===//
444 
445 struct ConvertVectorTransferRead final
446  : OpConversionPattern<vector::TransferReadOp> {
447  using OpConversionPattern::OpConversionPattern;
448 
449  LogicalResult
450  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
451  ConversionPatternRewriter &rewriter) const override {
452 
453  auto loc = op.getLoc();
454  auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
455  Type oldElementType = op.getType().getElementType();
456  Type newElementType = convertedType.getElementType();
457  int srcBits = oldElementType.getIntOrFloatBitWidth();
458  int dstBits = newElementType.getIntOrFloatBitWidth();
459 
460  if (dstBits % srcBits != 0) {
461  return rewriter.notifyMatchFailure(
462  op, "only dstBits % srcBits == 0 supported");
463  }
464  int scale = dstBits / srcBits;
465 
466  auto origElements = op.getVectorType().getNumElements();
467  if (origElements % scale != 0)
468  return failure();
469 
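  // An illustrative sketch of the rewrite for i4 emulated on i8 (analogous
  // to the vector.load pattern above; names are made up):
  //
  //   %1 = vector.transfer_read %0[%c0, %c0], %pad :
  //     memref<3x4xi4>, vector<4xi4>
  //
  // becomes
  //
  //   %pad_i8 = arith.extui %pad : i4 to i8
  //   %1 = vector.transfer_read %0[%linear_index], %pad_i8 :
  //     memref<6xi8>, vector<2xi8>
  //   %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>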
470  auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
471  adaptor.getPadding());
472 
473  auto stridedMetadata =
474  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());
475 
476  OpFoldResult linearizedIndices;
477  std::tie(std::ignore, linearizedIndices) =
478  memref::getLinearizedMemRefOffsetAndSize(
479  rewriter, loc, srcBits, dstBits,
480  stridedMetadata.getConstifiedMixedOffset(),
481  stridedMetadata.getConstifiedMixedSizes(),
482  stridedMetadata.getConstifiedMixedStrides(),
483  getAsOpFoldResult(adaptor.getIndices()));
484 
485  auto numElements = (origElements + scale - 1) / scale;
486  auto newReadType = VectorType::get(numElements, newElementType);
487 
488  auto newRead = rewriter.create<vector::TransferReadOp>(
489  loc, newReadType, adaptor.getSource(),
490  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
491  newPadding);
492 
493  auto bitCast =
494  rewriter.create<vector::BitCastOp>(loc, op.getType(), newRead);
495 
496  rewriter.replaceOp(op, bitCast->getResult(0));
497  return success();
498  }
499 };
500 } // end anonymous namespace
501 
502 //===----------------------------------------------------------------------===//
503 // RewriteBitCastOfTruncI
504 //===----------------------------------------------------------------------===//
505 
506 namespace {
507 
508 /// Helper struct to keep track of the provenance of a contiguous set of bits
509 /// in a source vector.
510 struct SourceElementRange {
511  /// The index of the source vector element that contributes bits to *this.
512  int64_t sourceElementIdx;
513  /// The range of bits in the source vector element that contribute to *this.
514  int64_t sourceBitBegin;
515  int64_t sourceBitEnd;
516 };
517 
518 struct SourceElementRangeList : public SmallVector<SourceElementRange> {
519  /// Given the index of a SourceElementRange in the SourceElementRangeList,
520  /// compute the amount of bits that need to be shifted to the left to get the
521  /// bits in their final location. This shift amount is simply the sum of the
522  /// bits *before* `shuffleIdx` (i.e. the bits of `shuffleIdx = 0` are always
523  /// the LSBs, the bits of `shuffleIdx = 1` come next, etc).
524  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
525  int64_t res = 0;
526  for (int64_t i = 0; i < shuffleIdx; ++i)
527  res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
528  return res;
529  }
530 };
531 
532 /// Helper struct to enumerate the source elements and bit ranges that are
533 /// involved in a bitcast operation.
534 /// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
535 /// any 1-D vector shape and any source/target bitwidths.
536 /// This creates and holds a mapping of the form:
537 /// [dstVectorElementJ] ==
538 /// [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ]
539 /// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as:
540 /// [0] = {0, [0-8)}
541 /// [1] = {0, [8-16)}
542 /// [2] = {0, [16-24)}
543 /// and `vector.bitcast ... : vector<3xi10> to vector<2xi15>` is decomposed as:
544 /// [0] = {0, [0, 10)}, {1, [0, 5)}
545 /// [1] = {1, [5, 10)}, {2, [0, 10)}
546 struct BitCastBitsEnumerator {
547  BitCastBitsEnumerator(VectorType sourceVectorType,
548  VectorType targetVectorType);
549 
550  int64_t getMaxNumberOfEntries() {
551  int64_t numVectors = 0;
552  for (const auto &l : sourceElementRanges)
553  numVectors = std::max(numVectors, (int64_t)l.size());
554  return numVectors;
555  }
556 
557  VectorType sourceVectorType;
558  VectorType targetVectorType;
559  SmallVector<SourceElementRangeList> sourceElementRanges;
560 };
561 
562 /// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
563 /// advantage of high-level information to avoid leaving LLVM to scramble with
564 /// peephole optimizations.
565 /// BitCastBitsEnumerator encodes for each element of the target vector the
566 /// provenance of the bits in the source vector. We can "transpose" this
567 /// information to build a sequence of shuffles and bitwise ops that will
568 /// produce the desired result.
569 //
570 /// Consider the following motivating example:
571 /// ```
572 /// %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
573 /// ```
574 //
575 /// BitCastBitsEnumerator contains the following information:
576 /// ```
577 /// { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
578 /// { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
579 /// { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
580 /// { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
581 /// { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
582 /// { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
583 /// { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
584 /// {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
585 /// {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
586 /// {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
587 /// {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
588 /// {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
589 /// {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
590 /// {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
591 /// {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
592 /// {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
593 /// {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
594 /// {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
595 /// {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
596 /// {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
597 /// ```
598 ///
599 /// In the above, each row represents one target vector element and each
600 /// column represents one bit contribution from a source vector element.
601 /// The algorithm creates vector.shuffle operations (in this case there are 3
602 /// shuffles, i.e. the max number of columns in BitCastBitsEnumerator). The
603 /// algorithm populates the bits as follows:
604 /// ```
605 /// src bits 0 ...
606 /// 1st shuffle |xxxxx |xx |...
607 /// 2nd shuffle | xxx| xxxxx |...
608 /// 3rd shuffle | | x|...
609 /// ```
610 //
611 /// The algorithm proceeds as follows:
612 /// 1. for each vector.shuffle, collect the source vectors that participate in
613 /// this shuffle. One source vector per target element of the resulting
614 /// vector.shuffle. If there is no source element contributing bits for the
615 /// current vector.shuffle, take 0 (i.e. row 0 in the above example has only
616 /// 2 columns).
617 /// 2. represent the bitrange in the source vector as a mask. If there is no
618 /// source element contributing bits for the current vector.shuffle, take 0.
619 /// 3. shift right by the proper amount to align the source bitrange at
620 /// position 0. This is exactly the low end of the bitrange. For instance,
621 /// the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
622 /// shift right by 3 to get the bits contributed by the source element #1
623 /// into position 0.
624 /// 4. shift left by the proper amount to align to the desired position in
625 /// the result element vector. For instance, the contribution of the second
626 /// source element for the first row needs to be shifted by `5` to form the
627 /// first i8 result element.
628 ///
629 /// Eventually, we end up building the sequence
630 /// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
631 /// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
632 /// bits extracted from the source vector (i.e. the `shuffle -> and` part).
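///
/// As an illustrative sketch (types and value names are made up, not taken
/// from an actual run), one such step for an i8 shuffled element type lowers
/// to roughly:
///
///   %s = vector.shuffle %src, %src [0, 1, 3, ...] : vector<32xi8>, vector<32xi8>
///   %m = arith.andi %s, %mask_cst : vector<20xi8>
///   %r = arith.shrui %m, %shr_cst : vector<20xi8>
///   %l = arith.shli %r, %shl_cst : vector<20xi8>
///   %acc_next = arith.ori %acc, %l : vector<20xi8>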
633 struct BitCastRewriter {
634  /// Helper metadata struct to hold the static quantities for the rewrite.
635  struct Metadata {
636  SmallVector<int64_t> shuffles;
637  SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
638  };
639 
640  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
641 
642  /// Verify that general preconditions for the rewrite are met.
643  LogicalResult commonPrecondition(PatternRewriter &rewriter,
644  VectorType preconditionType, Operation *op);
645 
646  /// Precompute the metadata for the rewrite.
647  SmallVector<BitCastRewriter::Metadata>
648  precomputeMetadata(IntegerType shuffledElementType);
649 
650  /// Rewrite one step of the sequence:
651  /// `(shuffle -> and -> shiftright -> shiftleft -> or)`.
652  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
653  Value initialValue, Value runningResult,
654  const BitCastRewriter::Metadata &metadata);
655 
656 private:
657  /// Underlying enumerator that encodes the provenance of the bits in each
658  /// element of the result vector.
659  BitCastBitsEnumerator enumerator;
660 };
661 
662 } // namespace
663 
664 [[maybe_unused]] static raw_ostream &
665 operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
666  for (const auto &l : vec) {
667  for (auto it : llvm::enumerate(l)) {
668  os << "{ " << it.value().sourceElementIdx << ": b@["
669  << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
670  << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
671  }
672  os << "\n";
673  }
674  return os;
675 }
676 
677 BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
678  VectorType targetVectorType)
679  : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
680 
681  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
682  "requires 1-D non-scalable vector type");
683  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
684  "requires 1-D non-scalable vector type");
685  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
686  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
687  LDBG("sourceVectorType: " << sourceVectorType);
688 
689  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
690  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
691  LDBG("targetVectorType: " << targetVectorType);
692 
693  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
694  (void)mostMinorSourceDim;
695  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
696  "source and target bitwidths must match");
697 
698  // Prepopulate one source element range per target element.
699  sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
700  for (int64_t resultBit = 0; resultBit < bitwidth;) {
701  int64_t resultElement = resultBit / targetBitWidth;
702  int64_t resultBitInElement = resultBit % targetBitWidth;
703  int64_t sourceElementIdx = resultBit / sourceBitWidth;
704  int64_t sourceBitInElement = resultBit % sourceBitWidth;
705  int64_t step = std::min(sourceBitWidth - sourceBitInElement,
706  targetBitWidth - resultBitInElement);
707  sourceElementRanges[resultElement].push_back(
708  {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
709  resultBit += step;
710  }
711 }
712 
713 BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
714  VectorType targetVectorType)
715  : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
716  LDBG("\n" << enumerator.sourceElementRanges);
717 }
718 
719 /// Verify that the precondition type meets the common preconditions for any
720 /// conversion.
721 static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
722  VectorType preconditionType,
723  Operation *op) {
724  if (!preconditionType || preconditionType.isScalable())
725  return rewriter.notifyMatchFailure(op, "scalable vector");
726 
727  // TODO: consider relaxing this restriction in the future if we find ways
728  // to really work with subbyte elements across the MLIR/LLVM boundary.
729  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
730  if (bitwidth % 8 != 0)
731  return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
732 
733  return success();
734 }
735 
736 LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
737  VectorType preconditionType,
738  Operation *op) {
739  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
740  return rewriter.notifyMatchFailure(op, "types are not vector");
741 
742  if (!preconditionType || preconditionType.getRank() != 1)
743  return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
744 
745  return commonConversionPrecondition(rewriter, preconditionType, op);
746 }
747 
748 /// Verify that source and destination element types meet the precondition for
749 /// the supported aligned conversion cases. Alignment means that either the
750 /// source element type is a multiple of the destination element type or the
751 /// other way around.
752 ///
753 /// NOTE: This method assumes that common conversion preconditions are met.
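///
/// For instance (illustrative), vector<8xi4> -> vector<8xi8> and
/// vector<8xi4> -> vector<8xi32> satisfy this precondition, whereas
/// vector<7xi4> -> vector<7xi8> (odd trailing dimension) and
/// vector<8xi2> -> vector<8xi8> (source bitwidth is not 4) do not.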
754 static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
755  VectorType srcType,
756  VectorType dstType,
757  Operation *op) {
758  if (!srcType || !dstType)
759  return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
760  unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
761  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
762 
763  // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
764  if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
765  (dstElemBitwidth % srcElemBitwidth) != 0)
766  return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
767 
768  if ((srcType.getShape().back() % 2) != 0)
769  return rewriter.notifyMatchFailure(
770  op, "Not an even number of i4 elements in trailing dim");
771 
772  return success();
773 }
774 
775 SmallVector<BitCastRewriter::Metadata>
776 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
777  SmallVector<BitCastRewriter::Metadata> result;
778  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
779  shuffleIdx < e; ++shuffleIdx) {
780  SmallVector<int64_t> shuffles;
781  SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
782 
783  // Create the attribute quantities for the shuffle / mask / shift ops.
784  for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
785  int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
786  ? srcEltRangeList[shuffleIdx].sourceElementIdx
787  : 0;
788  shuffles.push_back(sourceElement);
789 
790  int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
791  ? srcEltRangeList[shuffleIdx].sourceBitBegin
792  : 0;
793  int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
794  ? srcEltRangeList[shuffleIdx].sourceBitEnd
795  : 0;
796  IntegerAttr mask = IntegerAttr::get(
797  shuffledElementType,
798  llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
799  bitLo, bitHi));
800  masks.push_back(mask);
801 
802  int64_t shiftRight = bitLo;
803  shiftRightAmounts.push_back(
804  IntegerAttr::get(shuffledElementType, shiftRight));
805 
806  int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
807  shiftLeftAmounts.push_back(
808  IntegerAttr::get(shuffledElementType, shiftLeft));
809  }
810 
811  result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
812  }
813  return result;
814 }
815 
816 Value BitCastRewriter::genericRewriteStep(
817  PatternRewriter &rewriter, Location loc, Value initialValue,
818  Value runningResult, const BitCastRewriter::Metadata &metadata) {
819  // Create vector.shuffle from the metadata.
820  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
821  loc, initialValue, initialValue, metadata.shuffles);
822 
823  // Intersect with the mask.
824  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
825  auto constOp = rewriter.create<arith::ConstantOp>(
826  loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
827  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);
828 
829  // Align right on 0.
830  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
831  loc,
832  DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
833  Value shiftedRight =
834  rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
835 
836  // Shift bits left into their final position.
837  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
838  loc,
839  DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
840  Value shiftedLeft =
841  rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
842 
843  runningResult =
844  runningResult
845  ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
846  : shiftedLeft;
847 
848  return runningResult;
849 }
850 
851 /// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
852 /// bitwise ops that take advantage of high-level information to avoid leaving
853 /// LLVM to scramble with peephole optimizations.
854 static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
855  Value srcValue) {
856  VectorType srcVecType = cast<VectorType>(srcValue.getType());
857  assert(srcVecType.getElementType().isSignlessInteger(4) &&
858  "Expected i4 type");
859 
860  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
861  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
862  constexpr int64_t i4Toi8BitwidthFactor = 2;
863  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
864  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
865  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
866 
867  // 2. Extend i4 elements to i8 elements using shifts. Low i4 elements of each
868  // byte are placed in one vector and the high i4 elements in another vector.
869  constexpr int8_t bitsToShift = 4;
870  auto shiftValues = rewriter.create<arith::ConstantOp>(
871  loc, DenseElementsAttr::get(i8VecType, bitsToShift));
872  Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
873  Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
874  Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
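  // Worked example on one packed byte (illustrative): 0xA3 holds the i4
  // values 3 (low nibble) and -6 (high nibble 0xA).
  //   shl 4:   0xA3 -> 0x30;  shrsi 4: 0x30 -> 0x03  ( 3, low element)
  //   shrsi 4: 0xA3 -> 0xFA                          (-6, high element)
  // Interleaving `low` and `high` then restores the original element order.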
875 
876  // 3. Interleave low and high i8 elements.
877  return rewriter.create<vector::InterleaveOp>(loc, low, high);
878 }
879 
880 /// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
881 /// bitwise ops that take advantage of high-level information to avoid leaving
882 /// LLVM to scramble with peephole optimizations.
883 static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
884  Value srcValue) {
885  VectorType srcVecType = cast<VectorType>(srcValue.getType());
886  assert(srcVecType.getElementType().isSignlessInteger(4) &&
887  "Expected i4 type");
888 
889  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
890  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
891  constexpr int64_t i4Toi8BitwidthFactor = 2;
892  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
893  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
894  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
895 
896  // 2. Extend the i4 elements using shifts & masking. Low i4 elements of each
897  // byte are placed in one vector and the high i4 elements in another vector.
898  constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask
899  auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
900  loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
901  Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
902  lowBitsMaskValues);
903  constexpr int8_t highBitsToShift = 4;
904  auto highShiftValues = rewriter.create<arith::ConstantOp>(
905  loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
906  Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
907 
908  // 3. Interleave low and high i8 elements.
909  return rewriter.create<vector::InterleaveOp>(loc, low, high);
910 }
911 
912 /// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
913 /// ops that take advantage of high-level information to avoid leaving LLVM to
914 /// scramble with peephole optimizations.
915 static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
916  Value srcValue) {
917  VectorType srcVecType = cast<VectorType>(srcValue.getType());
918  assert(srcVecType.getElementType().isSignlessInteger(8) &&
919  "Expected i8 type");
920 
921  // 1. De-interleave low and high i8 elements.
922  auto deinterleaveOp = rewriter.create<vector::DeinterleaveOp>(loc, srcValue);
923 
924  // 2. Zero out the upper side of each low i8 element.
925  constexpr int8_t i8LowBitMask = 0x0F;
926  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
927  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
928  loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
929  Value zeroOutLow = rewriter.create<arith::AndIOp>(
930  loc, deinterleaveOp.getRes1(), zeroOutMask);
931 
932  // 3. Move high i4 values to upper side of the byte.
933  constexpr int8_t bitsToShift = 4;
934  auto shiftValues = rewriter.create<arith::ConstantOp>(
935  loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
936  Value shlHigh = rewriter.create<arith::ShLIOp>(loc, deinterleaveOp.getRes2(),
937  shiftValues);
938 
939  // 4. Merge high and low i4 values.
940  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
941 
942  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
943  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
944  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
945 }
946 
947 namespace {
948 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
949 /// advantage of high-level information to avoid leaving LLVM to scramble with
950 /// peephole optimizations.
951 struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
952  using OpRewritePattern::OpRewritePattern;
953 
954  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
955  PatternRewriter &rewriter) const override {
956  // The source must be a trunc op.
957  auto truncOp =
958  bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
959  if (!truncOp)
960  return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
961 
962  // Set up the BitCastRewriter and verify the precondition.
963  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
964  VectorType targetVectorType = bitCastOp.getResultVectorType();
965  BitCastRewriter bcr(sourceVectorType, targetVectorType);
966  if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
967  return failure();
968 
969  // Perform the rewrite.
970  Value truncValue = truncOp.getIn();
971  auto shuffledElementType =
972  cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
973  Value runningResult;
974  for (const BitCastRewriter::Metadata &metadata :
975  bcr.precomputeMetadata(shuffledElementType)) {
976  runningResult = bcr.genericRewriteStep(
977  rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
978  }
979 
980  // Finalize the rewrite.
981  bool narrowing = targetVectorType.getElementTypeBitWidth() <=
982  shuffledElementType.getIntOrFloatBitWidth();
983  if (narrowing) {
984  if (runningResult.getType() == bitCastOp.getResultVectorType()) {
985  rewriter.replaceOp(bitCastOp, runningResult);
986  } else {
987  rewriter.replaceOpWithNewOp<arith::TruncIOp>(
988  bitCastOp, bitCastOp.getResultVectorType(), runningResult);
989  }
990  } else {
991  if (runningResult.getType() == bitCastOp.getResultVectorType()) {
992  rewriter.replaceOp(bitCastOp, runningResult);
993  } else {
994  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
995  bitCastOp, bitCastOp.getResultVectorType(), runningResult);
996  }
997  }
998 
999  return success();
1000  }
1001 };
1002 } // namespace
1003 
1004 //===----------------------------------------------------------------------===//
1005 // RewriteExtOfBitCast
1006 //===----------------------------------------------------------------------===//
1007 
1008 namespace {
1009 /// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
1010 /// take advantage of high-level information to avoid leaving LLVM to scramble
1011 /// with peephole optimizations.
1012 template <typename ExtOpType>
1013 struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
1014  using OpRewritePattern<ExtOpType>::OpRewritePattern;
1015 
1016  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
1017  : OpRewritePattern<ExtOpType>(context, benefit) {}
1018 
1019  LogicalResult matchAndRewrite(ExtOpType extOp,
1020  PatternRewriter &rewriter) const override {
1021  // The source must be a bitcast op.
1022  auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
1023  if (!bitCastOp)
1024  return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
1025 
1026  // Set up the BitCastRewriter and verify the precondition.
1027  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1028  VectorType targetVectorType = bitCastOp.getResultVectorType();
1029  BitCastRewriter bcr(sourceVectorType, targetVectorType);
1030  if (failed(bcr.commonPrecondition(
1031  rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
1032  return failure();
1033 
1034  // Perform the rewrite.
1035  Value runningResult;
1036  Value sourceValue = bitCastOp.getSource();
1037  auto shuffledElementType =
1038  cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
1039  for (const BitCastRewriter::Metadata &metadata :
1040  bcr.precomputeMetadata(shuffledElementType)) {
1041  runningResult = bcr.genericRewriteStep(
1042  rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
1043  }
1044 
1045  // Finalize the rewrite.
1046  bool narrowing =
1047  cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
1048  shuffledElementType.getIntOrFloatBitWidth();
1049  if (narrowing) {
1050  rewriter.replaceOpWithNewOp<arith::TruncIOp>(
1051  extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
1052  } else {
1053  rewriter.replaceOpWithNewOp<ExtOpType>(
1054  extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
1055  }
1056 
1057  return success();
1058  }
1059 };
1060 
1061 /// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
1062 /// bitwise ops that take advantage of high-level information to avoid leaving
1063 /// LLVM to scramble with peephole optimizations. Templated to choose between
1064 /// signed and unsigned conversions.
1065 ///
1066 /// For example (signed):
1067 /// arith.extsi %in : vector<8xi4> to vector<8xi32>
1068 /// is rewritten as
1069 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1070 /// %1 = arith.shli %0, 4 : vector<4xi8>
1071 /// %2 = arith.shrsi %1, 4 : vector<4xi8>
1072 /// %3 = arith.shrsi %0, 4 : vector<4xi8>
1073 /// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
1074 /// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
1075 ///
1076 /// arith.sitofp %in : vector<8xi4> to vector<8xf32>
1077 /// is rewritten as
1078 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1079 /// %1 = arith.shli %0, 4 : vector<4xi8>
1080 /// %2 = arith.shrsi %1, 4 : vector<4xi8>
1081 /// %3 = arith.shrsi %0, 4 : vector<4xi8>
1082 /// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
1083 /// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
1084 ///
1085 /// Example (unsigned):
1086 /// arith.extui %in : vector<8xi4> to vector<8xi32>
1087 /// is rewritten as
1088 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1089 /// %1 = arith.andi %0, 15 : vector<4xi8>
1090 /// %2 = arith.shrui %0, 4 : vector<4xi8>
1091 /// %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
1092 /// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
1093 ///
1094 template <typename ConversionOpType, bool isSigned>
1095 struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
1096  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
1097 
1098  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
1099  PatternRewriter &rewriter) const override {
1100  // Verify the preconditions.
1101  Value srcValue = conversionOp.getIn();
1102  auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
1103  auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
1104 
1105  if (failed(
1106  commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
1107  return failure();
1108 
1109  // Check general alignment preconditions.
1110  if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
1111  conversionOp)))
1112  return failure();
1113 
1114  // Perform the rewrite.
1115  Value subByteExt;
1116  if (isSigned) {
1117  subByteExt =
1118  rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
1119  } else {
1120  subByteExt =
1121  rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
1122  }
1123 
1124  // Finalize the rewrite.
1125  rewriter.replaceOpWithNewOp<ConversionOpType>(
1126  conversionOp, conversionOp.getType(), subByteExt);
1127  return success();
1128  }
1129 };
1130 
1131 /// Rewrite the i8 -> i4 part of any truncation into a deinterleave and
1132 /// bitwise ops that take advantage of high-level information to avoid leaving
1133 /// LLVM to scramble with peephole optimizations.
1134 ///
1135 /// For example:
1136 /// arith.trunci %in : vector<8xi32> to vector<8xi4>
1137 /// is rewritten as
1138 ///
1139 /// %cst = arith.constant dense<15> : vector<4xi8>
1140 /// %cst_0 = arith.constant dense<4> : vector<4xi8>
1141 /// %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
1142 /// %2 = arith.andi %0, %cst : vector<4xi8>
1143 /// %3 = arith.shli %1, %cst_0 : vector<4xi8>
1144 /// %4 = arith.ori %2, %3 : vector<4xi8>
1145 /// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
1146 ///
1147 struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
1148  using OpRewritePattern::OpRewritePattern;
1149 
1150  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
1151  PatternRewriter &rewriter) const override {
1152  // Verify the preconditions.
1153  Value srcValue = truncOp.getIn();
1154  auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
1155  auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
1156  if (!srcVecType || !dstVecType)
1157  return failure();
1158 
1159  if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
1160  return failure();
1161 
1162  // Check general alignment preconditions. We invert the src/dst type order
1163  // to reuse the existing precondition logic.
1164  if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
1165  truncOp)))
1166  return failure();
1167 
1168  // Create a new iX -> i8 truncation op.
1169  Location loc = truncOp.getLoc();
1170  auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
1171  Value i8TruncVal =
1172  rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
1173 
1174  // Rewrite the i8 -> i4 truncation part.
1175  Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
1176 
1177  // Finalize the rewrite.
1178  rewriter.replaceOp(truncOp, subByteTrunc);
1179  return success();
1180  }
1181 };
1182 
1183 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
1184 /// perform the transpose on wider (byte) element types.
1185 /// For example:
1186 /// %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
1187 ///
1188 /// is rewritten as:
1189 ///
1190 /// %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
1191 /// %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
1192 /// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
1193 ///
1194 struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
1195  using OpRewritePattern::OpRewritePattern;
1196 
1197  RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
1198  : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
1199 
1200  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
1201  PatternRewriter &rewriter) const override {
1202  // Precondition: sub-byte integer transpose.
1203  constexpr unsigned minNativeBitwidth = 8;
1204  VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
1205  if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
1206  srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
1207  return rewriter.notifyMatchFailure(transposeOp,
1208  "not a sub-byte transpose");
1209  }
1210 
1211  // Perform the rewrite.
1212  Location loc = transposeOp.getLoc();
1213  // Signed/unsigned interpretation shouldn't matter here as we are just
1214  // transposing the elements and truncating them back to the original size.
1215  // TODO: Use unsigned extension (more efficient) when emulation or backend
1216  // support is available.
1217  auto srcNativeVecType = srcSubByteVecType.cloneWith(
1218  std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
1219  Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
1220  transposeOp.getVector());
1221  Value newTranspose = rewriter.create<vector::TransposeOp>(
1222  loc, extOp, transposeOp.getPermutation());
1223  VectorType dstSubByteVecType = transposeOp.getResultVectorType();
1224  rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
1225  newTranspose);
1226  return success();
1227  }
1228 };
1229 
1230 } // namespace
1231 
1232 //===----------------------------------------------------------------------===//
1233 // Public Interface Definition
1234 //===----------------------------------------------------------------------===//
1235 
1236 void vector::populateVectorNarrowTypeEmulationPatterns(
1237  const arith::NarrowTypeEmulationConverter &typeConverter,
1238  RewritePatternSet &patterns) {
1239 
1240  // Populate `vector.*` conversion patterns.
1241  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
1242  ConvertVectorMaskedStore, ConvertVectorTransferRead>(
1243  typeConverter, patterns.getContext());
1244 }
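// An illustrative sketch of how a caller might drive these patterns (the
// concrete pass wiring, conversion target, and bitwidth are up to the
// client):
//
//   arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
//   ConversionTarget target(*ctx);
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
//   (void)applyPartialConversion(op, target, std::move(patterns));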
1245 
1246 void vector::populateVectorNarrowTypeRewritePatterns(
1247  RewritePatternSet &patterns, PatternBenefit benefit) {
1248  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
1249  RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
1250  benefit);
1251 
1252  // Patterns for aligned cases. We set a higher benefit as they are expected to
1253  // generate better performance for aligned cases.
1254  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
1255  RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
1256  RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
1257  benefit.getBenefit() + 1);
1258  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>>(
1259  patterns.getContext(), benefit.getBenefit() + 1);
1260 }
1261 
1262 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
1263  RewritePatternSet &patterns, PatternBenefit benefit) {
1264  patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
1265 }