1 //===- VectorEmulateNarrowType.cpp - Narrow type emulation ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to emulate
10 // narrow types that are not supported by the target hardware, e.g. i4
11 // ("emulated type"), using wider types, e.g. i8 ("container type").
12 //
13 /// Currently, only power-of-two integer types are supported. These are
14 /// converted to wider integers that are at least 8 bits wide.
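///
/// For example (an illustrative sketch, mirroring the store pattern documented
/// further below), an aligned store of i4 values emulated via an i8 container:
///
///   vector.store %src, %dest[%idx_1, %idx_2] : memref<4x8xi4>, vector<8xi4>
///
/// becomes:
///
///   %src_bitcast = vector.bitcast %src : vector<8xi4> to vector<4xi8>
///   vector.store %src_bitcast, %dest_bitcast[%idx] : memref<16xi8>, vector<4xi8>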
15 ///
16 /// TODO: Support for non-powers-of-two.
17 //===----------------------------------------------------------------------===//
18 
29 #include "mlir/IR/BuiltinTypes.h"
30 #include "mlir/IR/OpDefinition.h"
31 #include "mlir/IR/TypeUtilities.h"
32 #include "mlir/IR/Value.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/MathExtras.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include <cstdint>
39 #include <optional>
40 
41 using namespace mlir;
42 
43 #define DEBUG_TYPE "vector-narrow-type-emulation"
44 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
45 #define DBGSNL() (llvm::dbgs() << "\n")
46 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
47 
50 
51 //===----------------------------------------------------------------------===//
52 // Utils
53 //===----------------------------------------------------------------------===//
54 
55 /// Returns a compressed mask for the emulated vector. For example, when
56 /// emulating an eight-element `i8` vector with `i32` (i.e. when the source
57 /// elements span two dest elements), this method compresses `vector<8xi1>`
58 /// into `vector<2xi1>`.
59 ///
60 /// The compressed/output mask value is set iff any mask in the corresponding
61 /// `numSrcElemsPerDest` range of uncompressed/input masks is set. E.g., if
62 /// `numSrcElemsPerDest` equals 2, and `numFrontPadElems` equals 1, the
63 /// following mask:
64 ///
65 /// %mask = [1, 1, 0, 0, 0, 0]
66 ///
67 /// will first be padded in the front with `numFrontPadElems` zeros, and zeros
68 /// will be added in the back to make the number of elements a multiple of
69 /// `numSrcElemsPerDest` (for easier computation). The resulting mask will be:
70 ///
71 /// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
72 ///
73 /// then it will return the following new compressed mask:
74 ///
75 /// %mask = [1, 1, 0, 0]
76 ///
77 /// NOTE: `numFrontPadElems` is assumed to be strictly smaller than
78 /// `numSrcElemsPerDest`.
79 static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
80  Location loc, Value mask,
81  int numSrcElems,
82  int numSrcElemsPerDest,
83  int numFrontPadElems = 0) {
84 
85  assert(numFrontPadElems < numSrcElemsPerDest &&
86  "numFrontPadElems must be less than numSrcElemsPerDest");
87 
88  auto numDestElems =
89  (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
90  numSrcElemsPerDest;
91 
92  Operation *maskOp = mask.getDefiningOp();
93  SmallVector<vector::ExtractOp, 2> extractOps;
94  // TODO: add support for `vector.splat`.
95  // Finding the mask creation operation.
96  while (maskOp &&
97  !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
98  maskOp)) {
99  if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
100  maskOp = extractOp.getVector().getDefiningOp();
101  extractOps.push_back(extractOp);
102  }
103  }
104 
105  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
106  maskOp))
107  return failure();
108 
109  // Computing the "compressed" mask. All the emulation logic (i.e. computing
110  // new mask index) only happens on the last dimension of the vectors.
111  SmallVector<int64_t> maskShape(
112  cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
113  maskShape.back() = numDestElems;
114  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
115  std::optional<Operation *> newMask =
116      TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
117          .Case<vector::CreateMaskOp>(
118  [&](auto createMaskOp) -> std::optional<Operation *> {
119  OperandRange maskOperands = createMaskOp.getOperands();
120  // The `vector.create_mask` op creates a mask arrangement
121  // without any zeros at the front. Also, because
122  // `numFrontPadElems` is strictly smaller than
123  // `numSrcElemsPerDest`, the compressed mask generated by
124  // padding the original mask by `numFrontPadElems` will not
125  // have any zeros at the front as well.
126  AffineExpr s0;
127  bindSymbols(rewriter.getContext(), s0);
128  s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
129  OpFoldResult origIndex = getAsOpFoldResult(maskOperands.back());
130  OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
131      rewriter, loc, s0, origIndex);
132  SmallVector<Value> newMaskOperands(maskOperands.drop_back());
133  newMaskOperands.push_back(
134  getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
135  return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
136  newMaskOperands);
137  })
138  .Case<vector::ConstantMaskOp>(
139  [&](auto constantMaskOp) -> std::optional<Operation *> {
140  // Take the shape of mask, compress its trailing dimension:
141  SmallVector<int64_t> maskDimSizes(
142  constantMaskOp.getMaskDimSizes());
143  int64_t &maskIndex = maskDimSizes.back();
144  maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
145  numSrcElemsPerDest);
146  return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
147  maskDimSizes);
148  })
149  .Case<arith::ConstantOp>([&](auto constantOp)
150  -> std::optional<Operation *> {
151  // TODO: Support multiple dimensions.
152  if (maskShape.size() != 1)
153  return std::nullopt;
154  // Rearrange the original mask values to cover the whole potential
155  // loading region. For example, in the case of using byte-size for
156  // emulation, given the following mask:
157  //
158  // %mask = [0, 1, 0, 1, 0, 0]
159  //
160  // With front offset of 1, the mask will be padded 0s in the front
161  // and back so that:
162  // 1. It is aligned with the effective loading bits
163  // 2. Its length is a multiple of `numSrcElemsPerDest` (and the total
164  // coverage size is a multiple of bytes). The new mask will be like
165  // this before compressing:
166  //
167  // %new_mask = [0, 0, 1, 0, 1, 0, 0, 0]
168  auto originalMask =
169  cast<DenseIntElementsAttr>(constantOp.getValue());
170  SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
171  paddedMaskValues.append(originalMask.template value_begin<bool>(),
172  originalMask.template value_end<bool>());
173  paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);
174 
175  // Compressing by combining every `numSrcElemsPerDest` elements:
176  SmallVector<bool> compressedMaskValues;
177  for (size_t i = 0; i < paddedMaskValues.size();
178  i += numSrcElemsPerDest) {
179  bool combinedValue = false;
180  for (int j = 0; j < numSrcElemsPerDest; ++j) {
181  combinedValue |= paddedMaskValues[i + j];
182  }
183  compressedMaskValues.push_back(combinedValue);
184  }
185  return rewriter.create<arith::ConstantOp>(
186  loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
187  });
188 
189  if (!newMask)
190  return failure();
191 
192  while (!extractOps.empty()) {
193  newMask = rewriter.create<vector::ExtractOp>(
194  loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition());
195  extractOps.pop_back();
196  }
197 
198  return *newMask;
199 }
200 
201 /// Extracts 1-D subvector from a 1-D vector.
202 ///
203 /// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
204 /// from `src`, starting at `offset`. The result is also a rank-1 vector:
205 ///
206 /// vector<numElemsToExtract x !elemType>
207 ///
208 /// (`!elemType` is the element type of the source vector). As `offset` is a known
209 /// _static_ value, this helper hook emits `vector.extract_strided_slice`.
210 ///
211 /// EXAMPLE:
212 /// %res = vector.extract_strided_slice %src
213 /// { offsets = [offset], sizes = [numElemsToExtract], strides = [1] }
214 static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
215                                         Value src, int64_t offset,
216                                         int64_t numElemsToExtract) {
217  auto vectorType = cast<VectorType>(src.getType());
218  assert(vectorType.getRank() == 1 && "expected source to be a rank-1 vector");
219  assert(offset + numElemsToExtract <= vectorType.getNumElements() &&
220  "subvector out of bounds");
221 
222  // When extracting all available elements, just use the source vector as the
223  // result.
224  if (vectorType.getNumElements() == numElemsToExtract)
225  return src;
226 
227  auto offsets = rewriter.getI64ArrayAttr({offset});
228  auto sizes = rewriter.getI64ArrayAttr({numElemsToExtract});
229  auto strides = rewriter.getI64ArrayAttr({1});
230 
231  auto resultVectorType =
232  VectorType::get({numElemsToExtract}, vectorType.getElementType());
233  return rewriter
234  .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, src,
235  offsets, sizes, strides)
236  ->getResult(0);
237 }
238 
239 /// Inserts 1-D subvector into a 1-D vector.
240 ///
241 /// Inserts the input rank-1 source vector into the destination vector starting
242 /// at `offset`. As `offset` is a known _static_ value, this helper hook emits
243 /// `vector.insert_strided_slice`.
244 ///
245 /// EXAMPLE:
246 /// %res = vector.insert_strided_slice %src, %dest
247 /// {offsets = [%offset], strides [1]}
248 static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
249                                        Value src, Value dest, int64_t offset) {
250  [[maybe_unused]] auto srcVecTy = cast<VectorType>(src.getType());
251  [[maybe_unused]] auto destVecTy = cast<VectorType>(dest.getType());
252  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
253  "expected source and dest to be rank-1 vector types");
254 
255  // If overwriting the destination vector, just return the source.
256  if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0)
257  return src;
258 
259  auto offsets = rewriter.getI64ArrayAttr({offset});
260  auto strides = rewriter.getI64ArrayAttr({1});
261  return rewriter.create<vector::InsertStridedSliceOp>(loc, destVecTy, src,
262  dest, offsets, strides);
263 }
264 
265 /// Extracts 1-D subvector from a 1-D vector.
266 ///
267 /// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
268 /// from `src`, starting at `offset`. The result is also a rank-1 vector:
269 ///
270 /// vector<numElemsToExtract x !elType>
271 ///
272 /// (`!elType` is the element type of the source vector). As `offset` is assumed
273 /// to be a _dynamic_ SSA value, this helper method generates a sequence of
274 /// `vector.extract` + `vector.insert` pairs.
275 ///
276 /// EXAMPLE:
277 /// %v1 = vector.extract %src[%offset] : i2 from vector<8xi2>
278 /// %r1 = vector.insert %v1, %dest[0] : i2 into vector<3xi2>
279 /// %c1 = arith.constant 1 : index
280 /// %idx2 = arith.addi %offset, %c1 : index
281 /// %v2 = vector.extract %src[%idx2] : i2 from vector<8xi2>
282 /// %r2 = vector.insert %v2, %r1 [1] : i2 into vector<3xi2>
283 /// (...)
284 static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
285                                          Value src, Value dest,
286  OpFoldResult offset,
287  int64_t numElemsToExtract) {
288  auto srcVecTy = cast<VectorType>(src.getType());
289  assert(srcVecTy.getRank() == 1 && "expected source to be a rank-1 vector");
290  // NOTE: We are unable to take the offset into account in the following
291  // assert, hence it's still possible that the subvector is out-of-bounds even
292  // if the condition is true.
293  assert(numElemsToExtract <= srcVecTy.getNumElements() &&
294  "subvector out of bounds");
295 
296  // When extracting all available elements, just use the source vector as the
297  // result.
298  if (srcVecTy.getNumElements() == numElemsToExtract)
299  return src;
300 
301  for (int i = 0; i < numElemsToExtract; ++i) {
302  Value extractLoc =
303  (i == 0) ? dyn_cast<Value>(offset)
304  : rewriter.create<arith::AddIOp>(
305  loc, rewriter.getIndexType(), dyn_cast<Value>(offset),
306  rewriter.create<arith::ConstantIndexOp>(loc, i));
307  auto extractOp = rewriter.create<vector::ExtractOp>(loc, src, extractLoc);
308  dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
309  }
310  return dest;
311 }
312 
313 /// Inserts 1-D subvector into a 1-D vector.
314 ///
315 /// Inserts the input rank-1 source vector into the destination vector starting
316 /// at `offset`. As `offset` is assumed to be a _dynamic_ SSA value, this hook
317 /// uses a sequence of `vector.extract` + `vector.insert` pairs.
318 ///
319 /// EXAMPLE:
320 /// %v1 = vector.extract %src[0] : i2 from vector<8xi2>
321 /// %r1 = vector.insert %v1, %dest[%offset] : i2 into vector<3xi2>
322 /// %c1 = arith.constant 1 : index
323 /// %idx2 = arith.addi %offset, %c1 : index
324 /// %v2 = vector.extract %src[1] : i2 from vector<8xi2>
325 /// %r2 = vector.insert %v2, %r1 [%idx2] : i2 into vector<3xi2>
326 /// (...)
327 static Value dynamicallyInsertSubVector(OpBuilder &rewriter, Location loc,
328                                         Value src, Value dest,
329  OpFoldResult offset,
330  int64_t numElemsToInsert) {
331  auto srcVecTy = cast<VectorType>(src.getType());
332  auto destVecTy = cast<VectorType>(dest.getType());
333  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
334  "expected source and dest to be rank-1 vector types");
335  (void)srcVecTy;
336  (void)destVecTy;
337  assert(numElemsToInsert > 0 &&
338  "the number of elements to insert must be greater than 0");
339  // NOTE: We are unable to take the offset into account in the following
340  // assert, hence it's still possible that the subvector is out-of-bounds even
341  // if the condition is true.
342  assert(numElemsToInsert <= destVecTy.getNumElements() &&
343  "subvector out of bounds");
344 
345  Value destOffsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
346  for (int64_t i = 0; i < numElemsToInsert; ++i) {
347  auto insertLoc = i == 0
348  ? destOffsetVal
349  : rewriter.create<arith::AddIOp>(
350  loc, rewriter.getIndexType(), destOffsetVal,
351  rewriter.create<arith::ConstantIndexOp>(loc, i));
352  auto extractOp = rewriter.create<vector::ExtractOp>(loc, src, i);
353  dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
354  }
355  return dest;
356 }
357 
358 /// Emulate a vector load for `emulatedElemTy` using `containerElemTy`
359 ///
360 /// Specifically, use `containerElemTy` for loading a vector of
361 /// `emulatedElemTy`. The load location is given by `base` and
362 /// `linearizedIndices`, and the load size is given by
363 /// `numContainerElemsToLoad`.
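///
/// For illustration (a sketch; the exact types depend on the arguments), with
/// i2 emulated via an i8 container and `numContainerElemsToLoad` = 2, this emits:
///
///   %load = vector.load %base[%linearizedIndices] : memref<?xi8>, vector<2xi8>
///   %res = vector.bitcast %load : vector<2xi8> to vector<8xi2>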
364 static Value emulatedVectorLoad(OpBuilder &rewriter, Location loc,
365                                 Value base,
366  OpFoldResult linearizedIndices,
367  int64_t numContainerElemsToLoad,
368  Type emulatedElemTy,
369  Type containerElemTy) {
370  auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() /
371  emulatedElemTy.getIntOrFloatBitWidth();
372  auto newLoad = rewriter.create<vector::LoadOp>(
373  loc, VectorType::get(numContainerElemsToLoad, containerElemTy), base,
374  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
375  return rewriter.create<vector::BitCastOp>(
376  loc,
377  VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
378  emulatedElemTy),
379  newLoad);
380 }
381 
382 /// Downcast two values to `downcastType`, then select values
383 /// based on `mask`, and cast the result to `upcastType`.
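///
/// For illustration, a sketch with `downcastType` = vector<4xi2> and
/// `upcastType` = vector<1xi8> (assuming `falseValue` still carries the
/// upcast type and hence needs the initial bitcast):
///
///   %f = vector.bitcast %falseValue : vector<1xi8> to vector<4xi2>
///   %s = arith.select %mask, %trueValue, %f : vector<4xi1>, vector<4xi2>
///   %r = vector.bitcast %s : vector<4xi2> to vector<1xi8>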
384 static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
385                                      VectorType downcastType,
386  VectorType upcastType, Value mask,
387  Value trueValue, Value falseValue) {
388  assert(
389  downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
390  upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
391  "expected input and output number of bits to match");
392  if (trueValue.getType() != downcastType) {
393  trueValue = builder.create<vector::BitCastOp>(loc, downcastType, trueValue);
394  }
395  if (falseValue.getType() != downcastType) {
396  falseValue =
397  builder.create<vector::BitCastOp>(loc, downcastType, falseValue);
398  }
399  Value selectedType =
400  builder.create<arith::SelectOp>(loc, mask, trueValue, falseValue);
401  // Upcast the selected value to the new type.
402  return builder.create<vector::BitCastOp>(loc, upcastType, selectedType);
403 }
404 
405 /// Emits `memref.generic_atomic_rmw` op to store a subbyte-sized value to a
406 /// byte in `linearizedMemref`, with a mask. The `valueToStore` is a vector of
407 /// subbyte-sized elements that together span 8 bits, and the mask is used to select
408 /// which elements to store.
409 ///
410 /// Inputs:
411 /// linearizedMemref = |2|2|2|2| : <4xi2> (<1xi8>)
412 /// storeIdx = 2
413 /// valueToStore = |3|3|3|3| : vector<4xi2>
414 /// mask = |0|0|1|1| : vector<4xi1>
415 ///
416 /// Result:
417 /// linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
418 static void atomicRMW(OpBuilder &builder, Location loc,
419  MemRefValue linearizedMemref, Value storeIdx,
420  VectorValue valueToStore, Value mask) {
421  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
422 
423  // Create an atomic load-modify-write region using
424  // `memref.generic_atomic_rmw`.
425  auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
426  loc, linearizedMemref, ValueRange{storeIdx});
427  Value origValue = atomicOp.getCurrentValue();
428 
429  OpBuilder::InsertionGuard guard(builder);
430  builder.setInsertionPointToStart(atomicOp.getBody());
431 
432  // Load the original value from memory, and cast it to the original element
433  // type.
434  auto oneElemVecType = VectorType::get({1}, origValue.getType());
435  Value origVecValue = builder.create<vector::FromElementsOp>(
436  loc, oneElemVecType, ValueRange{origValue});
437 
438  // Construct the final masked value and yield it.
439  Value maskedValue =
440  downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
441  oneElemVecType, mask, valueToStore, origVecValue);
442  auto scalarMaskedValue =
443  builder.create<vector::ExtractOp>(loc, maskedValue, 0);
444  builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
445 }
446 
447 /// Generate a non-atomic read-modify-write sequence for storing to the emulated
448 /// type. It has similar logic to `atomicRMWStore`, but without atomicity.
449 static void nonAtomicRMW(OpBuilder &builder, Location loc,
450  MemRefValue linearizedMemref, Value linearizedIndex,
451  VectorValue valueToStore, Value mask) {
452  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
453 
454  auto oneElemVecType =
455  VectorType::get({1}, linearizedMemref.getType().getElementType());
456  Value origVecValue = builder.create<vector::LoadOp>(
457  loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
458  origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
459  origVecValue);
460 
461  Value maskedValue =
462  downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
463  oneElemVecType, mask, valueToStore, origVecValue);
464  builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
465  linearizedIndex);
466 }
467 
468 /// Extract `sliceNumElements` elements from the source `vector` at `extractOffset`,
469 /// and insert it into an empty vector at `insertOffset`.
470 /// Inputs:
471 /// vec_in = |0|1|2|3| : vector<4xi2>
472 /// extractOffset = 1
473 /// sliceNumElements = 2
474 /// insertOffset = 2
475 /// Output:
476 /// vec_out = |0|0|1|2| : vector<4xi2>
477 static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
478                                   Location loc, VectorValue vector,
479  int64_t extractOffset,
480  int64_t sliceNumElements,
481  int64_t insertOffset) {
482  assert(vector.getType().getRank() == 1 && "expected 1-D vector");
483  auto vectorElementType = vector.getType().getElementType();
484  // TODO: update and use `alignedConversionPrecondition` in the place of
485  // these asserts.
486  assert(
487  sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
488  "sliceNumElements * vector element size must be less than or equal to 8");
489  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
490  "vector element must be a valid sub-byte type");
491  auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
492  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
493  loc, VectorType::get({emulatedPerContainerElem}, vectorElementType),
494  rewriter.getZeroAttr(
495  VectorType::get({emulatedPerContainerElem}, vectorElementType)));
496  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
497  extractOffset, sliceNumElements);
498  return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
499  insertOffset);
500 }
501 
502 namespace {
503 
504 //===----------------------------------------------------------------------===//
505 // ConvertVectorStore
506 //===----------------------------------------------------------------------===//
507 
508 // Emulate `vector.store` using a multi-byte container type.
509 //
510 // The container type is obtained through Op adaptor and would normally be
511 // generated via `NarrowTypeEmulationConverter`.
512 //
513 // EXAMPLE 1
514 // (aligned store of i4, emulated using i8 as the container type)
515 //
516 // vector.store %src, %dest[%idx_1, %idx_2] : memref<4x8xi4>, vector<8xi4>
517 //
518 // is rewritten as:
519 //
520 // %src_bitcast = vector.bitcast %src : vector<8xi4> to vector<4xi8>
521 // vector.store %src_bitcast, %dest_bitcast[%idx]
522 // : memref<16xi8>, vector<4xi8>
523 //
524 // EXAMPLE 2
525 // (unaligned store of i2, emulated using i8 as the container type)
526 //
527 // vector.store %src, %dest[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
528 //
529 // The i2 store is emulated through 2 x RMW sequences. The destination i2 memref
530 // is modelled using 3 bytes:
531 //
532 // Byte 0 Byte 1 Byte 2
533 // +----------+----------+----------+
534 // | oooooooo | ooooNNNN | NNoooooo |
535 // +----------+----------+----------+
536 //
537 // N - (N)ew entries (i.e. to be overwritten by vector.store)
538 // o - (o)ld entries (to be preserved)
539 //
540 // For the generated output in the non-atomic case, see:
541 //   * `@vector_store_i2_const_index_two_partial_stores`
542 // in:
543 // * "vector-emulate-narrow-type-unaligned-non-atomic.mlir".
544 //
545 // NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
546 // `true` to generate non-atomic RMW sequences.
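//
// At the op level, EXAMPLE 2 lowers to two masked RMW sequences (a sketch;
// the test referenced above shows the exact output):
//
//   * Byte 1: RMW with per-element mask [0, 0, 1, 1] (the last two i2
//     elements of the byte are overwritten).
//   * Byte 2: RMW with per-element mask [1, 0, 0, 0] (only the first i2
//     element is overwritten).
//
// Each RMW is emitted via `atomicRMW` or `nonAtomicRMW`, depending on
// `disableAtomicRMW`.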
547 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
549 
550  ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
551  : OpConversionPattern<vector::StoreOp>(context),
552  disableAtomicRMW(disableAtomicRMW) {}
553 
554  LogicalResult
555  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
556  ConversionPatternRewriter &rewriter) const override {
557 
558  // See #115653
559  if (op.getValueToStore().getType().getRank() != 1)
560  return rewriter.notifyMatchFailure(op,
561  "only 1-D vectors are supported ATM");
562 
563  auto loc = op.getLoc();
564 
565  auto valueToStore = cast<VectorValue>(op.getValueToStore());
566  auto containerElemTy =
567  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
568  Type emulatedElemTy = op.getValueToStore().getType().getElementType();
569  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
570  int containerBits = containerElemTy.getIntOrFloatBitWidth();
571 
572  // Check per-element alignment.
573  if (containerBits % emulatedBits != 0) {
574  return rewriter.notifyMatchFailure(
575  op, "impossible to pack emulated elements into container elements "
576  "(bit-wise misalignment)");
577  }
578  int emulatedPerContainerElem = containerBits / emulatedBits;
579 
580  // Adjust the number of elements to store when emulating narrow types.
581  // Here only the 1-D vector store is considered, and the N-D memref types
582  // should be linearized.
583  // For example, to emulate i4 to i8, the following op:
584  //
585  // vector.store %arg1, %0[%arg2, %arg3] : memref<4x8xi4>, vector<8xi4>
586  //
587  // can be replaced with
588  //
589  // %bitcast = vector.bitcast %arg1 : vector<8xi4> to vector<4xi8>
590  // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
591  // vector<4xi8>
592 
593  auto origElements = valueToStore.getType().getNumElements();
594  // Note, per-element-alignment was already verified above.
595  bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
596  // Do the trailing dim for source and destination match? If yes, then the
597  // corresponding index must be 0.
598  // FIXME: There's no way to tell for dynamic shapes, so we should bail out.
599  // However, that makes some tests fail, so we need to audit first.
600  auto trailingDim = op.getBase().getType().getShape().back();
601  bool trailingDimsMatch =
602  ShapedType::isDynamic(trailingDim) || trailingDim == origElements;
603 
604  auto stridedMetadata =
605  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
606 
607  // FIXME: ATM, we do not test cases where offsets, sizes, or strides are
608  // non-zero. As such, this is not needed.
609  OpFoldResult linearizedIndices;
610  memref::LinearizedMemRefInfo linearizedInfo;
611  std::tie(linearizedInfo, linearizedIndices) =
612      memref::getLinearizedMemRefOffsetAndSize(
613          rewriter, loc, emulatedBits, containerBits,
614  stridedMetadata.getConstifiedMixedOffset(),
615  stridedMetadata.getConstifiedMixedSizes(),
616  stridedMetadata.getConstifiedMixedStrides(),
617  getAsOpFoldResult(adaptor.getIndices()));
618 
619  std::optional<int64_t> foldedNumFrontPadElems =
620  (isDivisibleInSize && trailingDimsMatch)
621  ? 0
622  : getConstantIntValue(linearizedInfo.intraDataOffset);
623 
624  if (!foldedNumFrontPadElems) {
625  return rewriter.notifyMatchFailure(
626  op, "subbyte store emulation: dynamic front padding size is "
627  "not yet implemented");
628  }
629 
630  auto memrefBase = cast<MemRefValue>(adaptor.getBase());
631 
632  // RMWs are not needed when:
633  // * no _partial_ stores are required.
634  // A partial store is defined as a store in which only a part of the
635  // container element is overwritten, e.g.
636  //
637  // Dest before (8 bits)
638  // +----------+
639  // | 11000000 |
640  // +----------+
641  //
642  // Dest after storing 0xF at offset 4 (in bits)
643  // +----------+
644  // | 11001111 |
645  // +----------+
646  //
647  // At a higher level, this translates to:
648  // 1. The source vector size (in bits) is a multiple of byte size.
649  // 2. The address of the store is aligned to the container type width
650  // boundary.
651  //
652  // EXAMPLE 1:
653  // Requires partial store:
654  // vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
655  //
656  // EXAMPLE 2:
657  // Does not require a partial store:
658  // vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
659  //
660  // TODO: Take linearizedInfo.linearizedOffset into account. This is
661  // currently not needed/used/exercised as all our tests set offset to 0.
662  bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;
663 
664  if (!emulationRequiresPartialStores) {
665  // Basic case: storing full bytes.
666  auto numElements = origElements / emulatedPerContainerElem;
667  auto bitCast = rewriter.create<vector::BitCastOp>(
668  loc, VectorType::get(numElements, containerElemTy),
669  op.getValueToStore());
670  rewriter.replaceOpWithNewOp<vector::StoreOp>(
671  op, bitCast.getResult(), memrefBase,
672  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
673  return success();
674  }
675 
676  // Next, handle the case when sub-byte read-modify-write
677  // sequences are needed to emulate a vector store.
678  // Here is an example:
679  //
680  // Vector to store: vector<7xi2>
681  // Value to store: 11 11 11 11 11 11 11 (all ones)
682  //
683  // Destination: memref<12xi2>
684  // Store offset: 2 (i.e. 4 bits into the 1st emulated byte).
685  //
686  // Input MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
687  //
688  // Destination memref before:
689  //
690  // Byte 0 Byte 1 Byte 2
691  // +----------+----------+----------+
692  // | 00000000 | 00000000 | 00000000 |
693  // +----------+----------+----------+
694  //
695  // Destination memref after:
696  //
697  // Byte 0 Byte 1 Byte 2
698  // +----------+----------+----------+
699  // | 00001111 | 11111111 | 11000000 |
700  // +----------+----------+----------+
701  //
702  // Note, stores to Byte 1 are "full-width" and hence don't require RMW (no
703  // need for atomicity). Stores to Byte 0 and Byte 2 are "partial", hence
704  // requiring RMW access (atomicity is required).
705 
706  // The index into the target memref we are storing to.
707  Value currentDestIndex =
708  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
709  // The index into the source vector we are currently processing.
710  auto currentSourceIndex = 0;
711 
712  // Build a mask used for rmw.
713  auto subWidthStoreMaskType =
714  VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());
715 
716  auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
717 
718  // 1. Partial width store for the leading byte.
719  // When the store address is not aligned to the emulated width boundary, deal
720  // with the unaligned part so that the remaining elements are aligned to the
721  // width boundary.
722  auto frontSubWidthStoreElem =
723  (emulatedPerContainerElem - *foldedNumFrontPadElems) %
724  emulatedPerContainerElem;
725  if (frontSubWidthStoreElem > 0) {
726  SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
727  if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
728  std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
729  origElements, true);
730  frontSubWidthStoreElem = origElements;
731  } else {
732  std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
733  *foldedNumFrontPadElems, true);
734  }
735  auto frontMask = rewriter.create<arith::ConstantOp>(
736  loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
737 
738  currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
739  auto value =
740  extractSliceIntoByte(rewriter, loc, valueToStore, 0,
741  frontSubWidthStoreElem, *foldedNumFrontPadElems);
742 
743  storeFunc(rewriter, loc, memrefBase, currentDestIndex,
744  cast<VectorValue>(value), frontMask.getResult());
745  }
746 
747  if (currentSourceIndex >= origElements) {
748  rewriter.eraseOp(op);
749  return success();
750  }
751 
752  // Increment the destination index by 1 to align to the emulated width
753  // boundary.
754  auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
755  currentDestIndex = rewriter.create<arith::AddIOp>(
756  loc, rewriter.getIndexType(), currentDestIndex, constantOne);
757 
758  // 2. Full width store for the inner output bytes.
759  // After the previous step, the store address is aligned to the emulated
760  // width boundary.
761  int64_t fullWidthStoreSize =
762  (origElements - currentSourceIndex) / emulatedPerContainerElem;
763  int64_t numNonFullWidthElements =
764  fullWidthStoreSize * emulatedPerContainerElem;
765  if (fullWidthStoreSize > 0) {
766  auto fullWidthStorePart = staticallyExtractSubvector(
767  rewriter, loc, valueToStore, currentSourceIndex,
768  numNonFullWidthElements);
769 
770  auto originType = cast<VectorType>(fullWidthStorePart.getType());
771  auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
772  auto storeType = VectorType::get(
773  {originType.getNumElements() / emulatedPerContainerElem},
774  memrefElemType);
775  auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
776  fullWidthStorePart);
777  rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
778  currentDestIndex);
779 
780  currentSourceIndex += numNonFullWidthElements;
781  currentDestIndex = rewriter.create<arith::AddIOp>(
782  loc, rewriter.getIndexType(), currentDestIndex,
783  rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
784  }
785 
786  // 3. Partial width store for the trailing output byte.
787  // It is needed when the residual length is smaller than the emulated width,
788  // which is not covered in step 2 above.
789  auto remainingElements = origElements - currentSourceIndex;
790  if (remainingElements != 0) {
791  auto subWidthStorePart =
792  extractSliceIntoByte(rewriter, loc, cast<VectorValue>(valueToStore),
793  currentSourceIndex, remainingElements, 0);
794 
795  // Generate back mask.
796  auto maskValues = SmallVector<bool>(emulatedPerContainerElem, 0);
797  std::fill_n(maskValues.begin(), remainingElements, 1);
798  auto backMask = rewriter.create<arith::ConstantOp>(
799  loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
800 
801  storeFunc(rewriter, loc, memrefBase, currentDestIndex,
802  cast<VectorValue>(subWidthStorePart), backMask.getResult());
803  }
804 
805  rewriter.eraseOp(op);
806  return success();
807  }
808 
809 private:
810  const bool disableAtomicRMW;
811 };
812 
813 //===----------------------------------------------------------------------===//
814 // ConvertVectorMaskedStore
815 //===----------------------------------------------------------------------===//
816 
817 // TODO: Document-me
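// At a high level (a simplified sketch; the worked example inside
// matchAndRewrite below is authoritative), a masked store of 8 x i4 emulated
// via an i8 container:
//
//   vector.maskedstore %base[%idx], %mask, %val
//     : memref<8xi4>, vector<8xi1>, vector<8xi4>
//
// is rewritten to operate on vector<4xi8>: the affected container bytes are
// masked-loaded using a compressed vector<4xi1> mask, bitcast to
// vector<8xi4>, merged with %val via arith.select on the original mask,
// bitcast back to vector<4xi8>, and masked-stored with the compressed mask.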
818 struct ConvertVectorMaskedStore final
819  : OpConversionPattern<vector::MaskedStoreOp> {
821 
822  LogicalResult
823  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
824  ConversionPatternRewriter &rewriter) const override {
825 
826  // See #115653
827  if (op.getValueToStore().getType().getRank() != 1)
828  return rewriter.notifyMatchFailure(op,
829  "only 1-D vectors are supported ATM");
830 
831  auto loc = op.getLoc();
832  auto containerElemTy =
833  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
834  Type emulatedElemTy = op.getValueToStore().getType().getElementType();
835  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
836  int containerBits = containerElemTy.getIntOrFloatBitWidth();
837 
838  // Check per-element alignment.
839  if (containerBits % emulatedBits != 0) {
840  return rewriter.notifyMatchFailure(
841  op, "impossible to pack emulated elements into container elements "
842  "(bit-wise misalignment)");
843  }
844 
845  int emulatedPerContainerElem = containerBits / emulatedBits;
846  int origElements = op.getValueToStore().getType().getNumElements();
847  if (origElements % emulatedPerContainerElem != 0)
848  return failure();
849 
850  auto stridedMetadata =
851  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
852  OpFoldResult linearizedIndicesOfr;
853  memref::LinearizedMemRefInfo linearizedInfo;
854  std::tie(linearizedInfo, linearizedIndicesOfr) =
855      memref::getLinearizedMemRefOffsetAndSize(
856          rewriter, loc, emulatedBits, containerBits,
857  stridedMetadata.getConstifiedMixedOffset(),
858  stridedMetadata.getConstifiedMixedSizes(),
859  stridedMetadata.getConstifiedMixedStrides(),
860  getAsOpFoldResult(adaptor.getIndices()));
861  Value linearizedIndices =
862  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
863 
864  // Load the whole data and use arith.select to handle the corner cases.
865  //
866  // As an example, for this masked store of i4 values:
867  //
868  // vector.maskedstore %0[%c0, %c0], %mask, %val_to_store
869  //
870  // and given these input values:
871  //
872  // %mask = [0, 1, 1, 1, 1, 0, 0, 0] (8 * i1)
873  // %0[%c0, %c0] =
874  // [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
875  // %val_to_store =
876  // [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0] (8 * i4)
877  //
878  // we'll have the following i4 output:
879  //
880  // expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x7, 0x8]
881  //
882  // Emulating the above using i8 will give:
883  //
884  // %compressed_mask = [1, 1, 1, 0] (4 * i1)
885  // %maskedload = [0x12, 0x34, 0x56, 0x00] (4 * i8)
886  // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0] (8 * i4)
887  // %select_using_shifted_mask =
888  // [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x0, 0x0] (8 * i4)
889  // %packed_data = [0x1A, 0xBC, 0xD6, 0x00] (4 * i8)
890  //
891  // Using the compressed mask to store %packed_data results in expected
892  // output.
893  //
894  // FIXME: Make an example based on the comment above work (see #115460 for
895  // reproducer).
896  FailureOr<Operation *> newMask = getCompressedMaskOp(
897  rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);
898  if (failed(newMask))
899  return failure();
900 
901  auto numElements = (origElements + emulatedPerContainerElem - 1) /
902  emulatedPerContainerElem;
903  auto newType = VectorType::get(numElements, containerElemTy);
904  auto passThru = rewriter.create<arith::ConstantOp>(
905  loc, newType, rewriter.getZeroAttr(newType));
906 
907  auto newLoad = rewriter.create<vector::MaskedLoadOp>(
908  loc, newType, adaptor.getBase(), linearizedIndices,
909  newMask.value()->getResult(0), passThru);
910 
911  auto newBitCastType =
912  VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
913  Value valueToStore =
914  rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
915  valueToStore = rewriter.create<arith::SelectOp>(
916  loc, op.getMask(), op.getValueToStore(), valueToStore);
917  valueToStore =
918  rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);
919 
920  rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
921  op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
922  valueToStore);
923  return success();
924  }
925 };
926 
927 //===----------------------------------------------------------------------===//
928 // ConvertVectorLoad
929 //===----------------------------------------------------------------------===//
930 
931 // TODO: Document-me
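// At a high level (a simplified sketch; a fully worked example follows inside
// matchAndRewrite), an aligned load of 4 x i4 emulated via an i8 container:
//
//   %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
//
// becomes:
//
//   %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
//   %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
//
// Unaligned cases load extra container elements and extract the requested
// subvector after the bitcast.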
932 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
934 
935  LogicalResult
936  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
937  ConversionPatternRewriter &rewriter) const override {
938 
939  // See #115653
940  if (op.getVectorType().getRank() != 1)
941  return rewriter.notifyMatchFailure(op,
942  "only 1-D vectors are supported ATM");
943 
944  auto loc = op.getLoc();
945  auto containerElemTy =
946  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
947  Type emulatedElemTy = op.getType().getElementType();
948  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
949  int containerBits = containerElemTy.getIntOrFloatBitWidth();
950 
951  // Check per-element alignment.
952  if (containerBits % emulatedBits != 0) {
953  return rewriter.notifyMatchFailure(
954  op, "impossible to pack emulated elements into container elements "
955  "(bit-wise misalignment)");
956  }
957  int emulatedPerContainerElem = containerBits / emulatedBits;
958 
959  // Adjust the number of elements to load when emulating narrow types,
960  // and then cast back to the original type with vector.bitcast op.
961  // Here only the 1-D vector load is considered, and the N-D memref types
962  // should be linearized.
963  // For example, to emulate i4 to i8, the following op:
964  //
965  // %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
966  //
967  // can be replaced with
968  //
969  // %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
970  // %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
971  //
972  // There are cases where the number of elements to load is not byte-aligned,
973  // for example:
974  //
975  // %1 = vector.load %0[%c1, %c0] : memref<3x3xi2>, vector<3xi2>
976  //
977  // we will have to load extra bytes and extract the exact slice in between.
978  //
979  // %1 = vector.load %0[%c2] : memref<3xi8>, vector<2xi8>
980  // %2 = vector.bitcast %1 : vector<2xi8> to vector<8xi2>
981  // %3 = vector.extract_strided_slice %2 {offsets = [2], sizes = [3], strides
982  // = [1]}
983  // : vector<8xi2> to vector<3xi2>
984  //
985  // TODO: Currently the extract_strided_slice's attributes must be known at
986  // compile time as they must be constants.
987 
988  auto origElements = op.getVectorType().getNumElements();
989  // Note, per-element-alignment was already verified above.
990  bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
991 
992  auto stridedMetadata =
993  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
994 
995  OpFoldResult linearizedIndices;
996  memref::LinearizedMemRefInfo linearizedInfo;
997  std::tie(linearizedInfo, linearizedIndices) =
998      memref::getLinearizedMemRefOffsetAndSize(
999          rewriter, loc, emulatedBits, containerBits,
1000  stridedMetadata.getConstifiedMixedOffset(),
1001  stridedMetadata.getConstifiedMixedSizes(),
1002  stridedMetadata.getConstifiedMixedStrides(),
1003  getAsOpFoldResult(adaptor.getIndices()));
1004 
1005  std::optional<int64_t> foldedIntraVectorOffset =
1006  isDivisibleInSize ? 0
1007  : getConstantIntValue(linearizedInfo.intraDataOffset);
1008 
1009  // Always load enough elements which can cover the original elements.
1010  int64_t maxintraDataOffset =
1011  foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1012  auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
1013  emulatedPerContainerElem);
1014  Value result =
1015  emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
1016  numElements, emulatedElemTy, containerElemTy);
1017 
1018  if (!foldedIntraVectorOffset) {
1019  auto resultVector = rewriter.create<arith::ConstantOp>(
1020  loc, op.getType(), rewriter.getZeroAttr(op.getType()));
1021  result = dynamicallyExtractSubVector(
1022  rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
1023  linearizedInfo.intraDataOffset, origElements);
1024  } else if (!isDivisibleInSize) {
1025  result = staticallyExtractSubvector(
1026  rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1027  }
1028  rewriter.replaceOp(op, result);
1029  return success();
1030  }
1031 };
1032 
1033 //===----------------------------------------------------------------------===//
1034 // ConvertVectorMaskedLoad
1035 //===----------------------------------------------------------------------===//
1036 
1037 // TODO: Document-me
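// At a high level (a simplified sketch; the worked example inside
// matchAndRewrite gives concrete values), a masked load of 6 x i4 emulated
// via an i8 container:
//
//   %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru
//     : memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
//
// is rewritten as a vector.maskedload of vector<3xi8> with a compressed
// vector<3xi1> mask and a bitcast pass-through, a vector.bitcast of the
// result back to vector<6xi4>, and an arith.select with the original mask to
// restore the pass-through lanes that were over-read.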
1038 struct ConvertVectorMaskedLoad final
1039  : OpConversionPattern<vector::MaskedLoadOp> {
1041 
1042  LogicalResult
1043  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
1044  ConversionPatternRewriter &rewriter) const override {
1045  // See #115653
1046  if (op.getVectorType().getRank() != 1)
1047  return rewriter.notifyMatchFailure(op,
1048  "only 1-D vectors are supported ATM");
1049 
1050  auto loc = op.getLoc();
1051 
1052  auto containerElemTy =
1053  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
1054  Type emulatedElemTy = op.getType().getElementType();
1055  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
1056  int containerBits = containerElemTy.getIntOrFloatBitWidth();
1057 
1058  // Check per-element alignment.
1059  if (containerBits % emulatedBits != 0) {
1060  return rewriter.notifyMatchFailure(
1061  op, "impossible to pack emulated elements into container elements "
1062  "(bit-wise misalignment)");
1063  }
1064  int emulatedPerContainerElem = containerBits / emulatedBits;
1065 
1066  // Adjust the number of elements to load when emulating narrow types,
1067  // and then cast back to the original type with vector.bitcast op.
1068  // For example, to emulate i4 to i8, the following op:
1069  //
1070  // %mask = vector.constant_mask [3] : vector<6xi1>
1071  // %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
1072  // memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
1073  //
1074  // can be replaced with
1075  //
1076  // %new_mask = vector.constant_mask [2] : vector<3xi1>
1077  // %new_pass_thru = vector.bitcast %pass_thru :
1078  // vector<6xi4> to vector<3xi8>
1079  // %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
1080  // memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
1081  // %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
1082  //
1083  // Since we are effectively loading 16 bits (2xi8) from the memref with the
1084  // new mask, while originally we only wanted to effectively load 12 bits
1085  // (3xi4) from the memref, we need to set the second half of the last i8
1086  // that was effectively loaded (i.e. the second i8) to %pass_thru.
1087  //
1088  // %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
1089  //
1090  // Given these input values:
1091  // %mask = [1, 1, 1, 0, 0, 0]
1092  // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
1093  // %pass_thru = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
1094  //
1095  // we'll have:
1096  //
1097  // expected output: [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
1098  //
1099  // %new_mask = [1, 1, 0]
1100  // %new_pass_thru = [0x78, 0x9A, 0xBC]
1101  // %1 = [0x12, 0x34, 0xBC]
1102  // %2 = [0x1, 0x2, 0x3, 0x4, 0xB, 0xC]
1103  // %3 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
1104  //
1105  // TODO: Currently, only loading an even number of elements is supported.
1106  // To deal with the odd number of elements, one has to extract the
1107  // subvector at the proper offset after bit-casting.
1108  auto origType = op.getVectorType();
1109  auto origElements = origType.getNumElements();
1110  // Note, per-element-alignment was already verified above.
1111  bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
1112 
1113  auto stridedMetadata =
1114  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
1115  OpFoldResult linearizedIndices;
1116  memref::LinearizedMemRefInfo linearizedInfo;
1117  std::tie(linearizedInfo, linearizedIndices) =
1118      memref::getLinearizedMemRefOffsetAndSize(
1119          rewriter, loc, emulatedBits, containerBits,
1120  stridedMetadata.getConstifiedMixedOffset(),
1121  stridedMetadata.getConstifiedMixedSizes(),
1122  stridedMetadata.getConstifiedMixedStrides(),
1123  getAsOpFoldResult(adaptor.getIndices()));
1124 
1125  std::optional<int64_t> foldedIntraVectorOffset =
1126  isDivisibleInSize ? 0
1127  : getConstantIntValue(linearizedInfo.intraDataOffset);
1128 
1129  int64_t maxIntraDataOffset =
1130  foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1131  FailureOr<Operation *> newMask =
1132  getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
1133  emulatedPerContainerElem, maxIntraDataOffset);
1134  if (failed(newMask))
1135  return failure();
1136 
1137  Value passthru = op.getPassThru();
1138 
1139  auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
1140  emulatedPerContainerElem);
1141  auto loadType = VectorType::get(numElements, containerElemTy);
1142  auto newBitcastType =
1143  VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
1144 
1145  auto emptyVector = rewriter.create<arith::ConstantOp>(
1146  loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
1147  if (!foldedIntraVectorOffset) {
1148  passthru = dynamicallyInsertSubVector(
1149  rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
1150  origElements);
1151  } else if (!isDivisibleInSize) {
1152  passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
1153  *foldedIntraVectorOffset);
1154  }
1155  auto newPassThru =
1156  rewriter.create<vector::BitCastOp>(loc, loadType, passthru);
1157 
1158  // Generating the new masked load.
1159  auto newLoad = rewriter.create<vector::MaskedLoadOp>(
1160  loc, loadType, adaptor.getBase(),
1161  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
1162  newMask.value()->getResult(0), newPassThru);
1163 
1164  // Setting the part that originally was not effectively loaded from memory
1165  // to pass through.
1166  auto bitCast =
1167  rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
1168 
1169  Value mask = op.getMask();
1170  auto newSelectMaskType = VectorType::get(
1171  numElements * emulatedPerContainerElem, rewriter.getI1Type());
1172  // TODO: try to fold if op's mask is constant
1173  auto emptyMask = rewriter.create<arith::ConstantOp>(
1174  loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
1175  if (!foldedIntraVectorOffset) {
1176  mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
1177  linearizedInfo.intraDataOffset,
1178  origElements);
1179  } else if (!isDivisibleInSize) {
1180  mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
1181  *foldedIntraVectorOffset);
1182  }
1183 
1184  Value result =
1185  rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
1186  if (!foldedIntraVectorOffset) {
1187  result = dynamicallyExtractSubVector(
1188  rewriter, loc, result, op.getPassThru(),
1189  linearizedInfo.intraDataOffset, origElements);
1190  } else if (!isDivisibleInSize) {
1191  result = staticallyExtractSubvector(
1192  rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1193  }
1194  rewriter.replaceOp(op, result);
1195 
1196  return success();
1197  }
1198 };
1199 
1200 /// Check whether `subByteVecTy` fits within a vector of `multiByteScalarTy`
1201 ///
1202 /// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
1203 /// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
1204 /// (a multi-byte scalar, e.g. i16), where N is some integer.
1205 ///
1206 /// Put differently, this method checks whether this would be valid:
1207 ///
1208 /// vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
1209 ///
1210 /// EXAMPLES:
1211 /// * vector<4xi4> -> i16 - yes (N = 1)
1212 /// * vector<4xi4> -> i8 - yes (N = 2)
1213 /// * vector<3xi4> -> i8 - no (N would have to be 1.5)
1214 /// * vector<3xi2> -> i16 - no (N would have to be 0.5)
1215 static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
1216  Type multiByteScalarTy) {
1217  assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
1218 
1219  int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
1220  int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();
1221 
1222  assert(subByteBits < 8 && "Not a sub-byte scalar type!");
1223  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
1224  assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");
1225 
1226  int elemsPerMultiByte = multiByteBits / subByteBits;
1227 
1228  // TODO: This is a bit too restrictive for vectors rank > 1.
1229  return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
1230 }
1231 
1232 //===----------------------------------------------------------------------===//
1233 // ConvertVectorTransferRead
1234 //===----------------------------------------------------------------------===//
1235 
1236 // TODO: Document-me
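// At a high level (a simplified sketch, mirroring ConvertVectorLoad above),
// a transfer_read of 4 x i4 emulated via an i8 container:
//
//   %r = vector.transfer_read %src[%idx], %pad : memref<8xi4>, vector<4xi4>
//
// becomes (with the padding value zero-extended to i8 and the memref
// linearized):
//
//   %p = arith.extui %pad : i4 to i8
//   %r = vector.transfer_read %src[%lin], %p : memref<4xi8>, vector<2xi8>
//   %b = vector.bitcast %r : vector<2xi8> to vector<4xi4>
//
// plus a subvector extraction when the read is over-wide due to
// misalignment.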
1237 struct ConvertVectorTransferRead final
1238  : OpConversionPattern<vector::TransferReadOp> {
1240 
1241  LogicalResult
1242  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
1243  ConversionPatternRewriter &rewriter) const override {
1244 
1245  // See #115653
1246  if (op.getVectorType().getRank() != 1)
1247  return rewriter.notifyMatchFailure(op,
1248  "only 1-D vectors are supported ATM");
1249 
1250  auto loc = op.getLoc();
1251  auto containerElemTy =
1252  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
1253  Type emulatedElemTy = op.getType().getElementType();
1254  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
1255  int containerBits = containerElemTy.getIntOrFloatBitWidth();
1256 
1257  // Check per-element alignment.
1258  if (containerBits % emulatedBits != 0) {
1259  return rewriter.notifyMatchFailure(
1260  op, "impossible to pack emulated elements into container elements "
1261  "(bit-wise misalignment)");
1262  }
1263  int emulatedPerContainerElem = containerBits / emulatedBits;
1264 
1265  auto origElements = op.getVectorType().getNumElements();
1266 
1267  // Note, per-element-alignment was already verified above.
1268  bool isDivisibleInSize =
1269  fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);
1270 
1271  auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
1272  adaptor.getPadding());
1273 
1274  auto stridedMetadata =
1275  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
1276 
1277  OpFoldResult linearizedIndices;
1278  memref::LinearizedMemRefInfo linearizedInfo;
1279  std::tie(linearizedInfo, linearizedIndices) =
1280      memref::getLinearizedMemRefOffsetAndSize(
1281          rewriter, loc, emulatedBits, containerBits,
1282  stridedMetadata.getConstifiedMixedOffset(),
1283  stridedMetadata.getConstifiedMixedSizes(),
1284  stridedMetadata.getConstifiedMixedStrides(),
1285  getAsOpFoldResult(adaptor.getIndices()));
1286 
1287  std::optional<int64_t> foldedIntraVectorOffset =
1288  isDivisibleInSize ? 0
1289  : getConstantIntValue(linearizedInfo.intraDataOffset);
1290 
1291  int64_t maxIntraDataOffset =
1292  foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1293  auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
1294  emulatedPerContainerElem);
1295 
1296  auto newRead = rewriter.create<vector::TransferReadOp>(
1297  loc, VectorType::get(numElements, containerElemTy), adaptor.getBase(),
1298  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
1299  newPadding);
1300 
1301  auto bitCast = rewriter.create<vector::BitCastOp>(
1302  loc,
1303  VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
1304  newRead);
1305 
1306  Value result = bitCast->getResult(0);
1307  if (!foldedIntraVectorOffset) {
1308  auto zeros = rewriter.create<arith::ConstantOp>(
1309  loc, op.getType(), rewriter.getZeroAttr(op.getType()));
1310  result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
1311  linearizedInfo.intraDataOffset,
1312  origElements);
1313  } else if (!isDivisibleInSize) {
1314  result = staticallyExtractSubvector(
1315  rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1316  }
1317  rewriter.replaceOp(op, result);
1318 
1319  return success();
1320  }
1321 };
1322 } // end anonymous namespace
1323 
1324 //===----------------------------------------------------------------------===//
1325 // RewriteBitCastOfTruncI
1326 //===----------------------------------------------------------------------===//
1327 
1328 namespace {
1329 
1330 /// Helper struct to keep track of the provenance of a contiguous set of bits
1331 /// in a source vector.
1332 struct SourceElementRange {
1333  /// The index of the source vector element that contributes bits to *this.
1334  int64_t sourceElementIdx;
1335  /// The range of bits in the source vector element that contribute to *this.
1336  int64_t sourceBitBegin;
1337  int64_t sourceBitEnd;
1338 };
1339 
1340 struct SourceElementRangeList : public SmallVector<SourceElementRange> {
1341  /// Given the index of a SourceElementRange in the SourceElementRangeList,
1342 /// compute the number of bits that need to be shifted to the left to get the
1343 /// bits in their final location. This shift amount is simply the total number
1344 /// of bits *before* `shuffleIdx` (i.e. the bits of `shuffleIdx = 0` are always
1345 /// the LSBs, the bits of `shuffleIdx = 1` come next, etc).
1346  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
1347  int64_t res = 0;
1348  for (int64_t i = 0; i < shuffleIdx; ++i)
1349  res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
1350  return res;
1351  }
1352 };
1353 
1354 /// Helper struct to enumerate the source elements and bit ranges that are
1355 /// involved in a bitcast operation.
1356 /// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
1357 /// any 1-D vector shape and any source/target bitwidths.
1358 /// This creates and holds a mapping of the form:
1359 /// [dstVectorElementJ] ==
1360 /// [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ]
1361 /// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as:
1362 /// [0] = {0, [0-8)}
1363 /// [1] = {0, [8-16)}
1364 /// [2] = {0, [16-24)}
1365 /// and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as:
1366 /// [0] = {0, [0, 10)}, {1, [0, 5)}
1367 /// [1] = {1, [5, 10)}, {2, [0, 10)}
1368 struct BitCastBitsEnumerator {
1369  BitCastBitsEnumerator(VectorType sourceVectorType,
1370  VectorType targetVectorType);
1371 
1372  int64_t getMaxNumberOfEntries() {
1373  int64_t numVectors = 0;
1374  for (const auto &l : sourceElementRanges)
1375  numVectors = std::max(numVectors, (int64_t)l.size());
1376  return numVectors;
1377  }
1378 
1379  VectorType sourceVectorType;
1380  VectorType targetVectorType;
1381  SmallVector<SourceElementRangeList> sourceElementRanges;
1382 };
1383 
1384 /// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
1385 /// advantage of high-level information to avoid leaving LLVM to scramble with
1386 /// peephole optimizations.
1387 /// BitCastBitsEnumerator encodes for each element of the target vector the
1388 /// provenance of the bits in the source vector. We can "transpose" this
1389 /// information to build a sequence of shuffles and bitwise ops that will
1390 /// produce the desired result.
1391 //
1392 /// Consider the following motivating example:
1393 /// ```
1394 /// %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
1395 /// ```
1396 //
1397 /// BitCastBitsEnumerator contains the following information:
1398 /// ```
1399 /// { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
1400 /// { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
1401 /// { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
1402 /// { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
1403 /// { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
1404 /// { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
1405 /// { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
1406 /// {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
1407 /// {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
1408 /// {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
1409 /// {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
1410 /// {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
1411 /// {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
1412 /// {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
1413 /// {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
1414 /// {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
1415 /// {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
1416 /// {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
1417 /// {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
1418 /// {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
1419 /// ```
1420 ///
1421 /// In the above, each row represents one target vector element and each
1422 /// column represents one bit contribution from a source vector element.
1423 /// The algorithm creates vector.shuffle operations (in this case there are 3
1424 /// shuffles, i.e. the max number of columns in BitCastBitsEnumerator). The
1425 /// algorithm populates the bits as follows:
1426 /// ```
1427 /// src bits 0 ...
1428 /// 1st shuffle |xxxxx |xx |...
1429 /// 2nd shuffle | xxx| xxxxx |...
1430 /// 3rd shuffle | | x|...
1431 /// ```
1432 //
1433 /// The algorithm proceeds as follows:
1434 /// 1. for each vector.shuffle, collect the source vectors that participate in
1435 /// this shuffle. One source vector per target element of the resulting
1436 /// vector.shuffle. If there is no source element contributing bits for the
1437 /// current vector.shuffle, take 0 (i.e. row 0 in the above example has only
1438 /// 2 columns).
1439 /// 2. represent the bitrange in the source vector as a mask. If there is no
1440 /// source element contributing bits for the current vector.shuffle, take 0.
1441 /// 3. shift right by the proper amount to align the source bitrange at
1442 /// position 0. This is exactly the low end of the bitrange. For instance,
1443 /// the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
1444 /// shift right by 3 to get the bits contributed by the source element #1
1445 /// into position 0.
1446 /// 4. shift left by the proper amount to align to the desired position in
1447 /// the result element vector. For instance, the contribution of the second
1448 /// source element for the first row needs to be shifted by `5` to form the
1449 /// first i8 result element.
1450 ///
1451 /// Eventually, we end up building the sequence
1452 /// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
1453 /// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
1454 /// bits extracted from the source vector (i.e. the `shuffle -> and` part).
1455 struct BitCastRewriter {
1456  /// Helper metadata struct to hold the static quantities for the rewrite.
1457  struct Metadata {
1458  SmallVector<int64_t> shuffles;
1459  SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1460  };
1461 
1462  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
1463 
1464  /// Verify that general preconditions for the rewrite are met.
1465  LogicalResult commonPrecondition(PatternRewriter &rewriter,
1466  VectorType preconditionType, Operation *op);
1467 
1468  /// Precompute the metadata for the rewrite.
1469  SmallVector<BitCastRewriter::Metadata>
1470  precomputeMetadata(IntegerType shuffledElementType);
1471 
1472  /// Rewrite one step of the sequence:
1473  /// `(shuffle -> and -> shiftright -> shiftleft -> or)`.
1474  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
1475  Value initialValue, Value runningResult,
1476  const BitCastRewriter::Metadata &metadata);
1477 
1478 private:
1479  /// Underlying enumerator that encodes the provenance of the bits in each
1480  /// element of the result vector.
1481  BitCastBitsEnumerator enumerator;
1482 };
1483 
1484 } // namespace
1485 
1486 [[maybe_unused]] static raw_ostream &
1487 operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
1488  for (const auto &l : vec) {
1489  for (auto it : llvm::enumerate(l)) {
1490  os << "{ " << it.value().sourceElementIdx << ": b@["
1491  << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
1492  << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
1493  }
1494  os << "\n";
1495  }
1496  return os;
1497 }
1498 
1499 BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
1500  VectorType targetVectorType)
1501  : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
1502 
1503  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
1504  "requires -D non-scalable vector type");
1505  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
1506  "requires -D non-scalable vector type");
1507  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
1508  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
1509  LDBG("sourceVectorType: " << sourceVectorType);
1510 
1511  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
1512  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
1513  LDBG("targetVectorType: " << targetVectorType);
1514 
1515  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
1516  (void)mostMinorSourceDim;
1517  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
1518  "source and target bitwidths must match");
1519 
1520  // Prepopulate one (empty) source element range list per target element.
1521  sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
1522  for (int64_t resultBit = 0; resultBit < bitwidth;) {
1523  int64_t resultElement = resultBit / targetBitWidth;
1524  int64_t resultBitInElement = resultBit % targetBitWidth;
1525  int64_t sourceElementIdx = resultBit / sourceBitWidth;
1526  int64_t sourceBitInElement = resultBit % sourceBitWidth;
1527  int64_t step = std::min(sourceBitWidth - sourceBitInElement,
1528  targetBitWidth - resultBitInElement);
1529  sourceElementRanges[resultElement].push_back(
1530  {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
1531  resultBit += step;
1532  }
1533 }
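// Editorial illustration (not part of the upstream file): the standalone
// sketch below re-runs the bit-walking loop above for the
// `vector<2xi15> to vector<3xi10>` example from the BitCastBitsEnumerator
// documentation and prints one {sourceElementIdx, bitRange} entry per step.
//
//   #include <algorithm>
//   #include <cstdio>
//
//   int main() {
//     const long long sourceBitWidth = 15, targetBitWidth = 10, bitwidth = 30;
//     for (long long resultBit = 0; resultBit < bitwidth;) {
//       long long resultElement = resultBit / targetBitWidth;
//       long long sourceElementIdx = resultBit / sourceBitWidth;
//       long long sourceBitInElement = resultBit % sourceBitWidth;
//       long long step = std::min(sourceBitWidth - sourceBitInElement,
//                                 targetBitWidth - resultBit % targetBitWidth);
//       std::printf("[%lld] += {%lld, [%lld, %lld)}\n", resultElement,
//                   sourceElementIdx, sourceBitInElement,
//                   sourceBitInElement + step);
//       resultBit += step;
//     }
//     // Prints:
//     //   [0] += {0, [0, 10)}
//     //   [1] += {0, [10, 15)}
//     //   [1] += {1, [0, 5)}
//     //   [2] += {1, [5, 15)}
//   }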
1534 
1535 BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
1536  VectorType targetVectorType)
1537  : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
1538  LDBG("\n" << enumerator.sourceElementRanges);
1539 }
1540 
1541 /// Verify that the precondition type meets the common preconditions for any
1542 /// conversion.
1543 static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
1544  VectorType preconditionType,
1545  Operation *op) {
1546  if (!preconditionType || preconditionType.isScalable())
1547  return rewriter.notifyMatchFailure(op, "scalable vector");
1548 
1549  // TODO: consider relaxing this restriction in the future if we find ways
1550  // to really work with subbyte elements across the MLIR/LLVM boundary.
1551  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
1552  if (bitwidth % 8 != 0)
1553  return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
1554 
1555  return success();
1556 }
1557 
1558 LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
1559  VectorType preconditionType,
1560  Operation *op) {
1561  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
1562  return rewriter.notifyMatchFailure(op, "types are not vector");
1563 
1564  if (!preconditionType || preconditionType.getRank() != 1)
1565  return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
1566 
1567  return commonConversionPrecondition(rewriter, preconditionType, op);
1568 }
1569 
1570 /// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
1571 ///
1572 /// Alignment means that `subByteVecTy` can be packed into a vector of
1573 /// `containerTy` elements. More specifically:
1574 /// 1. The bit-width of `containerTy` is a multiple of the
1575 /// bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
1576 /// this multiple is 4.
1577 /// 2. The multiple from 1. above divides evenly the number of the (trailing)
1578 /// elements in `subByteVecTy`.
1579 ///
1580 /// EXAMPLE 1:
1581 /// `subByteVecTy = vector<2xi4>`, and
1582 /// `containerTy = i16`
1583 ///
1584 /// 2 divides evenly 4 ( = 16 / 4), hence both conditions are _met_.
1585 ///
1586 /// EXAMPLE 2:
1587 /// `subByteVecTy = vector<3xi4>`, and
1588 /// `containerTy = i16`
1589 ///
1590 /// 3 _does not_ divide evenly 4 (= 16/4), hence the conditions are _not met_.
1591 ///
1592 /// EXAMPLE 3:
1593 /// `subByteVecTy = vector<3xi3>`, and
1594 /// `containerTy = i16`
1595 ///
1596 /// 16 _is not_ a multiple of 3, hence the conditions are _not met_.
1597 ///
1598 /// NOTE: This method assumes that common conversion preconditions are met. In
1599 /// particular, `containerTy` is assumed to be a
1600 /// multi-byte scalar type (e.g., i8, i16, i32).
1601 static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
1602  VectorType subByteVecTy,
1603  Type containerTy,
1604  Operation *op) {
1605  assert(containerTy.isIntOrFloat() &&
1606  "container element type is not a scalar");
1607 
1608  // TODO: This is validating the inputs rather than checking the conditions
1609  // documented above. Replace with an assert.
1610  if (!subByteVecTy)
1611  return rewriter.notifyMatchFailure(op, "not a vector!");
1612 
1613  unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
1614  unsigned containerBits = containerTy.getIntOrFloatBitWidth();
1615 
1616  // Enforced by the common pre-conditions.
1617  assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!");
1618 
1619  // TODO: Add support for other widths (when/if needed).
1620  if (subByteBits != 2 && subByteBits != 4)
1621  return rewriter.notifyMatchFailure(
1622  op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
1623 
1624  // Condition 1 ("per-element" alignment)
1625  if (containerBits % subByteBits != 0)
1626  return rewriter.notifyMatchFailure(op, "unaligned element types");
1627 
1628  // Condition 2 ("full" alignment)
1629  if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
1630  return rewriter.notifyMatchFailure(
1631  op, "not possible to fit this sub-byte vector type into a vector of "
1632  "the given multi-byte type");
1633 
1634  return success();
1635 }
1636 
1637 SmallVector<BitCastRewriter::Metadata>
1638 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
1639  SmallVector<BitCastRewriter::Metadata> result;
1640  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
1641  shuffleIdx < e; ++shuffleIdx) {
1642  SmallVector<int64_t> shuffles;
1643  SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1644 
1645  // Create the attribute quantities for the shuffle / mask / shift ops.
1646  for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
1647  int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
1648  ? srcEltRangeList[shuffleIdx].sourceElementIdx
1649  : 0;
1650  shuffles.push_back(sourceElement);
1651 
1652  int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
1653  ? srcEltRangeList[shuffleIdx].sourceBitBegin
1654  : 0;
1655  int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
1656  ? srcEltRangeList[shuffleIdx].sourceBitEnd
1657  : 0;
1658  IntegerAttr mask = IntegerAttr::get(
1659  shuffledElementType,
1660  llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
1661  bitLo, bitHi));
1662  masks.push_back(mask);
1663 
1664  int64_t shiftRight = bitLo;
1665  shiftRightAmounts.push_back(
1666  IntegerAttr::get(shuffledElementType, shiftRight));
1667 
1668  int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
1669  shiftLeftAmounts.push_back(
1670  IntegerAttr::get(shuffledElementType, shiftLeft));
1671  }
1672 
1673  result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
1674  }
1675  return result;
1676 }
1677 
1678 Value BitCastRewriter::genericRewriteStep(
1679  PatternRewriter &rewriter, Location loc, Value initialValue,
1680  Value runningResult, const BitCastRewriter::Metadata &metadata) {
1681  // Create vector.shuffle from the metadata.
1682  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
1683  loc, initialValue, initialValue, metadata.shuffles);
1684 
1685  // Intersect with the mask.
1686  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
1687  auto constOp = rewriter.create<arith::ConstantOp>(
1688  loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
1689  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);
1690 
1691  // Align right on 0.
1692  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
1693  loc,
1694  DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
1695  Value shiftedRight =
1696  rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
1697 
1698  // Shift bits left into their final position.
1699  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
1700  loc,
1701  DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
1702  Value shiftedLeft =
1703  rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
1704 
1705  runningResult =
1706  runningResult
1707  ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
1708  : shiftedLeft;
1709 
1710  return runningResult;
1711 }
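// Editorial scalar illustration (not part of the upstream file): for the
// `vector<32xi5> to vector<20xi8>` example above, result byte 0 combines the
// two metadata entries `{ 0: b@[0..5) lshl: 0}` and `{ 1: b@[0..3) lshl: 5}`.
// The vector.shuffle is modeled by simply picking src0/src1; each step is then
// mask -> shift right -> shift left -> or, exactly as in genericRewriteStep.
//
//   #include <cassert>
//   #include <cstdint>
//
//   int main() {
//     uint8_t src0 = 0x15, src1 = 0x06; // two i5 source elements (low 5 bits)
//     // Step 0: mask bits [0..5) of src0, shr by 0, shl by 0.
//     uint8_t result = static_cast<uint8_t>(((src0 & 0x1F) >> 0) << 0);
//     // Step 1: mask bits [0..3) of src1, shr by 0, shl by 5, or into result.
//     result |= static_cast<uint8_t>(((src1 & 0x07) >> 0) << 5);
//     assert(result == ((0x06 << 5) | 0x15));
//   }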
1712 
1713 /// Bitcasts the aligned `subByteVec` vector to a vector of i8, where
1714 /// "aligned" means it satisfies the alignedConversionPrecondition.
1715 ///
1716 /// Example:
1717 /// vector<16x16xi2> -> vector<16x4xi8>
1718 /// vector<16x16xi4> -> vector<16x8xi8>
1719 static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
1720  Value subByteVec) {
1721  auto srcVecType = cast<VectorType>(subByteVec.getType());
1722  int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
1723  assert(8 % srcBitwidth == 0 &&
1724  "Unsupported sub-byte type (not a divisor of i8)");
1725  int64_t numSrcElemsPerByte = 8 / srcBitwidth;
1726  SmallVector<int64_t> vecShape(srcVecType.getShape());
1727  // Adjust last dimension of the vector, so the total size remains the same.
1728  vecShape.back() = vecShape.back() / numSrcElemsPerByte;
1729  auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
1730  return rewriter.create<vector::BitCastOp>(loc, i8VecType, subByteVec);
1731 }
1732 
1733 /// Extracts a signed N-bit sequence from each element of a vector of bytes,
1734 /// starting at the specified bit index.
1735 /// The `bitIdx` starts at 0 from the LSB and moves to the left.
1736 ///
1737 /// Example for a single element:
1738 /// Extract numBits=2 starting at bitIdx=2
1739 /// src = [0 | 1 | 0 | 1 | 1 | 1 | 1 | 0]
1740 /// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1741 /// target = [. . . . ^ ^ . .]
1742 ///
1743 /// The target sequence is [11] (decimal=-1) as a signed 2-bit integer.
1744 /// So the result should be [11 11 11 11] (decimal=-1) as a signed 8-bit integer.
1745 ///
1746 /// src = [01 01 11 10]
1747 /// shl = arith.shl(src, 4) -> [11 10 00 00]
1748 /// result = arith.shrsi(shl, 6) -> [11 11 11 11]
1749 static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter,
1750  Location loc, Value src,
1751  int bitIdx, int numBits) {
1752  auto srcType = cast<VectorType>(src.getType());
1753  Value shl = src;
1754  int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
1755  assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
1756  "Invalid bitIdx range");
1757  if (bitsToShiftLeft != 0) {
1758  Value shiftLeftValues = rewriter.create<arith::ConstantOp>(
1759  loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
1760  shl = rewriter.create<arith::ShLIOp>(loc, src, shiftLeftValues);
1761  }
1762 
1763  int8_t bitsToShiftRight = 8 - numBits;
1764  Value shiftRightValues = rewriter.create<arith::ConstantOp>(
1765  loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1766  Value shr = rewriter.create<arith::ShRSIOp>(loc, shl, shiftRightValues);
1767  return shr;
1768 }
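// Editorial scalar illustration (not part of the upstream file) of the
// shl + arithmetic-shr trick above, using the same numbers as the example in
// the comment (numBits=2, bitIdx=2, src byte 0b01011110):
//
//   #include <cassert>
//   #include <cstdint>
//
//   static int8_t signExtendBitField(uint8_t byte, int bitIdx, int numBits) {
//     // Move the field up against the sign bit, then arithmetic-shift it back.
//     int8_t shl = static_cast<int8_t>(byte << (8 - numBits - bitIdx));
//     return static_cast<int8_t>(shl >> (8 - numBits));
//   }
//
//   int main() {
//     // Bits [3..2] are 0b11, i.e. -1 as a signed 2-bit value.
//     assert(signExtendBitField(0b01011110, /*bitIdx=*/2, /*numBits=*/2) == -1);
//   }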
1769 
1770 /// Extracts an unsigned N-bit sequence from each element of a vector of bytes,
1771 /// starting at the specified bit index.
1772 /// The `bitIdx` starts at 0 from the LSB and moves to the left.
1773 ///
1774 /// Example for a single element:
1775 /// Extract numBits=2 starting at bitIdx=2
1776 /// src = [0 | 1 | 0 | 1 | 1 | 0 | 1 | 0]
1777 /// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1778 /// target = [. . . . ^ ^ . .]
1779 ///
1780 /// The target sequence is [10] (decimal=2) as an unsigned 2-bit integer.
1781 /// So the result should be [00 00 00 10] (decimal=2) as an unsigned 8-bit integer.
1782 ///
1783 /// src = [01 01 10 10]
1784 /// mask = [00 00 00 11]
1785 /// shr = arith.shrui(src, 2) = [00 01 01 10]
1786 /// result = arith.andi(shr, mask) = [00 00 00 10]
1787 /// NOTE: As with extractNBitsPerByteAndSignExtendToI8, this could be
1788 /// achieved by using arith::ShLIOp + arith::ShRUIOp instead of the masking.
1789 /// However, by using arith::ShRUIOp + arith::AndIOp, we eliminate the left
1790 /// shift entirely when the index is 0.
1791 static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter,
1792  Location loc, Value src,
1793  int bitIdx, int numBits) {
1794  assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
1795  "Invalid bitIdx range");
1796  auto srcType = cast<VectorType>(src.getType());
1797  int8_t bitsToShiftRight = bitIdx;
1798  Value shr = src;
1799  if (bitsToShiftRight != 0) {
1800  Value shiftRightValues = rewriter.create<arith::ConstantOp>(
1801  loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1802  shr = rewriter.create<arith::ShRUIOp>(loc, src, shiftRightValues);
1803  }
1804  if (bitIdx + numBits == 8) {
1805  return shr;
1806  }
1807  uint8_t lowBitsMask = (1 << numBits) - 1;
1808  Value lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
1809  loc, DenseElementsAttr::get(srcType, lowBitsMask));
1810  return rewriter.create<arith::AndIOp>(loc, shr, lowBitsMaskValues);
1811 }
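// Editorial scalar illustration (not part of the upstream file) of the
// shr + and variant above, using the example from the comment (numBits=2,
// bitIdx=2, src byte 0b01011010):
//
//   #include <cassert>
//   #include <cstdint>
//
//   static uint8_t zeroExtendBitField(uint8_t byte, int bitIdx, int numBits) {
//     // Shift the field down to bit 0, then mask off everything above it.
//     uint8_t lowBitsMask = static_cast<uint8_t>((1 << numBits) - 1);
//     return static_cast<uint8_t>((byte >> bitIdx) & lowBitsMask);
//   }
//
//   int main() {
//     // Bits [3..2] are 0b10, i.e. 2 as an unsigned 2-bit value.
//     assert(zeroExtendBitField(0b01011010, /*bitIdx=*/2, /*numBits=*/2) == 2);
//   }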
1812 
1813 using ExtractNBitsFn =
1814  std::function<Value(PatternRewriter &, Location, Value, int, int)>;
1815 
1816 /// Rewrite the i4 -> i8 extension into a sequence of shuffles and
1817 /// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1818 static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
1819  Value srcValue, const ExtractNBitsFn &extFn) {
1820  [[maybe_unused]] auto srcVecType = cast<VectorType>(srcValue.getType());
1821  assert(srcVecType.getElementType().isSignlessInteger(4) &&
1822  "Expected i4 type");
1823 
1824  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1825  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1826 
1827  // 2. Extend i4 elements to i8 elements. The low i4 elements of each
1828  // byte are placed in one vector and the high i4 elements in another vector.
1829  Value low = extFn(rewriter, loc, i8Vector, 0, 4);
1830  Value high = extFn(rewriter, loc, i8Vector, 4, 4);
1831 
1832  // 3. Interleave low and high i8 elements.
1833  return rewriter.create<vector::InterleaveOp>(loc, low, high);
1834 }
1835 
1836 /// Rewrite the i2 -> i8 extension into a sequence of shuffles and
1837 /// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1838 static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
1839  Value srcValue, const ExtractNBitsFn &extFn) {
1840  [[maybe_unused]] VectorType srcVecType = cast<VectorType>(srcValue.getType());
1841  assert(srcVecType.getElementType().isSignlessInteger(2) &&
1842  "Expected i2 type");
1843 
1844  // 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
1845  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1846 
1847  // 2. Extract each i2 element.
1848  // Position 0 (bits 0-1)
1849  Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
1850  // Position 1 (bits 2-3)
1851  Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
1852  // Position 2 (bits 4-5)
1853  Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
1854  // Position 3 (bits 6-7)
1855  Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);
1856 
1857  // 3. Interleave all 4 elements by first interleaving
1858  // even elements and then odd
1859  // vec0 = [0,0,0,0],...
1860  // vec1 = [1,1,1,1],...
1861  // vec2 = [2,2,2,2],...
1862  // vec3 = [3,3,3,3],...
1863  // 02 = [0,2,0,2,0,2,0,2],...
1864  // 13 = [1,3,1,3,1,3,1,3],...
1865  // 0213 = [0,1,2,3,...],...
1866  Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, vec0, vec2);
1867  Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, vec1, vec3);
1868  return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
1869 }
1870 
1871 /// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
1872 /// ops to avoid leaving LLVM to scramble with peephole optimizations.
1873 static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
1874  Value srcValue) {
1875  VectorType srcVecType = cast<VectorType>(srcValue.getType());
1876  assert(srcVecType.getElementType().isSignlessInteger(8) &&
1877  "Expected i8 type");
1878 
1879  // 1. De-interleave low and high i8 elements.
1880  auto deinterleaveOp = rewriter.create<vector::DeinterleaveOp>(loc, srcValue);
1881 
1882  // 2. Zero out the upper side of each low i8 element.
1883  constexpr int8_t i8LowBitMask = 0x0F;
1884  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
1885  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
1886  loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
1887  Value zeroOutLow = rewriter.create<arith::AndIOp>(
1888  loc, deinterleaveOp.getRes1(), zeroOutMask);
1889 
1890  // 3. Move high i4 values to upper side of the byte.
1891  constexpr int8_t bitsToShift = 4;
1892  auto shiftValues = rewriter.create<arith::ConstantOp>(
1893  loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
1894  Value shlHigh = rewriter.create<arith::ShLIOp>(loc, deinterleaveOp.getRes2(),
1895  shiftValues);
1896 
1897  // 4. Merge high and low i4 values.
1898  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
1899 
1900  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
1901  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
1902  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
1903 }
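// Editorial scalar illustration (not part of the upstream file) of steps 2-4
// above: packing one low and one high i4 value into a single byte.
//
//   #include <cassert>
//   #include <cstdint>
//
//   int main() {
//     uint8_t low = 0x0A, high = 0x03; // two values already truncated to i4
//     // Zero out the upper nibble of `low`, move `high` into the upper nibble,
//     // then merge the two nibbles.
//     uint8_t packed = static_cast<uint8_t>((low & 0x0F) | (high << 4));
//     assert(packed == 0x3A);
//   }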
1904 
1905 namespace {
1906 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
1907 /// advantage of high-level information to avoid leaving LLVM to scramble with
1908 /// peephole optimizations.
1909 struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
1910  using OpRewritePattern::OpRewritePattern;
1911 
1912  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
1913  PatternRewriter &rewriter) const override {
1914  // The source must be a trunc op.
1915  auto truncOp =
1916  bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
1917  if (!truncOp)
1918  return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
1919 
1920  // Set up the BitCastRewriter and verify the precondition.
1921  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1922  VectorType targetVectorType = bitCastOp.getResultVectorType();
1923  BitCastRewriter bcr(sourceVectorType, targetVectorType);
1924  if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
1925  return failure();
1926 
1927  // Perform the rewrite.
1928  Value truncValue = truncOp.getIn();
1929  auto shuffledElementType =
1930  cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
1931  Value runningResult;
1932  for (const BitCastRewriter ::Metadata &metadata :
1933  bcr.precomputeMetadata(shuffledElementType)) {
1934  runningResult = bcr.genericRewriteStep(
1935  rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
1936  }
1937 
1938  // Finalize the rewrite.
1939  bool narrowing = targetVectorType.getElementTypeBitWidth() <=
1940  shuffledElementType.getIntOrFloatBitWidth();
1941  if (narrowing) {
1942  if (runningResult.getType() == bitCastOp.getResultVectorType()) {
1943  rewriter.replaceOp(bitCastOp, runningResult);
1944  } else {
1945  rewriter.replaceOpWithNewOp<arith::TruncIOp>(
1946  bitCastOp, bitCastOp.getResultVectorType(), runningResult);
1947  }
1948  } else {
1949  if (runningResult.getType() == bitCastOp.getResultVectorType()) {
1950  rewriter.replaceOp(bitCastOp, runningResult);
1951  } else {
1952  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1953  bitCastOp, bitCastOp.getResultVectorType(), runningResult);
1954  }
1955  }
1956 
1957  return success();
1958  }
1959 };
1960 } // namespace
1961 
1962 //===----------------------------------------------------------------------===//
1963 // RewriteExtOfBitCast
1964 //===----------------------------------------------------------------------===//
1965 
1966 namespace {
1967 /// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
1968 /// take advantage of high-level information to avoid leaving LLVM to scramble
1969 /// with peephole optimizations.
1970 template <typename ExtOpType>
1971 struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
1972  using OpRewritePattern<ExtOpType>::OpRewritePattern;
1973 
1974  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
1975  : OpRewritePattern<ExtOpType>(context, benefit) {}
1976 
1977  LogicalResult matchAndRewrite(ExtOpType extOp,
1978  PatternRewriter &rewriter) const override {
1979  // The source must be a bitcast op.
1980  auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
1981  if (!bitCastOp)
1982  return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
1983 
1984  // Set up the BitCastRewriter and verify the precondition.
1985  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1986  VectorType targetVectorType = bitCastOp.getResultVectorType();
1987  BitCastRewriter bcr(sourceVectorType, targetVectorType);
1988  if (failed(bcr.commonPrecondition(
1989  rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
1990  return failure();
1991 
1992  // Perform the rewrite.
1993  Value runningResult;
1994  Value sourceValue = bitCastOp.getSource();
1995  auto shuffledElementType =
1996  cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
1997  for (const BitCastRewriter::Metadata &metadata :
1998  bcr.precomputeMetadata(shuffledElementType)) {
1999  runningResult = bcr.genericRewriteStep(
2000  rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
2001  }
2002 
2003  // Finalize the rewrite.
2004  bool narrowing =
2005  cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
2006  shuffledElementType.getIntOrFloatBitWidth();
2007  if (narrowing) {
2008  rewriter.replaceOpWithNewOp<arith::TruncIOp>(
2009  extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
2010  } else {
2011  rewriter.replaceOpWithNewOp<ExtOpType>(
2012  extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
2013  }
2014 
2015  return success();
2016  }
2017 };
2018 
2019 /// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
2020 /// bitwise ops that take advantage of high-level information to avoid leaving
2021 /// LLVM to scramble with peephole optimizations. Templated to choose between
2022 /// signed and unsigned conversions.
2023 ///
2024 /// EXAMPLE 1 (signed):
2025 /// arith.extsi %in : vector<8xi4> to vector<8xi32>
2026 /// is rewritten as:
2027 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2028 /// %1 = arith.shli %0, 4 : vector<4xi8>
2029 /// %2 = arith.shrsi %1, 4 : vector<4xi8>
2030 /// %3 = arith.shrsi %0, 4 : vector<4xi8>
2031 /// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
2032 /// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
2033 ///
2034 /// EXAMPLE 2 (fp):
2035 /// arith.sitofp %in : vector<8xi4> to vector<8xf32>
2038 /// is rewritten as:
2037 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2038 /// %1 = arith.shli %0, 4 : vector<4xi8>
2039 /// %2 = arith.shrsi %1, 4 : vector<4xi8>
2040 /// %3 = arith.shrsi %0, 4 : vector<4xi8>
2041 /// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
2042 /// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
2043 ///
2044 /// EXAMPLE 3 (unsigned):
2045 /// arith.extui %in : vector<8xi4> to vector<8xi32>
2046 /// is rewritten as:
2047 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2048 /// %1 = arith.andi %0, 15 : vector<4xi8>
2049 /// %2 = arith.shrui %0, 4 : vector<4xi8>
2050 /// %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
2051 /// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
2052 ///
2053 template <typename ConversionOpType, bool isSigned>
2054 struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
2055  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
2056 
2057  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
2058  PatternRewriter &rewriter) const override {
2059  // Verify the preconditions.
2060  Value srcValue = conversionOp.getIn();
2061  VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
2062  VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
2063 
2064  if (failed(
2065  commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
2066  return failure();
2067 
2068  // Check general alignment preconditions.
2069  if (failed(alignedConversionPrecondition(
2070  rewriter, srcVecType,
2071  /*containerTy=*/rewriter.getI8Type(), conversionOp)))
2072  return failure();
2073 
2074  // Perform the rewrite.
2075  Location loc = conversionOp.getLoc();
2076  const auto &extFn = isSigned ? extractNBitsPerByteAndSignExtendToI8
2077  : extractNBitsPerByteAndExtendToI8;
2078  Value subByteExt;
2079  switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
2080  case 2:
2081  subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn);
2082  break;
2083  case 4:
2084  subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn);
2085  break;
2086  default:
2087  return failure();
2088  }
2089 
2090  // Finalize the rewrite.
2091  rewriter.replaceOpWithNewOp<ConversionOpType>(
2092  conversionOp, conversionOp.getType(), subByteExt);
2093  return success();
2094  }
2095 };
2096 
2097 /// Rewrite the i8 -> i4 part of any truncation into a deinterleave and
2098 /// bitwise ops that take advantage of high-level information to avoid leaving
2099 /// LLVM to scramble with peephole optimizations.
2100 ///
2101 /// For example:
2102 /// arith.trunci %in : vector<8xi32> to vector<8xi4>
2103 ///
2104 /// is rewritten as:
2105 ///
2106 /// %cst = arith.constant dense<15> : vector<4xi8>
2107 /// %cst_0 = arith.constant dense<4> : vector<4xi8>
2108 /// %tr = arith.trunci %in : vector<8xi32> to vector<8xi8>
     /// %0, %1 = vector.deinterleave %tr : vector<8xi8> -> vector<4xi8>
2109 /// %2 = arith.andi %0, %cst : vector<4xi8>
2110 /// %3 = arith.shli %1, %cst_0 : vector<4xi8>
2111 /// %4 = arith.ori %2, %3 : vector<4xi8>
2112 /// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
2113 ///
2114 struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
2115  using OpRewritePattern::OpRewritePattern;
2116 
2117  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
2118  PatternRewriter &rewriter) const override {
2119  // Verify the preconditions.
2120  Value srcValue = truncOp.getIn();
2121  auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
2122  auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
2123  if (!srcVecType || !dstVecType)
2124  return failure();
2125 
2126  if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
2127  return failure();
2128 
2129  // TODO: Add support for truncating to i2.
2130  if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
2131  return failure();
2132 
2133  // Check general alignment preconditions. We invert the src/dst type order
2134  // to reuse the existing precondition logic.
2135  if (failed(alignedConversionPrecondition(
2136  rewriter, dstVecType,
2137  /*containerTy=*/rewriter.getI8Type(), truncOp)))
2138  return failure();
2139 
2140  // Create a new iX -> i8 truncation op.
2141  Location loc = truncOp.getLoc();
2142  auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
2143  Value i8TruncVal =
2144  rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
2145 
2146  // Rewrite the i8 -> i4 truncation part.
2147  Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
2148 
2149  // Finalize the rewrite.
2150  rewriter.replaceOp(truncOp, subByteTrunc);
2151  return success();
2152  }
2153 };
2154 
2155 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
2156 /// perform the transpose on wider (byte) element types.
2157 ///
2158 /// EXAMPLE:
2159 /// %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
2160 ///
2161 /// is rewritten as:
2162 ///
2163 /// %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
2164 /// %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
2165 /// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
2166 ///
2167 struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
2168  using OpRewritePattern::OpRewritePattern;
2169 
2170  RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
2171  : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
2172 
2173  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
2174  PatternRewriter &rewriter) const override {
2175  // Precondition: sub-byte integer transpose.
2176  constexpr unsigned minNativeBitwidth = 8;
2177  VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
2178  if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
2179  srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
2180  return rewriter.notifyMatchFailure(transposeOp,
2181  "not a sub-byte transpose");
2182  }
2183 
2184  // Perform the rewrite.
2185  Location loc = transposeOp.getLoc();
2186  // Signed/unsigned interpretation shouldn't matter here as we are just
2187  // transposing the elements and truncating them back to the original size.
2188  // TODO: Use unsigned extension (more efficient) when emulation or backend
2189  // support is available.
2190  auto srcNativeVecType = srcSubByteVecType.cloneWith(
2191  std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
2192  Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
2193  transposeOp.getVector());
2194  Value newTranspose = rewriter.create<vector::TransposeOp>(
2195  loc, extOp, transposeOp.getPermutation());
2196  VectorType dstSubByteVecType = transposeOp.getResultVectorType();
2197  rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
2198  newTranspose);
2199  return success();
2200  }
2201 };
2202 
2203 } // namespace
2204 
2205 //===----------------------------------------------------------------------===//
2206 // Public Interface Definition
2207 //===----------------------------------------------------------------------===//
2208 
2209 // The emulated type is inferred from the converted memref type.
2210 void vector::populateVectorNarrowTypeEmulationPatterns(
2211  const arith::NarrowTypeEmulationConverter &typeConverter,
2212  RewritePatternSet &patterns, bool disableAtomicRMW) {
2213 
2214  // Populate `vector.*` conversion patterns.
2215  // TODO: #119553 support atomicity
2216  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
2217  ConvertVectorMaskedStore, ConvertVectorTransferRead>(
2218  typeConverter, patterns.getContext());
2219 
2220  // Populate `vector.*` store conversion patterns. The caller can choose
2221  // to avoid emitting atomic operations and use a plain read-modify-write
2222  // sequence for stores if it is known there is no thread contention.
2223  patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);
2224 }
2225 
2226 void vector::populateVectorNarrowTypeRewritePatterns(
2227  RewritePatternSet &patterns, PatternBenefit benefit) {
2228  // TODO: Document what the emulated type is.
2229  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
2230  RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
2231  benefit);
2232 
2233  // Patterns for aligned cases. We set a higher benefit as they are expected
2234  // to generate better performance for aligned cases.
2235  // The container type is always i8.
2236  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
2237  RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
2238  RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
2239  benefit.getBenefit() + 1);
2240  // The container type is always i8.
2241  patterns
2242  .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
2243  RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
2244  patterns.getContext(), benefit.getBenefit() + 1);
2245 }
2246 
2247 // The container type is always i8.
2248 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
2249  RewritePatternSet &patterns, PatternBenefit benefit) {
2250  patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
2251 }
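// Editorial usage sketch (not part of the upstream file): a hypothetical pass
// body wiring these populate functions into a greedy rewrite. The pass class
// name `EmulateNarrowTypeTestPass` is illustrative, the populate functions are
// assumed to be declared in VectorRewritePatterns.h, and the greedy-driver
// entry point may differ slightly across MLIR versions.
//
//   #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
//   #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
//
//   void EmulateNarrowTypeTestPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     vector::populateVectorNarrowTypeRewritePatterns(patterns);
//     vector::populateVectorTransposeNarrowTypeRewritePatterns(patterns);
//     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//       signalPassFailure();
//   }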