1 //===- VectorEmulateNarrowType.cpp - Narrow type emulation ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to emulate
10 // narrow types that are not supported by the target hardware, e.g. i4
11 // ("emulated type"), using wider types, e.g. i8 ("container type").
12 //
13 /// Currently, only power-of-two integer types are supported. These are
14 /// converted to wider integers that are either 8 bits wide or wider.
15 ///
16 /// TODO: Support for non-powers-of-two.
17 //===----------------------------------------------------------------------===//
18 
29 #include "mlir/IR/BuiltinTypes.h"
30 #include "mlir/IR/OpDefinition.h"
31 #include "mlir/IR/TypeUtilities.h"
32 #include "mlir/IR/Value.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/MathExtras.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include <cstdint>
39 #include <optional>
40 
41 using namespace mlir;
42 
43 #define DEBUG_TYPE "vector-narrow-type-emulation"
44 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
45 #define DBGSNL() (llvm::dbgs() << "\n")
46 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
47 
48 using VectorValue = TypedValue<VectorType>;
49 using MemRefValue = TypedValue<MemRefType>;
50 
51 //===----------------------------------------------------------------------===//
52 // Utils
53 //===----------------------------------------------------------------------===//
54 
55 /// Returns a compressed mask for the emulated vector. For example, when
56 /// emulating an eight-element `i8` vector with `i32` (i.e. when the source
57 /// elements span two dest elements), this method compresses `vector<8xi1>`
58 /// into `vector<2xi1>`.
59 ///
60 /// The compressed/output mask value is set iff any mask bit in the corresponding
61 /// `numSrcElemsPerDest` range of the uncompressed/input mask is set. E.g., if
62 /// `numSrcElemsPerDest` equals 2 and `numFrontPadElems` equals 1, the
63 /// following mask:
64 ///
65 /// %mask = [1, 1, 0, 0, 0, 0]
66 ///
67 /// will first be padded in the front with `numFrontPadElems` zeros, and zeros
68 /// will be added in the back to make the number of elements a multiple of
69 /// `numSrcElemsPerDest` (for easier computation). The resulting mask will be:
70 ///
71 /// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
72 ///
73 /// then it will return the following new compressed mask:
74 ///
75 /// %mask = [1, 1, 0, 0]
76 ///
77 /// NOTE: `numFrontPadElems` is assumed to be strictly smaller than
78 /// `numSrcElemsPerDest`.
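///
/// As a further illustrative sketch (not taken from a test), for the
/// `vector.create_mask` case with `numSrcElemsPerDest` = 4 and
/// `numFrontPadElems` = 0, a dynamic mask such as:
///
///   %mask = vector.create_mask %c5 : vector<8xi1>
///
/// would be compressed to:
///
///   %new_mask = vector.create_mask %c2 : vector<2xi1>
///
/// since ceil((5 + 0) / 4) = 2 container elements are covered.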
79 static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
80  Location loc, Value mask,
81  int numSrcElems,
82  int numSrcElemsPerDest,
83  int numFrontPadElems = 0) {
84 
85  assert(numFrontPadElems < numSrcElemsPerDest &&
86  "numFrontPadElems must be less than numSrcElemsPerDest");
87 
88  auto numDestElems =
89  (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
90  numSrcElemsPerDest;
91 
 92  Operation *maskOp = mask.getDefiningOp();
 93  SmallVector<vector::ExtractOp, 2> extractOps;
 94  // TODO: add support for `vector.splat`.
95  // Finding the mask creation operation.
96  while (maskOp &&
97  !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
98  maskOp)) {
99  if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
100  maskOp = extractOp.getVector().getDefiningOp();
101  extractOps.push_back(extractOp);
102  }
103  }
104 
105  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
106  maskOp))
107  return failure();
108 
109  // Computing the "compressed" mask. All the emulation logic (i.e. computing
110  // new mask index) only happens on the last dimension of the vectors.
111  SmallVector<int64_t> maskShape(
112  cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
113  maskShape.back() = numDestElems;
114  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
 115  std::optional<Operation *> newMask =
 116  llvm::TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
117  .Case<vector::CreateMaskOp>(
118  [&](auto createMaskOp) -> std::optional<Operation *> {
119  OperandRange maskOperands = createMaskOp.getOperands();
120  // The `vector.create_mask` op creates a mask arrangement
121  // without any zeros at the front. Also, because
122  // `numFrontPadElems` is strictly smaller than
123  // `numSrcElemsPerDest`, the compressed mask generated by
124  // padding the original mask by `numFrontPadElems` will not
 125  // have any zeros at the front either.
126  AffineExpr s0;
127  bindSymbols(rewriter.getContext(), s0);
128  s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
 129  OpFoldResult origIndex = getAsOpFoldResult(maskOperands.back());
 130  OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
 131  rewriter, loc, s0, origIndex);
132  SmallVector<Value> newMaskOperands(maskOperands.drop_back());
133  newMaskOperands.push_back(
134  getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
135  return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
136  newMaskOperands);
137  })
138  .Case<vector::ConstantMaskOp>(
139  [&](auto constantMaskOp) -> std::optional<Operation *> {
140  // Take the shape of mask, compress its trailing dimension:
141  SmallVector<int64_t> maskDimSizes(
142  constantMaskOp.getMaskDimSizes());
143  int64_t &maskIndex = maskDimSizes.back();
144  maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
145  numSrcElemsPerDest);
146  return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
147  maskDimSizes);
148  })
149  .Case<arith::ConstantOp>([&](auto constantOp)
150  -> std::optional<Operation *> {
151  // TODO: Support multiple dimensions.
152  if (maskShape.size() != 1)
153  return std::nullopt;
154  // Rearrange the original mask values to cover the whole potential
155  // loading region. For example, in the case of using byte-size for
156  // emulation, given the following mask:
157  //
158  // %mask = [0, 1, 0, 1, 0, 0]
159  //
 160  // With a front offset of 1, the mask will be padded with 0s in the front
 161  // and back so that:
 162  // 1. It is aligned with the effective loading bits.
 163  // 2. Its length is a multiple of `numSrcElemsPerDest` (and the total
 164  // coverage size is a multiple of bytes). The new mask will look like
 165  // this before compressing:
166  //
167  // %new_mask = [0, 0, 1, 0, 1, 0, 0, 0]
168  auto originalMask =
169  cast<DenseIntElementsAttr>(constantOp.getValue());
170  SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
171  paddedMaskValues.append(originalMask.template value_begin<bool>(),
172  originalMask.template value_end<bool>());
173  paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);
174 
175  // Compressing by combining every `numSrcElemsPerDest` elements:
176  SmallVector<bool> compressedMaskValues;
177  for (size_t i = 0; i < paddedMaskValues.size();
178  i += numSrcElemsPerDest) {
179  bool combinedValue = false;
180  for (int j = 0; j < numSrcElemsPerDest; ++j) {
181  combinedValue |= paddedMaskValues[i + j];
182  }
183  compressedMaskValues.push_back(combinedValue);
184  }
185  return rewriter.create<arith::ConstantOp>(
186  loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
187  });
188 
189  if (!newMask)
190  return failure();
191 
192  while (!extractOps.empty()) {
193  newMask = rewriter.create<vector::ExtractOp>(
194  loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition());
195  extractOps.pop_back();
196  }
197 
198  return *newMask;
199 }
200 
201 /// Extracts 1-D subvector from a 1-D vector.
202 ///
203 /// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
204 /// from `src`, starting at `offset`. The result is also a rank-1 vector:
205 ///
206 /// vector<numElemsToExtract x !elemType>
207 ///
 208 /// (`!elemType` is the element type of the source vector). As `offset` is a known
209 /// _static_ value, this helper hook emits `vector.extract_strided_slice`.
210 ///
211 /// EXAMPLE:
212 /// %res = vector.extract_strided_slice %src
213 /// { offsets = [offset], sizes = [numElemsToExtract], strides = [1] }
 214 static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
 215  Value src, int64_t offset,
216  int64_t numElemsToExtract) {
217  auto vectorType = cast<VectorType>(src.getType());
218  assert(vectorType.getRank() == 1 && "expected source to be rank-1-D vector ");
219  assert(offset + numElemsToExtract <= vectorType.getNumElements() &&
220  "subvector out of bounds");
221 
222  // When extracting all available elements, just use the source vector as the
223  // result.
224  if (vectorType.getNumElements() == numElemsToExtract)
225  return src;
226 
227  auto offsets = rewriter.getI64ArrayAttr({offset});
228  auto sizes = rewriter.getI64ArrayAttr({numElemsToExtract});
229  auto strides = rewriter.getI64ArrayAttr({1});
230 
231  auto resultVectorType =
232  VectorType::get({numElemsToExtract}, vectorType.getElementType());
233  return rewriter
234  .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, src,
235  offsets, sizes, strides)
236  ->getResult(0);
237 }
238 
239 /// Inserts 1-D subvector into a 1-D vector.
240 ///
241 /// Inserts the input rank-1 source vector into the destination vector starting
242 /// at `offset`. As `offset` is a known _static_ value, this helper hook emits
243 /// `vector.insert_strided_slice`.
244 ///
245 /// EXAMPLE:
246 /// %res = vector.insert_strided_slice %src, %dest
247 /// {offsets = [%offset], strides [1]}
 248 static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
 249  Value src, Value dest, int64_t offset) {
250  [[maybe_unused]] auto srcVecTy = cast<VectorType>(src.getType());
251  [[maybe_unused]] auto destVecTy = cast<VectorType>(dest.getType());
252  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
253  "expected source and dest to be rank-1 vector types");
254 
 255  // If overwriting the destination vector, just return the source.
256  if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0)
257  return src;
258 
259  auto offsets = rewriter.getI64ArrayAttr({offset});
260  auto strides = rewriter.getI64ArrayAttr({1});
261  return rewriter.create<vector::InsertStridedSliceOp>(loc, destVecTy, src,
262  dest, offsets, strides);
263 }
264 
265 /// Extracts 1-D subvector from a 1-D vector.
266 ///
 267 /// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
 268 /// from `src`, starting at `offset`. The result is also a rank-1 vector:
 269 ///
 270 /// vector<numElemsToExtract x !elType>
271 ///
272 /// (`!elType` is the element type of the source vector). As `offset` is assumed
273 /// to be a _dynamic_ SSA value, this helper method generates a sequence of
274 /// `vector.extract` + `vector.insert` pairs.
275 ///
276 /// EXAMPLE:
277 /// %v1 = vector.extract %src[%offset] : i2 from vector<8xi2>
278 /// %r1 = vector.insert %v1, %dest[0] : i2 into vector<3xi2>
279 /// %c1 = arith.constant 1 : index
280 /// %idx2 = arith.addi %offset, %c1 : index
281 /// %v2 = vector.extract %src[%idx2] : i2 from vector<8xi2>
282 /// %r2 = vector.insert %v2, %r1 [1] : i2 into vector<3xi2>
283 /// (...)
 284 static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
 285  Value src, Value dest,
286  OpFoldResult offset,
287  int64_t numElemsToExtract) {
288  auto srcVecTy = cast<VectorType>(src.getType());
289  assert(srcVecTy.getRank() == 1 && "expected source to be rank-1-D vector ");
290  // NOTE: We are unable to take the offset into account in the following
 291  // assert, hence it's still possible that the subvector is out-of-bounds even
292  // if the condition is true.
293  assert(numElemsToExtract <= srcVecTy.getNumElements() &&
294  "subvector out of bounds");
295 
296  // When extracting all available elements, just use the source vector as the
297  // result.
298  if (srcVecTy.getNumElements() == numElemsToExtract)
299  return src;
300 
301  for (int i = 0; i < numElemsToExtract; ++i) {
302  Value extractLoc =
303  (i == 0) ? dyn_cast<Value>(offset)
304  : rewriter.create<arith::AddIOp>(
305  loc, rewriter.getIndexType(), dyn_cast<Value>(offset),
306  rewriter.create<arith::ConstantIndexOp>(loc, i));
307  auto extractOp = rewriter.create<vector::ExtractOp>(loc, src, extractLoc);
308  dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
309  }
310  return dest;
311 }
312 
313 /// Inserts 1-D subvector into a 1-D vector.
314 ///
315 /// Inserts the input rank-1 source vector into the destination vector starting
316 /// at `offset`. As `offset` is assumed to be a _dynamic_ SSA value, this hook
317 /// uses a sequence of `vector.extract` + `vector.insert` pairs.
318 ///
319 /// EXAMPLE:
320 /// %v1 = vector.extract %src[0] : i2 from vector<8xi2>
321 /// %r1 = vector.insert %v1, %dest[%offset] : i2 into vector<3xi2>
322 /// %c1 = arith.constant 1 : index
323 /// %idx2 = arith.addi %offset, %c1 : index
324 /// %v2 = vector.extract %src[1] : i2 from vector<8xi2>
325 /// %r2 = vector.insert %v2, %r1 [%idx2] : i2 into vector<3xi2>
326 /// (...)
 327 static Value dynamicallyInsertSubVector(OpBuilder &rewriter, Location loc,
 328  Value src, Value dest,
329  OpFoldResult offset,
330  int64_t numElemsToInsert) {
331  auto srcVecTy = cast<VectorType>(src.getType());
332  auto destVecTy = cast<VectorType>(dest.getType());
333  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
334  "expected source and dest to be rank-1 vector types");
335  (void)srcVecTy;
336  (void)destVecTy;
337  assert(numElemsToInsert > 0 &&
338  "the number of elements to insert must be greater than 0");
339  // NOTE: We are unable to take the offset into account in the following
 340  // assert, hence it's still possible that the subvector is out-of-bounds even
341  // if the condition is true.
342  assert(numElemsToInsert <= destVecTy.getNumElements() &&
343  "subvector out of bounds");
344 
345  Value destOffsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
346  for (int64_t i = 0; i < numElemsToInsert; ++i) {
347  auto insertLoc = i == 0
348  ? destOffsetVal
349  : rewriter.create<arith::AddIOp>(
350  loc, rewriter.getIndexType(), destOffsetVal,
351  rewriter.create<arith::ConstantIndexOp>(loc, i));
352  auto extractOp = rewriter.create<vector::ExtractOp>(loc, src, i);
353  dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
354  }
355  return dest;
356 }
357 
358 /// Emulate a vector load for `emulatedElemTy` using `containerElemTy`
359 ///
360 /// Specifically, use `containerElemTy` for loading a vector of
361 /// `emulatedElemTy`. The load location is given by `base` and
362 /// `linearizedIndices`, and the load size is given by
 363 /// `numContainerElemsToLoad`.
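///
/// For example (an illustrative sketch, not from a test): with i2 emulated
/// via i8 containers and `numContainerElemsToLoad` = 2, this emits a
/// `vector.load` of vector<2xi8> at the linearized index, followed by a
/// `vector.bitcast` to vector<8xi2>.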
 364 static Value emulatedVectorLoad(OpBuilder &rewriter, Location loc,
 365  Value base,
366  OpFoldResult linearizedIndices,
367  int64_t numContainerElemsToLoad,
368  Type emulatedElemTy,
369  Type containerElemTy) {
370  auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() /
371  emulatedElemTy.getIntOrFloatBitWidth();
372  auto newLoad = rewriter.create<vector::LoadOp>(
373  loc, VectorType::get(numContainerElemsToLoad, containerElemTy), base,
374  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
375  return rewriter.create<vector::BitCastOp>(
376  loc,
377  VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
378  emulatedElemTy),
379  newLoad);
380 }
381 
382 /// Downcast two values to `downcastType`, then select values
 383 /// based on `mask`, and cast the result to `upcastType`.
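///
/// For example (illustrative types): with `downcastType` = vector<4xi2> and
/// `upcastType` = vector<1xi8>, both 8-bit-sized operands are bit-cast to
/// vector<4xi2> (if needed), selected lane-by-lane using the vector<4xi1>
/// `mask`, and the result is bit-cast back to vector<1xi8>.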
 384 static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
 385  VectorType downcastType,
386  VectorType upcastType, Value mask,
387  Value trueValue, Value falseValue) {
388  assert(
389  downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
390  upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
391  "expected input and output number of bits to match");
392  if (trueValue.getType() != downcastType) {
393  trueValue = builder.create<vector::BitCastOp>(loc, downcastType, trueValue);
394  }
395  if (falseValue.getType() != downcastType) {
396  falseValue =
397  builder.create<vector::BitCastOp>(loc, downcastType, falseValue);
398  }
399  Value selectedType =
400  builder.create<arith::SelectOp>(loc, mask, trueValue, falseValue);
401  // Upcast the selected value to the new type.
402  return builder.create<vector::BitCastOp>(loc, upcastType, selectedType);
403 }
404 
405 /// Emits `memref.generic_atomic_rmw` op to store a subbyte-sized value to a
406 /// byte in `linearizedMemref`, with a mask. The `valueToStore` is a vector of
 407 /// sub-byte elements whose total size is 8 bits (one byte), and the mask is
 408 /// used to select which elements to store.
409 ///
410 /// Inputs:
411 /// linearizedMemref = |2|2|2|2| : <4xi2> (<1xi8>)
412 /// storeIdx = 2
413 /// valueToStore = |3|3|3|3| : vector<4xi2>
414 /// mask = |0|0|1|1| : vector<4xi1>
415 ///
416 /// Result:
417 /// linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
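///
/// The emitted IR is roughly of the form below (an illustrative sketch,
/// assuming a vector<4xi2> value stored into an i8 container):
///
///   %res = memref.generic_atomic_rmw %linearizedMemref[%storeIdx] : memref<?xi8> {
///   ^bb0(%current : i8):
///     %cur_vec = vector.from_elements %current : vector<1xi8>
///     %cur_bits = vector.bitcast %cur_vec : vector<1xi8> to vector<4xi2>
///     %sel = arith.select %mask, %valueToStore, %cur_bits : vector<4xi1>, vector<4xi2>
///     %packed = vector.bitcast %sel : vector<4xi2> to vector<1xi8>
///     %scalar = vector.extract %packed[0] : i8 from vector<1xi8>
///     memref.atomic_yield %scalar : i8
///   }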
418 static void atomicRMW(OpBuilder &builder, Location loc,
419  MemRefValue linearizedMemref, Value storeIdx,
420  VectorValue valueToStore, Value mask) {
421  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
422 
423  // Create an atomic load-modify-write region using
424  // `memref.generic_atomic_rmw`.
425  auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
426  loc, linearizedMemref, ValueRange{storeIdx});
427  Value origValue = atomicOp.getCurrentValue();
428 
429  OpBuilder::InsertionGuard guard(builder);
430  builder.setInsertionPointToStart(atomicOp.getBody());
431 
432  // Load the original value from memory, and cast it to the original element
433  // type.
434  auto oneElemVecType = VectorType::get({1}, origValue.getType());
435  Value origVecValue = builder.create<vector::FromElementsOp>(
436  loc, oneElemVecType, ValueRange{origValue});
437 
438  // Construct the final masked value and yield it.
439  Value maskedValue =
440  downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
441  oneElemVecType, mask, valueToStore, origVecValue);
442  auto scalarMaskedValue =
443  builder.create<vector::ExtractOp>(loc, maskedValue, 0);
444  builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
445 }
446 
447 /// Generate a non-atomic read-modify-write sequence for storing to the emulated
 448 /// type. It has similar logic to `atomicRMW`, but without atomicity.
449 static void nonAtomicRMW(OpBuilder &builder, Location loc,
450  MemRefValue linearizedMemref, Value linearizedIndex,
451  VectorValue valueToStore, Value mask) {
452  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
453 
454  auto oneElemVecType =
455  VectorType::get({1}, linearizedMemref.getType().getElementType());
456  Value origVecValue = builder.create<vector::LoadOp>(
457  loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
458  origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
459  origVecValue);
460 
461  Value maskedValue =
462  downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
463  oneElemVecType, mask, valueToStore, origVecValue);
464  builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
465  linearizedIndex);
466 }
467 
468 /// Extract `sliceNumElements` from source `vector` at `extractOffset`,
469 /// and insert it into an empty vector at `insertOffset`.
470 /// Inputs:
471 /// vec_in = |0|1|2|3| : vector<4xi2>
472 /// extractOffset = 1
473 /// sliceNumElements = 2
474 /// insertOffset = 2
475 /// Output:
476 /// vec_out = |0|0|1|2| : vector<4xi2>
 477 static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
 478  Location loc, VectorValue vector,
479  int64_t extractOffset,
480  int64_t sliceNumElements,
481  int64_t insertOffset) {
482  assert(vector.getType().getRank() == 1 && "expected 1-D vector");
483  auto vectorElementType = vector.getType().getElementType();
 484  // TODO: update and use `alignedConversionPrecondition` in place of
485  // these asserts.
486  assert(
487  sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
488  "sliceNumElements * vector element size must be less than or equal to 8");
489  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
490  "vector element must be a valid sub-byte type");
491  auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
492  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
493  loc, VectorType::get({emulatedPerContainerElem}, vectorElementType),
494  rewriter.getZeroAttr(
495  VectorType::get({emulatedPerContainerElem}, vectorElementType)));
496  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
497  extractOffset, sliceNumElements);
498  return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
499  insertOffset);
500 }
501 
502 namespace {
503 
504 //===----------------------------------------------------------------------===//
505 // ConvertVectorStore
506 //===----------------------------------------------------------------------===//
507 
508 // Emulate `vector.store` using a multi-byte container type.
509 //
510 // The container type is obtained through Op adaptor and would normally be
511 // generated via `NarrowTypeEmulationConverter`.
512 //
513 // EXAMPLE 1
514 // (aligned store of i4, emulated using i8 as the container type)
515 //
516 // vector.store %src, %dest[%idx_1, %idx_2] : memref<4x8xi4>, vector<8xi4>
517 //
518 // is rewritten as:
519 //
520 // %src_bitcast = vector.bitcast %src : vector<8xi4> to vector<4xi8>
521 // vector.store %src_bitcast, %dest_bitcast[%idx]
522 // : memref<16xi8>, vector<4xi8>
523 //
524 // EXAMPLE 2
525 // (unaligned store of i2, emulated using i8 as the container type)
526 //
527 // vector.store %src, %dest[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
528 //
529 // The i2 store is emulated through 2 x RMW sequences. The destination i2 memref
530 // is modelled using 3 bytes:
531 //
532 // Byte 0 Byte 1 Byte 2
533 // +----------+----------+----------+
534 // | oooooooo | ooooNNNN | NNoooooo |
535 // +----------+----------+----------+
536 //
537 // N - (N)ew entries (i.e. to be overwritten by vector.store)
538 // o - (o)ld entries (to be preserved)
539 //
540 // For the generated output in the non-atomic case, see:
 541 // * `@vector_store_i2_const_index_two_partial_stores`
542 // in:
543 // * "vector-emulate-narrow-type-unaligned-non-atomic.mlir".
544 //
545 // NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
 546 // `true` to generate non-atomic RMW sequences.
 547 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 548  using OpConversionPattern::OpConversionPattern;
 549 
550  ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
551  : OpConversionPattern<vector::StoreOp>(context),
552  disableAtomicRMW(disableAtomicRMW) {}
553 
554  LogicalResult
555  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
556  ConversionPatternRewriter &rewriter) const override {
557 
558  // See #115653
559  if (op.getValueToStore().getType().getRank() != 1)
560  return rewriter.notifyMatchFailure(op,
561  "only 1-D vectors are supported ATM");
562 
563  auto loc = op.getLoc();
564 
565  auto valueToStore = cast<VectorValue>(op.getValueToStore());
566  auto containerElemTy =
567  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
568  Type emulatedElemTy = op.getValueToStore().getType().getElementType();
569  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
570  int containerBits = containerElemTy.getIntOrFloatBitWidth();
571 
572  // Check per-element alignment.
573  if (containerBits % emulatedBits != 0) {
574  return rewriter.notifyMatchFailure(
575  op, "impossible to pack emulated elements into container elements "
576  "(bit-wise misalignment)");
577  }
578  int emulatedPerContainerElem = containerBits / emulatedBits;
579 
580  // Adjust the number of elements to store when emulating narrow types.
581  // Here only the 1-D vector store is considered, and the N-D memref types
582  // should be linearized.
583  // For example, to emulate i4 to i8, the following op:
584  //
585  // vector.store %arg1, %0[%arg2, %arg3] : memref<4x8xi4>, vector<8xi4>
586  //
587  // can be replaced with
588  //
589  // %bitcast = vector.bitcast %arg1 : vector<8xi4> to vector<4xi8>
590  // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
591  // vector<4xi8>
592 
593  auto origElements = valueToStore.getType().getNumElements();
594  // Note, per-element-alignment was already verified above.
595  bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
596 
597  auto stridedMetadata =
598  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
599 
600  OpFoldResult linearizedIndices;
601  memref::LinearizedMemRefInfo linearizedInfo;
602  std::tie(linearizedInfo, linearizedIndices) =
 603  memref::getLinearizedMemRefOffsetAndSize(
 604  rewriter, loc, emulatedBits, containerBits,
605  stridedMetadata.getConstifiedMixedOffset(),
606  stridedMetadata.getConstifiedMixedSizes(),
607  stridedMetadata.getConstifiedMixedStrides(),
608  getAsOpFoldResult(adaptor.getIndices()));
609 
610  std::optional<int64_t> foldedNumFrontPadElems =
611  isDivisibleInSize ? 0
612  : getConstantIntValue(linearizedInfo.intraDataOffset);
613 
614  if (!foldedNumFrontPadElems) {
615  return rewriter.notifyMatchFailure(
616  op, "subbyte store emulation: dynamic front padding size is "
617  "not yet implemented");
618  }
619 
620  auto memrefBase = cast<MemRefValue>(adaptor.getBase());
621 
622  // Conditions when atomic RMWs are not needed:
623  // 1. The source vector size (in bits) is a multiple of byte size.
624  // 2. The address of the store is aligned to the emulated width boundary.
625  //
 626  // For example, storing a vector<4xi2> into memref<13xi2> at offset 4 does not
627  // need unaligned emulation because the store address is aligned and the
628  // source is a whole byte.
629  bool emulationRequiresPartialStores =
630  !isDivisibleInSize || *foldedNumFrontPadElems != 0;
631  if (!emulationRequiresPartialStores) {
632  // Basic case: storing full bytes.
633  auto numElements = origElements / emulatedPerContainerElem;
634  auto bitCast = rewriter.create<vector::BitCastOp>(
635  loc, VectorType::get(numElements, containerElemTy),
636  op.getValueToStore());
637  rewriter.replaceOpWithNewOp<vector::StoreOp>(
638  op, bitCast.getResult(), memrefBase,
639  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
640  return success();
641  }
642 
643  // Next, handle the case when sub-byte read-modify-write
644  // sequences are needed to emulate a vector store.
645  // Here is an example:
646  //
647  // Vector to store: vector<7xi2>
648  // Value to store: 11 11 11 11 11 11 11 (all ones)
649  //
650  // Destination: memref<12xi2>
651  // Store offset: 2 (i.e. 4 bits into the 1st emulated byte).
652  //
653  // Input MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
654  //
655  // Destination memref before:
656  //
657  // Byte 0 Byte 1 Byte 2
658  // +----------+----------+----------+
659  // | 00000000 | 00000000 | 00000000 |
660  // +----------+----------+----------+
661  //
662  // Destination memref after:
663  //
664  // Byte 0 Byte 1 Byte 2
665  // +----------+----------+----------+
666  // | 00001111 | 11111111 | 11000000 |
667  // +----------+----------+----------+
668  //
669  // Note, stores to Byte 1 are "full-width" and hence don't require RMW (no
 670  // need for atomicity). Stores to Byte 0 and Byte 2 are "partial" and hence
 671  // require RMW access (atomicity is required).
672 
673  // The index into the target memref we are storing to.
674  Value currentDestIndex =
675  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
676  // The index into the source vector we are currently processing.
677  auto currentSourceIndex = 0;
678 
679  // Build a mask used for rmw.
680  auto subWidthStoreMaskType =
681  VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());
682 
683  auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
684 
685  // 1. Partial width store for the leading byte.
 686  // When the store address is not aligned to the emulated width boundary, deal
 687  // with the unaligned part so that the remaining elements are aligned to the
 688  // width boundary.
689  auto frontSubWidthStoreElem =
690  (emulatedPerContainerElem - *foldedNumFrontPadElems) %
691  emulatedPerContainerElem;
692  if (frontSubWidthStoreElem > 0) {
693  SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
694  if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
695  std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
696  origElements, true);
697  frontSubWidthStoreElem = origElements;
698  } else {
699  std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
700  *foldedNumFrontPadElems, true);
701  }
702  auto frontMask = rewriter.create<arith::ConstantOp>(
703  loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
704 
705  currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
706  auto value =
707  extractSliceIntoByte(rewriter, loc, valueToStore, 0,
708  frontSubWidthStoreElem, *foldedNumFrontPadElems);
709 
710  storeFunc(rewriter, loc, memrefBase, currentDestIndex,
711  cast<VectorValue>(value), frontMask.getResult());
712  }
713 
714  if (currentSourceIndex >= origElements) {
715  rewriter.eraseOp(op);
716  return success();
717  }
718 
719  // Increment the destination index by 1 to align to the emulated width
720  // boundary.
721  auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
722  currentDestIndex = rewriter.create<arith::AddIOp>(
723  loc, rewriter.getIndexType(), currentDestIndex, constantOne);
724 
725  // 2. Full width store for the inner output bytes.
726  // After the previous step, the store address is aligned to the emulated
727  // width boundary.
728  int64_t fullWidthStoreSize =
729  (origElements - currentSourceIndex) / emulatedPerContainerElem;
730  int64_t numNonFullWidthElements =
731  fullWidthStoreSize * emulatedPerContainerElem;
732  if (fullWidthStoreSize > 0) {
733  auto fullWidthStorePart = staticallyExtractSubvector(
734  rewriter, loc, valueToStore, currentSourceIndex,
735  numNonFullWidthElements);
736 
737  auto originType = cast<VectorType>(fullWidthStorePart.getType());
738  auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
739  auto storeType = VectorType::get(
740  {originType.getNumElements() / emulatedPerContainerElem},
741  memrefElemType);
742  auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
743  fullWidthStorePart);
744  rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
745  currentDestIndex);
746 
747  currentSourceIndex += numNonFullWidthElements;
748  currentDestIndex = rewriter.create<arith::AddIOp>(
749  loc, rewriter.getIndexType(), currentDestIndex,
750  rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
751  }
752 
753  // 3. Partial width store for the trailing output byte.
754  // It is needed when the residual length is smaller than the emulated width,
755  // which is not covered in step 2 above.
756  auto remainingElements = origElements - currentSourceIndex;
757  if (remainingElements != 0) {
758  auto subWidthStorePart =
759  extractSliceIntoByte(rewriter, loc, cast<VectorValue>(valueToStore),
760  currentSourceIndex, remainingElements, 0);
761 
762  // Generate back mask.
763  auto maskValues = SmallVector<bool>(emulatedPerContainerElem, 0);
764  std::fill_n(maskValues.begin(), remainingElements, 1);
765  auto backMask = rewriter.create<arith::ConstantOp>(
766  loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
767 
768  storeFunc(rewriter, loc, memrefBase, currentDestIndex,
769  cast<VectorValue>(subWidthStorePart), backMask.getResult());
770  }
771 
772  rewriter.eraseOp(op);
773  return success();
774  }
775 
776 private:
777  const bool disableAtomicRMW;
778 };
779 
780 //===----------------------------------------------------------------------===//
781 // ConvertVectorMaskedStore
782 //===----------------------------------------------------------------------===//
783 
784 // TODO: Document-me
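// Emulate `vector.maskedstore` using a multi-byte container type; see the
// worked i4-to-i8 example in `matchAndRewrite` below.
//
// The affected container elements are first loaded with `vector.maskedload`
// (using a compressed mask), combined with the value to store via
// `arith.select` on the original mask, and written back with
// `vector.maskedstore` using the compressed mask.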
785 struct ConvertVectorMaskedStore final
786  : OpConversionPattern<vector::MaskedStoreOp> {
 787  using OpConversionPattern::OpConversionPattern;
 788 
789  LogicalResult
790  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
791  ConversionPatternRewriter &rewriter) const override {
792 
793  // See #115653
794  if (op.getValueToStore().getType().getRank() != 1)
795  return rewriter.notifyMatchFailure(op,
796  "only 1-D vectors are supported ATM");
797 
798  auto loc = op.getLoc();
799  auto containerElemTy =
800  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
801  Type emulatedElemTy = op.getValueToStore().getType().getElementType();
802  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
803  int containerBits = containerElemTy.getIntOrFloatBitWidth();
804 
805  // Check per-element alignment.
806  if (containerBits % emulatedBits != 0) {
807  return rewriter.notifyMatchFailure(
808  op, "impossible to pack emulated elements into container elements "
809  "(bit-wise misalignment)");
810  }
811 
812  int emulatedPerContainerElem = containerBits / emulatedBits;
813  int origElements = op.getValueToStore().getType().getNumElements();
814  if (origElements % emulatedPerContainerElem != 0)
815  return failure();
816 
817  auto stridedMetadata =
818  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
819  OpFoldResult linearizedIndicesOfr;
820  memref::LinearizedMemRefInfo linearizedInfo;
821  std::tie(linearizedInfo, linearizedIndicesOfr) =
 822  memref::getLinearizedMemRefOffsetAndSize(
 823  rewriter, loc, emulatedBits, containerBits,
824  stridedMetadata.getConstifiedMixedOffset(),
825  stridedMetadata.getConstifiedMixedSizes(),
826  stridedMetadata.getConstifiedMixedStrides(),
827  getAsOpFoldResult(adaptor.getIndices()));
828  Value linearizedIndices =
829  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
830 
831  // Load the whole data and use arith.select to handle the corner cases.
832  //
833  // As an example, for this masked store of i4 values:
834  //
835  // vector.maskedstore %0[%c0, %c0], %mask, %val_to_store
836  //
837  // and given these input values:
838  //
839  // %mask = [0, 1, 1, 1, 1, 0, 0, 0] (8 * i1)
840  // %0[%c0, %c0] =
841  // [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
842  // %val_to_store =
843  // [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0] (8 * i4)
844  //
845  // we'll have the following i4 output:
846  //
847  // expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x7, 0x8]
848  //
849  // Emulating the above using i8 will give:
850  //
851  // %compressed_mask = [1, 1, 1, 0] (4 * i1)
852  // %maskedload = [0x12, 0x34, 0x56, 0x00] (4 * i8)
853  // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0] (8 * i4)
854  // %select_using_shifted_mask =
855  // [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x0, 0x0] (8 * i4)
856  // %packed_data = [0x1A, 0xBC, 0xD6, 0x00] (4 * i8)
857  //
858  // Using the compressed mask to store %packed_data results in expected
859  // output.
860  //
861  // FIXME: Make an example based on the comment above work (see #115460 for
862  // reproducer).
863  FailureOr<Operation *> newMask = getCompressedMaskOp(
864  rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);
865  if (failed(newMask))
866  return failure();
867 
868  auto numElements = (origElements + emulatedPerContainerElem - 1) /
869  emulatedPerContainerElem;
870  auto newType = VectorType::get(numElements, containerElemTy);
871  auto passThru = rewriter.create<arith::ConstantOp>(
872  loc, newType, rewriter.getZeroAttr(newType));
873 
874  auto newLoad = rewriter.create<vector::MaskedLoadOp>(
875  loc, newType, adaptor.getBase(), linearizedIndices,
876  newMask.value()->getResult(0), passThru);
877 
878  auto newBitCastType =
879  VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
880  Value valueToStore =
881  rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
882  valueToStore = rewriter.create<arith::SelectOp>(
883  loc, op.getMask(), op.getValueToStore(), valueToStore);
884  valueToStore =
885  rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);
886 
887  rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
888  op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
889  valueToStore);
890  return success();
891  }
892 };
893 
894 //===----------------------------------------------------------------------===//
895 // ConvertVectorLoad
896 //===----------------------------------------------------------------------===//
897 
898 // TODO: Document-me
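// Emulate `vector.load` using a multi-byte container type; see the examples
// in `matchAndRewrite` below.
//
// The load is rewritten as a wider `vector.load` of container elements
// followed by a `vector.bitcast` back to the emulated type. When the access
// is not aligned to a container element, extra elements are loaded and the
// exact subvector is extracted afterwards.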
899 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
 900  using OpConversionPattern::OpConversionPattern;
 901 
902  LogicalResult
903  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
904  ConversionPatternRewriter &rewriter) const override {
905 
906  // See #115653
907  if (op.getVectorType().getRank() != 1)
908  return rewriter.notifyMatchFailure(op,
909  "only 1-D vectors are supported ATM");
910 
911  auto loc = op.getLoc();
912  auto containerElemTy =
913  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
914  Type emulatedElemTy = op.getType().getElementType();
915  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
916  int containerBits = containerElemTy.getIntOrFloatBitWidth();
917 
918  // Check per-element alignment.
919  if (containerBits % emulatedBits != 0) {
920  return rewriter.notifyMatchFailure(
921  op, "impossible to pack emulated elements into container elements "
922  "(bit-wise misalignment)");
923  }
924  int emulatedPerContainerElem = containerBits / emulatedBits;
925 
926  // Adjust the number of elements to load when emulating narrow types,
927  // and then cast back to the original type with vector.bitcast op.
928  // Here only the 1-D vector load is considered, and the N-D memref types
929  // should be linearized.
930  // For example, to emulate i4 to i8, the following op:
931  //
932  // %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
933  //
934  // can be replaced with
935  //
936  // %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
937  // %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
938  //
939  // There are cases where the number of elements to load is not byte-aligned,
940  // for example:
941  //
942  // %1 = vector.load %0[%c1, %c0] : memref<3x3xi2>, vector<3xi2>
943  //
944  // we will have to load extra bytes and extract the exact slice in between.
945  //
946  // %1 = vector.load %0[%c2] : memref<3xi8>, vector<2xi8>
947  // %2 = vector.bitcast %1 : vector<2xi8> to vector<8xi2>
948  // %3 = vector.extract_strided_slice %1 {offsets = [2], sizes = [3], strides
949  // = [1]}
950  // : vector<8xi2> to vector<3xi2>
951  //
952  // TODO: Currently the extract_strided_slice's attributes must be known at
953  // compile time as they must be constants.
954 
955  auto origElements = op.getVectorType().getNumElements();
956  // Note, per-element-alignment was already verified above.
957  bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
958 
959  auto stridedMetadata =
960  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
961 
962  OpFoldResult linearizedIndices;
963  memref::LinearizedMemRefInfo linearizedInfo;
964  std::tie(linearizedInfo, linearizedIndices) =
 965  memref::getLinearizedMemRefOffsetAndSize(
 966  rewriter, loc, emulatedBits, containerBits,
967  stridedMetadata.getConstifiedMixedOffset(),
968  stridedMetadata.getConstifiedMixedSizes(),
969  stridedMetadata.getConstifiedMixedStrides(),
970  getAsOpFoldResult(adaptor.getIndices()));
971 
972  std::optional<int64_t> foldedIntraVectorOffset =
973  isDivisibleInSize ? 0
974  : getConstantIntValue(linearizedInfo.intraDataOffset);
975 
 976  // Always load enough elements to cover the original elements.
977  int64_t maxintraDataOffset =
978  foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
979  auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
980  emulatedPerContainerElem);
981  Value result =
982  emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
983  numElements, emulatedElemTy, containerElemTy);
984 
985  if (!foldedIntraVectorOffset) {
986  auto resultVector = rewriter.create<arith::ConstantOp>(
987  loc, op.getType(), rewriter.getZeroAttr(op.getType()));
 988  result = dynamicallyExtractSubVector(
 989  rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
990  linearizedInfo.intraDataOffset, origElements);
991  } else if (!isDivisibleInSize) {
 992  result = staticallyExtractSubvector(
 993  rewriter, loc, result, *foldedIntraVectorOffset, origElements);
994  }
995  rewriter.replaceOp(op, result);
996  return success();
997  }
998 };
999 
1000 //===----------------------------------------------------------------------===//
1001 // ConvertVectorMaskedLoad
1002 //===----------------------------------------------------------------------===//
1003 
1004 // TODO: Document-me
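// Emulate `vector.maskedload` using a multi-byte container type; see the
// worked i4-to-i8 example in `matchAndRewrite` below.
//
// The load is rewritten as a wider `vector.maskedload` (with a compressed
// mask and a bit-cast pass-through value), a `vector.bitcast` back to the
// emulated type, and an `arith.select` on the original mask so that lanes
// that were not requested still observe the pass-through value.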
1005 struct ConvertVectorMaskedLoad final
1006  : OpConversionPattern<vector::MaskedLoadOp> {
 1007  using OpConversionPattern::OpConversionPattern;
 1008 
1009  LogicalResult
1010  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
1011  ConversionPatternRewriter &rewriter) const override {
1012  // See #115653
1013  if (op.getVectorType().getRank() != 1)
1014  return rewriter.notifyMatchFailure(op,
1015  "only 1-D vectors are supported ATM");
1016 
1017  auto loc = op.getLoc();
1018 
1019  auto containerElemTy =
1020  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
1021  Type emulatedElemTy = op.getType().getElementType();
1022  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
1023  int containerBits = containerElemTy.getIntOrFloatBitWidth();
1024 
1025  // Check per-element alignment.
1026  if (containerBits % emulatedBits != 0) {
1027  return rewriter.notifyMatchFailure(
1028  op, "impossible to pack emulated elements into container elements "
1029  "(bit-wise misalignment)");
1030  }
1031  int emulatedPerContainerElem = containerBits / emulatedBits;
1032 
1033  // Adjust the number of elements to load when emulating narrow types,
1034  // and then cast back to the original type with vector.bitcast op.
1035  // For example, to emulate i4 to i8, the following op:
1036  //
1037  // %mask = vector.constant_mask [3] : vector<6xi1>
1038  // %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
1039  // memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
1040  //
1041  // can be replaced with
1042  //
1043  // %new_mask = vector.constant_mask [2] : vector<3xi1>
1044  // %new_pass_thru = vector.bitcast %pass_thru :
1045  // vector<6xi4> to vector<3xi8>
1046  // %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
1047  // memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
1048  // %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
1049  //
1050  // Since we are effectively loading 16 bits (2xi8) from the memref with the
1051  // new mask, while originally we only wanted to effectively load 12 bits
1052  // (3xi4) from the memref, we need to set the second half of the last i8
1053  // that was effectively loaded (i.e. the second i8) to %pass_thru.
1054  //
1055  // %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
1056  //
1057  // Given these input values:
1058  // %mask = [1, 1, 1, 0, 0, 0]
1059  // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
1060  // %pass_thru = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
1061  //
1062  // we'll have:
1063  //
1064  // expected output: [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
1065  //
1066  // %new_mask = [1, 1, 0]
1067  // %new_pass_thru = [0x78, 0x9A, 0xBC]
1068  // %1 = [0x12, 0x34, 0xBC]
1069  // %2 = [0x1, 0x2, 0x3, 0x4, 0xB, 0xC]
1070  // %3 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
1071  //
 1072  // TODO: Currently, only loading an even number of elements is supported.
1073  // To deal with the odd number of elements, one has to extract the
1074  // subvector at the proper offset after bit-casting.
1075  auto origType = op.getVectorType();
1076  auto origElements = origType.getNumElements();
1077  // Note, per-element-alignment was already verified above.
1078  bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
1079 
1080  auto stridedMetadata =
1081  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
1082  OpFoldResult linearizedIndices;
1083  memref::LinearizedMemRefInfo linearizedInfo;
1084  std::tie(linearizedInfo, linearizedIndices) =
 1085  memref::getLinearizedMemRefOffsetAndSize(
 1086  rewriter, loc, emulatedBits, containerBits,
1087  stridedMetadata.getConstifiedMixedOffset(),
1088  stridedMetadata.getConstifiedMixedSizes(),
1089  stridedMetadata.getConstifiedMixedStrides(),
1090  getAsOpFoldResult(adaptor.getIndices()));
1091 
1092  std::optional<int64_t> foldedIntraVectorOffset =
1093  isDivisibleInSize ? 0
1094  : getConstantIntValue(linearizedInfo.intraDataOffset);
1095 
1096  int64_t maxIntraDataOffset =
1097  foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1098  FailureOr<Operation *> newMask =
1099  getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
1100  emulatedPerContainerElem, maxIntraDataOffset);
1101  if (failed(newMask))
1102  return failure();
1103 
1104  Value passthru = op.getPassThru();
1105 
1106  auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
1107  emulatedPerContainerElem);
1108  auto loadType = VectorType::get(numElements, containerElemTy);
1109  auto newBitcastType =
1110  VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
1111 
1112  auto emptyVector = rewriter.create<arith::ConstantOp>(
1113  loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
1114  if (!foldedIntraVectorOffset) {
1115  passthru = dynamicallyInsertSubVector(
1116  rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
1117  origElements);
1118  } else if (!isDivisibleInSize) {
1119  passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
1120  *foldedIntraVectorOffset);
1121  }
1122  auto newPassThru =
1123  rewriter.create<vector::BitCastOp>(loc, loadType, passthru);
1124 
1125  // Generating the new masked load.
1126  auto newLoad = rewriter.create<vector::MaskedLoadOp>(
1127  loc, loadType, adaptor.getBase(),
1128  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
1129  newMask.value()->getResult(0), newPassThru);
1130 
1131  // Setting the part that originally was not effectively loaded from memory
1132  // to pass through.
1133  auto bitCast =
1134  rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
1135 
1136  Value mask = op.getMask();
1137  auto newSelectMaskType = VectorType::get(
1138  numElements * emulatedPerContainerElem, rewriter.getI1Type());
1139  // TODO: try to fold if op's mask is constant
1140  auto emptyMask = rewriter.create<arith::ConstantOp>(
1141  loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
1142  if (!foldedIntraVectorOffset) {
1143  mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
1144  linearizedInfo.intraDataOffset,
1145  origElements);
1146  } else if (!isDivisibleInSize) {
1147  mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
1148  *foldedIntraVectorOffset);
1149  }
1150 
1151  Value result =
1152  rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
1153  if (!foldedIntraVectorOffset) {
1154  result = dynamicallyExtractSubVector(
1155  rewriter, loc, result, op.getPassThru(),
1156  linearizedInfo.intraDataOffset, origElements);
1157  } else if (!isDivisibleInSize) {
1158  result = staticallyExtractSubvector(
1159  rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1160  }
1161  rewriter.replaceOp(op, result);
1162 
1163  return success();
1164  }
1165 };
1166 
 1167 /// Check whether `subByteVecTy` fits within a vector of `multiByteScalarTy`
1168 ///
1169 /// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
1170 /// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
1171 /// (a multi-byte scalar, e.g. i16), where N is some integer.
1172 ///
1173 /// Put differently, this method checks whether this would be valid:
1174 ///
1175 /// vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
1176 ///
1177 /// EXAMPLES:
1178 /// * vector<4xi4> -> i16 - yes (N = 1)
1179 /// * vector<4xi4> -> i8 - yes (N = 2)
1180 /// * vector<3xi4> -> i8 - no (N would have to be 1.5)
1181 /// * vector<3xi2> -> i16 - no (N would have to be 0.5)
1182 static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
1183  Type multiByteScalarTy) {
1184  assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
1185 
1186  int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
1187  int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();
1188 
1189  assert(subByteBits < 8 && "Not a sub-byte scalar type!");
1190  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
 1191  assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");
1192 
1193  int elemsPerMultiByte = multiByteBits / subByteBits;
1194 
1195  // TODO: This is a bit too restrictive for vectors rank > 1.
1196  return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
1197 }
1198 
1199 //===----------------------------------------------------------------------===//
1200 // ConvertVectorTransferRead
1201 //===----------------------------------------------------------------------===//
1202 
1203 // TODO: Document-me
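// Emulate `vector.transfer_read` using a multi-byte container type (a summary
// based on the rewrite below).
//
// The read is rewritten as a wider `vector.transfer_read` of container
// elements, with the padding value zero-extended to the container type,
// followed by a `vector.bitcast` and, for unaligned accesses, extraction of
// the relevant subvector.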
1204 struct ConvertVectorTransferRead final
1205  : OpConversionPattern<vector::TransferReadOp> {
 1206  using OpConversionPattern::OpConversionPattern;
 1207 
1208  LogicalResult
1209  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
1210  ConversionPatternRewriter &rewriter) const override {
1211 
1212  // See #115653
1213  if (op.getVectorType().getRank() != 1)
1214  return rewriter.notifyMatchFailure(op,
1215  "only 1-D vectors are supported ATM");
1216 
1217  auto loc = op.getLoc();
1218  auto containerElemTy =
1219  cast<MemRefType>(adaptor.getSource().getType()).getElementType();
1220  Type emulatedElemTy = op.getType().getElementType();
1221  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
1222  int containerBits = containerElemTy.getIntOrFloatBitWidth();
1223 
1224  // Check per-element alignment.
1225  if (containerBits % emulatedBits != 0) {
1226  return rewriter.notifyMatchFailure(
1227  op, "impossible to pack emulated elements into container elements "
1228  "(bit-wise misalignment)");
1229  }
1230  int emulatedPerContainerElem = containerBits / emulatedBits;
1231 
1232  auto origElements = op.getVectorType().getNumElements();
1233 
1234  // Note, per-element-alignment was already verified above.
1235  bool isDivisibleInSize =
1236  fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);
1237 
1238  auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
1239  adaptor.getPadding());
1240 
1241  auto stridedMetadata =
1242  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());
1243 
1244  OpFoldResult linearizedIndices;
1245  memref::LinearizedMemRefInfo linearizedInfo;
1246  std::tie(linearizedInfo, linearizedIndices) =
 1247  memref::getLinearizedMemRefOffsetAndSize(
 1248  rewriter, loc, emulatedBits, containerBits,
1249  stridedMetadata.getConstifiedMixedOffset(),
1250  stridedMetadata.getConstifiedMixedSizes(),
1251  stridedMetadata.getConstifiedMixedStrides(),
1252  getAsOpFoldResult(adaptor.getIndices()));
1253 
1254  std::optional<int64_t> foldedIntraVectorOffset =
1255  isDivisibleInSize ? 0
1256  : getConstantIntValue(linearizedInfo.intraDataOffset);
1257 
1258  int64_t maxIntraDataOffset =
1259  foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1260  auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
1261  emulatedPerContainerElem);
1262 
1263  auto newRead = rewriter.create<vector::TransferReadOp>(
1264  loc, VectorType::get(numElements, containerElemTy), adaptor.getSource(),
1265  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
1266  newPadding);
1267 
1268  auto bitCast = rewriter.create<vector::BitCastOp>(
1269  loc,
1270  VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
1271  newRead);
1272 
1273  Value result = bitCast->getResult(0);
1274  if (!foldedIntraVectorOffset) {
1275  auto zeros = rewriter.create<arith::ConstantOp>(
1276  loc, op.getType(), rewriter.getZeroAttr(op.getType()));
1277  result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
1278  linearizedInfo.intraDataOffset,
1279  origElements);
1280  } else if (!isDivisibleInSize) {
1281  result = staticallyExtractSubvector(
1282  rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1283  }
1284  rewriter.replaceOp(op, result);
1285 
1286  return success();
1287  }
1288 };
1289 } // end anonymous namespace
1290 
1291 //===----------------------------------------------------------------------===//
1292 // RewriteBitCastOfTruncI
1293 //===----------------------------------------------------------------------===//
1294 
1295 namespace {
1296 
1297 /// Helper struct to keep track of the provenance of a contiguous set of bits
1298 /// in a source vector.
1299 struct SourceElementRange {
1300  /// The index of the source vector element that contributes bits to *this.
1301  int64_t sourceElementIdx;
1302  /// The range of bits in the source vector element that contribute to *this.
1303  int64_t sourceBitBegin;
1304  int64_t sourceBitEnd;
1305 };
1306 
1307 struct SourceElementRangeList : public SmallVector<SourceElementRange> {
1308  /// Given the index of a SourceElementRange in the SourceElementRangeList,
1309  /// compute the amount of bits that need to be shifted to the left to get the
1310  /// bits in their final location. This shift amount is simply the sum of the
1311  /// bits *before* `shuffleIdx` (i.e. the bits of `shuffleIdx = 0` are always
 1312  /// the LSBs, the bits of `shuffleIdx = 1` come next, etc).
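 /// For example, for row [1] = {0, [10, 15)}, {1, [0, 5)} of the i15 -> i10
 /// decomposition above, the entry at `shuffleIdx` = 0 is shifted by 0 and
 /// the entry at `shuffleIdx` = 1 is shifted by 5 (the number of bits that
 /// come before it).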
1313  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
1314  int64_t res = 0;
1315  for (int64_t i = 0; i < shuffleIdx; ++i)
1316  res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
1317  return res;
1318  }
1319 };
1320 
1321 /// Helper struct to enumerate the source elements and bit ranges that are
1322 /// involved in a bitcast operation.
1323 /// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
1324 /// any 1-D vector shape and any source/target bitwidths.
1325 /// This creates and holds a mapping of the form:
1326 /// [dstVectorElementJ] ==
1327 /// [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ]
1328 /// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as:
1329 /// [0] = {0, [0-8)}
1330 /// [1] = {0, [8-16)}
1331 /// [2] = {0, [16-24)}
1332 /// and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as:
 1333 /// [0] = {0, [0, 10)}, [1] = {0, [10, 15)}, {1, [0, 5)},
 1334 /// [2] = {1, [5, 15)}
1335 struct BitCastBitsEnumerator {
1336  BitCastBitsEnumerator(VectorType sourceVectorType,
1337  VectorType targetVectorType);
1338 
1339  int64_t getMaxNumberOfEntries() {
1340  int64_t numVectors = 0;
1341  for (const auto &l : sourceElementRanges)
1342  numVectors = std::max(numVectors, (int64_t)l.size());
1343  return numVectors;
1344  }
1345 
1346  VectorType sourceVectorType;
1347  VectorType targetVectorType;
1348  SmallVector<SourceElementRangeList> sourceElementRanges;
1349 };
1350 
1351 /// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
1352 /// advantage of high-level information to avoid leaving LLVM to scramble with
1353 /// peephole optimizations.
1354 /// BitCastBitsEnumerator encodes for each element of the target vector the
1355 /// provenance of the bits in the source vector. We can "transpose" this
1356 /// information to build a sequence of shuffles and bitwise ops that will
1357 /// produce the desired result.
1358 //
1359 /// Consider the following motivating example:
1360 /// ```
1361 /// %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
1362 /// ```
1363 //
1364 /// BitCastBitsEnumerator contains the following information:
1365 /// ```
1366 /// { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
1367 /// { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
1368 /// { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
1369 /// { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
1370 /// { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
1371 /// { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
1372 /// { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
1373 /// {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
1374 /// {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
1375 /// {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
1376 /// {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
1377 /// {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
1378 /// {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
1379 /// {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
1380 /// {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
1381 /// {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
1382 /// {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
1383 /// {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
1384 /// {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
1385 /// {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
1386 /// ```
1387 ///
1388 /// In the above, each row represents one target vector element and each
1389 /// column represents one bit contribution from a source vector element.
1390 /// The algorithm creates vector.shuffle operations (in this case there are 3
1391 /// shuffles, i.e. the max number of columns in BitCastBitsEnumerator). The
1392 /// algorithm populates the bits as follows:
1393 /// ```
1394 /// src bits 0 ...
1395 /// 1st shuffle |xxxxx |xx |...
1396 /// 2nd shuffle | xxx| xxxxx |...
1397 /// 3rd shuffle | | x|...
1398 /// ```
1399 //
1400 /// The algorithm proceeds as follows:
1401 /// 1. for each vector.shuffle, collect the source vectors that participate in
1402 /// this shuffle. One source vector per target element of the resulting
1403 /// vector.shuffle. If there is no source element contributing bits for the
1404 /// current vector.shuffle, take 0 (i.e. row 0 in the above example has only
1405 /// 2 columns).
1406 /// 2. represent the bitrange in the source vector as a mask. If there is no
1407 /// source element contributing bits for the current vector.shuffle, take 0.
1408 /// 3. shift right by the proper amount to align the source bitrange at
1409 /// position 0. This is exactly the low end of the bitrange. For instance,
1410 /// the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
1411 /// shift right by 3 to get the bits contributed by the source element #1
1412 /// into position 0.
1413 /// 4. shift left by the proper amount to align to the desired position in
1414 /// the result element vector. For instance, the contribution of the second
1415 /// source element for the first row needs to be shifted by `5` to form the
1416 /// first i8 result element.
1417 ///
1418 /// Eventually, we end up building the sequence
1419 /// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
1420 /// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
1421 /// bits extracted from the source vector (i.e. the `shuffle -> and` part).
1422 struct BitCastRewriter {
1423  /// Helper metadata struct to hold the static quantities for the rewrite.
1424  struct Metadata {
1425  SmallVector<int64_t> shuffles;
1426  SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1427  };
1428 
1429  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
1430 
1431  /// Verify that general preconditions for the rewrite are met.
1432  LogicalResult commonPrecondition(PatternRewriter &rewriter,
1433  VectorType preconditionType, Operation *op);
1434 
1435  /// Precompute the metadata for the rewrite.
1436  SmallVector<BitCastRewriter::Metadata>
1437  precomputeMetadata(IntegerType shuffledElementType);
1438 
1439  /// Rewrite one step of the sequence:
1440  /// `(shuffle -> and -> shiftright -> shiftleft -> or)`.
1441  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
1442  Value initialValue, Value runningResult,
1443  const BitCastRewriter::Metadata &metadata);
1444 
1445 private:
1446  /// Underlying enumerator that encodes the provenance of the bits in each
1447  /// element of the result vector.
1448  BitCastBitsEnumerator enumerator;
1449 };
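// Editor's note (illustrative, not part of the upstream source): for the first
// row of the vector<32xi5> -> vector<20xi8> example above,
// { 0: b@[0..5) lshl: 0 }{ 1: b@[0..3) lshl: 5 }, the rewrite computes
//
//   res[0] = ((src[0] & 0b00011111) >> 0) << 0
//          | ((src[1] & 0b00000111) >> 0) << 5
//
// i.e. the `and` isolates the contributing bit range, the right shift aligns
// it at bit 0, and the left shift places it at its final position before the
// `or` accumulates it into the running result.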
1450 
1451 } // namespace
1452 
1453 [[maybe_unused]] static raw_ostream &
1454 operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
1455  for (const auto &l : vec) {
1456  for (auto it : llvm::enumerate(l)) {
1457  os << "{ " << it.value().sourceElementIdx << ": b@["
1458  << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
1459  << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
1460  }
1461  os << "\n";
1462  }
1463  return os;
1464 }
1465 
1466 BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
1467  VectorType targetVectorType)
1468  : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
1469 
1470  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
1471  "requires -D non-scalable vector type");
1472  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
1473  "requires -D non-scalable vector type");
1474  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
1475  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
1476  LDBG("sourceVectorType: " << sourceVectorType);
1477 
1478  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
1479  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
1480  LDBG("targetVectorType: " << targetVectorType);
1481 
1482  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
1483  (void)mostMinorSourceDim;
1484  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
1485  "source and target bitwidths must match");
1486 
1487  // Prepopulate one source element range per target element.
1488  sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
1489  for (int64_t resultBit = 0; resultBit < bitwidth;) {
1490  int64_t resultElement = resultBit / targetBitWidth;
1491  int64_t resultBitInElement = resultBit % targetBitWidth;
1492  int64_t sourceElementIdx = resultBit / sourceBitWidth;
1493  int64_t sourceBitInElement = resultBit % sourceBitWidth;
1494  int64_t step = std::min(sourceBitWidth - sourceBitInElement,
1495  targetBitWidth - resultBitInElement);
1496  sourceElementRanges[resultElement].push_back(
1497  {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
1498  resultBit += step;
1499  }
1500 }
1501 
1502 BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
1503  VectorType targetVectorType)
1504  : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
1505  LDBG("\n" << enumerator.sourceElementRanges);
1506 }
1507 
1508 /// Verify that the precondition type meets the common preconditions for any
1509 /// conversion.
1510 static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
1511  VectorType preconditionType,
1512  Operation *op) {
1513  if (!preconditionType || preconditionType.isScalable())
1514  return rewriter.notifyMatchFailure(op, "scalable vector");
1515 
1516  // TODO: consider relaxing this restriction in the future if we find ways
1517  // to really work with subbyte elements across the MLIR/LLVM boundary.
1518  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
1519  if (bitwidth % 8 != 0)
1520  return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
1521 
1522  return success();
1523 }
1524 
1525 LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
1526  VectorType preconditionType,
1527  Operation *op) {
1528  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
1529  return rewriter.notifyMatchFailure(op, "types are not vector");
1530 
1531  if (!preconditionType || preconditionType.getRank() != 1)
1532  return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
1533 
1534  return commonConversionPrecondition(rewriter, preconditionType, op);
1535 }
1536 
1537 /// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
1538 ///
1539 /// Alignment means that `subByteVecTy` can be packed into a vector of
1540 /// `containerTy` elements. More specifically:
1541 /// 1. The bit-width of `containerTy` is a multiple of the
1542 /// bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
1543 /// this multiple is 4.
1544 /// 2. The multiple from 1. above divides evenly the number of the (trailing)
1545 /// elements in `subByteVecTy`.
1546 ///
1547 /// EXAMPLE 1:
1548 /// `subByteVecTy = vector<2xi4>`, and
1549 /// `containerTy = i16`
1550 ///
1551 /// 2 divides 4 (= 16 / 4) evenly, hence both conditions are _met_.
1552 ///
1553 /// EXAMPLE 2:
1554 /// `subByteVecTy = vector<3xi4>`, and
1555 /// `containerTy = i16`
1556 ///
1557 /// 3 _does not_ divide 4 (= 16 / 4) evenly, hence the conditions are _not met_.
1558 ///
1559 /// EXAMPLE 3:
1560 /// `subByteVecTy = vector<3xi3>`, and
1561 /// `containerTy = i16`
1562 ///
1563 /// 16 _is not_ a multiple of 3, hence the conditions are _not met_.
1564 ///
1565 /// NOTE: This method assumes that common conversion preconditions are met. In
1566 /// particular, `containerTy` is assumed to be a
1567 /// multi-byte scalar type (e.g., i8, i16, i32).
1568 static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
1569  VectorType subByteVecTy,
1570  Type containerTy,
1571  Operation *op) {
1572  assert(containerTy.isIntOrFloat() &&
1573  "container element type is not a scalar");
1574 
1575  // TODO: This is validating the inputs rather than checking the conditions
1576  // documented above. Replace with an assert.
1577  if (!subByteVecTy)
1578  return rewriter.notifyMatchFailure(op, "not a vector!");
1579 
1580  unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
1581  unsigned containerBits = containerTy.getIntOrFloatBitWidth();
1582 
1583  // Enforced by the common pre-conditions.
1584  assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!");
1585 
1586  // TODO: Add support for other widths (when/if needed).
1587  if (subByteBits != 2 && subByteBits != 4)
1588  return rewriter.notifyMatchFailure(
1589  op, "only 2-bit and 4-bit sub-byte types are supported at this moment");
1590 
1591  // Condition 1 ("per-element" alignment)
1592  if (containerBits % subByteBits != 0)
1593  return rewriter.notifyMatchFailure(op, "unaligned element types");
1594 
1595  // Condition 2 ("full" alignment)
1596  if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
1597  return rewriter.notifyMatchFailure(
1598  op, "not possible to fit this sub-byte vector type into a vector of "
1599  "the given multi-byte type");
1600 
1601  return success();
1602 }
1603 
1604 SmallVector<BitCastRewriter::Metadata>
1605 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
1606  SmallVector<BitCastRewriter::Metadata> result;
1607  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
1608  shuffleIdx < e; ++shuffleIdx) {
1609  SmallVector<int64_t> shuffles;
1610  SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1611 
1612  // Create the attribute quantities for the shuffle / mask / shift ops.
1613  for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
1614  int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
1615  ? srcEltRangeList[shuffleIdx].sourceElementIdx
1616  : 0;
1617  shuffles.push_back(sourceElement);
1618 
1619  int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
1620  ? srcEltRangeList[shuffleIdx].sourceBitBegin
1621  : 0;
1622  int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
1623  ? srcEltRangeList[shuffleIdx].sourceBitEnd
1624  : 0;
1625  IntegerAttr mask = IntegerAttr::get(
1626  shuffledElementType,
1627  llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
1628  bitLo, bitHi));
1629  masks.push_back(mask);
1630 
1631  int64_t shiftRight = bitLo;
1632  shiftRightAmounts.push_back(
1633  IntegerAttr::get(shuffledElementType, shiftRight));
1634 
1635  int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
1636  shiftLeftAmounts.push_back(
1637  IntegerAttr::get(shuffledElementType, shiftLeft));
1638  }
1639 
1640  result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
1641  }
1642  return result;
1643 }
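// Editor's note (illustrative, not part of the upstream source): for the
// vector<32xi5> -> vector<20xi8> example above and an assumed i8
// `shuffledElementType`, the metadata for shuffleIdx = 0 starts as
//
//   shuffles          = [0, 1, 3, 4, 6, ...]
//   masks             = [0x1f, 0x18, 0x1e, 0x10, 0x1c, ...]
//   shiftRightAmounts = [0, 3, 1, 4, 2, ...]
//   shiftLeftAmounts  = [0, 0, 0, 0, 0, ...]
//
// i.e. one entry per target element, taken from the first column of each row.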
1644 
1645 Value BitCastRewriter::genericRewriteStep(
1646  PatternRewriter &rewriter, Location loc, Value initialValue,
1647  Value runningResult, const BitCastRewriter::Metadata &metadata) {
1648  // Create vector.shuffle from the metadata.
1649  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
1650  loc, initialValue, initialValue, metadata.shuffles);
1651 
1652  // Intersect with the mask.
1653  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
1654  auto constOp = rewriter.create<arith::ConstantOp>(
1655  loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
1656  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);
1657 
1658  // Align right on 0.
1659  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
1660  loc,
1661  DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
1662  Value shiftedRight =
1663  rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
1664 
1665  // Shift bits left into their final position.
1666  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
1667  loc,
1668  DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
1669  Value shiftedLeft =
1670  rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
1671 
1672  runningResult =
1673  runningResult
1674  ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
1675  : shiftedLeft;
1676 
1677  return runningResult;
1678 }
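// Editor's note (illustrative, not part of the upstream source): one call to
// genericRewriteStep emits IR of roughly the following shape (the SSA names
// are made up for readability):
//
//   %shuffled = vector.shuffle %initial, %initial [ ...metadata.shuffles... ]
//   %masked   = arith.andi  %shuffled, %cst_masks
//   %aligned  = arith.shrui %masked,   %cst_shift_right
//   %placed   = arith.shli  %aligned,  %cst_shift_left
//   %running  = arith.ori   %running_prev, %placed  // omitted on the first step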
1679 
1680 /// Bitcasts the aligned `subByteVec` vector to a vector of i8.
1681 /// Here "aligned" means that it satisfies alignedConversionPrecondition.
1682 ///
1683 /// Example:
1684 /// vector<16x16xi2> -> vector<16x4xi8>
1685 /// vector<16x16xi4> -> vector<16x8xi8>
1686 static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
1687  Value subByteVec) {
1688  auto srcVecType = cast<VectorType>(subByteVec.getType());
1689  int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
1690  assert(8 % srcBitwidth == 0 &&
1691  "Unsupported sub-byte type (not a divisor of i8)");
1692  int64_t numSrcElemsPerByte = 8 / srcBitwidth;
1693  SmallVector<int64_t> vecShape(srcVecType.getShape());
1694  // Adjust last dimension of the vector, so the total size remains the same.
1695  vecShape.back() = vecShape.back() / numSrcElemsPerByte;
1696  auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
1697  return rewriter.create<vector::BitCastOp>(loc, i8VecType, subByteVec);
1698 }
1699 
1700 /// Extracts a signed N-bit sequence from each element of a vector of bytes,
1701 /// starting at the specified bit index.
1702 /// The `bitIdx` starts at 0 from the LSB and moves to the left.
1703 ///
1704 /// Example for a single element:
1705 /// Extract numBits=2 starting at bitIdx=2
1706 /// src = [0 | 1 | 0 | 1 | 1 | 1 | 1 | 0]
1707 /// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1708 /// target = [. . . . ^ ^ . .]
1709 ///
1710 /// The target sequence is [11](decimal=-1) as signed 2-bit integer.
1711 /// So the result should be [11 11 11 11](decimal=-1) as signed 8-bit integer.
1712 ///
1713 /// src = [01 01 11 10]
1714 /// shl = arith.shl(src, 4) -> [11 10 00 00]
1715 /// result = arith.shrsi(shl, 6) -> [11 11 11 11]
1716 static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter,
1717  Location loc, Value src,
1718  int bitIdx, int numBits) {
1719  auto srcType = cast<VectorType>(src.getType());
1720  Value shl = src;
1721  int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
1722  assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
1723  "Invalid bitIdx range");
1724  if (bitsToShiftLeft != 0) {
1725  Value shiftLeftValues = rewriter.create<arith::ConstantOp>(
1726  loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
1727  shl = rewriter.create<arith::ShLIOp>(loc, src, shiftLeftValues);
1728  }
1729 
1730  int8_t bitsToShiftRight = 8 - numBits;
1731  Value shiftRightValues = rewriter.create<arith::ConstantOp>(
1732  loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1733  Value shr = rewriter.create<arith::ShRSIOp>(loc, shl, shiftRightValues);
1734  return shr;
1735 }
1736 
1737 /// Extracts an unsigned N-bit sequence from each element of a vector of bytes,
1738 /// starting at the specified bit index.
1739 /// The `bitIdx` starts at 0 from the LSB and moves to the left.
1740 ///
1741 /// Example for a single element:
1742 /// Extract numBits=2 starting at bitIdx=2
1743 /// src = [0 | 1 | 0 | 1 | 1 | 0 | 1 | 0]
1744 /// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1745 /// target = [. . . . ^ ^ . .]
1746 ///
1747 /// The target sequence is [10](decimal=2) as unsigned 2-bit integer.
1748 /// So the result should be [00 00 00 10](decimal=2) as unsigned 8-bit integer.
1749 ///
1750 /// src = [01 01 10 10]
1751 /// mask = [00 00 00 11]
1752 /// shr = arith.shrui(src, 2) = [00 01 01 10]
1753 /// result = arith.andi(shr, mask) = [00 00 00 10]
1754 /// NOTE: Similarly to extractNBitsPerByteAndSignExtendToI8, this could be
1755 /// achieved by using arith::ShLIOp + arith::ShRUIOp instead of the masking.
1756 /// However, by using arith::ShRUIOp + arith::AndIOp, we are eliminating shift
1757 /// left when the index is 0.
1758 static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter,
1759  Location loc, Value src,
1760  int bitIdx, int numBits) {
1761  assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
1762  "Invalid bitIdx range");
1763  auto srcType = cast<VectorType>(src.getType());
1764  int8_t bitsToShiftRight = bitIdx;
1765  Value shr = src;
1766  if (bitsToShiftRight != 0) {
1767  Value shiftRightValues = rewriter.create<arith::ConstantOp>(
1768  loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1769  shr = rewriter.create<arith::ShRUIOp>(loc, src, shiftRightValues);
1770  }
1771  if (bitIdx + numBits == 8) {
1772  return shr;
1773  }
1774  uint8_t lowBitsMask = (1 << numBits) - 1;
1775  Value lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
1776  loc, DenseElementsAttr::get(srcType, lowBitsMask));
1777  return rewriter.create<arith::AndIOp>(loc, shr, lowBitsMaskValues);
1778 }
1779 
1780 using ExtractNBitsFn =
1781  std::function<Value(PatternRewriter &, Location, Value, int, int)>;
1782 
1783 /// Rewrite the i4 -> i8 extension into a sequence of shuffles and
1784 /// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1785 static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
1786  Value srcValue, const ExtractNBitsFn &extFn) {
1787  [[maybe_unused]] auto srcVecType = cast<VectorType>(srcValue.getType());
1788  assert(srcVecType.getElementType().isSignlessInteger(4) &&
1789  "Expected i4 type");
1790 
1791  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1792  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1793 
1794  // 2. Extend i4 elements to i8 elements. The low i4 elements of each
1795  // byte are placed in one vector and the high i4 elements in another vector.
1796  Value low = extFn(rewriter, loc, i8Vector, 0, 4);
1797  Value high = extFn(rewriter, loc, i8Vector, 4, 4);
1798 
1799  // 3. Interleave low and high i8 elements.
1800  return rewriter.create<vector::InterleaveOp>(loc, low, high);
1801 }
1802 
1803 /// Rewrite the i2 -> i8 extension into a sequence of shuffles and
1804 /// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1805 static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
1806  Value srcValue, const ExtractNBitsFn &extFn) {
1807  [[maybe_unused]] VectorType srcVecType = cast<VectorType>(srcValue.getType());
1808  assert(srcVecType.getElementType().isSignlessInteger(2) &&
1809  "Expected i2 type");
1810 
1811  // 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
1812  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1813 
1814  // 2. Extract each i2 element
1815  // Position 0 (bits 0-1)
1816  Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
1817  // Position 1 (bits 2-3)
1818  Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
1819  // Position 2 (bits 4-5)
1820  Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
1821  // Position 3 (bits 6-7)
1822  Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);
1823 
1824  // 3. Interleave all 4 elements by first interleaving
1825  // even elements and then odd
1826  // vec0 = [0,0,0,0],...
1827  // vec1 = [1,1,1,1],...
1828  // vec2 = [2,2,2,2],...
1829  // vec3 = [3,3,3,3],...
1830  // 02 = [0,2,0,2,0,2,0,2],...
1831  // 13 = [1,3,1,3,1,3,1,3],...
1832  // 0213 = [0,1,2,3,...],...
1833  Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, vec0, vec2);
1834  Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, vec1, vec3);
1835  return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
1836 }
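// Editor's note (illustrative, not part of the upstream source): for a
// hypothetical vector<8xi2> input, the sequence above produces roughly:
//
//   %b   = vector.bitcast %src : vector<8xi2> to vector<2xi8>
//   %e0, %e1, %e2, %e3 = <four 2-bit extractions via extFn> : vector<2xi8>
//   %i02 = vector.interleave %e0, %e2 : vector<2xi8> -> vector<4xi8>
//   %i13 = vector.interleave %e1, %e3 : vector<2xi8> -> vector<4xi8>
//   %res = vector.interleave %i02, %i13 : vector<4xi8> -> vector<8xi8>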
1837 
1838 /// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
1839 /// ops to avoid leaving LLVM to scramble with peephole optimizations.
1840 static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
1841  Value srcValue) {
1842  VectorType srcVecType = cast<VectorType>(srcValue.getType());
1843  assert(srcVecType.getElementType().isSignlessInteger(8) &&
1844  "Expected i8 type");
1845 
1846  // 1. De-interleave low and high i8 elements.
1847  auto deinterleaveOp = rewriter.create<vector::DeinterleaveOp>(loc, srcValue);
1848 
1849  // 2. Zero out the upper side of each low i8 element.
1850  constexpr int8_t i8LowBitMask = 0x0F;
1851  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
1852  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
1853  loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
1854  Value zeroOutLow = rewriter.create<arith::AndIOp>(
1855  loc, deinterleaveOp.getRes1(), zeroOutMask);
1856 
1857  // 3. Move high i4 values to upper side of the byte.
1858  constexpr int8_t bitsToShift = 4;
1859  auto shiftValues = rewriter.create<arith::ConstantOp>(
1860  loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
1861  Value shlHigh = rewriter.create<arith::ShLIOp>(loc, deinterleaveOp.getRes2(),
1862  shiftValues);
1863 
1864  // 4. Merge high and low i4 values.
1865  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
1866 
1867  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
1868  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
1869  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
1870 }
1871 
1872 namespace {
1873 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
1874 /// advantage of high-level information to avoid leaving LLVM to scramble with
1875 /// peephole optimizations.
1876 struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
1877  using OpRewritePattern::OpRewritePattern;
1878 
1879  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
1880  PatternRewriter &rewriter) const override {
1881  // The source must be a trunc op.
1882  auto truncOp =
1883  bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
1884  if (!truncOp)
1885  return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
1886 
1887  // Set up the BitCastRewriter and verify the precondition.
1888  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1889  VectorType targetVectorType = bitCastOp.getResultVectorType();
1890  BitCastRewriter bcr(sourceVectorType, targetVectorType);
1891  if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
1892  return failure();
1893 
1894  // Perform the rewrite.
1895  Value truncValue = truncOp.getIn();
1896  auto shuffledElementType =
1897  cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
1898  Value runningResult;
1899  for (const BitCastRewriter ::Metadata &metadata :
1900  bcr.precomputeMetadata(shuffledElementType)) {
1901  runningResult = bcr.genericRewriteStep(
1902  rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
1903  }
1904 
1905  // Finalize the rewrite.
1906  bool narrowing = targetVectorType.getElementTypeBitWidth() <=
1907  shuffledElementType.getIntOrFloatBitWidth();
1908  if (narrowing) {
1909  if (runningResult.getType() == bitCastOp.getResultVectorType()) {
1910  rewriter.replaceOp(bitCastOp, runningResult);
1911  } else {
1912  rewriter.replaceOpWithNewOp<arith::TruncIOp>(
1913  bitCastOp, bitCastOp.getResultVectorType(), runningResult);
1914  }
1915  } else {
1916  if (runningResult.getType() == bitCastOp.getResultVectorType()) {
1917  rewriter.replaceOp(bitCastOp, runningResult);
1918  } else {
1919  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1920  bitCastOp, bitCastOp.getResultVectorType(), runningResult);
1921  }
1922  }
1923 
1924  return success();
1925  }
1926 };
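// Editor's note (illustrative, not part of the upstream source): the pattern
// above matches IR of the following shape (types are hypothetical, chosen so
// that the result element width is a multiple of 8):
//
//   %t = arith.trunci %x : vector<8xi32> to vector<8xi6>
//   %r = vector.bitcast %t : vector<8xi6> to vector<6xi8>
//
// and rebuilds %r directly from %x via shuffles and bitwise ops, so the
// intermediate i6 vector never materializes.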
1927 } // namespace
1928 
1929 //===----------------------------------------------------------------------===//
1930 // RewriteExtOfBitCast
1931 //===----------------------------------------------------------------------===//
1932 
1933 namespace {
1934 /// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
1935 /// take advantage of high-level information to avoid leaving LLVM to scramble
1936 /// with peephole optimizations.
1937 template <typename ExtOpType>
1938 struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
1939  using OpRewritePattern<ExtOpType>::OpRewritePattern;
1940 
1941  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
1942  : OpRewritePattern<ExtOpType>(context, benefit) {}
1943 
1944  LogicalResult matchAndRewrite(ExtOpType extOp,
1945  PatternRewriter &rewriter) const override {
1946  // The source must be a bitcast op.
1947  auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
1948  if (!bitCastOp)
1949  return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
1950 
1951  // Set up the BitCastRewriter and verify the precondition.
1952  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1953  VectorType targetVectorType = bitCastOp.getResultVectorType();
1954  BitCastRewriter bcr(sourceVectorType, targetVectorType);
1955  if (failed(bcr.commonPrecondition(
1956  rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
1957  return failure();
1958 
1959  // Perform the rewrite.
1960  Value runningResult;
1961  Value sourceValue = bitCastOp.getSource();
1962  auto shuffledElementType =
1963  cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
1964  for (const BitCastRewriter::Metadata &metadata :
1965  bcr.precomputeMetadata(shuffledElementType)) {
1966  runningResult = bcr.genericRewriteStep(
1967  rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
1968  }
1969 
1970  // Finalize the rewrite.
1971  bool narrowing =
1972  cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
1973  shuffledElementType.getIntOrFloatBitWidth();
1974  if (narrowing) {
1975  rewriter.replaceOpWithNewOp<arith::TruncIOp>(
1976  extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
1977  } else {
1978  rewriter.replaceOpWithNewOp<ExtOpType>(
1979  extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
1980  }
1981 
1982  return success();
1983  }
1984 };
1985 
1986 /// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
1987 /// bitwise ops that take advantage of high-level information to avoid leaving
1988 /// LLVM to scramble with peephole optimizations. Templated to choose between
1989 /// signed and unsigned conversions.
1990 ///
1991 /// EXAMPLE 1 (signed):
1992 /// arith.extsi %in : vector<8xi4> to vector<8xi32>
1993 /// is rewritten as:
1994 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1995 /// %1 = arith.shli %0, 4 : vector<4xi8>
1996 /// %2 = arith.shrsi %1, 4 : vector<4xi8>
1997 /// %3 = arith.shrsi %0, 4 : vector<4xi8>
1998 /// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
1999 /// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
2000 ///
2001 /// EXAMPLE 2 (fp):
2002 /// arith.sitofp %in : vector<8xi4> to vector<8xf32>
2003 /// is rewritten as:
2004 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2005 /// %1 = arith.shli %0, 4 : vector<4xi8>
2006 /// %2 = arith.shrsi %1, 4 : vector<4xi8>
2007 /// %3 = arith.shrsi %0, 4 : vector<4xi8>
2008 /// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
2009 /// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
2010 ///
2011 /// EXAMPLE 3 (unsigned):
2012 /// arith.extui %in : vector<8xi4> to vector<8xi32>
2013 /// is rewritten as:
2014 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2015 /// %1 = arith.andi %0, 15 : vector<4xi8>
2016 /// %2 = arith.shrui %0, 4 : vector<4xi8>
2017 /// %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
2018 /// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
2019 ///
2020 template <typename ConversionOpType, bool isSigned>
2021 struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
2022  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
2023 
2024  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
2025  PatternRewriter &rewriter) const override {
2026  // Verify the preconditions.
2027  Value srcValue = conversionOp.getIn();
2028  VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
2029  VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
2030 
2031  if (failed(
2032  commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
2033  return failure();
2034 
2035  // Check general alignment preconditions.
2036  if (failed(alignedConversionPrecondition(
2037  rewriter, srcVecType,
2038  /*containerTy=*/rewriter.getI8Type(), conversionOp)))
2039  return failure();
2040 
2041  // Perform the rewrite.
2042  Location loc = conversionOp.getLoc();
2043  const auto &extFn = isSigned ? extractNBitsPerByteAndSignExtendToI8
2044  : extractNBitsPerByteAndExtendToI8;
2045  Value subByteExt;
2046  switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
2047  case 2:
2048  subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn);
2049  break;
2050  case 4:
2051  subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn);
2052  break;
2053  default:
2054  return failure();
2055  }
2056 
2057  // Finalize the rewrite.
2058  rewriter.replaceOpWithNewOp<ConversionOpType>(
2059  conversionOp, conversionOp.getType(), subByteExt);
2060  return success();
2061  }
2062 };
2063 
2064 /// Rewrite the i8 -> i4 part of any truncation into a deinterleave and
2065 /// bitwise ops that take advantage of high-level information to avoid leaving
2066 /// LLVM to scramble with peephole optimizations.
2067 ///
2068 /// For example:
2069 /// arith.trunci %in : vector<8xi32> to vector<8xi4>
2070 ///
2071 /// is rewritten as:
2072 ///
2073 /// %cst = arith.constant dense<15> : vector<4xi8>
2074 /// %cst_0 = arith.constant dense<4> : vector<4xi8>
2075 /// %0, %1 = vector.deinterleave %in : vector<8xi8> -> vector<4xi8>
2076 /// %2 = arith.andi %0, %cst : vector<4xi8>
2077 /// %3 = arith.shli %1, %cst_0 : vector<4xi8>
2078 /// %4 = arith.ori %2, %3 : vector<4xi8>
2079 /// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
2080 ///
2081 struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
2082  using OpRewritePattern::OpRewritePattern;
2083 
2084  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
2085  PatternRewriter &rewriter) const override {
2086  // Verify the preconditions.
2087  Value srcValue = truncOp.getIn();
2088  auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
2089  auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
2090  if (!srcVecType || !dstVecType)
2091  return failure();
2092 
2093  if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
2094  return failure();
2095 
2096  // TODO: Add support for truncating to i2.
2097  if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
2098  return failure();
2099 
2100  // Check general alignment preconditions. We invert the src/dst type order
2101  // to reuse the existing precondition logic.
2102  if (failed(alignedConversionPrecondition(
2103  rewriter, dstVecType,
2104  /*containerTy=*/rewriter.getI8Type(), truncOp)))
2105  return failure();
2106 
2107  // Create a new iX -> i8 truncation op.
2108  Location loc = truncOp.getLoc();
2109  auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
2110  Value i8TruncVal =
2111  rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
2112 
2113  // Rewrite the i8 -> i4 truncation part.
2114  Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
2115 
2116  // Finalize the rewrite.
2117  rewriter.replaceOp(truncOp, subByteTrunc);
2118  return success();
2119  }
2120 };
2121 
2122 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
2123 /// perform the transpose on wider (byte) element types.
2124 ///
2125 /// EXAMPLE:
2126 /// %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
2127 ///
2128 /// is rewritten as:
2129 ///
2130 /// %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
2131 /// %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
2132 /// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
2133 ///
2134 struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
2135  using OpRewritePattern::OpRewritePattern;
2136 
2137  RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
2138  : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
2139 
2140  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
2141  PatternRewriter &rewriter) const override {
2142  // Precondition: sub-byte integer transpose.
2143  constexpr unsigned minNativeBitwidth = 8;
2144  VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
2145  if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
2146  srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
2147  return rewriter.notifyMatchFailure(transposeOp,
2148  "not a sub-byte transpose");
2149  }
2150 
2151  // Perform the rewrite.
2152  Location loc = transposeOp.getLoc();
2153  // Signed/unsigned interpretation shouldn't matter here as we are just
2154  // transposing the elements and truncating them back to the original size.
2155  // TODO: Use unsigned extension (more efficient) when emulation or backend
2156  // support is available.
2157  auto srcNativeVecType = srcSubByteVecType.cloneWith(
2158  std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
2159  Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
2160  transposeOp.getVector());
2161  Value newTranspose = rewriter.create<vector::TransposeOp>(
2162  loc, extOp, transposeOp.getPermutation());
2163  VectorType dstSubByteVecType = transposeOp.getResultVectorType();
2164  rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
2165  newTranspose);
2166  return success();
2167  }
2168 };
2169 
2170 } // namespace
2171 
2172 //===----------------------------------------------------------------------===//
2173 // Public Interface Definition
2174 //===----------------------------------------------------------------------===//
2175 
2176 // The emulated type is inferred from the converted memref type.
2177 void vector::populateVectorNarrowTypeEmulationPatterns(
2178  const arith::NarrowTypeEmulationConverter &typeConverter,
2179  RewritePatternSet &patterns, bool disableAtomicRMW) {
2180 
2181  // Populate `vector.*` conversion patterns.
2182  // TODO: #119553 support atomicity
2183  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
2184  ConvertVectorMaskedStore, ConvertVectorTransferRead>(
2185  typeConverter, patterns.getContext());
2186 
2187  // Populate `vector.*` store conversion patterns. The caller can choose
2188  // to avoid emitting atomic operations and instead lower stores to a plain
2189  // read-modify-write sequence when it is known there is no thread contention.
2190  patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);
2191 }
2192 
2193 void vector::populateVectorNarrowTypeRewritePatterns(
2194  RewritePatternSet &patterns, PatternBenefit benefit) {
2195  // TODO: Document what the emulated type is.
2196  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
2197  RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
2198  benefit);
2199 
2200  // Patterns for aligned cases. We set higher priority as they are expected to
2201  // generate better performance for aligned cases.
2202  // The container type is always i8.
2203  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
2204  RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
2205  RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
2206  benefit.getBenefit() + 1);
2207  // The container type is always i8.
2208  patterns
2209  .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
2210  RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
2211  patterns.getContext(), benefit.getBenefit() + 1);
2212 }
2213 
2214 // The container type is always i8.
2215 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
2216  RewritePatternSet &patterns, PatternBenefit benefit) {
2217  patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
2218 }
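// Editor's sketch (not part of the upstream source): a typical way to use the
// rewrite entry points from inside a pass's runOnOperation, assuming the
// greedy pattern driver `applyPatternsGreedily` is available in this MLIR
// version:
//
//   RewritePatternSet patterns(&getContext());
//   vector::populateVectorNarrowTypeRewritePatterns(patterns);
//   vector::populateVectorTransposeNarrowTypeRewritePatterns(patterns);
//   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//     signalPassFailure();
//
// The emulation patterns (populateVectorNarrowTypeEmulationPatterns) are
// instead intended for a dialect-conversion driver, paired with an
// arith::NarrowTypeEmulationConverter.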