MLIR 23.0.0git
VectorEmulateNarrowType.cpp
Go to the documentation of this file.
1//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites and utilities to emulate
10// narrow types that are not supported by the target hardware, e.g. i4
11// ("emulated type"), using wider types, e.g. i8 ("container type").
12//
13/// Currently, only power-of-two integer types are supported. These are
14/// converted to wider integers that are either 8 bits wide or wider.
15///
16/// TODO: Support for non-powers-of-two.
17//===----------------------------------------------------------------------===//
18
32#include "mlir/IR/Value.h"
34#include "llvm/ADT/SmallVector.h"
35#include "llvm/Support/DebugLog.h"
36#include "llvm/Support/MathExtras.h"
37#include "llvm/Support/raw_ostream.h"
38#include <cstdint>
39#include <optional>
40
42
43using namespace mlir;
44
45#define DEBUG_TYPE "vector-narrow-type-emulation"
46
49
50//===----------------------------------------------------------------------===//
51// Utils
52//===----------------------------------------------------------------------===//
53
54/// Returns a compressed mask for the emulated vector. For example, when
55/// emulating an eight-element `i8` vector with `i32` (i.e. when the source
56/// elements span two dest elements), this method compresses `vector<8xi1>`
57/// into `vector<2xi1>`.
58///
59/// The compressed/output mask value is set iff any mask in the corresponding
60/// `numSrcElemsPerDest` range of uncompressed/input masks is set. E.g., if
61/// `numSrcElemsPerDest` equals to 2, and `numFrontPadElems` equals to 1, the
62/// following mask:
63///
64/// %mask = [1, 1, 0, 0, 0, 0]
65///
66/// will first be padded in the front with `numFrontPadElems` zeros, and zeros
67/// will be added in the back to make the number of elements a multiple of
68/// `numSrcElemsPerDest` (for easier computation). The resulting mask will be:
69///
70/// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
71///
72/// then it will return the following new compressed mask:
73///
74/// %mask = [1, 1, 0, 0]
75///
76/// NOTE: `numFrontPadElems` is assumed to be strictly smaller than
77/// `numSrcElemsPerDest`.
78static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
79 Location loc, Value mask,
80 int numSrcElems,
81 int numSrcElemsPerDest,
82 int numFrontPadElems = 0) {
83
84 assert(numFrontPadElems < numSrcElemsPerDest &&
85 "numFrontPadElems must be less than numSrcElemsPerDest");
86
87 auto numDestElems =
88 (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
89 numSrcElemsPerDest;
90
91 Operation *maskOp = mask.getDefiningOp();
93 // TODO: add support to `vector.broadcast`.
94 // Finding the mask creation operation.
95 while (maskOp &&
96 !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
97 maskOp)) {
98 if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
99 maskOp = extractOp.getSource().getDefiningOp();
100 extractOps.push_back(extractOp);
101 }
102 }
103
104 if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
105 maskOp))
106 return failure();
107
108 // Computing the "compressed" mask. All the emulation logic (i.e. computing
109 // new mask index) only happens on the last dimension of the vectors.
110 SmallVector<int64_t> maskShape(
111 cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
112 maskShape.back() = numDestElems;
113 auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
114 std::optional<Operation *> newMask =
116 .Case(
117 [&](vector::CreateMaskOp createMaskOp)
118 -> std::optional<Operation *> {
119 OperandRange maskOperands = createMaskOp.getOperands();
120 // The `vector.create_mask` op creates a mask arrangement
121 // without any zeros at the front. Also, because
122 // `numFrontPadElems` is strictly smaller than
123 // `numSrcElemsPerDest`, the compressed mask generated by
124 // padding the original mask by `numFrontPadElems` will not
125 // have any zeros at the front as well.
126 AffineExpr s0;
127 bindSymbols(rewriter.getContext(), s0);
128 s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
129 OpFoldResult origIndex = getAsOpFoldResult(maskOperands.back());
131 rewriter, loc, s0, origIndex);
132 SmallVector<Value> newMaskOperands(maskOperands.drop_back());
133 newMaskOperands.push_back(
134 getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
135 return vector::CreateMaskOp::create(rewriter, loc, newMaskType,
136 newMaskOperands);
137 })
138 .Case([&](vector::ConstantMaskOp constantMaskOp)
139 -> std::optional<Operation *> {
140 // Take the shape of mask, compress its trailing dimension:
141 SmallVector<int64_t> maskDimSizes(constantMaskOp.getMaskDimSizes());
142 int64_t &maskIndex = maskDimSizes.back();
143 maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
144 numSrcElemsPerDest);
145 return vector::ConstantMaskOp::create(rewriter, loc, newMaskType,
146 maskDimSizes);
147 })
148 .Case([&](arith::ConstantOp constantOp)
149 -> std::optional<Operation *> {
150 // TODO: Support multiple dimensions.
151 if (maskShape.size() != 1)
152 return std::nullopt;
153 // Rearrange the original mask values to cover the whole potential
154 // loading region. For example, in the case of using byte-size for
155 // emulation, given the following mask:
156 //
157 // %mask = [0, 1, 0, 1, 0, 0]
158 //
159 // With front offset of 1, the mask will be padded 0s in the front
160 // and back so that:
161 // 1. It is aligned with the effective loading bits
162 // 2. Its length is multiple of `numSrcElemPerDest` (and the total
163 // coverage size is mulitiple of bytes). The new mask will be like
164 // this before compressing:
165 //
166 // %new_mask = [0, 0, 1, 0, 1, 0, 0, 0]
167 auto originalMask =
168 cast<DenseIntElementsAttr>(constantOp.getValue());
169 SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
170 paddedMaskValues.append(originalMask.template value_begin<bool>(),
171 originalMask.template value_end<bool>());
172 paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);
173
174 // Compressing by combining every `numSrcElemsPerDest` elements:
175 SmallVector<bool> compressedMaskValues;
176 for (size_t i = 0; i < paddedMaskValues.size();
177 i += numSrcElemsPerDest) {
178 bool combinedValue = false;
179 for (int j = 0; j < numSrcElemsPerDest; ++j) {
180 combinedValue |= paddedMaskValues[i + j];
181 }
182 compressedMaskValues.push_back(combinedValue);
183 }
184 return arith::ConstantOp::create(
185 rewriter, loc,
186 DenseElementsAttr::get(newMaskType, compressedMaskValues));
187 });
188
189 if (!newMask)
190 return failure();
191
192 while (!extractOps.empty()) {
193 newMask =
194 vector::ExtractOp::create(rewriter, loc, (*newMask)->getResults()[0],
195 extractOps.back().getMixedPosition());
196 extractOps.pop_back();
197 }
198
199 return *newMask;
200}
201
202/// Extracts 1-D subvector from a 1-D vector.
203///
204/// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
205/// from `src`, starting at `offset`. The result is also a rank-1 vector:
206///
207/// vector<numElemsToExtract x !elType>
208///
209/// (`!elType` is the element type of the source vector). As `offset` is a known
210/// _static_ value, this helper hook emits `vector.extract_strided_slice`.
/// When `numElemsToExtract` equals the number of elements in `src`, no op is
/// created and `src` itself is returned.
211///
212/// EXAMPLE:
213/// %res = vector.extract_strided_slice %src
214/// { offsets = [offset], sizes = [numElemsToExtract], strides = [1] }
216 Value src, int64_t offset,
217 int64_t numElemsToExtract) {
218 auto vectorType = cast<VectorType>(src.getType());
219 assert(vectorType.getRank() == 1 && "expected source to be rank-1-D vector ");
220 assert(offset + numElemsToExtract <= vectorType.getNumElements() &&
221 "subvector out of bounds");
222
223 // When extracting all available elements, just use the source vector as the
224 // result.
225 if (vectorType.getNumElements() == numElemsToExtract)
226 return src;
227
// Single-element attribute arrays: the slice is taken along the only
// dimension with unit stride.
228 auto offsets = rewriter.getI64ArrayAttr({offset});
229 auto sizes = rewriter.getI64ArrayAttr({numElemsToExtract});
230 auto strides = rewriter.getI64ArrayAttr({1});
231
232 auto resultVectorType =
233 VectorType::get({numElemsToExtract}, vectorType.getElementType());
234 return vector::ExtractStridedSliceOp::create(rewriter, loc, resultVectorType,
235 src, offsets, sizes, strides)
236 ->getResult(0);
237}
238
239/// Inserts 1-D subvector into a 1-D vector.
240///
241/// Inserts the input rank-1 source vector into the destination vector starting
242/// at `offset`. As `offset` is a known _static_ value, this helper hook emits
243/// `vector.insert_strided_slice`. When `src` and `dest` have the same number
/// of elements and `offset` is 0, no op is created and `src` is returned.
244///
245/// EXAMPLE:
246/// %res = vector.insert_strided_slice %src, %dest
247/// {offsets = [offset], strides = [1]}
249 Value src, Value dest, int64_t offset) {
250 [[maybe_unused]] auto srcVecTy = cast<VectorType>(src.getType());
251 [[maybe_unused]] auto destVecTy = cast<VectorType>(dest.getType());
252 assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
253 "expected source and dest to be rank-1 vector types");
254
255 // If overwriting the whole destination vector, just return the source.
256 if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0)
257 return src;
258
259 auto offsets = rewriter.getI64ArrayAttr({offset});
260 auto strides = rewriter.getI64ArrayAttr({1});
261 return vector::InsertStridedSliceOp::create(rewriter, loc, destVecTy, src,
262 dest, offsets, strides);
263}
264
265/// Extracts 1-D subvector from a 1-D vector.
266///
267/// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
268/// from `src`, starting at `offset`. The result is also a rank-1 vector:
269///
270/// vector<numElemsToExtract x !elType>
271///
272/// (`!elType` is the element type of the source vector). As `offset` is assumed
273/// to be a _dynamic_ SSA value, this helper method generates a sequence of
274/// `vector.extract` + `vector.insert` pairs. The extracted elements are
/// inserted into `dest` at positions 0 .. numElemsToExtract - 1 and the
/// updated `dest` is returned. When `numElemsToExtract` equals the number of
/// elements in `src`, `src` is returned unchanged and `dest` is not used.
275///
276/// EXAMPLE:
277/// %v1 = vector.extract %src[%offset] : i2 from vector<8xi2>
278/// %r1 = vector.insert %v1, %dest[0] : i2 into vector<3xi2>
279/// %c1 = arith.constant 1 : index
280/// %idx2 = arith.addi %offset, %c1 : index
281/// %v2 = vector.extract %src[%idx2] : i2 from vector<8xi2>
282/// %r2 = vector.insert %v2, %r1 [1] : i2 into vector<3xi2>
283/// (...)
285 Value src, Value dest,
286 OpFoldResult offset,
287 int64_t numElemsToExtract) {
288 auto srcVecTy = cast<VectorType>(src.getType());
289 assert(srcVecTy.getRank() == 1 && "expected source to be rank-1-D vector ");
290 // NOTE: We are unable to take the offset into account in the following
291 // assert, hence it's still possible that the subvector is out-of-bounds even
292 // if the condition is true.
293 assert(numElemsToExtract <= srcVecTy.getNumElements() &&
294 "subvector out of bounds");
295
296 // When extracting all available elements, just use the source vector as the
297 // result.
298 if (srcVecTy.getNumElements() == numElemsToExtract)
299 return src;
300
// NOTE(review): `dyn_cast<Value>(offset)` presumably yields a null Value when
// `offset` holds an attribute rather than an SSA value; the dynamic insert
// helper below materializes the offset with getValueOrCreateConstantIndexOp
// instead. Confirm that callers only ever pass SSA values here.
301 for (int i = 0; i < numElemsToExtract; ++i) {
// Element `i` is read from `src[offset + i]`; the addition is skipped for
// the first element.
302 Value extractLoc =
303 (i == 0) ? dyn_cast<Value>(offset)
304 : arith::AddIOp::create(
305 rewriter, loc, rewriter.getIndexType(),
306 dyn_cast<Value>(offset),
307 arith::ConstantIndexOp::create(rewriter, loc, i));
308 auto extractOp = vector::ExtractOp::create(rewriter, loc, src, extractLoc);
309 dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, i);
310 }
311 return dest;
312}
313
314/// Inserts 1-D subvector into a 1-D vector.
315///
316/// Inserts the first `numElemsToInsert` elements of the input rank-1 source
/// vector into the destination vector starting
317/// at `offset`. As `offset` is assumed to be a _dynamic_ SSA value, this hook
318/// uses a sequence of `vector.extract` + `vector.insert` pairs. The updated
/// destination vector is returned.
319///
320/// EXAMPLE:
321/// %v1 = vector.extract %src[0] : i2 from vector<8xi2>
322/// %r1 = vector.insert %v1, %dest[%offset] : i2 into vector<3xi2>
323/// %c1 = arith.constant 1 : index
324/// %idx2 = arith.addi %offset, %c1 : index
325/// %v2 = vector.extract %src[1] : i2 from vector<8xi2>
326/// %r2 = vector.insert %v2, %r1 [%idx2] : i2 into vector<3xi2>
327/// (...)
329 Value src, Value dest,
330 OpFoldResult offset,
331 int64_t numElemsToInsert) {
332 auto srcVecTy = cast<VectorType>(src.getType());
333 auto destVecTy = cast<VectorType>(dest.getType());
334 assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
335 "expected source and dest to be rank-1 vector types");
// Silence unused-variable warnings when asserts are compiled out.
336 (void)srcVecTy;
337 (void)destVecTy;
338 assert(numElemsToInsert > 0 &&
339 "the number of elements to insert must be greater than 0");
340 // NOTE: We are unable to take the offset into account in the following
341 // assert, hence it's still possible that the subvector is out-of-bounds even
342 // if the condition is true.
343 assert(numElemsToInsert <= destVecTy.getNumElements() &&
344 "subvector out of bounds");
345
// Materialize the offset as an index value (a constant op is created when
// `offset` is an attribute).
346 Value destOffsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
347 for (int64_t i = 0; i < numElemsToInsert; ++i) {
// Element `i` of `src` is written to `dest[offset + i]`; the addition is
// skipped for the first element.
348 auto insertLoc =
349 i == 0 ? destOffsetVal
350 : arith::AddIOp::create(
351 rewriter, loc, rewriter.getIndexType(), destOffsetVal,
352 arith::ConstantIndexOp::create(rewriter, loc, i));
353 auto extractOp = vector::ExtractOp::create(rewriter, loc, src, i);
354 dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, insertLoc);
355 }
356 return dest;
357}
358
359/// Emulate a vector load for `emulatedElemTy` using `containerElemTy`
360///
361/// Specifically, use `containerElemTy` for loading a vector of
362/// `emulatedElemTy`. The load location is given by `base` and
363/// `linearizedIndices`, and the load size (counted in container elements) is
364/// given by `numContainerElemsToLoad`. The loaded container vector is bitcast
/// to a vector of `emulatedElemTy` holding
/// `numContainerElemsToLoad * (containerBits / emulatedBits)` elements.
366 Value base,
367 OpFoldResult linearizedIndices,
368 int64_t numContainerElemsToLoad,
369 Type emulatedElemTy,
370 Type containerElemTy) {
// Number of emulated elements packed into one container element (integer
// division of the two bit widths).
371 auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() /
372 emulatedElemTy.getIntOrFloatBitWidth();
373 auto newLoad = vector::LoadOp::create(
374 rewriter, loc, VectorType::get(numContainerElemsToLoad, containerElemTy),
375 base, getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
376 return vector::BitCastOp::create(
377 rewriter, loc,
378 VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
379 emulatedElemTy),
380 newLoad);
381}
382
383/// Downcast two values to `downcastType`, then select values
384/// based on `mask`, and casts the result to `upcastType`.
///
/// `trueValue` / `falseValue` are bitcast to `downcastType` only when they do
/// not already have that type. `downcastType` and `upcastType` must describe
/// the same total number of bits.
386 VectorType downcastType,
387 VectorType upcastType, Value mask,
388 Value trueValue, Value falseValue) {
389 assert(
390 downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
391 upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
392 "expected input and output number of bits to match");
393 if (trueValue.getType() != downcastType) {
394 trueValue =
395 vector::BitCastOp::create(builder, loc, downcastType, trueValue);
396 }
397 if (falseValue.getType() != downcastType) {
398 falseValue =
399 vector::BitCastOp::create(builder, loc, downcastType, falseValue);
400 }
// Despite its name, `selectedType` holds the selected *value* (still in
// `downcastType`).
401 Value selectedType =
402 arith::SelectOp::create(builder, loc, mask, trueValue, falseValue);
403 // Upcast the selected value to the new type.
404 return vector::BitCastOp::create(builder, loc, upcastType, selectedType);
405}
406
407/// Emits `memref.generic_atomic_rmw` op to store a subbyte-sized value to a
408/// byte in `linearizedMemref`, with a mask. The `valueToStore` is a vector of
409/// subbyte-sized elements, with size of 8 bits, and the mask is used to select
410/// which elements to store.
411///
412/// Inputs:
413/// linearizedMemref = |2|2|2|2| : <4xi2> (<1xi8>)
414/// storeIdx = 2
415/// valueToStore = |3|3|3|3| : vector<4xi2>
416/// mask = |0|0|1|1| : vector<4xi1>
417///
418/// Result:
419/// linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
420static void atomicRMW(OpBuilder &builder, Location loc,
421 MemRefValue linearizedMemref, Value storeIdx,
422 VectorValue valueToStore, Value mask) {
423 assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
424
425 // Create an atomic load-modify-write region using
426 // `memref.generic_atomic_rmw`.
427 auto atomicOp = memref::GenericAtomicRMWOp::create(
428 builder, loc, linearizedMemref, ValueRange{storeIdx});
429 Value origValue = atomicOp.getCurrentValue();
430
431 OpBuilder::InsertionGuard guard(builder);
432 builder.setInsertionPointToStart(atomicOp.getBody());
433
434 // Load the original value from memory, and cast it to the original element
435 // type.
436 auto oneElemVecType = VectorType::get({1}, origValue.getType());
437 Value origVecValue = vector::FromElementsOp::create(
438 builder, loc, oneElemVecType, ValueRange{origValue});
439
440 // Construct the final masked value and yield it.
441 Value maskedValue =
442 downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
443 oneElemVecType, mask, valueToStore, origVecValue);
444 auto scalarMaskedValue =
445 vector::ExtractOp::create(builder, loc, maskedValue, 0);
446 memref::AtomicYieldOp::create(builder, loc, scalarMaskedValue);
447}
448
449/// Generate a non-atomic read-modify-write sequence for storing to the emulated
450/// type. It has similar logic to `atomicRMWStore`, but without atomicity.
451static void nonAtomicRMW(OpBuilder &builder, Location loc,
452 MemRefValue linearizedMemref, Value linearizedIndex,
453 VectorValue valueToStore, Value mask) {
454 assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
455
456 auto oneElemVecType =
457 VectorType::get({1}, linearizedMemref.getType().getElementType());
458 Value origVecValue =
459 vector::LoadOp::create(builder, loc, oneElemVecType, linearizedMemref,
460 ValueRange{linearizedIndex});
461 origVecValue = vector::BitCastOp::create(builder, loc, valueToStore.getType(),
462 origVecValue);
463
464 Value maskedValue =
465 downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
466 oneElemVecType, mask, valueToStore, origVecValue);
467 vector::StoreOp::create(builder, loc, maskedValue, linearizedMemref,
468 linearizedIndex);
469}
470
471/// Extract `sliceNumElements` from source `vector` at `extractOffset`,
472/// and insert it into an empty vector at `insertOffset`.
/// The "empty" vector is a zero-initialized constant whose elements together
/// span 8 bits (i.e. it holds `8 / elementBitWidth` elements).
473/// Inputs:
474/// vec_in = |0|1|2|3| : vector<4xi2>
475/// extractOffset = 1
476/// sliceNumElements = 2
477/// insertOffset = 2
478/// Output:
479/// vec_out = |0|0|1|2| : vector<4xi2>
480static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
482 int64_t extractOffset,
483 int64_t sliceNumElements,
484 int64_t insertOffset) {
485 assert(vector.getType().getRank() == 1 && "expected 1-D vector");
486 auto vectorElementType = vector.getType().getElementType();
487 // TODO: update and use `alignedConversionPrecondition` in the place of
488 // these asserts.
489 assert(
490 sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
491 "sliceNumElements * vector element size must be less than or equal to 8");
492 assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
493 "vector element must be a valid sub-byte type");
// Number of emulated elements that fit in one byte.
494 auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
// All-zero, byte-sized destination vector.
495 auto emptyByteVector = arith::ConstantOp::create(
496 rewriter, loc,
497 VectorType::get({emulatedPerContainerElem}, vectorElementType),
498 rewriter.getZeroAttr(
499 VectorType::get({emulatedPerContainerElem}, vectorElementType)));
500 auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
501 extractOffset, sliceNumElements);
502 return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
503 insertOffset);
504}
505
506namespace {
507
508//===----------------------------------------------------------------------===//
509// ConvertVectorStore
510//===----------------------------------------------------------------------===//
511
512// Emulate `vector.store` using a multi-byte container type.
513//
514// When `assumeAligned` is true, store offsets are assumed to be aligned to
515// container element boundaries, so a store whose source vector fills whole
516// container elements (isDivisibleInSize) is emitted as a simple bitcast +
517// store without checking the offset. Stores that are not divisible in size
518// are rejected. This is useful for downstream users that have already
519// ensured alignment.
520//
521// The container type is obtained through Op adaptor and would normally be
522// generated via `NarrowTypeEmulationConverter`.
523//
524// EXAMPLE 1
525// (aligned store of i4, emulated using i8 as the container type)
526//
527// vector.store %src, %dest[%idx_1, %idx_2] : memref<4x8xi4>, vector<8xi4>
528//
529// is rewritten as:
530//
531// %src_bitcast = vector.bitcast %src : vector<8xi4> to vector<4xi8>
532// vector.store %src_bitcast, %dest_bitcast[%idx]
533// : memref<16xi8>, vector<4xi8>
534//
535// EXAMPLE 2
536// (unaligned store of i2, emulated using i8 as the container type)
537//
538// vector.store %src, %dest[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
539//
540// The i2 store is emulated through 2 x RMW sequences. The destination i2 memref
541// is modelled using 3 bytes:
542//
543// Byte 0 Byte 1 Byte 2
544// +----------+----------+----------+
545// | oooooooo | ooooNNNN | NNoooooo |
546// +----------+----------+----------+
547//
548// N - (N)ew entries (i.e. to be overwritten by vector.store)
549// o - (o)ld entries (to be preserved)
550//
551// For the generated output in the non-atomic case, see:
552// * @vector_store_i2_const_index_two_partial_stores`
553// in:
554// * "vector-emulate-narrow-type-unaligned-non-atomic.mlir".
555//
556// NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
557// `false` to generate non-atomic RMW sequences.
558struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
559 using Base::Base;
560
561 ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW,
562 bool assumeAligned)
563 : OpConversionPattern<vector::StoreOp>(context),
564 disableAtomicRMW(disableAtomicRMW), assumeAligned(assumeAligned) {}
565
566 LogicalResult
567 matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
568 ConversionPatternRewriter &rewriter) const override {
569
570 if (op.getValueToStore().getType().getRank() != 1)
571 return rewriter.notifyMatchFailure(op,
572 "only 1-D vectors are supported ATM");
573
574 auto loc = op.getLoc();
575
576 auto valueToStore = cast<VectorValue>(op.getValueToStore());
577 auto containerElemTy =
578 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
579 Type emulatedElemTy = op.getValueToStore().getType().getElementType();
580 int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
581 int containerBits = containerElemTy.getIntOrFloatBitWidth();
582
583 // Check per-element alignment.
584 if (containerBits % emulatedBits != 0) {
585 return rewriter.notifyMatchFailure(
586 op, "impossible to pack emulated elements into container elements "
587 "(bit-wise misalignment)");
588 }
589 int emulatedPerContainerElem = containerBits / emulatedBits;
590
591 // Adjust the number of elements to store when emulating narrow types.
592 // Here only the 1-D vector store is considered, and the N-D memref types
593 // should be linearized.
594 // For example, to emulate i4 to i8, the following op:
595 //
596 // vector.store %arg1, %0[%arg2, %arg3] : memref<4x8xi4>, vector<8xi4>
597 //
598 // can be replaced with
599 //
600 // %bitcast = vector.bitcast %arg1 : vector<8xi4> to vector<4xi8>
601 // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
602 // vector<4xi8>
603
604 auto origElements = valueToStore.getType().getNumElements();
605 // Note, per-element-alignment was already verified above.
606 bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
607
608 // In assume-aligned mode, isDivisibleInSize alone is sufficient — the
609 // caller guarantees that store offsets are aligned to container element
610 // boundaries.
611 if (assumeAligned) {
612 if (!isDivisibleInSize)
613 return rewriter.notifyMatchFailure(
614 op, "the source vector does not fill whole container elements "
615 "(not divisible in size)");
616
617 auto stridedMetadata =
618 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
619 OpFoldResult linearizedIndices;
620 std::tie(std::ignore, linearizedIndices) =
622 rewriter, loc, emulatedBits, containerBits,
623 stridedMetadata.getConstifiedMixedOffset(),
624 stridedMetadata.getConstifiedMixedSizes(),
625 stridedMetadata.getConstifiedMixedStrides(),
626 getAsOpFoldResult(adaptor.getIndices()));
627 auto memrefBase = cast<MemRefValue>(adaptor.getBase());
628 int numElements = origElements / emulatedPerContainerElem;
629 auto bitCast = vector::BitCastOp::create(
630 rewriter, loc, VectorType::get(numElements, containerElemTy),
631 op.getValueToStore());
632 rewriter.replaceOpWithNewOp<vector::StoreOp>(
633 op, bitCast.getResult(), memrefBase,
634 getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
635 return success();
636 }
637
638 auto stridedMetadata =
639 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
640
641 OpFoldResult linearizedIndices;
642 memref::LinearizedMemRefInfo linearizedInfo;
643 std::tie(linearizedInfo, linearizedIndices) =
645 rewriter, loc, emulatedBits, containerBits,
646 stridedMetadata.getConstifiedMixedOffset(),
647 stridedMetadata.getConstifiedMixedSizes(),
648 stridedMetadata.getConstifiedMixedStrides(),
649 getAsOpFoldResult(adaptor.getIndices()));
650
651 // Use the exact intraDataOffset when it can be folded. Dynamic values are
652 // rejected in this path because a dynamic offset is not necessarily aligned
653 // to a container element boundary. Callers that can guarantee alignment
654 // should use assumeAligned.
655 std::optional<int64_t> foldedNumFrontPadElems =
656 getConstantIntValue(linearizedInfo.intraDataOffset);
657
658 if (!foldedNumFrontPadElems) {
659 return rewriter.notifyMatchFailure(
660 op, "subbyte store emulation: dynamic front padding size is "
661 "not yet implemented");
662 }
663
664 auto memrefBase = cast<MemRefValue>(adaptor.getBase());
665
666 // RMWs are not needed when:
667 // * no _partial_ stores are required.
668 // A partial store is defined as a store in which only a part of the
669 // container element is overwritten, e.g.
670 //
671 // Dest before (8 bits)
672 // +----------+
673 // | 11000000 |
674 // +----------+
675 //
676 // Dest after storing 0xF at offset 4 (in bits)
677 // +----------+
678 // | 11001111 |
679 // +----------+
680 //
681 // At a higher level, this translats to:
682 // 1. The source vector size (in bits) is a multiple of byte size.
683 // 2. The address of the store is aligned to the container type width
684 // boundary.
685 //
686 // EXAMPLE 1:
687 // Requires partial store:
688 // vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
689 //
690 // EXAMPLE 2:
691 // Does not require a partial store:
692 // vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
693 //
694 // TODO: Take linearizedInfo.linearizedOffset into account. This is
695 // currently not needed/used/exercised as all our tests set offset to 0.
696 bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;
697
698 if (!emulationRequiresPartialStores) {
699 // Basic case: storing full bytes.
700 auto numElements = origElements / emulatedPerContainerElem;
701 auto bitCast = vector::BitCastOp::create(
702 rewriter, loc, VectorType::get(numElements, containerElemTy),
703 op.getValueToStore());
704 rewriter.replaceOpWithNewOp<vector::StoreOp>(
705 op, bitCast.getResult(), memrefBase,
706 getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
707 return success();
708 }
709
710 // Next, handle the case when sub-byte read-modify-write
711 // sequences are needed to emulate a vector store.
712 // Here is an example:
713 //
714 // Vector to store: vector<7xi2>
715 // Value to store: 11 11 11 11 11 11 11 (all ones)
716 //
717 // Destination: memref<12xi2>
718 // Store offset: 2 (i.e. 4 bits into the 1st emulated byte).
719 //
720 // Input MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
721 //
722 // Destination memref before:
723 //
724 // Byte 0 Byte 1 Byte 2
725 // +----------+----------+----------+
726 // | 00000000 | 00000000 | 00000000 |
727 // +----------+----------+----------+
728 //
729 // Destination memref after:
730 //
731 // Byte 0 Byte 1 Byte 2
732 // +----------+----------+----------+
733 // | 00001111 | 11111111 | 11000000 |
734 // +----------+----------+----------+
735 //
736 // Note, stores to Byte 1 are "full-width" and hence don't require RMW (no
737 // need for atomicity). Stores to Bytes 0 and Byte 2 are "partial", hence
738 // requiring RMW access (atomicity is required).
739
740 // The index into the target memref we are storing to.
741 Value currentDestIndex =
742 getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
743 // The index into the source vector we are currently processing.
744 auto currentSourceIndex = 0;
745
746 // Build a mask used for rmw.
747 auto subWidthStoreMaskType =
748 VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());
749
750 auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
751
752 // 1. Partial width store for the leading byte.
753 // When the store address is not aligned to emulated width boundary, deal
754 // with the unaligned part so that the rest elements are aligned to width
755 // boundary.
756 auto frontSubWidthStoreElem =
757 (emulatedPerContainerElem - *foldedNumFrontPadElems) %
758 emulatedPerContainerElem;
759 if (frontSubWidthStoreElem > 0) {
760 SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
761 if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
762 std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
763 origElements, true);
764 frontSubWidthStoreElem = origElements;
765 } else {
766 std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
767 *foldedNumFrontPadElems, true);
768 }
769 auto frontMask = arith::ConstantOp::create(
770 rewriter, loc,
771 DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
772
773 currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
774 auto value =
775 extractSliceIntoByte(rewriter, loc, valueToStore, 0,
776 frontSubWidthStoreElem, *foldedNumFrontPadElems);
777
778 storeFunc(rewriter, loc, memrefBase, currentDestIndex,
779 cast<VectorValue>(value), frontMask.getResult());
780 }
781
782 if (currentSourceIndex >= origElements) {
783 rewriter.eraseOp(op);
784 return success();
785 }
786
787 // Increment the destination index by 1 to align to the emulated width
788 // boundary.
789 auto constantOne = arith::ConstantIndexOp::create(rewriter, loc, 1);
790 currentDestIndex = arith::AddIOp::create(
791 rewriter, loc, rewriter.getIndexType(), currentDestIndex, constantOne);
792
793 // 2. Full width store for the inner output bytes.
794 // After the previous step, the store address is aligned to the emulated
795 // width boundary.
796 int64_t fullWidthStoreSize =
797 (origElements - currentSourceIndex) / emulatedPerContainerElem;
798 int64_t numNonFullWidthElements =
799 fullWidthStoreSize * emulatedPerContainerElem;
800 if (fullWidthStoreSize > 0) {
801 auto fullWidthStorePart = staticallyExtractSubvector(
802 rewriter, loc, valueToStore, currentSourceIndex,
803 numNonFullWidthElements);
804
805 auto originType = cast<VectorType>(fullWidthStorePart.getType());
806 auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
807 auto storeType = VectorType::get(
808 {originType.getNumElements() / emulatedPerContainerElem},
809 memrefElemType);
810 auto bitCast = vector::BitCastOp::create(rewriter, loc, storeType,
811 fullWidthStorePart);
812 vector::StoreOp::create(rewriter, loc, bitCast.getResult(), memrefBase,
813 currentDestIndex);
814
815 currentSourceIndex += numNonFullWidthElements;
816 currentDestIndex = arith::AddIOp::create(
817 rewriter, loc, rewriter.getIndexType(), currentDestIndex,
818 arith::ConstantIndexOp::create(rewriter, loc, fullWidthStoreSize));
819 }
820
821 // 3. Partial width store for the trailing output byte.
822 // It is needed when the residual length is smaller than the emulated width,
823 // which is not covered in step 2 above.
824 auto remainingElements = origElements - currentSourceIndex;
825 if (remainingElements != 0) {
826 auto subWidthStorePart =
827 extractSliceIntoByte(rewriter, loc, cast<VectorValue>(valueToStore),
828 currentSourceIndex, remainingElements, 0);
829
830 // Generate back mask.
831 auto maskValues = SmallVector<bool>(emulatedPerContainerElem, false);
832 std::fill_n(maskValues.begin(), remainingElements, 1);
833 auto backMask = arith::ConstantOp::create(
834 rewriter, loc,
835 DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
836
837 storeFunc(rewriter, loc, memrefBase, currentDestIndex,
838 cast<VectorValue>(subWidthStorePart), backMask.getResult());
839 }
840
841 rewriter.eraseOp(op);
842 return success();
843 }
844
845private:
846 const bool disableAtomicRMW;
847 const bool assumeAligned;
848};
849
850//===----------------------------------------------------------------------===//
851// ConvertVectorMaskedStore
852//===----------------------------------------------------------------------===//
853
854/// Converts `vector.maskedstore` operations on narrow element types to work
855/// with wider, byte-aligned container types by adjusting the mask and using
856/// bitcasting.
857///
858/// Example: Storing `vector<6xi4>` is emulated by bitcasting to `vector<3xi8>`
859/// (each `i8` container element holds two `i4` values) and storing with an
860/// adjusted mask.
861struct ConvertVectorMaskedStore final
862 : OpConversionPattern<vector::MaskedStoreOp> {
863 using Base::Base;
864
865 LogicalResult
866 matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
867 ConversionPatternRewriter &rewriter) const override {
868
869 // Prerequisite: memref in the vector.maskedstore op is flattened into 1-D.
 // (What is actually checked below is the rank of the stored vector.)
870 if (op.getValueToStore().getType().getRank() != 1)
871 return rewriter.notifyMatchFailure(
872 op, "Memref in vector.maskedstore op must be flattened beforehand.");
873
874 auto loc = op.getLoc();
875 auto containerElemTy =
876 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
877 Type emulatedElemTy = op.getValueToStore().getType().getElementType();
878 int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
879 int containerBits = containerElemTy.getIntOrFloatBitWidth();
880
881 // Check per-element alignment.
882 if (containerBits % emulatedBits != 0) {
883 return rewriter.notifyMatchFailure(
884 op, "impossible to pack emulated elements into container elements "
885 "(bit-wise misalignment)");
886 }
887
888 int emulatedPerContainerElem = containerBits / emulatedBits;
889 int origElements = op.getValueToStore().getType().getNumElements();
 // Only vectors that fill a whole number of container elements are handled;
 // any other size is rejected (without a diagnostic message).
890 if (origElements % emulatedPerContainerElem != 0)
891 return failure();
892
893 auto stridedMetadata =
894 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
895 OpFoldResult linearizedIndicesOfr;
896 memref::LinearizedMemRefInfo linearizedInfo;
897 std::tie(linearizedInfo, linearizedIndicesOfr) =
 // NOTE(review): the line naming the callee (file line 898) is absent from
 // this listing; presumably `memref::getLinearizedMemRefOffsetAndSize(` -
 // confirm against the real file.
899 rewriter, loc, emulatedBits, containerBits,
900 stridedMetadata.getConstifiedMixedOffset(),
901 stridedMetadata.getConstifiedMixedSizes(),
902 stridedMetadata.getConstifiedMixedStrides(),
903 getAsOpFoldResult(adaptor.getIndices()));
904 Value linearizedIndices =
905 getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
906
907 // Load the whole data and use arith.select to handle the corner cases.
908 //
909 // As an example, for this masked store of i4 values:
910 //
911 // vector.maskedstore %0[%c0, %c0], %mask, %val_to_store
912 //
913 // and given these input values:
914 //
915 // %mask = [0, 1, 1, 1, 1, 0, 0, 0] (8 * i1)
916 // %0[%c0, %c0] =
917 // [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
918 // %val_to_store =
919 // [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0] (8 * i4)
920 //
921 // we'll have the following i4 output:
922 //
923 // expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x7, 0x8]
924 //
925 // Emulating the above using i8 will give:
926 //
927 // %compressed_mask = [1, 1, 1, 0] (4 * i1)
928 // %maskedload = [0x12, 0x34, 0x56, 0x00] (4 * i8)
929 // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0] (8 * i4)
930 // %select_using_shifted_mask =
931 // [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x0, 0x0] (8 * i4)
932 // %packed_data = [0x1A, 0xBC, 0xD6, 0x00] (4 * i8)
933 //
934 // Using the compressed mask to store %packed_data results in expected
935 // output.
936 //
937 // FIXME: Make an example based on the comment above work (see #115460 for
938 // reproducer).
939 FailureOr<Operation *> newMask = getCompressedMaskOp(
940 rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);
941 if (failed(newMask))
942 return failure();
943
 // Number of container elements needed to cover `origElements` (rounded up).
