VectorEmulateNarrowType.cpp
//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>

using namespace mlir;

#define DEBUG_TYPE "vector-narrow-type-emulation"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
namespace {

//===----------------------------------------------------------------------===//
// ConvertVectorStore
//===----------------------------------------------------------------------===//

struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getValueToStore().getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    // Adjust the number of elements to store when emulating narrow types.
    // Here only the 1-D vector store is considered, and the N-D memref types
    // should be linearized.
    // For example, to emulate i4 to i8, the following op:
    //
    //   vector.store %arg1, %0[%arg2, %arg3] : memref<4x8xi4>, vector<8xi4>
    //
    // can be replaced with
    //
    //   %bitcast = vector.bitcast %arg1 : vector<8xi4> to vector<4xi8>
    //   vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
    //     vector<4xi8>

    auto origElements = op.getValueToStore().getType().getNumElements();
    if (origElements % scale != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    std::tie(std::ignore, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    auto numElements = origElements / scale;
    auto bitCast = rewriter.create<vector::BitCastOp>(
        loc, VectorType::get(numElements, newElementType),
        op.getValueToStore());

    rewriter.replaceOpWithNewOp<vector::StoreOp>(
        op, bitCast.getResult(), adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
    return success();
  }
};
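
// A minimal illustrative sketch (hypothetical helper, not used by the
// pattern): the index arithmetic performed by the linearization above for the
// i4 -> i8 store example. Element (i, j) of memref<4x8xi4> starts at bit
// (i * 8 + j) * 4, so it lands at emulated byte (i * 8 + j) / 2 of the
// memref<16xi8> view.
[[maybe_unused]] static int64_t exampleEmulatedStoreIndex(int64_t i,
                                                          int64_t j) {
  constexpr int64_t srcBits = 4, dstBits = 8, innerSize = 8;
  return (i * innerSize + j) * srcBits / dstBits; // == (i * 8 + j) / 2
}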

//===----------------------------------------------------------------------===//
// ConvertVectorLoad
//===----------------------------------------------------------------------===//

struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    // Adjust the number of elements to load when emulating narrow types,
    // and then cast back to the original type with vector.bitcast op.
    // Here only the 1-D vector load is considered, and the N-D memref types
    // should be linearized.
    // For example, to emulate i4 to i8, the following op:
    //
    //   %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
    //
    // can be replaced with
    //
    //   %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
    //   %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
    //
    // TODO: Currently, only loading a multiple of `scale` elements is
    // supported. To deal with the odd number of elements, one has to extract
    // the subvector at the proper offset after bit-casting.

    auto origElements = op.getVectorType().getNumElements();
    if (origElements % scale != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    std::tie(std::ignore, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    auto numElements = (origElements + scale - 1) / scale;
    auto newLoad = rewriter.create<vector::LoadOp>(
        loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));

    auto bitCast =
        rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);

    rewriter.replaceOp(op, bitCast->getResult(0));
    return success();
  }
};
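
// A minimal illustrative sketch (hypothetical helper, not used by the
// pattern): the ceiling division above, e.g. for the memref<3x4xi4> example a
// vector<4xi4> load becomes a vector<2xi8> load since (4 + 2 - 1) / 2 == 2.
[[maybe_unused]] static int64_t exampleEmulatedLoadSize(int64_t origElements,
                                                        int64_t scale) {
  // Equivalent to llvm::divideCeil(origElements, scale).
  return (origElements + scale - 1) / scale;
}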

//===----------------------------------------------------------------------===//
// ConvertVectorMaskedLoad
//===----------------------------------------------------------------------===//

struct ConvertVectorMaskedLoad final
    : OpConversionPattern<vector::MaskedLoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    // Adjust the number of elements to load when emulating narrow types,
    // and then cast back to the original type with vector.bitcast op.
    // For example, to emulate i4 to i8, the following op:
    //
    //   %mask = vector.constant_mask [3] : vector<6xi1>
    //   %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
    //     memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
    //
    // can be replaced with
    //
    //   %new_mask = vector.constant_mask [2] : vector<3xi1>
    //   %new_pass_thru = vector.bitcast %pass_thru :
    //     vector<6xi4> to vector<3xi8>
    //   %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
    //     memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
    //   %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
    //
    // Since we are effectively loading 16 bits (2xi8) from the memref with the
    // new mask, while originally we only wanted to effectively load 12 bits
    // (3xi4) from the memref, we need to set the second half of the last i8
    // that was effectively loaded (i.e. the second i8) to %pass_thru.
    //
    //   %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
    //
    // Given these input values:
    //   %mask = [1, 1, 1, 0, 0, 0]
    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
    //   %pass_thru = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
    //
    // we'll have:
    //
    //   expected output: [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
    //
    //   %new_mask = [1, 1, 0]
    //   %new_pass_thru = [0x78, 0x9A, 0xBC]
    //   %1 = [0x12, 0x34, 0xBC]
    //   %2 = [0x1, 0x2, 0x3, 0x4, 0xB, 0xC]
    //   %3 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
    //
    // TODO: Currently, only loading a multiple of `scale` elements is
    // supported. To deal with the odd number of elements, one has to extract
    // the subvector at the proper offset after bit-casting.

    auto origType = op.getVectorType();
    auto origElements = origType.getNumElements();
    if (origElements % scale != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    std::tie(std::ignore, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    auto numElements = (origElements + scale - 1) / scale;
    auto newType = VectorType::get(numElements, newElementType);

    auto maskOp = op.getMask().getDefiningOp();
    SmallVector<vector::ExtractOp, 2> extractOps;
    // Finding the mask creation operation.
    while (maskOp &&
           !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
      if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
        maskOp = extractOp.getVector().getDefiningOp();
        extractOps.push_back(extractOp);
      }
    }
    auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
    auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
    if (!createMaskOp && !constantMaskOp)
      return failure();

    // Computing the "compressed" mask. All the emulation logic (i.e. computing
    // the new mask index) only happens on the last dimension of the vectors.
    Operation *newMask = nullptr;
    auto shape = llvm::to_vector(
        maskOp->getResultTypes()[0].cast<VectorType>().getShape().drop_back());
    shape.push_back(numElements);
    auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      auto numMaskOperands = maskOperands.size();
      AffineExpr s0;
      bindSymbols(rewriter.getContext(), s0);
      s0 = s0 + scale - 1;
      s0 = s0.floorDiv(scale);
      OpFoldResult origIndex =
          getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
      OpFoldResult maskIndex =
          affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
      auto newMaskOperands = llvm::to_vector(maskOperands.drop_back());
      newMaskOperands.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
      newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                      newMaskOperands);
    } else if (constantMaskOp) {
      auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
      auto numMaskOperands = maskDimSizes.size();
      auto origIndex =
          cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
      auto maskIndex =
          rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
      auto newMaskDimSizes = llvm::to_vector(maskDimSizes.drop_back());
      newMaskDimSizes.push_back(maskIndex);
      newMask = rewriter.create<vector::ConstantMaskOp>(
          loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
    }

    while (!extractOps.empty()) {
      newMask = rewriter.create<vector::ExtractOp>(
          loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
      extractOps.pop_back();
    }

    auto newPassThru =
        rewriter.create<vector::BitCastOp>(loc, newType, op.getPassThru());

    // Generating the new masked load.
    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, newType, adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newMask->getResult(0), newPassThru);

    // Setting the part that originally was not effectively loaded from memory
    // to pass through.
    auto bitCast =
        rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
    auto select = rewriter.create<arith::SelectOp>(loc, op.getMask(), bitCast,
                                                   op.getPassThru());
    rewriter.replaceOp(op, select->getResult(0));

    return success();
  }
};
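
// Illustrative sketch only: when the mask comes from a dynamic
// vector.create_mask %n, the affine expression built above compresses the
// bound with a ceiling division, producing IR roughly of the form (for
// scale == 2):
//
//   %bound = affine.apply affine_map<()[s0] -> ((s0 + 1) floordiv 2)>()[%n]
//   %new_mask = vector.create_mask %bound : vector<3xi1>
//
// so a bound of 3 i4 elements becomes 2 i8 elements, matching the
// vector.constant_mask example in the comment above.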

//===----------------------------------------------------------------------===//
// ConvertVectorTransferRead
//===----------------------------------------------------------------------===//

struct ConvertVectorTransferRead final
    : OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    auto origElements = op.getVectorType().getNumElements();
    if (origElements % scale != 0)
      return failure();

    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                      adaptor.getPadding());

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());

    OpFoldResult linearizedIndices;
    std::tie(std::ignore, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    auto numElements = (origElements + scale - 1) / scale;
    auto newReadType = VectorType::get(numElements, newElementType);

    auto newRead = rewriter.create<vector::TransferReadOp>(
        loc, newReadType, adaptor.getSource(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newPadding);

    auto bitCast =
        rewriter.create<vector::BitCastOp>(loc, op.getType(), newRead);

    rewriter.replaceOp(op, bitCast->getResult(0));
    return success();
  }
};
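
// Illustrative sketch only (this pattern carries no inline example): for an
// i4 -> i8 emulation, a read such as
//
//   %1 = vector.transfer_read %0[%c0, %c0], %pad :
//     memref<3x4xi4>, vector<4xi4>
//
// is rewritten roughly into
//
//   %new_pad = arith.extui %pad : i4 to i8
//   %1 = vector.transfer_read %alloc[%linear_index], %new_pad :
//     memref<6xi8>, vector<2xi8>
//   %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>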
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// RewriteBitCastOfTruncI
//===----------------------------------------------------------------------===//

namespace {

/// Helper struct to keep track of the provenance of a contiguous set of bits
/// in a source vector.
struct SourceElementRange {
  /// The index of the source vector element that contributes bits to *this.
  int64_t sourceElementIdx;
  /// The range of bits in the source vector element that contribute to *this.
  int64_t sourceBitBegin;
  int64_t sourceBitEnd;
};

struct SourceElementRangeList : public SmallVector<SourceElementRange> {
  /// Given the index of a SourceElementRange in the SourceElementRangeList,
  /// compute the amount of bits that need to be shifted to the left to get the
  /// bits in their final location. This shift amount is simply the sum of the
  /// bits *before* `shuffleIdx` (i.e. the bits of `shuffleIdx = 0` are always
  /// the LSBs, the bits of `shuffleIdx = 1` come next, etc).
  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
    int64_t res = 0;
    for (int64_t i = 0; i < shuffleIdx; ++i)
      res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
    return res;
  }
};
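
// A minimal illustrative sketch (hypothetical helper, not used by the
// rewrite): for row 2 of the i5 -> i8 example documented below,
// {1: b@[3..5)}{2: b@[0..5)}{3: b@[0..1)}, the left-shift amount of an entry
// is the sum of the bit widths of the entries before it.
[[maybe_unused]] static void exampleLeftShiftAmounts() {
  SourceElementRangeList row;
  row.push_back(
      {/*sourceElementIdx=*/1, /*sourceBitBegin=*/3, /*sourceBitEnd=*/5});
  row.push_back({2, 0, 5});
  row.push_back({3, 0, 1});
  // computeLeftShiftAmount(0) == 0, (1) == 2 (i.e. 5 - 3), (2) == 7 (2 + 5).
  (void)row.computeLeftShiftAmount(2);
}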

/// Helper struct to enumerate the source elements and bit ranges that are
/// involved in a bitcast operation.
/// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
/// any 1-D vector shape and any source/target bitwidths.
/// This creates and holds a mapping of the form:
///   [dstVectorElementJ] ==
///   [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ]
/// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as:
///   [0] = {0, [0-8)}
///   [1] = {0, [8-16)}
///   [2] = {0, [16-24)}
/// and `vector.bitcast ... : vector<3xi10> to vector<2xi15>` is decomposed as:
///   [0] = {0, [0-10)}, {1, [0-5)}
///   [1] = {1, [5-10)}, {2, [0-10)}
struct BitCastBitsEnumerator {
  BitCastBitsEnumerator(VectorType sourceVectorType,
                        VectorType targetVectorType);

  int64_t getMaxNumberOfEntries() {
    int64_t numVectors = 0;
    for (const auto &l : sourceElementRanges)
      numVectors = std::max(numVectors, (int64_t)l.size());
    return numVectors;
  }

  VectorType sourceVectorType;
  VectorType targetVectorType;
  SmallVector<SourceElementRangeList> sourceElementRanges;
};

/// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
/// advantage of high-level information to avoid leaving LLVM to scramble with
/// peephole optimizations.
/// BitCastBitsEnumerator encodes for each element of the target vector the
/// provenance of the bits in the source vector. We can "transpose" this
/// information to build a sequence of shuffles and bitwise ops that will
/// produce the desired result.
///
/// Consider the following motivating example:
/// ```
///   %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
/// ```
///
/// BitCastBitsEnumerator contains the following information:
/// ```
///   { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
///   { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
///   { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
///   { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
///   { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
///   { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
///   { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
///   {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
///   {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
///   {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
///   {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
///   {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
///   {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
///   {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
///   {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
///   {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
///   {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
///   {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
///   {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
///   {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
/// ```
///
/// In the above, each row represents one target vector element and each
/// column represents one bit contribution from a source vector element.
/// The algorithm creates vector.shuffle operations (in this case there are 3
/// shuffles, i.e. the max number of columns in BitCastBitsEnumerator). The
/// algorithm populates the bits as follows:
/// ```
///   src bits 0                ...
///   1st shuffle |xxxxx   |xx      |...
///   2nd shuffle |     xxx|  xxxxx |...
///   3rd shuffle |        |       x|...
/// ```
///
/// The algorithm proceeds as follows:
///   1. for each vector.shuffle, collect the source vectors that participate
///      in this shuffle. One source vector per target element of the resulting
///      vector.shuffle. If there is no source element contributing bits for
///      the current vector.shuffle, take 0 (i.e. row 0 in the above example
///      has only 2 columns).
///   2. represent the bitrange in the source vector as a mask. If there is no
///      source element contributing bits for the current vector.shuffle,
///      take 0.
///   3. shift right by the proper amount to align the source bitrange at
///      position 0. This is exactly the low end of the bitrange. For instance,
///      the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs
///      to shift right by 3 to get the bits contributed by source element #1
///      into position 0.
///   4. shift left by the proper amount to align to the desired position in
///      the result element vector. For instance, the contribution of the
///      second source element for the first row needs to be shifted by `5` to
///      form the first i8 result element.
///
/// Eventually, we end up building the sequence
/// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
/// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
/// bits extracted from the source vector (i.e. the `shuffle -> and` part).
struct BitCastRewriter {
  /// Helper metadata struct to hold the static quantities for the rewrite.
  struct Metadata {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
  };

  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);

  /// Verify that the preconditions for the rewrite are met.
  LogicalResult precondition(PatternRewriter &rewriter,
                             VectorType preconditionVectorType, Operation *op);

  /// Precompute the metadata for the rewrite.
  SmallVector<BitCastRewriter::Metadata>
  precomputeMetadata(IntegerType shuffledElementType);

  /// Rewrite one step of the sequence:
  ///   `(shuffle -> and -> shiftright -> shiftleft -> or)`.
  Value rewriteStep(PatternRewriter &rewriter, Location loc, Value initialValue,
                    Value runningResult,
                    const BitCastRewriter::Metadata &metadata);

private:
  /// Underlying enumerator that encodes the provenance of the bits in each
  /// element of the result vector.
  BitCastBitsEnumerator enumerator;
};

} // namespace

[[maybe_unused]] static raw_ostream &
operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
  for (const auto &l : vec) {
    for (auto it : llvm::enumerate(l)) {
      os << "{ " << it.value().sourceElementIdx << ": b@["
         << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
         << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
    }
    os << "\n";
  }
  return os;
}

BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
                                             VectorType targetVectorType)
    : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {

  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
  LDBG("sourceVectorType: " << sourceVectorType);

  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
  LDBG("targetVectorType: " << targetVectorType);

  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
  (void)mostMinorSourceDim;
  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
         "source and target bitwidths must match");

  // Prepopulate one source element range per target element.
  sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
  for (int64_t resultBit = 0; resultBit < bitwidth;) {
    int64_t resultElement = resultBit / targetBitWidth;
    int64_t resultBitInElement = resultBit % targetBitWidth;
    int64_t sourceElementIdx = resultBit / sourceBitWidth;
    int64_t sourceBitInElement = resultBit % sourceBitWidth;
    int64_t step = std::min(sourceBitWidth - sourceBitInElement,
                            targetBitWidth - resultBitInElement);
    sourceElementRanges[resultElement].push_back(
        {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
    resultBit += step;
  }
}
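
// Worked illustration of the loop above for
// `vector.bitcast ... : vector<3xi10> to vector<2xi15>` (30 bits in total):
//
//   resultBit =  0: result elt 0, bit  0; source elt 0, bit 0; step = 10
//   resultBit = 10: result elt 0, bit 10; source elt 1, bit 0; step = 5
//   resultBit = 15: result elt 1, bit  0; source elt 1, bit 5; step = 5
//   resultBit = 20: result elt 1, bit  5; source elt 2, bit 0; step = 10
//
// which yields [0] = {0, [0-10)}, {1, [0-5)} and
// [1] = {1, [5-10)}, {2, [0-10)}, i.e. the decomposition listed in the
// BitCastBitsEnumerator comment.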

BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
                                 VectorType targetVectorType)
    : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
  LDBG("\n" << enumerator.sourceElementRanges);
}

LogicalResult BitCastRewriter::precondition(PatternRewriter &rewriter,
                                            VectorType precondition,
                                            Operation *op) {
  if (precondition.getRank() != 1 || precondition.isScalable())
    return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");

  // TODO: consider relaxing this restriction in the future if we find ways
  // to really work with subbyte elements across the MLIR/LLVM boundary.
  int64_t resultBitwidth = precondition.getElementTypeBitWidth();
  if (resultBitwidth % 8 != 0)
    return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");

  return success();
}

SmallVector<BitCastRewriter::Metadata>
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
  SmallVector<BitCastRewriter::Metadata> result;
  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
       shuffleIdx < e; ++shuffleIdx) {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;

    // Create the attribute quantities for the shuffle / mask / shift ops.
    for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
      int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
                                  ? srcEltRangeList[shuffleIdx].sourceElementIdx
                                  : 0;
      shuffles.push_back(sourceElement);

      int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitBegin
                          : 0;
      int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitEnd
                          : 0;
      IntegerAttr mask = IntegerAttr::get(
          shuffledElementType,
          llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
                                  bitLo, bitHi));
      masks.push_back(mask);

      int64_t shiftRight = bitLo;
      shiftRightAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftRight));

      int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
      shiftLeftAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftLeft));
    }

    result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
  }
  return result;
}
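
// Worked illustration of the metadata above: for
// `vector.bitcast ... : vector<8xi4> to vector<4xi8>` with an i32 shuffled
// element type (e.g. the pre-trunc operand type in RewriteBitCastOfTruncI
// below), every i8 result element draws from two i4 source elements, so
// getMaxNumberOfEntries() == 2 and the two steps are:
//
//   shuffleIdx 0: shuffles = [0, 2, 4, 6], masks = 0xF, shr = 0, shl = 0
//   shuffleIdx 1: shuffles = [1, 3, 5, 7], masks = 0xF, shr = 0, shl = 4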

Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
                                   Value initialValue, Value runningResult,
                                   const BitCastRewriter::Metadata &metadata) {
  // Create vector.shuffle from the metadata.
  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
      loc, initialValue, initialValue, metadata.shuffles);

  // Intersect with the mask.
  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
  auto constOp = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);

  // Align right on 0.
  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
  Value shiftedRight =
      rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);

  // Shift bits left into their final position.
  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
  Value shiftedLeft =
      rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);

  runningResult =
      runningResult
          ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
          : shiftedLeft;

  return runningResult;
}
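
// Illustrative sketch only: for the vector<8xi4> -> vector<4xi8> metadata
// example above with an i32 shuffled element type, the second step roughly
// materializes
//
//   %s = vector.shuffle %src, %src [1, 3, 5, 7] : vector<8xi32>, vector<8xi32>
//   %m = arith.andi %s, %mask_0xF : vector<4xi32>
//   %r = arith.shrui %m, %shr_0 : vector<4xi32>
//   %l = arith.shli %r, %shl_4 : vector<4xi32>
//   %acc = arith.ori %running, %l : vector<4xi32>
//
// where %mask_0xF, %shr_0 and %shl_4 are arith.constant dense splats built
// from the precomputed metadata.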

namespace {
/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
/// advantage of high-level information to avoid leaving LLVM to scramble with
/// peephole optimizations.
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
                                PatternRewriter &rewriter) const override {
    // The source must be a trunc op.
    auto truncOp =
        bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
    if (!truncOp)
      return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");

    // Set up the BitCastRewriter and verify the precondition.
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.precondition(rewriter, targetVectorType, bitCastOp)))
      return failure();

    // Perform the rewrite.
    Value truncValue = truncOp.getIn();
    auto shuffledElementType =
        cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
    Value runningResult;
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
                                      runningResult, metadata);
    }

    // Finalize the rewrite.
    bool narrowing = targetVectorType.getElementTypeBitWidth() <=
                     shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      rewriter.replaceOpWithNewOp<arith::TruncIOp>(
          bitCastOp, bitCastOp.getResultVectorType(), runningResult);
    } else {
      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
          bitCastOp, bitCastOp.getResultVectorType(), runningResult);
    }

    return success();
  }
};
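
// Illustrative sketch only: an input of the form
//
//   %t = arith.trunci %a : vector<8xi32> to vector<8xi4>
//   %b = vector.bitcast %t : vector<8xi4> to vector<4xi8>
//
// is rewritten into the shuffle/and/shift/or steps applied directly to %a
// (the i32 operand), followed by a final arith.trunci from vector<4xi32> to
// the vector<4xi8> result type.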
} // namespace

//===----------------------------------------------------------------------===//
// RewriteExtOfBitCast
//===----------------------------------------------------------------------===//

namespace {
/// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
/// take advantage of high-level information to avoid leaving LLVM to scramble
/// with peephole optimizations.
template <typename ExtOpType>
struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
  using OpRewritePattern<ExtOpType>::OpRewritePattern;

  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
      : OpRewritePattern<ExtOpType>(context, benefit) {}

  LogicalResult matchAndRewrite(ExtOpType extOp,
                                PatternRewriter &rewriter) const override {
    // The source must be a bitcast op.
    auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
    if (!bitCastOp)
      return rewriter.notifyMatchFailure(extOp, "not a bitcast source");

    // Set up the BitCastRewriter and verify the precondition.
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.precondition(
            rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
      return failure();

    // Perform the rewrite.
    Value runningResult;
    Value sourceValue = bitCastOp.getSource();
    auto shuffledElementType =
        cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(),
                                      sourceValue, runningResult, metadata);
    }

    // Finalize the rewrite.
    bool narrowing =
        cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
        shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      rewriter.replaceOpWithNewOp<arith::TruncIOp>(
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
    } else {
      rewriter.replaceOpWithNewOp<ExtOpType>(
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
    }

    return success();
  }
};
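
// Illustrative sketch only: an input of the form
//
//   %b = vector.bitcast %a : vector<4xi8> to vector<8xi4>
//   %e = arith.extui %b : vector<8xi4> to vector<8xi32>
//
// is rewritten into shuffles and bitwise ops on the i8 source %a, followed by
// a final arith.extui from vector<8xi8> to vector<8xi32> (an arith.trunci is
// emitted instead when the extension result is narrower than the shuffled
// element type).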
} // namespace

//===----------------------------------------------------------------------===//
// Public Interface Definition
//===----------------------------------------------------------------------===//

void vector::populateVectorNarrowTypeEmulationPatterns(
    arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {

  // Populate `vector.*` conversion patterns.
  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
               ConvertVectorTransferRead>(typeConverter, patterns.getContext());
}

void vector::populateVectorNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
               RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                    benefit);
}
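
// A minimal usage sketch (hypothetical pass body; the converter setup,
// legality rules, and constructor arguments are assumptions, not part of this
// file):
//
//   MLIRContext *ctx = op->getContext();
//   arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
//   ConversionTarget target(*ctx);
//   target.markUnknownOpDynamicallyLegal(
//       [&](Operation *o) { return typeConverter.isLegal(o); });
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();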