944 auto numElements = (origElements + emulatedPerContainerElem - 1) /
945 emulatedPerContainerElem;
946 auto newType = VectorType::get(numElements, containerElemTy);
947 auto passThru = arith::ConstantOp::create(rewriter, loc, newType,
948 rewriter.getZeroAttr(newType));
949
 // Read back the current container contents under the compressed mask, using
 // zeros as the pass-through value.
950 auto newLoad = vector::MaskedLoadOp::create(
951 rewriter, loc, newType, adaptor.getBase(), linearizedIndices,
952 newMask.value()->getResult(0), passThru);
953
954 auto newBitCastType =
955 VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
956 Value valueToStore =
957 vector::BitCastOp::create(rewriter, loc, newBitCastType, newLoad);
 // Per emulated element: take the new value where the original mask is set,
 // otherwise keep the value that was just loaded.
958 valueToStore = arith::SelectOp::create(rewriter, loc, op.getMask(),
959 op.getValueToStore(), valueToStore);
960 valueToStore =
961 vector::BitCastOp::create(rewriter, loc, newType, valueToStore);
962
963 rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
964 op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
965 valueToStore);
966 return success();
967 }
968};
969
970//===----------------------------------------------------------------------===//
971// ConvertVectorLoad
972//===----------------------------------------------------------------------===//
973
974/// Converts `vector.load` on narrow element types to work with
975/// wider, byte-aligned container types by adjusting load sizes and using
976/// bitcasting.
977///
978/// Example: `vector.load` of `vector<4xi4>` from `memref<3x4xi4>` is emulated
979/// by loading `vector<2xi8>` from the linearized `memref<6xi8>` (each `i8`
980/// container holds two `i4` values) and bitcasting back.
981///
982/// There are cases where the number of elements to load is not byte-aligned. In
983/// those cases, loads are converted to byte-aligned, byte-sized loads and the
984/// target vector is extracted from the loaded vector.
985struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
986 using Base::Base;
987
988 LogicalResult
989 matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
990 ConversionPatternRewriter &rewriter) const override {
991 // Prerequisite: memref in the vector.load op is flattened into 1-D.
 // (What is actually checked below is the rank of the loaded vector.)
992 if (op.getVectorType().getRank() != 1)
993 return rewriter.notifyMatchFailure(
994 op, "Memref in emulated vector ops must be flattened beforehand.");
995
996 auto loc = op.getLoc();
997 auto containerElemTy =
998 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
999 Type emulatedElemTy = op.getType().getElementType();
1000 int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
1001 int containerBits = containerElemTy.getIntOrFloatBitWidth();
1002
1003 // Check per-element alignment.
1004 if (containerBits % emulatedBits != 0) {
1005 return rewriter.notifyMatchFailure(
1006 op, "impossible to pack emulated elements into container elements "
1007 "(bit-wise misalignment)");
1008 }
1009 int emulatedPerContainerElem = containerBits / emulatedBits;
1010
1011 // Adjust the number of elements to load when emulating narrow types,
1012 // and then cast back to the original type with vector.bitcast op.
1013 // For example, to emulate i4 to i8, the following op:
1014 //
1015 // %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
1016 //
1017 // can be replaced with
1018 //
1019 // %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
1020 // %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
1021 //
1022 // There are cases where the number of elements to load is not byte-aligned,
1023 // for example:
1024 //
1025 // %1 = vector.load %0[%c1, %c0] : memref<3x3xi2>, vector<3xi2>
1026 //
1027 // we will have to load extra bytes and extract the exact slice in between.
1028 //
1029 // %1 = vector.load %0[%c2] : memref<3xi8>, vector<2xi8>
1030 // %2 = vector.bitcast %1 : vector<2xi8> to vector<8xi2>
1031 // %3 = vector.extract_strided_slice %2 {offsets = [2], sizes = [3], strides
1032 // = [1]}
1033 // : vector<8xi2> to vector<3xi2>
1034 //
1035 // TODO: Currently the extract_strided_slice's attributes must be known at
1036 // compile time as they must be constants.
1037
1038 auto origElements = op.getVectorType().getNumElements();
1039 // Note, per-element-alignment was already verified above.
1040 bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
1041
1042 auto stridedMetadata =
1043 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1044
1045 OpFoldResult linearizedIndices;
1046 memref::LinearizedMemRefInfo linearizedInfo;
1047 std::tie(linearizedInfo, linearizedIndices) =
 // NOTE(review): the line naming the callee (file line 1048) is absent from
 // this listing; presumably `memref::getLinearizedMemRefOffsetAndSize(` -
 // confirm against the real file.
1049 rewriter, loc, emulatedBits, containerBits,
1050 stridedMetadata.getConstifiedMixedOffset(),
1051 stridedMetadata.getConstifiedMixedSizes(),
1052 stridedMetadata.getConstifiedMixedStrides(),
1053 getAsOpFoldResult(adaptor.getIndices()));
1054
 // The intra-container offset is taken to be 0 when the vector size is a
 // multiple of `emulatedPerContainerElem`; otherwise it is the constant value
 // of `intraDataOffset`, or empty when that value is not a known constant.
1055 std::optional<int64_t> foldedIntraVectorOffset =
1056 isDivisibleInSize ? 0
1057 : getConstantIntValue(linearizedInfo.intraDataOffset);
1058
1059 // Always load enough elements which can cover the original elements.
 // With an unknown offset, the worst case `emulatedPerContainerElem - 1` is
 // assumed.
1060 int64_t maxintraDataOffset =
1061 foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1062 auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
1063 emulatedPerContainerElem);
1064 Value result =
1065 emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
1066 numElements, emulatedElemTy, containerElemTy);
1067
1068 if (!foldedIntraVectorOffset) {
1069 auto resultVector = arith::ConstantOp::create(
1070 rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
 // NOTE(review): file line 1071 is absent from this listing; presumably
 // `result = dynamicallyExtractSubVector(` - confirm.
1072 rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
1073 linearizedInfo.intraDataOffset, origElements);
1074 } else if (!isDivisibleInSize) {
 // NOTE(review): file line 1075 is absent from this listing; presumably
 // `result = staticallyExtractSubvector(` - confirm.
1076 rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1077 }
1078 rewriter.replaceOp(op, result);
1079 return success();
1080 }
1081};
1082
1083//===----------------------------------------------------------------------===//
1084// ConvertVectorMaskedLoad
1085//===----------------------------------------------------------------------===//
1086
1087/// Converts `vector.maskedload` operations on narrow element types to work with
1088/// wider, byte-aligned container types by adjusting the mask and using
1089/// bitcasting.
1090///
1091/// Example: Loading `vector<6xi4>` is emulated by loading `vector<3xi8>` and
1092/// bitcasting, since each `i8` container element holds two `i4` values.
1093struct ConvertVectorMaskedLoad final
1094 : OpConversionPattern<vector::MaskedLoadOp> {
1095 using Base::Base;
1096
1097 LogicalResult
1098 matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
1099 ConversionPatternRewriter &rewriter) const override {
 // Only 1-D result vectors are handled.
1100 if (op.getVectorType().getRank() != 1)
1101 return rewriter.notifyMatchFailure(
1102 op, "Memref in emulated vector ops must be flattened beforehand.");
1103
1104 auto loc = op.getLoc();
1105
1106 auto containerElemTy =
1107 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
1108 Type emulatedElemTy = op.getType().getElementType();
1109 int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
1110 int containerBits = containerElemTy.getIntOrFloatBitWidth();
1111
1112 // Check per-element alignment.
1113 if (containerBits % emulatedBits != 0) {
1114 return rewriter.notifyMatchFailure(
1115 op, "impossible to pack emulated elements into container elements "
1116 "(bit-wise misalignment)");
1117 }
1118 int emulatedPerContainerElem = containerBits / emulatedBits;
1119
1120 // Adjust the number of elements to load when emulating narrow types,
1121 // and then cast back to the original type with vector.bitcast op.
1122 // For example, to emulate i4 to i8, the following op:
1123 //
1124 // %mask = vector.constant_mask [3] : vector<6xi1>
1125 // %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
1126 // memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
1127 //
1128 // can be replaced with
1129 //
1130 // %new_mask = vector.constant_mask [2] : vector<3xi1>
1131 // %new_pass_thru = vector.bitcast %pass_thru :
1132 // vector<6xi4> to vector<3xi8>
1133 // %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
1134 // memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
1135 // %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
1136 //
1137 // Since we are effectively loading 16 bits (2xi8) from the memref with the
1138 // new mask, while originally we only wanted to effectively load 12 bits
1139 // (3xi4) from the memref, we need to set the second half of the last i8
1140 // that was effectively loaded (i.e. the second i8) to %pass_thru.
1141 //
1142 // %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
1143 //
1144 // Given these input values:
1145 // %mask = [1, 1, 1, 0, 0, 0]
1146 // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
1147 // %pass_thru = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
1148 //
1149 // we'll have:
1150 //
1151 // expected output: [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
1152 //
1153 // %new_mask = [1, 1, 0]
1154 // %new_pass_thru = [0x78, 0x9A, 0xBC]
1155 // %1 = [0x12, 0x34, 0xBC]
1156 // %2 = [0x1, 0x2, 0x3, 0x4, 0xB, 0xC]
1157 // %3 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
1158 //
1159 // TODO: Currently, only the even number of elements loading is supported.
1160 // To deal with the odd number of elements, one has to extract the
1161 // subvector at the proper offset after bit-casting.
 // NOTE(review): the code below already has `!isDivisibleInSize` branches that
 // insert/extract sub-vectors at the intra-container offset, so the TODO above
 // may be stale - confirm.
1162 auto origType = op.getVectorType();
1163 auto origElements = origType.getNumElements();
1164 // Note, per-element-alignment was already verified above.
1165 bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
1166
1167 auto stridedMetadata =
1168 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1169 OpFoldResult linearizedIndices;
1170 memref::LinearizedMemRefInfo linearizedInfo;
1171 std::tie(linearizedInfo, linearizedIndices) =
 // NOTE(review): the line naming the callee (file line 1172) is absent from
 // this listing; presumably `memref::getLinearizedMemRefOffsetAndSize(` -
 // confirm against the real file.
1173 rewriter, loc, emulatedBits, containerBits,
1174 stridedMetadata.getConstifiedMixedOffset(),
1175 stridedMetadata.getConstifiedMixedSizes(),
1176 stridedMetadata.getConstifiedMixedStrides(),
1177 getAsOpFoldResult(adaptor.getIndices()));
1178
 // Offset of the first loaded element inside its container element: 0 when the
 // size is container-aligned, the constant `intraDataOffset` when known, and
 // empty otherwise.
1179 std::optional<int64_t> foldedIntraVectorOffset =
1180 isDivisibleInSize ? 0
1181 : getConstantIntValue(linearizedInfo.intraDataOffset);
1182
 // With an unknown offset, the worst case `emulatedPerContainerElem - 1` is
 // assumed when sizing the mask and the load.
1183 int64_t maxIntraDataOffset =
1184 foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1185 FailureOr<Operation *> newMask =
1186 getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
1187 emulatedPerContainerElem, maxIntraDataOffset);
1188 if (failed(newMask))
1189 return failure();
1190
1191 Value passthru = op.getPassThru();
1192
1193 auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
1194 emulatedPerContainerElem);
1195 auto loadType = VectorType::get(numElements, containerElemTy);
1196 auto newBitcastType =
1197 VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
1198
 // In the unaligned cases, place the original pass-through into a zero vector
 // of the widened (emulated-element) type at the intra-container offset.
1199 auto emptyVector = arith::ConstantOp::create(
1200 rewriter, loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
1201 if (!foldedIntraVectorOffset) {
1202 passthru = dynamicallyInsertSubVector(
1203 rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
1204 origElements);
1205 } else if (!isDivisibleInSize) {
1206 passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
1207 *foldedIntraVectorOffset);
1208 }
1209 auto newPassThru =
1210 vector::BitCastOp::create(rewriter, loc, loadType, passthru);
1211
1212 // Generating the new masked load.
1213 auto newLoad = vector::MaskedLoadOp::create(
1214 rewriter, loc, loadType, adaptor.getBase(),
1215 getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
1216 newMask.value()->getResult(0), newPassThru);
1217
1218 // Setting the part that originally was not effectively loaded from memory
1219 // to pass through.
1220 auto bitCast =
1221 vector::BitCastOp::create(rewriter, loc, newBitcastType, newLoad);
1222
1223 Value mask = op.getMask();
1224 auto newSelectMaskType = VectorType::get(
1225 numElements * emulatedPerContainerElem, rewriter.getI1Type());
1226 // TODO: try to fold if op's mask is constant
1227 auto emptyMask =
1228 arith::ConstantOp::create(rewriter, loc, newSelectMaskType,
1229 rewriter.getZeroAttr(newSelectMaskType));
1230 if (!foldedIntraVectorOffset) {
1231 mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
1232 linearizedInfo.intraDataOffset,
1233 origElements);
1234 } else if (!isDivisibleInSize) {
1235 mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
1236 *foldedIntraVectorOffset);
1237 }
1238
1239 Value result =
1240 arith::SelectOp::create(rewriter, loc, mask, bitCast, passthru);
1241 if (!foldedIntraVectorOffset) {
 // NOTE(review): file line 1242 is absent from this listing; presumably
 // `result = dynamicallyExtractSubVector(` - confirm.
1243 rewriter, loc, result, op.getPassThru(),
1244 linearizedInfo.intraDataOffset, origElements);
1245 } else if (!isDivisibleInSize) {
 // NOTE(review): file line 1246 is absent from this listing; presumably
 // `result = staticallyExtractSubvector(` - confirm.
1247 rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1248 }
1249 rewriter.replaceOp(op, result);
1250
1251 return success();
1252 }
1253};
1254
1255/// Check whether `subByteVecTy` fits wthin a vector of `multiByteScalarTy`
1256///
1257/// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
1258/// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
1259/// (a multi-byte scalar, e.g. i16), where N is some integer.
1260///
1261/// Put differently, this method checks whether this would be valid:
1262///
1263/// vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
1264///
1265/// EXAMPLES:
1266/// * vector<4xi4> -> i16 - yes (N = 1)
1267/// * vector<4xi4> -> i8 - yes (N = 2)
1268/// * vector<3xi4> -> i8 - no (N would have to be 1.5)
1269/// * vector<3xi2> -> i16 - no (N would have to be 0.5)
1270static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
1271 Type multiByteScalarTy) {
1272 assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
1273
1274 int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
1275 int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();
1276
1277 assert(subByteBits < 8 && "Not a sub-byte scalar type!");
1278 assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
1279 assert(multiByteBits % subByteBits == 0 && "Unalagined element types!");
1280
1281 int elemsPerMultiByte = multiByteBits / subByteBits;
1282
1283 return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
1284}
1285
1286//===----------------------------------------------------------------------===//
1287// ConvertVectorTransferRead
1288//===----------------------------------------------------------------------===//
1289
1290/// Converts `vector.transfer_read` on narrow element types into a
/// `vector.transfer_read` of container-typed elements at the linearized index,
/// followed by a `vector.bitcast` back to the emulated element type. When the
/// vector does not fill whole container elements, the requested elements are
/// extracted from the bitcast result at the intra-container offset.
///
/// The padding value is bitcast to an integer if it is not one already and then
/// zero-extended to the container element type.
1291struct ConvertVectorTransferRead final
1292 : OpConversionPattern<vector::TransferReadOp> {
1293 using Base::Base;
1294
1295 LogicalResult
1296 matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
1297 ConversionPatternRewriter &rewriter) const override {
1298
1299 // Prerequisites: memref in the vector.transfer_read op is flattened into
1300 // 1-D.
 // (What is actually checked below is the rank of the result vector.)
1301 if (op.getVectorType().getRank() != 1)
1302 return rewriter.notifyMatchFailure(
1303 op, "Memref in emulated vector ops must be flattened beforehand.");
1304
1305 auto loc = op.getLoc();
1306 auto containerElemTy =
1307 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
1308 Type emulatedElemTy = op.getType().getElementType();
1309 int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
1310 int containerBits = containerElemTy.getIntOrFloatBitWidth();
1311
1312 // Check per-element alignment.
1313 if (containerBits % emulatedBits != 0) {
1314 return rewriter.notifyMatchFailure(
1315 op, "impossible to pack emulated elements into container elements "
1316 "(bit-wise misalignment)");
1317 }
1318 int emulatedPerContainerElem = containerBits / emulatedBits;
1319
1320 auto origElements = op.getVectorType().getNumElements();
1321
1322 // Note, per-element-alignment was already verified above.
1323 bool isDivisibleInSize =
1324 fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);
1325
1326 // Pad the padding value with 0s on the left. These bits are discarded and
1327 // thus their values don't matter.
1328 Value padding = adaptor.getPadding();
 // A non-integer padding value is first reinterpreted as an integer of the
 // same bit width so that it can be zero-extended below.
1329 if (!padding.getType().isInteger()) {
1330 padding = arith::BitcastOp::create(
1331 rewriter, loc,
1332 IntegerType::get(rewriter.getContext(),
1333 padding.getType().getIntOrFloatBitWidth()),
1334 padding);
1335 }
1336 auto newPadding =
1337 arith::ExtUIOp::create(rewriter, loc, containerElemTy, padding);
1338
1339 auto stridedMetadata =
1340 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1341
1342 OpFoldResult linearizedIndices;
1343 memref::LinearizedMemRefInfo linearizedInfo;
1344 std::tie(linearizedInfo, linearizedIndices) =
 // NOTE(review): the line naming the callee (file line 1345) is absent from
 // this listing; presumably `memref::getLinearizedMemRefOffsetAndSize(` -
 // confirm against the real file.
1346 rewriter, loc, emulatedBits, containerBits,
1347 stridedMetadata.getConstifiedMixedOffset(),
1348 stridedMetadata.getConstifiedMixedSizes(),
1349 stridedMetadata.getConstifiedMixedStrides(),
1350 getAsOpFoldResult(adaptor.getIndices()));
1351
 // Offset of the first read element inside its container element: 0 when the
 // size is container-aligned, the constant `intraDataOffset` when known, and
 // empty otherwise.
1352 std::optional<int64_t> foldedIntraVectorOffset =
1353 isDivisibleInSize ? 0
1354 : getConstantIntValue(linearizedInfo.intraDataOffset);
1355
 // With an unknown offset, the worst case `emulatedPerContainerElem - 1` is
 // assumed so that enough container elements are read.
1356 int64_t maxIntraDataOffset =
1357 foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1358 auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
1359 emulatedPerContainerElem);
1360
1361 auto newRead = vector::TransferReadOp::create(
1362 rewriter, loc, VectorType::get(numElements, containerElemTy),
1363 adaptor.getBase(),
1364 getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
1365 newPadding);
1366
1367 auto bitCast = vector::BitCastOp::create(
1368 rewriter, loc,
1369 VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
1370 newRead);
1371
1372 Value result = bitCast->getResult(0);
1373 if (!foldedIntraVectorOffset) {
1374 auto zeros = arith::ConstantOp::create(
1375 rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
1376 result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
1377 linearizedInfo.intraDataOffset,
1378 origElements);
1379 } else if (!isDivisibleInSize) {
 // NOTE(review): file line 1380 is absent from this listing; presumably
 // `result = staticallyExtractSubvector(` - confirm.
1381 rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1382 }
1383 rewriter.replaceOp(op, result);
1384
1385 return success();
1386 }
1387};
1388} // end anonymous namespace
1389
1390//===----------------------------------------------------------------------===//
1391// RewriteBitCastOfTruncI
1392//===----------------------------------------------------------------------===//
1393
1394namespace {
1395
1396/// Helper struct to keep track of the provenance of a contiguous set of bits
1397/// in a source vector.
1398struct SourceElementRange {
1399 /// The index of the source vector element that contributes bits to *this.
1400 int64_t sourceElementIdx;
1401 /// The half-open range `[sourceBitBegin, sourceBitEnd)` of bits in the source
 /// vector element that contribute to *this; its width is
 /// `sourceBitEnd - sourceBitBegin`.
1402 int64_t sourceBitBegin;
1403 int64_t sourceBitEnd;
1404};
1405
1406struct SourceElementRangeList : public SmallVector<SourceElementRange> {
1407 /// Given the index of a SourceElementRange in the SourceElementRangeList,
1408 /// compute the amount of bits that need to be shifted to the left to get the
1409 /// bits in their final location. This shift amount is simply the sum of the
1410 /// bits *before* `shuffleIdx` (i.e. the bits of `shuffleIdx = 0` are always
1411 /// the LSBs, the bits of `shuffleIdx = ` come next, etc).
1412 int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
1413 int64_t res = 0;
1414 for (int64_t i = 0; i < shuffleIdx; ++i)
1415 res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
1416 return res;
1417 }
1418};
1419
1420/// Helper struct to enumerate the source elements and bit ranges that are
1421/// involved in a bitcast operation.
1422/// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
1423/// any 1-D vector shape and any source/target bitwidths.
1424/// This creates and holds a mapping of the form:
1425/// [dstVectorElementJ] ==
1426/// [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ]
1427/// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as:
1428/// [0] = {0, [0-8)}
1429/// [1] = {0, [8-16)}
1430/// [2] = {0, [16-24)}
1431/// and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as:
1432/// [0] = {0, [0, 10)}, {1, [0, 5)}
1433/// [1] = {1, [5, 10)}, {2, [0, 10)}
1434struct BitCastBitsEnumerator {
1435 BitCastBitsEnumerator(VectorType sourceVectorType,
1436 VectorType targetVectorType);
1437
1438 int64_t getMaxNumberOfEntries() {
1439 int64_t numVectors = 0;
1440 for (const auto &l : sourceElementRanges)
1441 numVectors = std::max(numVectors, (int64_t)l.size());
1442 return numVectors;
1443 }
1444
1445 VectorType sourceVectorType;
1446 VectorType targetVectorType;
1447 SmallVector<SourceElementRangeList> sourceElementRanges;
1448};
1449
1450/// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
1451/// advantage of high-level information to avoid leaving LLVM to scramble with
1452/// peephole optimizations.
1453/// BitCastBitsEnumerator encodes for each element of the target vector the
1454/// provenance of the bits in the source vector. We can "transpose" this
1455/// information to build a sequence of shuffles and bitwise ops that will
1456/// produce the desired result.
1457//
1458/// Consider the following motivating example:
1459/// ```
1460/// %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
1461/// ```
1462//
1463/// BitCastBitsEnumerator contains the following information:
1464/// ```
1465/// { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
1466/// { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
1467/// { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
1468/// { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
1469/// { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
1470/// { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
1471/// { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
1472/// {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
1473/// {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
1474/// {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
1475/// {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
1476/// {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
1477/// {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
1478/// {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
1479/// {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
1480/// {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
1481/// {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
1482/// {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
1483/// {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
1484/// {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
1485/// ```
1486///
1487/// In the above, each row represents one target vector element and each
1488/// column represents one bit contribution from a source vector element.
1489/// The algorithm creates vector.shuffle operations (in this case there are 3
1490/// shuffles, i.e. the max number of columns in BitCastBitsEnumerator). The
1491/// algorithm populates the bits as follows:
1492/// ```
1493/// src bits 0 ...
1494/// 1st shuffle |xxxxx |xx |...
1495/// 2nd shuffle | xxx| xxxxx |...
1496/// 3rd shuffle | | x|...
1497/// ```
1498//
1499/// The algorithm proceeds as follows:
1500/// 1. for each vector.shuffle, collect the source vectors that participate in
1501/// this shuffle. One source vector per target element of the resulting
1502/// vector.shuffle. If there is no source element contributing bits for the
1503/// current vector.shuffle, take 0 (i.e. row 0 in the above example has only
1504/// 2 columns).
1505/// 2. represent the bitrange in the source vector as a mask. If there is no
1506/// source element contributing bits for the current vector.shuffle, take 0.
1507/// 3. shift right by the proper amount to align the source bitrange at
1508/// position 0. This is exactly the low end of the bitrange. For instance,
1509/// the first element of row 1 is `{ 1: b@[3..5) lshl: 0}` and one needs to
1510/// shift right by 3 to get the bits contributed by the source element #1
1511/// into position 0.
1512/// 4. shift left by the proper amount to align to the desired position in
1513/// the result element vector. For instance, the contribution of the second
1514/// source element for the first row needs to be shifted by `5` to form the
1515/// first i8 result element.
1516///
1517/// Eventually, we end up building the sequence
1518/// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
1519/// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
1520/// bits extracted from the source vector (i.e. the `shuffle -> and` part).
1521struct BitCastRewriter {
1522 /// Helper metadata struct to hold the static quantities for the rewrite.
 /// One instance is consumed by each `genericRewriteStep` call. The field names
 /// presumably match the shuffle / and / shiftright / shiftleft stages of that
 /// step - confirm against `precomputeMetadata`.
1523 struct Metadata {
1524 SmallVector<int64_t> shuffles;
1525 SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1526 };
1527
 /// Builds the underlying `BitCastBitsEnumerator` for a bitcast from
 /// `sourceVectorType` to `targetVectorType`.
1528 BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
1529
1530 /// Verify that general preconditions for the rewrite are met.
1531 LogicalResult commonPrecondition(PatternRewriter &rewriter,
1532 VectorType preconditionType, Operation *op);
1533
1534 /// Precompute the metadata for the rewrite.
1535 SmallVector<BitCastRewriter::Metadata>
1536 precomputeMetadata(IntegerType shuffledElementType);
1537
1538 /// Rewrite one step of the sequence:
1539 /// `(shuffle -> and -> shiftright -> shiftleft -> or)`.
1540 Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
1541 Value initialValue, Value runningResult,
1542 const BitCastRewriter::Metadata &metadata);
1543
1544private:
1545 /// Underlying enumerator that encodes the provenance of the bits in each
1546 /// element of the result vector.
1547 BitCastBitsEnumerator enumerator;
1548};
1549
1550} // namespace
1551
/// Debug printer: writes one text line per entry of `vec`; each line lists, for
/// every range in that entry, the source element index, the bit range
/// `[begin..end)` and the left shift amount computed for that position.
// NOTE(review): the line carrying the function name and parameter list (file
// line 1553) is absent from this listing. From the body, `os` is the output
// stream and `vec` a sequence of SourceElementRangeList; presumably this is the
// `operator<<` used when streaming `enumerator.sourceElementRanges` - confirm.
1552[[maybe_unused]] static raw_ostream &
1554 for (const auto &l : vec) {
1555 for (auto it : llvm::enumerate(l)) {
1556 os << "{ " << it.value().sourceElementIdx << ": b@["
1557 << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
1558 << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
1559 }
1560 os << "\n";
1561 }
1562 return os;
1563}
1564
1565BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
1566 VectorType targetVectorType)
1567 : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
1568
1569 assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
1570 "requires -D non-scalable vector type");
1571 assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
1572 "requires -D non-scalable vector type");
1573 int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
1574 int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
1575 LDBG() << "sourceVectorType: " << sourceVectorType;
1576
1577 int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
1578 int64_t mostMinorTargetDim = targetVectorType.getShape().back();
1579 LDBG() << "targetVectorType: " << targetVectorType;
1580
1581 int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
1582 (void)mostMinorSourceDim;
1583 assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
1584 "source and target bitwidths must match");
1585
1586 // Prepopulate one source element range per target element.
1587 sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
1588 for (int64_t resultBit = 0; resultBit < bitwidth;) {
1589 int64_t resultElement = resultBit / targetBitWidth;
1590 int64_t resultBitInElement = resultBit % targetBitWidth;
1591 int64_t sourceElementIdx = resultBit / sourceBitWidth;
1592 int64_t sourceBitInElement = resultBit % sourceBitWidth;
1593 int64_t step = std::min(sourceBitWidth - sourceBitInElement,
1594 targetBitWidth - resultBitInElement);
1595 sourceElementRanges[resultElement].push_back(
1596 {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
1597 resultBit += step;
1598 }
1599}
1600
1601BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
1602 VectorType targetVectorType)
1603 : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
1604 LDBG() << "\n" << enumerator.sourceElementRanges;
1605}
1606
1607/// Verify that the precondition type meets the common preconditions for any
1608/// conversion.
1609static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
1610 VectorType preconditionType,
1611 Operation *op) {
1612 if (!preconditionType || preconditionType.isScalable())
1613 return rewriter.notifyMatchFailure(op, "scalable vector");
1614
1615 // TODO: consider relaxing this restriction in the future if we find ways
1616 // to really work with subbyte elements across the MLIR/LLVM boundary.
1617 unsigned bitwidth = preconditionType.getElementTypeBitWidth();
1618 if (bitwidth % 8 != 0)
1619 return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
1620
1621 return success();
1622}
1623
1624LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
1625 VectorType preconditionType,
1626 Operation *op) {
1627 if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
1628 return rewriter.notifyMatchFailure(op, "types are not vector");
1629
1630 if (!preconditionType || preconditionType.getRank() != 1)
1631 return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
1632
1633 return commonConversionPrecondition(rewriter, preconditionType, op);
1634}
1635
1636/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
1637///
1638/// Alignment means that `subByteVecTy` can be packed into a vector of
1639/// `containerTy` elements. More specifically:
1640/// 1. The bit-width of `containerTy` is a multiple of the
1641/// bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
1642/// this multiple is 4.
1643/// 2. The multiple from 1. above divides evenly the number of the (trailing)
1644/// elements in `subByteVecTy`.
1645///
1646/// EXAMPLE 1:
1647/// `subByteVecTy = vector<2xi4>`, and
1648/// `containerTy = i16`
1649///
1650/// 2 divides evenly 4 ( = 16 / 4), hence both conditions are _met_.
1651///
1652/// EXAMPLE 2:
1653/// `subByteVecTy = vector<3xi4>`, and
1654/// `containerTy = i16`
1655///
1656/// 3 _does not_ divide evenly 4 (= 16/4), hence the conditions are _not met_.
1657///
1658/// EXAMPLE 3:
1659/// `subByteVecTy = vector<3xi3>`, and
1660/// `containerTy = i16`
1661///
1662/// 16 _is not_ a multiple of 3, hence the conditions are _not met_.
1663///
1664/// NOTE: This method assumes that common conversion preconditions are met. In
1665/// particular, `containerTy` is assumed to be a
1666/// multi-byte scalar type (e.g., i8, i16, i32).
1667static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
1668 VectorType subByteVecTy,
1669 Type containerTy,
1670 Operation *op) {
1671 assert(containerTy.isIntOrFloat() &&
1672 "container element type is not a scalar");
1673
1674 // TODO: This is validating the inputs rather than checking the conditions
1675 // documented above. Replace with an assert.
1676 if (!subByteVecTy)
1677 return rewriter.notifyMatchFailure(op, "not a vector!");
1678
1679 unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
1680 unsigned containerBits = containerTy.getIntOrFloatBitWidth();
1681
1682 // Enforced by the common pre-conditions.
1683 assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!");
1684
1685 // TODO: Add support other widths (when/if needed)
1686 if (subByteBits != 2 && subByteBits != 4)
1687 return rewriter.notifyMatchFailure(
1688 op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
1689
1690 // Condition 1 ("per-element" alignment)
1691 if (containerBits % subByteBits != 0)
1692 return rewriter.notifyMatchFailure(op, "unalagined element types");
1693
1694 // Condition 2 ("full" alignment)
1695 if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
1696 return rewriter.notifyMatchFailure(
1697 op, "not possible to fit this sub-byte vector type into a vector of "
1698 "the given multi-byte type");
1699
1700 return success();
1701}
1702
1703SmallVector<BitCastRewriter::Metadata>
1704BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
1705 SmallVector<BitCastRewriter::Metadata> result;
1706 for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
1707 shuffleIdx < e; ++shuffleIdx) {
1708 SmallVector<int64_t> shuffles;
1709 SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1710
1711 // Create the attribute quantities for the shuffle / mask / shift ops.
1712 for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
1713 int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
1714 ? srcEltRangeList[shuffleIdx].sourceElementIdx
1715 : 0;
1716 shuffles.push_back(sourceElement);
1717
1718 int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
1719 ? srcEltRangeList[shuffleIdx].sourceBitBegin
1720 : 0;
1721 int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
1722 ? srcEltRangeList[shuffleIdx].sourceBitEnd
1723 : 0;
1724 IntegerAttr mask = IntegerAttr::get(
1725 shuffledElementType,
1726 llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
1727 bitLo, bitHi));
1728 masks.push_back(mask);
1729
1730 int64_t shiftRight = bitLo;
1731 shiftRightAmounts.push_back(
1732 IntegerAttr::get(shuffledElementType, shiftRight));
1733
1734 int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
1735 shiftLeftAmounts.push_back(
1736 IntegerAttr::get(shuffledElementType, shiftLeft));
1737 }
1738
1739 result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
1740 }
1741 return result;
1742}
1743
1744Value BitCastRewriter::genericRewriteStep(
1745 PatternRewriter &rewriter, Location loc, Value initialValue,
1746 Value runningResult, const BitCastRewriter::Metadata &metadata) {
1747 // Create vector.shuffle from the metadata.
1748 auto shuffleOp = vector::ShuffleOp::create(rewriter, loc, initialValue,
1749 initialValue, metadata.shuffles);
1750
1751 // Intersect with the mask.
1752 VectorType shuffledVectorType = shuffleOp.getResultVectorType();
1753 auto constOp = arith::ConstantOp::create(
1754 rewriter, loc,
1755 DenseElementsAttr::get(shuffledVectorType, metadata.masks));
1756 Value andValue = arith::AndIOp::create(rewriter, loc, shuffleOp, constOp);
1757
1758 // Align right on 0.
1759 auto shiftRightConstantOp = arith::ConstantOp::create(
1760 rewriter, loc,
1761 DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
1762 Value shiftedRight =
1763 arith::ShRUIOp::create(rewriter, loc, andValue, shiftRightConstantOp);
1764
1765 // Shift bits left into their final position.
1766 auto shiftLeftConstantOp = arith::ConstantOp::create(
1767 rewriter, loc,
1768 DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
1769 Value shiftedLeft =
1770 arith::ShLIOp::create(rewriter, loc, shiftedRight, shiftLeftConstantOp);
1771
1772 runningResult =
1773 runningResult
1774 ? arith::OrIOp::create(rewriter, loc, runningResult, shiftedLeft)
1775 : shiftedLeft;
1776
1777 return runningResult;
1778}
1779
1780/// Bitcasts the aligned `subByteVec` vector to a vector of i8.
1781/// Where aligned means it satisfies the alignedConversionPreconditions.
1782///
1783/// Example:
1784/// vector<16x16xi2> -> vector<16x4xi8>
1785/// vector<16x16xi4> -> vector<16x8xi8>
1787 Value subByteVec) {
1788 auto srcVecType = cast<VectorType>(subByteVec.getType());
1789 int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
1790 assert(8 % srcBitwidth == 0 &&
1791 "Unsupported sub-byte type (not a divisor of i8)");
1792 int64_t numSrcElemsPerByte = 8 / srcBitwidth;
1793 SmallVector<int64_t> vecShape(srcVecType.getShape());
1794 // Adjust last dimension of the vector, so the total size remains the same.
1795 vecShape.back() = vecShape.back() / numSrcElemsPerByte;
1796 auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
1797 return vector::BitCastOp::create(rewriter, loc, i8VecType, subByteVec);
1798}
1799
1800/// Extracts a signed N-bit sequence from each element of a vector of bytes,
1801/// starting at the specified bit index.
1802/// The `bitIdx` starts at 0 from the LSB and moves to the left.
1803///
1804/// Example for a single element:
1805/// Extract numBits=2 starting at bitIdx=2
1806/// src = [0 | 1 | 0 | 1 | 1 | 1 | 1 | 0]
1807/// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1808/// target = [. . . . ^ ^ . .]
1809///
1810/// The target sequence is [11](decimal=-1) as signed 2-bit integer.
1811/// So the result should be [11 11 11 11](decimal=-1) as signed 8-bit integer.
1812///
1813/// src = [01 01 11 10]
1814/// shl = arith.shl(src, 4) -> [11 10 00 00]
1815/// result = arith.shrsi(shl, 6) -> [11 11 11 11]
1817 Location loc, Value src,
1818 int bitIdx, int numBits) {
1819 auto srcType = cast<VectorType>(src.getType());
1820 Value shl = src;
1821 int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
1822 assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
1823 "Invalid bitIdx range");
1824 if (bitsToShiftLeft != 0) {
1825 Value shiftLeftValues = arith::ConstantOp::create(
1826 rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
1827 shl = arith::ShLIOp::create(rewriter, loc, src, shiftLeftValues);
1828 }
1829
1830 int8_t bitsToShiftRight = 8 - numBits;
1831 Value shiftRightValues = arith::ConstantOp::create(
1832 rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1833 Value shr = arith::ShRSIOp::create(rewriter, loc, shl, shiftRightValues);
1834 return shr;
1835}
1836
1837/// Extracts an unsigned N-bit sequence from each element of a vector of bytes,
1838/// starting at the specified bit index.
1839/// The `bitIdx` starts at 0 from the LSB and moves to the left.
1840///
1841/// Example for a single element:
1842/// Extract numBits=2 starting at bitIdx=2
1843/// src = [0 | 1 | 0 | 1 | 1 | 0 | 1 | 0]
1844/// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1845/// target = [. . . . ^ ^ . .]
1846///
1847/// The target sequence is [10](decimal=2) as unsigned 2-bit integer.
1848/// So the result should be [00 00 00 10](decimal=2) as unsigned 8-bit integer.
1849///
1850/// src = [01 01 10 10]
1851/// mask = [00 00 00 11]
1852/// shr = arith.shrui(src, 2) = [00 01 01 10]
1853/// result = arith.andi(shr, mask) = [00 00 00 10]
1854/// NOTE: Similarly to extractNBitsPerByteAndSignExtendToI8, this could be
1855/// achieved by using arith::ShLIOp + arith::ShRUIOp instead of the masking.
1856/// However, by using arith::ShRUIOp + arith::AndIOp, we are eliminating shift
1857/// left when the index is 0.
1859 Location loc, Value src,
1860 int bitIdx, int numBits) {
1861 assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
1862 "Invalid bitIdx range");
1863 auto srcType = cast<VectorType>(src.getType());
1864 int8_t bitsToShiftRight = bitIdx;
1865 Value shr = src;
1866 if (bitsToShiftRight != 0) {
1867 Value shiftRightValues = arith::ConstantOp::create(
1868 rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1869 shr = arith::ShRUIOp::create(rewriter, loc, src, shiftRightValues);
1870 }
1871 if (bitIdx + numBits == 8) {
1872 return shr;
1873 }
1874 uint8_t lowBitsMask = (1 << numBits) - 1;
1875 Value lowBitsMaskValues = arith::ConstantOp::create(
1876 rewriter, loc, DenseElementsAttr::get(srcType, lowBitsMask));
1877 return arith::AndIOp::create(rewriter, loc, shr, lowBitsMaskValues);
1878}
1879
1881 std::function<Value(PatternRewriter &, Location, Value, int, int)>;
1882
1883/// Rewrite the i4 -> i8 extension into a sequence of shuffles and
1884/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1886 Value srcValue, const ExtractNBitsFn &extFn) {
1887 [[maybe_unused]] auto srcVecType = cast<VectorType>(srcValue.getType());
1888 assert(srcVecType.getElementType().isSignlessInteger(4) &&
1889 "Expected i4 type");
1890
1891 // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1892 Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1893
1894 // 2. Extend i4 elements to i8 elements. Low i4 elemens of each
1895 // byte are place in one vector and the high i4 elements in another vector.
1896 Value low = extFn(rewriter, loc, i8Vector, 0, 4);
1897 Value high = extFn(rewriter, loc, i8Vector, 4, 4);
1898
1899 // 3. Interleave low and high i8 elements.
1900 return vector::InterleaveOp::create(rewriter, loc, low, high);
1901}
1902
1903/// Rewrite the i2 -> i8 extension into a sequence of shuffles and
1904/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1906 Value srcValue, const ExtractNBitsFn &extFn) {
1907 [[maybe_unused]] VectorType srcVecType = cast<VectorType>(srcValue.getType());
1908 assert(srcVecType.getElementType().isSignlessInteger(2) &&
1909 "Expected i2 type");
1910
1911 // 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
1912 Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1913
1914 // 2. Extract each i2 element
1915 // Positon 0 (bits 0-1)
1916 Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
1917 // Position 1 (bits 2-3)
1918 Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
1919 // Position 2 (bits 4-5)
1920 Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
1921 // Position 3 (bits 6-7)
1922 Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);
1923
1924 // 3. Interleave all 4 elements by first interleaving
1925 // even elements and then odd
1926 // vec0 = [0,0,0,0],...
1927 // vec1 = [1,1,1,1],...
1928 // vec2 = [2,2,2,2],...
1929 // vec3 = [3,3,3,3],...
1930 // 02 = [0,2,0,2,0,2,0,2],...
1931 // 13 = [1,3,1,3,1,3,1,3],...
1932 // 0213 = [0,1,2,3,...],...
1933 Value interleave02 = vector::InterleaveOp::create(rewriter, loc, vec0, vec2);
1934 Value interleave13 = vector::InterleaveOp::create(rewriter, loc, vec1, vec3);
1935 return vector::InterleaveOp::create(rewriter, loc, interleave02,
1936 interleave13);
1937}
1938
1939/// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
1940/// ops to avoid leaving LLVM to scramble with peephole optimizations.
1942 Value srcValue) {
1943 VectorType srcVecType = cast<VectorType>(srcValue.getType());
1944 assert(srcVecType.getElementType().isSignlessInteger(8) &&
1945 "Expected i8 type");
1946
1947 // 1. De-interleave low and high i8 elements.
1948 auto deinterleaveOp = vector::DeinterleaveOp::create(rewriter, loc, srcValue);
1949
1950 // 2. Zero out the upper side of each low i8 element.
1951 constexpr int8_t i8LowBitMask = 0x0F;
1952 VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
1953 Value zeroOutMask = arith::ConstantOp::create(
1954 rewriter, loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
1955 Value zeroOutLow = arith::AndIOp::create(
1956 rewriter, loc, deinterleaveOp.getRes1(), zeroOutMask);
1957
1958 // 3. Move high i4 values to upper side of the byte.
1959 constexpr int8_t bitsToShift = 4;
1960 auto shiftValues = arith::ConstantOp::create(
1961 rewriter, loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
1962 Value shlHigh = arith::ShLIOp::create(rewriter, loc, deinterleaveOp.getRes2(),
1963 shiftValues);
1964
1965 // 4. Merge high and low i4 values.
1966 auto mergedHiLowOp = arith::OrIOp::create(rewriter, loc, zeroOutLow, shlHigh);
1967
1968 // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
1969 auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
1970 return vector::BitCastOp::create(rewriter, loc, i4VecType, mergedHiLowOp);
1971}
1972
1973namespace {
1974/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
1975/// advantage of high-level information to avoid leaving LLVM to scramble with
1976/// peephole optimizations.
1977struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
1978 using Base::Base;
1979
1980 LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
1981 PatternRewriter &rewriter) const override {
1982 // The source must be a trunc op.
1983 auto truncOp =
1984 bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
1985 if (!truncOp)
1986 return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
1987
1988 // Set up the BitCastRewriter and verify the precondition.
1989 VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1990 VectorType targetVectorType = bitCastOp.getResultVectorType();
1991 BitCastRewriter bcr(sourceVectorType, targetVectorType);
1992 if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
1993 return failure();
1994
1995 // Perform the rewrite.
1996 Value truncValue = truncOp.getIn();
1997 auto shuffledElementType =
1998 cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
1999 Value runningResult;
2000 for (const BitCastRewriter ::Metadata &metadata :
2001 bcr.precomputeMetadata(shuffledElementType)) {
2002 runningResult = bcr.genericRewriteStep(
2003 rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
2004 }
2005
2006 // Finalize the rewrite.
2007 bool narrowing = targetVectorType.getElementTypeBitWidth() <=
2008 shuffledElementType.getIntOrFloatBitWidth();
2009 if (narrowing) {
2010 if (runningResult.getType() == bitCastOp.getResultVectorType()) {
2011 rewriter.replaceOp(bitCastOp, runningResult);
2012 } else {
2013 rewriter.replaceOpWithNewOp<arith::TruncIOp>(
2014 bitCastOp, bitCastOp.getResultVectorType(), runningResult);
2015 }
2016 } else {
2017 if (runningResult.getType() == bitCastOp.getResultVectorType()) {
2018 rewriter.replaceOp(bitCastOp, runningResult);
2019 } else {
2020 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2021 bitCastOp, bitCastOp.getResultVectorType(), runningResult);
2022 }
2023 }
2024
2025 return success();
2026 }
2027};
2028} // namespace
2029
2030//===----------------------------------------------------------------------===//
2031// RewriteExtOfBitCast
2032//===----------------------------------------------------------------------===//
2033
2034namespace {
2035/// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
2036/// take advantage of high-level information to avoid leaving LLVM to scramble
2037/// with peephole optimizations.
2038template <typename ExtOpType>
2039struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
2040 using OpRewritePattern<ExtOpType>::OpRewritePattern;
2041
2042 RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
2043 : OpRewritePattern<ExtOpType>(context, benefit) {}
2044
2045 LogicalResult matchAndRewrite(ExtOpType extOp,
2046 PatternRewriter &rewriter) const override {
2047 // The source must be a bitcast op.
2048 auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
2049 if (!bitCastOp)
2050 return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
2051
2052 // Set up the BitCastRewriter and verify the precondition.
2053 VectorType sourceVectorType = bitCastOp.getSourceVectorType();
2054 VectorType targetVectorType = bitCastOp.getResultVectorType();
2055 BitCastRewriter bcr(sourceVectorType, targetVectorType);
2056 if (failed(bcr.commonPrecondition(
2057 rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
2058 return failure();
2059
2060 // Perform the rewrite.
2061 Value runningResult;
2062 Value sourceValue = bitCastOp.getSource();
2063 auto shuffledElementType =
2064 cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
2065 for (const BitCastRewriter::Metadata &metadata :
2066 bcr.precomputeMetadata(shuffledElementType)) {
2067 runningResult = bcr.genericRewriteStep(
2068 rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
2069 }
2070
2071 // Finalize the rewrite.
2072 bool narrowing =
2073 cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
2074 shuffledElementType.getIntOrFloatBitWidth();
2075 if (narrowing) {
2076 rewriter.replaceOpWithNewOp<arith::TruncIOp>(
2077 extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
2078 } else {
2079 rewriter.replaceOpWithNewOp<ExtOpType>(
2080 extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
2081 }
2082
2083 return success();
2084 }
2085};
2086
2087/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
2088/// bitwise ops that take advantage of high-level information to avoid leaving
2089/// LLVM to scramble with peephole optimizations. Templated to choose between
2090/// signed and unsigned conversions.
2091///
2092/// EXAMPLE 1 (signed):
2093/// arith.extsi %in : vector<8xi4> to vector<8xi32>
2094/// is rewriten as:
2095/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2096/// %1 = arith.shli %0, 4 : vector<4xi8>
2097/// %2 = arith.shrsi %1, 4 : vector<4xi8>
2098/// %3 = arith.shrsi %0, 4 : vector<4xi8>
2099/// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
2100/// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
2101///
2102/// EXAMPLE 2 (fp):
2103/// arith.sitofp %in : vector<8xi4> to vector<8xf32>
2104/// is rewriten as:
2105/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2106/// %1 = arith.shli %0, 4 : vector<4xi8>
2107/// %2 = arith.shrsi %1, 4 : vector<4xi8>
2108/// %3 = arith.shrsi %0, 4 : vector<4xi8>
2109/// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
2110/// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
2111///
2112/// EXAMPLE 3 (unsigned):
2113/// arith.extui %in : vector<8xi4> to vector<8xi32>
2114/// is rewritten as:
2115/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2116/// %1 = arith.andi %0, 15 : vector<4xi8>
2117/// %2 = arith.shrui %0, 4 : vector<4xi8>
2118/// %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
2119/// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
2120///
2121template <typename ConversionOpType, bool isSigned>
2122struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
2123 using OpRewritePattern<ConversionOpType>::OpRewritePattern;
2124
2125 LogicalResult matchAndRewrite(ConversionOpType conversionOp,
2126 PatternRewriter &rewriter) const override {
2127 // Verify the preconditions.
2128 Value srcValue = conversionOp.getIn();
2129 VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
2130 VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
2131
2132 if (failed(
2133 commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
2134 return failure();
2135
2136 // Check general alignment preconditions.
2138 rewriter, srcVecType,
2139 /*containerTy=*/rewriter.getI8Type(), conversionOp)))
2140 return failure();
2141
2142 // Perform the rewrite.
2143 Location loc = conversionOp.getLoc();
2144 const auto &extFn = isSigned ? extractNBitsPerByteAndSignExtendToI8
2146 Value subByteExt;
2147 switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
2148 case 2:
2149 subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn);
2150 break;
2151 case 4:
2152 subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn);
2153 break;
2154 default:
2155 return failure();
2156 }
2157
2158 // Finalize the rewrite. If subByteExt already has the destination type
2159 // (e.g. extsi i4->i8 where the container is i8), replace directly without
2160 // creating a new conversion op that would have identical src and dst types.
2161 if (subByteExt.getType() == conversionOp.getType())
2162 rewriter.replaceOp(conversionOp, subByteExt);
2163 else
2164 rewriter.replaceOpWithNewOp<ConversionOpType>(
2165 conversionOp, conversionOp.getType(), subByteExt);
2166 return success();
2167 }
2168};
2169
2170/// Rewrite the i8 -> i4 part of any truncation into a deinterleave and
2171/// bitwise ops that take advantage of high-level information to avoid leaving
2172/// LLVM to scramble with peephole optimizations.
2173///
2174/// For example:
2175/// arith.trunci %in : vector<8xi32> to vector<8xi4>
2176///
2177/// is rewriten as:
2178///
2179/// %cst = arith.constant dense<15> : vector<4xi8>
2180/// %cst_0 = arith.constant dense<4> : vector<4xi8>
2181/// %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
2182/// %2 = arith.andi %0, %cst : vector<4xi8>
2183/// %3 = arith.shli %1, %cst_0 : vector<4xi8>
2184/// %4 = arith.ori %2, %3 : vector<4xi8>
2185/// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
2186///
2187struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
2188 using Base::Base;
2189
2190 LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
2191 PatternRewriter &rewriter) const override {
2192 // Verify the preconditions.
2193 Value srcValue = truncOp.getIn();
2194 auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
2195 auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
2196 if (!srcVecType || !dstVecType)
2197 return failure();
2198
2199 if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
2200 return failure();
2201
2202 // TODO: Add support for truncating to i2.
2203 if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
2204 return failure();
2205
2206 // Check general alignment preconditions. We invert the src/dst type order
2207 // to reuse the existing precondition logic.
2209 rewriter, dstVecType,
2210 /*containerTy=*/rewriter.getI8Type(), truncOp)))
2211 return failure();
2212
2213 // Create a new iX -> i8 truncation op, unless the source is already i8.
2214 Location loc = truncOp.getLoc();
2215 auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
2216 Value i8TruncVal =
2217 srcVecType == i8VecType
2218 ? srcValue
2219 : arith::TruncIOp::create(rewriter, loc, i8VecType, srcValue);
2220
2221 // Rewrite the i8 -> i4 truncation part.
2222 Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
2223
2224 // Finalize the rewrite.
2225 rewriter.replaceOp(truncOp, subByteTrunc);
2226 return success();
2227 }
2228};
2229
2230/// Rewrite a sub-byte vector transpose into a sequence of instructions that
2231/// perform the transpose on wider (byte) element types.
2232///
2233/// EXAMPLE:
2234/// %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
2235///
2236/// is rewritten as:
2237///
2238/// %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
2239/// %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
2240/// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
2241///
2242struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
2243 using Base::Base;
2244
2245 RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
2246 : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
2247
2248 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
2249 PatternRewriter &rewriter) const override {
2250 // Precondition: sub-byte integer transpose.
2251 constexpr unsigned minNativeBitwidth = 8;
2252 VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
2253 if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
2254 srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
2255 return rewriter.notifyMatchFailure(transposeOp,
2256 "not a sub-byte transpose");
2257 }
2258
2259 // Perform the rewrite.
2260 Location loc = transposeOp.getLoc();
2261 // Signed/unsigned interpretation shouldn't matter here as we are just
2262 // transposing the elements and truncating them back to the original size.
2263 // TODO: Use unsigned extension (more efficient) when emulation or backend
2264 // support is available.
2265 auto srcNativeVecType = srcSubByteVecType.cloneWith(
2266 std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
2267 Value extOp = arith::ExtSIOp::create(rewriter, loc, srcNativeVecType,
2268 transposeOp.getVector());
2269 Value newTranspose = vector::TransposeOp::create(
2270 rewriter, loc, extOp, transposeOp.getPermutation());
2271 VectorType dstSubByteVecType = transposeOp.getResultVectorType();
2272 rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
2273 newTranspose);
2274 return success();
2275 }
2276};
2277
2278} // namespace
2279
2280//===----------------------------------------------------------------------===//
2281// Public Interface Definition
2282//===----------------------------------------------------------------------===//
2283
2284// The emulated type is inferred from the converted memref type.
2285void vector::populateVectorNarrowTypeEmulationPatterns(
2286 const arith::NarrowTypeEmulationConverter &typeConverter,
2287 RewritePatternSet &patterns, bool disableAtomicRMW, bool assumeAligned) {
2288 // Populate `vector.*` conversion patterns.
2289 // TODO: #119553 support atomicity
2290 patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
2291 ConvertVectorMaskedStore, ConvertVectorTransferRead>(
2292 typeConverter, patterns.getContext());
2293
2294 // Populate `vector.*` store conversion patterns. The caller can choose
2295 // to avoid emitting atomic operations and reduce it to read-modify-write
2296 // sequence for stores if it is known there are no thread contentions.
2297 patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW,
2298 assumeAligned);
2299}
2300
2301void vector::populateVectorNarrowTypeRewritePatterns(
2302 RewritePatternSet &patterns, PatternBenefit benefit) {
2303 // TODO: Document what the emulated type is.
2304 patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
2305 RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
2306 benefit);
2307
2308 // Patterns for aligned cases. We set higher priority as they are expected to
2309 // generate better performance for aligned cases.
2310 // The container type is always i8.
2311 patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
2312 RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
2313 RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
2314 benefit.getBenefit() + 1);
2315 // The container type is always i8.
2316 patterns
2317 .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
2318 RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
2319 patterns.getContext(), benefit.getBenefit() + 1);
2320}
2321
2322// The container type is always i8.
2323void vector::populateVectorTransposeNarrowTypeRewritePatterns(
2324 RewritePatternSet &patterns, PatternBenefit benefit) {
2325 patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
2326}
2327
2328void vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
2329 arith::NarrowTypeEmulationConverter &typeConverter,
2330 RewritePatternSet &patterns) {
2332 vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
2333}
return success()
static Type getElementType(Type type)
Determine the element type of type.
static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter, Location loc, VectorValue vector, int64_t extractOffset, int64_t sliceNumElements, int64_t insertOffset)
Extract sliceNumElements from source vector at extractOffset, and insert it into an empty vector at insertOffset.
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc, Value srcValue)
Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise ops to avoid leaving LLVM t...
std::function< Value(PatternRewriter &, Location, Value, int, int)> ExtractNBitsFn
TypedValue< MemRefType > MemRefValue
static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base, OpFoldResult linearizedIndices, int64_t numContainerElemsToLoad, Type emulatedElemTy, Type containerElemTy)
Emulate a vector load for emulatedElemTy using containerElemTy
TypedValue< VectorType > VectorValue
static FailureOr< Operation * > getCompressedMaskOp(OpBuilder &rewriter, Location loc, Value mask, int numSrcElems, int numSrcElemsPerDest, int numFrontPadElems=0)
Returns a compressed mask for the emulated vector.
static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc, VectorType downcastType, VectorType upcastType, Value mask, Value trueValue, Value falseValue)
Downcast two values to downcastType, then select values based on mask, and casts the result to upcastType.
static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc, Value srcValue, const ExtractNBitsFn &extFn)
Rewrite the i4 -> i8 extension into a sequence of shuffles and bitwise ops to avoid leaving LLVM to s...
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc, Value src, Value dest, OpFoldResult offset, int64_t numElemsToInsert)
Inserts 1-D subvector into a 1-D vector.
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc, Value src, Value dest, int64_t offset)
Inserts 1-D subvector into a 1-D vector.
static void atomicRMW(OpBuilder &builder, Location loc, MemRefValue linearizedMemref, Value storeIdx, VectorValue valueToStore, Value mask)
Emits memref.generic_atomic_rmw op to store a subbyte-sized value to a byte in linearizedMemref,...
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc, Value src, int64_t offset, int64_t numElemsToExtract)
Extracts 1-D subvector from a 1-D vector.
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter, VectorType preconditionType, Operation *op)
Verify that the precondition type meets the common preconditions for any conversion.
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc, Value src, Value dest, OpFoldResult offset, int64_t numElemsToExtract)
Extracts 1-D subvector from a 1-D vector.
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter, VectorType subByteVecTy, Type containerTy, Operation *op)
Verify that subByteVecTy (vector) and containerTy (scalar) are aligned.
static void nonAtomicRMW(OpBuilder &builder, Location loc, MemRefValue linearizedMemref, Value linearizedIndex, VectorValue valueToStore, Value mask)
Generate a non-atomic read-modify-write sequence for storing to the emulated type.
static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc, Value subByteVec)
Bitcasts the aligned subByteVec vector to a vector of i8.
static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter, Location loc, Value src, int bitIdx, int numBits)
Extracts an unsigned N-bit sequence from each element of a vector of bytes, starting at the specified bit index.
static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc, Value srcValue, const ExtractNBitsFn &extFn)
Rewrite the i2 -> i8 extension into a sequence of shuffles and bitwise ops to avoid leaving LLVM to s...
static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter, Location loc, Value src, int bitIdx, int numBits)
Extracts a signed N-bit sequence from each element of a vector of bytes, starting at the specified bit index.
Base type for affine expression.
Definition AffineExpr.h:68
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
IntegerType getI4Type()
Definition Builders.cpp:61
IntegerType getI1Type()
Definition Builders.cpp:57
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:286
IndexType getIndexType()
Definition Builders.cpp:55
IntegerType getI8Type()
Definition Builders.cpp:63
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
result_type_range getResultTypes()
Definition Operation.h:453
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={}, LinearizedDivKind sizeDivKind=LinearizedDivKind::Floor)
void populateFlattenMemrefsPatterns(RewritePatternSet &patterns)
Patterns for flattening all supported multi-dimensional memref operations into one-dimensional memref...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Definition AffineExpr.h:325
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.