MLIR 23.0.0git
VectorContractToAMXDotProduct.cpp
Go to the documentation of this file.
1//===- VectorContractToAMXDotProduct.cpp ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15
17#include "mlir/IR/Dominance.h"
19#include "llvm/Support/Casting.h"
20
21#include "mlir/Pass/Pass.h"
23
24using namespace mlir;
25using namespace mlir::vector;
26using namespace mlir::x86;
27
28namespace {
29
30// Function to collapse the last two dimension (vnni and k) to help the
31// amx.tile_load to correctly load the packed element type.
33 Value input) {
34 ShapedType inputType = cast<ShapedType>(input.getType());
35 int64_t firstDimToCollapse = inputType.getRank() - 2;
36
37 if (inputType.getRank() == 1)
38 return input;
39
41 for (int64_t i = 0; i < firstDimToCollapse; ++i)
42 reassociation.push_back(ReassociationIndices{i});
43
44 ReassociationIndices collapsedIndices;
45 for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
46 collapsedIndices.push_back(i);
47
48 reassociation.push_back(collapsedIndices);
49 return memref::CollapseShapeOp::create(builder, loc, input, reassociation);
50}
51
52// Get the MemRef source and offset index for the operands of
53// vector.contract.
54static FailureOr<std::pair<Value, SmallVector<Value>>>
55getSrcIndxValue(OpBuilder &rewriter, Location loc, Value operand,
56 bool isNotAcc) {
57 Operation *defOp = operand.getDefiningOp();
58 if (!defOp)
59 return failure();
60
61 Value srcBuff;
64 .Case<TransferReadOp, LoadOp>([&](auto readOp) {
65 indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
66 readOp.getIndices().end());
67 srcBuff = readOp.getOperand(0);
68 });
69
70 if (!srcBuff)
71 return failure();
72
73 if (isNotAcc)
74 indexVals.pop_back();
75
77 indices.reserve(indexVals.size());
78
79 for (OpFoldResult ofr : indexVals) {
80 indices.push_back(
81 mlir::getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
82 }
83
84 if (isNotAcc) {
85 srcBuff = collapseInnerDims(rewriter, loc, srcBuff);
86 }
87
88 return std::make_pair(srcBuff, indices);
89}
90
91// Function to validate the vector.contract operation.
92static LogicalResult validateContractOps(OpBuilder &rewriter,
93 vector::ContractionOp contractOp,
94 unsigned int blockingFactor,
95 Value srcBuffLhs, Value srcBuffRhs,
96 bool srcValidate) {
97
98 if (srcValidate) {
99 // Get the MemRef buffer of LHS operand.
100 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
101 contractOp.getLhs(), false);
102 if (failed(srcIndxLhs))
103 return failure();
104 auto [buffLhs, indicesLhs] = *srcIndxLhs;
105
106 // Get the MemRef buffer of RHS operand.
107 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
108 contractOp.getRhs(), false);
109 if (failed(srcIndxRhs))
110 return failure();
111 auto [buffRhs, indicesRhs] = *srcIndxRhs;
112
113 // Return failure if the Memref buff didn't match.
114 if (buffLhs != srcBuffLhs)
115 return failure();
116
117 if (buffRhs != srcBuffRhs)
118 return failure();
119 }
120
121 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
122 if (!accTy)
123 return failure();
124
125 // The Accumulator dims should be 16 or 1. Like <1x16x16> or <16x16>.
126 ArrayRef<int64_t> accShape = accTy.getShape();
127 llvm::SmallVector<int64_t> nonUnitDimAcc;
128 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
129 [](int64_t dim) { return (dim != 16 && dim != 1); });
130
131 if (nonUnitDimAcc.size() != 0)
132 return failure();
133
134 // The LHS dims should be 16 or vnni or 1. Like <1x16x16x2> or
135 // <16x16x4>. The vnni dims should be 2 or 4.
136 VectorType lhsTy = contractOp.getLhsType();
137 ArrayRef<int64_t> lhsShape = lhsTy.getShape();
138 llvm::SmallVector<int64_t> nonUnitDimLhs;
139 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
140 [](int64_t dim) { return (dim != 16 && dim != 1); });
141
142 if (nonUnitDimLhs.size() != 1)
143 return failure();
144
145 if (nonUnitDimLhs[0] != blockingFactor)
146 return failure();
147
148 // The RHS dims should be 16 or vnni or 1. Like <1x16x16x2> or
149 // <16x16x4>. The vnni dims should be 2 or 4.
150 VectorType rhsTy = contractOp.getRhsType();
151 ArrayRef<int64_t> rhsShape = rhsTy.getShape();
152 llvm::SmallVector<int64_t> nonUnitDimRhs;
153 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
154 [](int64_t dim) { return (dim != 16 && dim != 1); });
155
156 if (nonUnitDimRhs.size() != 1)
157 return failure();
158
159 if (nonUnitDimRhs[0] != blockingFactor)
160 return failure();
161
162 return success();
163}
164
165// Returns the loop index position to get mapped during the
166// MemRef type clone.
167static unsigned getIndexPosition(Value operand, scf::ForOp loop) {
168 Value iv = loop.getInductionVar();
169
170 Value srcBuff;
172 .Case<TransferReadOp, LoadOp>(
173 [&](auto readOp) { srcBuff = readOp.getOperand(0); });
174
175 auto subview = srcBuff.getDefiningOp<memref::SubViewOp>();
176 if (!subview)
177 return 0;
178
179 auto offsets = subview.getOffsets();
180
181 for (auto it : llvm::enumerate(offsets)) {
182 if (it.value() == iv)
183 return it.index();
184 }
185
186 return 0;
187}
188
189// Creates amx.tile_loads.
190static amx::TileLoadOp createTileLoads(OpBuilder &rewriter, Location loc,
191 Value operand, Value mat, Type ipType,
192 bool rhs, unsigned int offset,
193 bool isVnni) {
194
195 auto srcIndx = getSrcIndxValue(rewriter, loc, operand, false);
196 auto [srcBuff, indices] = *srcIndx;
197 if (isVnni) {
198 indices.pop_back();
199 }
200
201 if (rhs && isVnni) {
202 auto cOffset = arith::ConstantIndexOp::create(rewriter, loc, offset);
203 indices[indices.size() - 1] = arith::MulIOp::create(
204 rewriter, loc, indices[indices.size() - 1], cOffset);
205 }
206
207 amx::TileType tileType = amx::TileType::get({16, (16 * offset)}, ipType);
208 return amx::TileLoadOp::create(rewriter, loc, tileType, mat, indices);
209}
210
211static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
212 Type ipType, unsigned int offset, Value packedBuffer,
213 Value indxToStoreInBuffer) {
214
215 Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
216 Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
217
218 auto subview = matB.getDefiningOp<mlir::memref::SubViewOp>();
219 SmallVector<Value> subviewOffset(subview.getOffsets().size(), c0);
220
221 Value cStep = arith::ConstantIndexOp::create(rewriter, loc, offset);
222 Value cBound = arith::ConstantIndexOp::create(rewriter, loc, (16 * offset));
223 Value offsetIndx =
224 arith::ConstantIndexOp::create(rewriter, loc, (offset / 2));
225
226 scf::ForOp::create(
227 rewriter, loc, c0, cBound, cStep, ValueRange{},
228 [&](OpBuilder &nestedBuilder, Location loc, Value iv,
229 ValueRange iterArgs) {
230 subviewOffset[subviewOffset.size() - 2] = iv;
231 auto vec1 = vector::LoadOp::create(
232 rewriter, loc, VectorType::get((16 * offset), ipType), matB,
233 ValueRange(subviewOffset));
234
235 // Increment the iv by 1 or 2 based on the type to load the next 32/64
236 // elements
237 Value incIV = arith::AddIOp::create(rewriter, loc, offsetIndx, iv);
238 subviewOffset[subviewOffset.size() - 2] = incIV;
239 auto vec2 = vector::LoadOp::create(
240 rewriter, loc, VectorType::get((16 * offset), ipType), matB,
241 ValueRange(subviewOffset));
242
243 vector::ShuffleOp shuffle1;
244 vector::ShuffleOp shuffle2;
245
246 if (ipType.isBF16()) {
247
248 shuffle1 = vector::ShuffleOp::create(
249 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
250 vec2,
251 ArrayRef<int64_t>{0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9,
252 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50,
253 19, 51, 24, 56, 25, 57, 26, 58, 27, 59});
254
255 shuffle2 = vector::ShuffleOp::create(
256 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
257 vec2,
258 ArrayRef<int64_t>{4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13,
259 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54,
260 23, 55, 28, 60, 29, 61, 30, 62, 31, 63});
261 }
262
263 if (ipType.isSignlessInteger(8)) {
264
265 shuffle1 = vector::ShuffleOp::create(
266 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
267 vec2,
269 0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3,
270 35, 67, 99, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42,
271 74, 106, 11, 43, 75, 107, 16, 48, 80, 112, 17, 49, 81,
272 113, 18, 50, 82, 114, 19, 51, 83, 115, 24, 56, 88, 120,
273 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123});
274
275 shuffle2 = vector::ShuffleOp::create(
276 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
277 vec2,
279 4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39,
280 71, 103, 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110,
281 15, 47, 79, 111, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54,
282 86, 118, 23, 55, 87, 119, 28, 60, 92, 124, 29, 61, 93, 125,
283 30, 62, 94, 126, 31, 63, 95, 127});
284 }
285
286 // iv to store the shuffled elements
287 Value ivShuff1 = arith::DivUIOp::create(rewriter, loc, iv, cStep);
288 Value ivShuff2 = arith::AddIOp::create(rewriter, loc, ivShuff1, c16);
289
290 vector::StoreOp::create(rewriter, loc, shuffle1, packedBuffer,
291 ValueRange{indxToStoreInBuffer, ivShuff1, c0});
292 vector::StoreOp::create(rewriter, loc, shuffle2, packedBuffer,
293 ValueRange{indxToStoreInBuffer, ivShuff2, c0});
294
295 scf::YieldOp::create(nestedBuilder, loc);
296 });
297}
298
300packInputs(OpBuilder &rewriter, Location loc,
302 unsigned int offset, Value packedBuffer, bool pack,
303 Value indxToStoreInBuffer, Value indxToLoadFromMatB) {
304
306 Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
307 Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
308
309 for (size_t j = 0; j < ops.size(); j++) {
310 for (size_t i = 0; i < ops.size(); i++) {
311
312 if (i != j && validatePairVectorContract(ops[j], ops[i], true, 16)) {
313
314 Operation *readOpRhs = ops[j].getRhs().getDefiningOp();
315 auto itRhs = readsToTileLoads.find(readOpRhs);
316 if (itRhs != readsToTileLoads.end()) {
317 continue;
318 }
319
320 if (pack) {
321 performShuffle(rewriter, loc, matB, ipType, offset, packedBuffer,
322 indxToStoreInBuffer);
323 }
324
325 amx::TileType tileType =
326 amx::TileType::get({16, (16 * offset)}, ipType);
327 auto loadRow1 =
328 amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
329 ValueRange{indxToLoadFromMatB, c0, c0});
330
331 auto loadRow2 =
332 amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
333 ValueRange{indxToLoadFromMatB, c16, c0});
334
335 readsToTileLoads.try_emplace(readOpRhs, loadRow1);
336 readsToTileLoads.try_emplace(ops[i].getRhs().getDefiningOp(), loadRow2);
337 }
338 }
339 }
340
341 return readsToTileLoads;
342}
343
344// Creates tiled amx dot-products.
346createTiledDp(OpBuilder &rewriter, Location loc,
348 Type ipType, Type opType, ValueRange accIterArgs,
349 unsigned int offset, bool isVnni, Value packedBuffer, bool pack,
350 Value indxToStoreInBuffer, Value indxToLoadFromMatB) {
351
352 if (isVnni) {
353 matA = collapseInnerDims(rewriter, loc, matA);
354 matB = collapseInnerDims(rewriter, loc, matB);
355 }
356
357 SmallVector<Value> accumulators;
358 // Stores the amx.tile_load operation vs it's equivalent vector tranfer_read
359 // or load operations.
361
362 // function call to online pack the input B matrix
363 if (!isVnni) {
364 readsToTileLoads =
365 packInputs(rewriter, loc, ops, matB, ipType, offset, packedBuffer, pack,
366 indxToStoreInBuffer, indxToLoadFromMatB);
367 }
368
369 // Iterate over the contraction operations and compute the tiled dot-product.
370 for (size_t i = 0; i < ops.size(); i++) {
371
372 Operation *readOpLhs = ops[i].getLhs().getDefiningOp();
373 amx::TileLoadOp tilesLhs;
374 auto itLhs = readsToTileLoads.find(readOpLhs);
375 if (itLhs != readsToTileLoads.end()) {
376 tilesLhs = itLhs->second;
377 } else {
378 tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(), matA, ipType,
379 false, offset, isVnni);
380 readsToTileLoads.try_emplace(readOpLhs, tilesLhs);
381 }
382
383 Operation *readOpRhs = ops[i].getRhs().getDefiningOp();
384 amx::TileLoadOp tilesRhs;
385 auto itRhs = readsToTileLoads.find(readOpRhs);
386 if (itRhs != readsToTileLoads.end()) {
387 tilesRhs = itRhs->second;
388 } else {
389 tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(), matB, ipType,
390 true, offset, isVnni);
391 readsToTileLoads.try_emplace(readOpRhs, tilesRhs);
392 }
393
394 auto accTileType = amx::TileType::get({16, 16}, opType);
395
396 Value dp;
397 if (ipType.isBF16())
398 dp = amx::TileMulFOp::create(rewriter, loc, accTileType, tilesLhs,
399 tilesRhs, accIterArgs[i]);
400
401 if (ipType.isSignlessInteger(8))
402 dp = amx::TileMulIOp::create(rewriter, loc, accTileType, tilesLhs,
403 tilesRhs, accIterArgs[i]);
404
405 accumulators.push_back(dp);
406 }
407 return accumulators;
408}
409
410static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
411 Type opType, scf::ForOp outerLoop,
412 int64_t size) {
413 rewriter.setInsertionPoint(outerLoop);
414
415 SmallVector<Value> loopItrArgs;
416 auto zeroTileType = amx::TileType::get({16, 16}, opType);
417
418 for (int i = 0; i < size; i++) {
419 auto zeroTile = amx::TileZeroOp::create(rewriter, loc, zeroTileType);
420 loopItrArgs.push_back(zeroTile);
421 }
422 return loopItrArgs;
423}
424
425static Value getIndxToLoadStoreFromPckBuffer(
426 OpBuilder &rewriter, Location loc, Value ivInnerLoop, Value ivOuterLoop,
427 bool isInnerLoopUBHasOddQuot, bool isInnerLoopUBLarger, bool pack,
428 unsigned int blockingFactor) {
429
430 Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);
431 Value packOffset =
432 arith::ConstantIndexOp::create(rewriter, loc, (16 * blockingFactor));
433
434 Value quotientInnerLoop =
435 arith::DivUIOp::create(rewriter, loc, ivInnerLoop, packOffset);
436 Value remInnerLoop = arith::RemUIOp::create(
437 rewriter, loc, rewriter.getIndexType(), quotientInnerLoop, c2);
438
439 if (!isInnerLoopUBLarger && !pack) {
440 remInnerLoop = arith::RemUIOp::create(
441 rewriter, loc, rewriter.getIndexType(), ivOuterLoop, c2);
442 }
443
444 if (isInnerLoopUBHasOddQuot) {
445 auto remOuterLoop = arith::RemUIOp::create(
446 rewriter, loc, rewriter.getIndexType(), ivOuterLoop, c2);
447 auto remAdd = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
448 remInnerLoop, remOuterLoop);
449 remInnerLoop = arith::RemUIOp::create(rewriter, loc,
450 rewriter.getIndexType(), remAdd, c2);
451 }
452
453 return remInnerLoop;
454}
455
// Builds one replacement reduction loop and returns it. The loop is an
// scf.for from `lowerBound` to `upperBound` by `step`, carrying `loopItrArgs`
// as iter_args, and its body uses AMX tile ops instead of the original
// vector.contract ops.
//
// Each iteration of the new loop:
//   1. Clones the LHS subview `vectorOpLhs`. Its offset operands that are the
//      induction variables of `outerLoop` (if present) and `innerLoop` are
//      remapped to `ivOuterLoop` and to the new induction variable.
//   2. For non-VNNI inputs, picks slots of the two-slot `packedBuffer`:
//      `indxToStoreInBuffer` is where B is packed, and `indxToLoadFromBuffer`
//      is where the packed B tiles are read. In the packing variants the
//      load slot is `(store slot + 1) % 2`. The induction variables used for
//      the RHS may be advanced, so B data of a later iteration is packed
//      ahead of time (the s/w pipelining described on the pattern).
//   3. Clones the RHS subview `vectorOpRhs` with the same kind of remapping.
//   4. Emits the tiled dot-products for `ops` via createTiledDp and yields
//      the resulting accumulators.
static scf::ForOp
createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
            Value upperBound, Value step, SmallVector<Value> loopItrArgs,
            Type ipType, Type opType, unsigned int blockingFactor, bool isVnni,
            Operation *vectorOpLhs, Operation *vectorOpRhs,
            vector::ContractionOp contractOp, scf::ForOp outerLoop,
            scf::ForOp innerLoop, SmallVector<vector::ContractionOp> ops,
            Value ivOuterLoop, Value packedBuffer, bool pack,
            arith::ConstantIndexOp innerLoopIndex, bool isInnerLoopUBLarger,
            bool isInnerLoopUBHasOddQuot) {

  Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
  Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
  Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);

  auto newLoop = scf::ForOp::create(
      rewriter, loc, lowerBound, upperBound, step, loopItrArgs,
      [&](OpBuilder &rewriterNewInnerLoop, Location locNewInnerLoop,
          Value ivNewInnerLoop, ValueRange iterArgsNewInnerLoop) {
        // Re-create the LHS subview for this iteration. Operand 0 of the
        // subview is its source, hence the +1 on the offset position.
        IRMapping mapping;
        if (outerLoop)
          mapping.map(vectorOpLhs->getOperand(
                          getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
                      ivOuterLoop);

        mapping.map(vectorOpLhs->getOperand(
                        getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
                    ivNewInnerLoop);
        auto lhsClone = rewriterNewInnerLoop.clone(*vectorOpLhs, mapping);

        // Slots of the two-slot packed-B buffer; only non-VNNI inputs use
        // anything other than slot 0.
        Value indxToStoreInBuffer = c0;
        Value indxToLoadFromBuffer = c0;

        if (!isVnni) {
          if (outerLoop) {
            // NOTE(review): `innerLoopIndex == 0` appears to select the
            // peeled last-K-iteration variant; confirm against the caller.
            if (innerLoopIndex.value() == 0) {
              if (pack) {
                // The RHS cloned below reads the first K block of the next
                // outer-loop iteration, so it is packed ahead of its use.
                ivNewInnerLoop = c0;
                ivOuterLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
                                                    c1, ivOuterLoop);

                if (!isInnerLoopUBLarger || isInnerLoopUBHasOddQuot) {
                  indxToStoreInBuffer = arith::RemUIOp::create(
                      rewriter, locNewInnerLoop, rewriter.getIndexType(),
                      ivOuterLoop, c2);
                }

                // Load from the slot that is not being packed into.
                Value indxToLoadFromMatB = arith::AddIOp::create(
                    rewriter, loc, indxToStoreInBuffer, c1);
                indxToLoadFromBuffer = arith::RemUIOp::create(
                    rewriter, loc, rewriter.getIndexType(), indxToLoadFromMatB,
                    c2);
              }

            } else {
              // The RHS cloned below reads the K block 16 * blockingFactor
              // ahead of the current one.
              Value nLoadIndx = arith::ConstantIndexOp::create(
                  rewriter, locNewInnerLoop, (16 * blockingFactor));
              ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
                                                     nLoadIndx, ivNewInnerLoop);
              indxToStoreInBuffer = getIndxToLoadStoreFromPckBuffer(
                  rewriter, loc, ivNewInnerLoop, ivOuterLoop,
                  isInnerLoopUBHasOddQuot, isInnerLoopUBLarger, pack,
                  blockingFactor);
              Value indxToLoadFromMatB =
                  arith::AddIOp::create(rewriter, loc, indxToStoreInBuffer, c1);
              indxToLoadFromBuffer =
                  arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
                                         indxToLoadFromMatB, c2);
            }
          } else {
            // Single reduction loop: the store slot alternates with the
            // K-block index of the advanced induction variable.
            if (pack) {
              Value nLoadIndx = arith::ConstantIndexOp::create(
                  rewriter, locNewInnerLoop, (16 * blockingFactor));
              ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
                                                     nLoadIndx, ivNewInnerLoop);
              Value quotient_K = arith::DivUIOp::create(
                  rewriter, loc, ivNewInnerLoop, nLoadIndx);
              indxToStoreInBuffer = arith::RemUIOp::create(
                  rewriter, loc, rewriter.getIndexType(), quotient_K, c2);

              Value indxToLoadFromMatB =
                  arith::AddIOp::create(rewriter, loc, indxToStoreInBuffer, c1);
              indxToLoadFromBuffer =
                  arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
                                         indxToLoadFromMatB, c2);
            }
          }
        }

        // Re-create the RHS subview, possibly with the advanced induction
        // variables computed above.
        IRMapping rhsMapping;
        if (outerLoop)
          rhsMapping.map(
              vectorOpRhs->getOperand(
                  getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
              ivOuterLoop);

        rhsMapping.map(
            vectorOpRhs->getOperand(
                getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
            ivNewInnerLoop);
        auto rhsClone = rewriterNewInnerLoop.clone(*vectorOpRhs, rhsMapping);

        Value matB = rhsClone->getResult(0);

        if (!isVnni) {
          // Without packing in this loop, B is not read here (matB is
          // cleared); only the slot to load the packed tiles from is needed.
          if (outerLoop) {
            if (!pack) {
              Value nLoadIndx = arith::ConstantIndexOp::create(
                  rewriter, locNewInnerLoop, (16 * blockingFactor));
              matB = Value();
              // This initial value is immediately overwritten below.
              indxToLoadFromBuffer = c0;
              indxToLoadFromBuffer = getIndxToLoadStoreFromPckBuffer(
                  rewriter, loc, nLoadIndx, ivOuterLoop,
                  isInnerLoopUBHasOddQuot, isInnerLoopUBLarger, pack,
                  blockingFactor);
            }
          } else {
            if (!pack) {
              Value nLoadIndx = arith::ConstantIndexOp::create(
                  rewriter, locNewInnerLoop, (16 * blockingFactor));
              matB = Value();
              Value quotient_K = arith::DivUIOp::create(
                  rewriter, loc, ivNewInnerLoop, nLoadIndx);
              indxToLoadFromBuffer = arith::RemUIOp::create(
                  rewriter, loc, rewriter.getIndexType(), quotient_K, c2);
            }
          }
        }

        // Compute the tiled dot-products and carry the accumulators forward.
        SmallVector<Value> accumulators = createTiledDp(
            rewriter, locNewInnerLoop, ops, lhsClone->getResult(0), matB,
            ipType, opType, iterArgsNewInnerLoop, blockingFactor, isVnni,
            packedBuffer, pack, indxToStoreInBuffer, indxToLoadFromBuffer);

        scf::YieldOp::create(rewriterNewInnerLoop, locNewInnerLoop,
                             accumulators);
      });

  return newLoop;
}
597
598// Implements tiled dot-product operation for a vector.contract operation or a
599// sequence of vector.contracts inside the reduction loops.
600//
601// For example:
602// Case 1: register blocked vector.contract with prepacked input
603// ```
604// vector.transfer_read %arg0 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
605// vector.transfer_read %arg1 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
606// vector.contract <16x16x4xi8>, <16x16x4xi8> into <16x16xi32>
607// vector.transfer_write %arg2 {{.}*} : vector<16x16xi32>, memref<32x32xi32>
608// ```
609// to
610// ```
611// amx.tile_load %arg0 {{.}*} : memref<16x32x4xi8> into !amx.tile<16x64xi8>
612// amx.tile_load %arg1 {{.}*} : memref<16x32x4xi8> into !amx.tile<16x64xi8>
613// amx.tile_muli !amx.tile<16x64xi8> -> !amx.tile<16x16xi32>
614// amx.tile_store %arg2{{.}*} : memref<32x32xi32>, !amx.tile<16x16xi32>
615// ```
616//
617//
618// Case2: vector.contract with register blocked
619//
620// Output IR with online packing (with s/w pipeline advantage):
621// s/w pipeline: load, pack to VNNI, and store the B sub matrix
622// of the 0th batch-reduce and K iteration.
623// scf.for (0 to 31) {
624// - load 0th and 1st vector<32xbf16>, pack into VNNI, store the
625// first shuffle in 0th and 2nd shuffle in 16th index of the
626// buffer.
627// }
628// scf.for (0 to br-2) { batch-reduce loop
629// scf.for (0 to k-2) { K loop
630// - load A matrix
631// - scf.loop for s/w pipeline: load, pack to VNNI, and store the B sub
632// matrix for the next K loop iteration (c) load VNNI pack B matrix of K
633// iteration from the buffer (d) compute the tiled dot-product
634// }
635// Last iteration of the K loop (k-1) {
636// - load A matrix
637// - scf.loop for s/w pipeline: load, pack to VNNI, and store the B sub
638// matrix for the next batch-reduce + K loop iteration (c) load VNNI pack B
639// matrix of K iteration from the buffer (d) compute the tiled dot-product
640// }
641// }
642// Last iteration of the batch-reduce loop (br-1) {
643// scf.for (0 to k-2) { K loop
644// - load A matrix
645// - scf.loop for s/w pipeline: load, pack to VNNI, and store the B sub
646// matrix for the next K loop iteration (c) load VNNI pack B matrix of K
647// iteration from the buffer (d) compute the tiled dot-product
648// }
649// Last iteration of the K loop (k-1) {
650// - load A matrix
651// - load VNNI pack B matrix of K iteration from the buffer
652// - compute the tiled dot-product
653// }
654// }
655//
656// scf.for (0 to M)
657// scf.for (0 to N)
658// - Load the ith and i+1th acc
659// - Shuffle them as we packed using vpunpack
660// - Load C matrix and do arith.add with the shuffle
661// - Store back into C matrix
662struct VectorContractToAMXDotProduct
663 : public OpRewritePattern<vector::ContractionOp> {
664 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
665
666 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
667 PatternRewriter &rewriter) const override {
668
669 if (contractOp.getKind() != vector::CombiningKind::ADD)
670 return rewriter.notifyMatchFailure(contractOp,
671 "Expects add combining kind.");
672
673 unsigned int blockingFactor =
674 contractOp.getLhsType().getElementType().isBF16() ? 2 : 4;
675 bool isVnni =
676 isInVnniLayout(contractOp.getOperation(),
677 contractOp.getIndexingMapsArray(), blockingFactor);
678
679 VectorType lhsTy = contractOp.getLhsType();
680 if (!lhsTy.getElementType().isBF16() &&
681 !lhsTy.getElementType().isSignlessInteger(8))
682 return rewriter.notifyMatchFailure(
683 contractOp, "Only BF16/Int8 lowering is supported.");
684
685 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
686 if (!accTy)
687 return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
688
689 if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
690 (lhsTy.getElementType().isSignlessInteger(8) &&
691 !accTy.getElementType().isSignlessInteger(32)))
692 return rewriter.notifyMatchFailure(contractOp,
693 "Only F32 for BF16 or Int32 for Int8 "
694 "accumulation type is supported.");
695
696 Operation *accReadOp =
697 traceToVectorReadLikeParentOperation(contractOp.getAcc());
698
699 Operation *resultWriteOp =
700 traceToVectorWriteLikeUserOperation(contractOp.getResult());
701
702 if (!accReadOp || !resultWriteOp)
703 return rewriter.notifyMatchFailure(
704 contractOp, "The ACC operand of the vector.contract should be a "
705 "transfer_read or a load. And, the result should be "
706 "stored using transfer_write or store.");
707
708 Type ipType = rewriter.getBF16Type();
709 Type opType = rewriter.getF32Type();
710
711 if (lhsTy.getElementType().isSignlessInteger(8)) {
712 ipType = rewriter.getIntegerType(8);
713 opType = rewriter.getIntegerType(32);
714 }
715
716 if (accReadOp->getBlock() == contractOp->getBlock() &&
717 resultWriteOp->getBlock() != contractOp->getBlock())
718 return rewriter.notifyMatchFailure(
719 contractOp, "The accumulator store is in different block.");
720
721 if (accReadOp->getBlock() != contractOp->getBlock() &&
722 resultWriteOp->getBlock() == contractOp->getBlock())
723 return rewriter.notifyMatchFailure(
724 contractOp, "The accumulator read is in different block.");
725
726 unsigned int dimValue = blockingFactor;
727 if (!isVnni)
728 dimValue = 16 * blockingFactor;
729
730 // Case 1: For just one VC rewrite. Where all accumulator read/write
731 // within the same block.
732 if (accReadOp->getBlock() == contractOp->getBlock() &&
733 resultWriteOp->getBlock() == contractOp->getBlock()) {
734
735 bool collapse = false;
736 if (isVnni)
737 collapse = true;
738
739 LogicalResult validate = validateContractOps(
740 rewriter, contractOp, dimValue, Value(), Value(), false);
741
742 if (failed(validate))
743 return rewriter.notifyMatchFailure(
744 contractOp, "The contract operation doesn't satisfy the operands "
745 "dimensions. M, N, and vnni dims are 16, 16, and 2/4. "
746 "The rest dims should be 1.");
747
748 Location loc = contractOp.getLoc();
749
750 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
751 contractOp.getLhs(), collapse);
752 if (failed(srcIndxLhs))
753 return rewriter.notifyMatchFailure(contractOp,
754 "The LHS src is not a MemRef type.");
755 auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
756
757 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
758 contractOp.getRhs(), collapse);
759 if (failed(srcIndxRhs))
760 return rewriter.notifyMatchFailure(contractOp,
761 "The RHS src is not a MemRef type.");
762 auto rhsSrc = *srcIndxRhs;
763 auto srcBuffRhs = rhsSrc.first;
764 auto indicesRhs = rhsSrc.second;
765
766 auto srcIndxAcc = getSrcIndxValue(rewriter, contractOp.getLoc(),
767 contractOp.getAcc(), false);
768 if (failed(srcIndxAcc))
769 return rewriter.notifyMatchFailure(contractOp,
770 "The ACC src is not a MemRef type.");
771 auto [srcBuffAcc, indicesAcc] = *srcIndxAcc;
772
773 // amx.tile_loads
774 auto tileType = amx::TileType::get({16, (16 * blockingFactor)}, ipType);
775 auto loadLhs = amx::TileLoadOp::create(rewriter, loc, tileType,
776 srcBuffLhs, indicesLhs);
777
778 // Create the subview and then load.
779 amx::TileLoadOp loadRhs;
780 if (!isVnni) {
781 VectorType vecTy;
782 SmallVector<OpFoldResult> indexVals;
783 llvm::TypeSwitch<Operation *>(contractOp.getRhs().getDefiningOp())
784 .Case<TransferReadOp, LoadOp>([&](auto readOp) {
785 indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
786 readOp.getIndices().end());
787 vecTy = readOp.getType();
788 });
789 auto one = rewriter.getIndexAttr(1);
790 SmallVector<OpFoldResult> strides(indexVals.size(), one);
791 SmallVector<OpFoldResult> sizes = getAsIndexOpFoldResult(
792 contractOp.getRhs().getDefiningOp()->getContext(),
793 vecTy.getShape());
794 auto subview = memref::SubViewOp::create(rewriter, loc, srcBuffRhs,
795 indexVals, sizes, strides);
796 auto bufferType = MemRefType::get({16, (16 * blockingFactor)}, ipType);
797 auto packedBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
798
799 // create a loop that does online packing.
800 Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
801 Value step =
802 arith::ConstantIndexOp::create(rewriter, loc, blockingFactor);
803 Value uBound = arith::ConstantIndexOp::create(rewriter, loc,
804 (blockingFactor * 16));
805 Value nextLoadIndx =
806 arith::ConstantIndexOp::create(rewriter, loc, (blockingFactor / 2));
807 Value nextStoreIndx = arith::ConstantIndexOp::create(
808 rewriter, loc, 16 * (blockingFactor / 2));
809
810 scf::ForOp::create(
811 rewriter, loc, c0, uBound, step, ValueRange{},
812 [&](OpBuilder &nestedBuilder, Location loc, Value iv,
813 ValueRange iterArgs) {
814 Value i1_load =
815 arith::AddIOp::create(rewriter, loc, nextLoadIndx, iv);
816
817 indicesRhs[indicesRhs.size() - 2] = iv;
818 ValueRange range1(indicesRhs);
819 auto vec1 = vector::LoadOp::create(
820 rewriter, loc,
821 VectorType::get(16 * (blockingFactor / 2), ipType), subview,
822 range1);
823
824 indicesRhs[indicesRhs.size() - 2] = i1_load;
825 ValueRange range2(indicesRhs);
826 auto vec2 = vector::LoadOp::create(
827 rewriter, loc,
828 VectorType::get(16 * (blockingFactor / 2), ipType), subview,
829 range2);
830
831 vector::ShuffleOp shuffle1;
832 vector::ShuffleOp shuffle2;
833
834 if (blockingFactor == 2) {
835
836 shuffle1 = vector::ShuffleOp::create(
837 rewriter, loc, VectorType::get({16}, ipType), vec1, vec2,
838 ArrayRef<int64_t>{0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21,
839 6, 22, 7, 23});
840
841 shuffle2 = vector::ShuffleOp::create(
842 rewriter, loc, VectorType::get({16}, ipType), vec1, vec2,
843 ArrayRef<int64_t>{8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13,
844 29, 14, 30, 15, 31});
845 }
846
847 if (blockingFactor == 4) {
848 shuffle1 = vector::ShuffleOp::create(
849 rewriter, loc, VectorType::get({32}, ipType), vec1, vec2,
850 ArrayRef<int64_t>{0, 16, 32, 48, 1, 17, 33, 49,
851 2, 18, 34, 50, 3, 19, 35, 51,
852 4, 20, 36, 52, 5, 21, 37, 53,
853 6, 22, 38, 54, 7, 23, 39, 55});
854
855 shuffle2 = vector::ShuffleOp::create(
856 rewriter, loc, VectorType::get({32}, ipType), vec1, vec2,
857 ArrayRef<int64_t>{8, 24, 40, 56, 9, 25, 41, 57,
858 10, 26, 42, 58, 11, 27, 43, 59,
859 12, 28, 44, 60, 13, 29, 45, 61,
860 14, 30, 46, 62, 15, 31, 47, 63});
861 }
862
863 auto rem = arith::RemUIOp::create(
864 rewriter, loc, rewriter.getIndexType(), iv, step);
865
866 vector::StoreOp::create(rewriter, loc, shuffle1, packedBuffer,
867 ValueRange{rem, c0});
868 vector::StoreOp::create(rewriter, loc, shuffle2, packedBuffer,
869 ValueRange{rem, nextStoreIndx});
870
871 scf::YieldOp::create(nestedBuilder, loc);
872 });
873 loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
874 ValueRange{c0, c0});
875 } else {
876
877 loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType, srcBuffRhs,
878 indicesRhs);
879 }
880
881 auto tileTypeAcc = amx::TileType::get({16, 16}, opType);
882 auto loadAcc = amx::TileLoadOp::create(rewriter, loc, tileTypeAcc,
883 srcBuffAcc, indicesAcc);
884
885 // Tiled dot-product.
886 Value dp;
887 if (ipType.isBF16())
888 dp = amx::TileMulFOp::create(rewriter, loc, tileTypeAcc, loadLhs,
889 loadRhs, loadAcc);
890
891 if (ipType.isSignlessInteger(8))
892 dp = amx::TileMulIOp::create(rewriter, loc, tileTypeAcc, loadLhs,
893 loadRhs, loadAcc);
894
895 amx::TileStoreOp::create(rewriter, loc, srcBuffAcc, indicesAcc, dp);
896
897 rewriter.eraseOp(resultWriteOp);
898 return success();
899 }
900
901 // Case 2: The acc are passed as iter args through the reduction loop.
902 // We support, reduction loop depth until 2. TODO: Support for n-depth
903 // reduction loop.
904 // TODOs: Re-factor 2a and 2b.
905 SmallVector<scf::ForOp> loopLists;
906 Operation *current = contractOp;
907 while (true) {
908 Operation *parent = current->getParentOfType<scf::ForOp>();
909
910 if (!parent)
911 return rewriter.notifyMatchFailure(
912 contractOp,
913 "Accumulator read and contract op not within scf.for op");
914
915 loopLists.push_back(dyn_cast<scf::ForOp>(parent));
916
917 if (accReadOp->getBlock() == parent->getBlock()) {
918 break;
919 }
920
921 current = parent;
922 }
923 if (loopLists.size() > 2 || loopLists.size() == 0)
924 return rewriter.notifyMatchFailure(
925 contractOp, "Rewrite is supported until reduction loop depth of 2.");
926
927 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
928 contractOp.getLhs(), false);
929 if (failed(srcIndxLhs))
930 return rewriter.notifyMatchFailure(contractOp,
931 "The LHS src is not a MemRef type.");
932 auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
933
934 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
935 contractOp.getRhs(), false);
936 if (failed(srcIndxRhs))
937 return rewriter.notifyMatchFailure(contractOp,
938 "The RHS src is not a MemRef type.");
939 auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
940 Operation *vectorOpLhs;
941 llvm::TypeSwitch<Operation *>(contractOp.getLhs().getDefiningOp())
942 .Case<TransferReadOp, LoadOp>([&](auto readOp) {
943 vectorOpLhs = readOp.getBase().getDefiningOp();
944 });
945
946 Operation *vectorOpRhs;
947 llvm::TypeSwitch<Operation *>(contractOp.getRhs().getDefiningOp())
948 .Case<TransferReadOp, LoadOp>([&](auto readOp) {
949 vectorOpRhs = readOp.getBase().getDefiningOp();
950 });
951
952 // Retrive all the contaction operation within the loop.
953 SmallVector<vector::ContractionOp> ops;
954 for (mlir::Operation &op : loopLists[0].getBody()->getOperations()) {
955
956 if (auto contract = llvm::dyn_cast<mlir::vector::ContractionOp>(op)) {
957
958 LogicalResult validate = validateContractOps(
959 rewriter, contract, dimValue, srcBuffLhs, srcBuffRhs, true);
960
961 if (failed(validate))
962 return rewriter.notifyMatchFailure(
963 contractOp, "The associated contract operations doesn't satisfy "
964 "the re-write conditions either the dimensions are "
965 "wrong or MemRef source are different.");
966
967 ops.push_back(contract);
968 }
969 }
970
971 if (!isVnni) {
972 unsigned int pairCount = 0;
973 for (size_t j = 0; j < ops.size(); j++) {
974 for (size_t i = j; i < ops.size(); i++) {
975 if (i != j && validatePairVectorContract(ops[j], ops[i], true, 16))
976 pairCount = pairCount + 2;
977 }
978 }
979
980 if (pairCount != ops.size())
981 return rewriter.notifyMatchFailure(
982 contractOp, "Coudn't find the pair vector contract ");
983 }
984
985 scf::ForOp innerLoop;
986 scf::ForOp outerLoop;
987
988 scf::ForOp newLoop;
989 // Case 2a: Reduction loop depth is 2.
990 if (loopLists.size() == 2) {
991 outerLoop = loopLists[1];
992 innerLoop = loopLists[0];
993
994 SmallVector<Value> loopItrArgs = createTileZeros(
995 rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
996
997 if (isVnni) {
998 newLoop = scf::ForOp::create(
999 rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
1000 outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
1001 [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
1002 Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
1003 auto newInnerLoop = createLoops(
1004 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1005 innerLoop.getUpperBound(), innerLoop.getStep(),
1006 iterArgsOuterLoop, ipType, opType, blockingFactor, isVnni,
1007 vectorOpLhs, vectorOpRhs, contractOp, outerLoop, innerLoop,
1008 ops, ivOuterLoop, nullptr, true, nullptr, false, false);
1009
1010 scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
1011 newInnerLoop.getResults());
1012 });
1013
1014 } else {
1015
1016 bool isInnerLoopUBLarger = false;
1017 bool isInnerLoopUBHasOddQuot = false;
1018
1019 int64_t ubVal = 16 * blockingFactor;
1020 mlir::Value ub = innerLoop.getUpperBound();
1021 if (auto constOp = ub.getDefiningOp<mlir::arith::ConstantOp>()) {
1022 if (auto intAttr =
1023 llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue())) {
1024 ubVal = intAttr.getInt();
1025 }
1026 }
1027
1028 isInnerLoopUBLarger = ubVal > 16 * blockingFactor;
1029 isInnerLoopUBHasOddQuot =
1030 (((ubVal / (16 * blockingFactor)) % 2) == 1) && isInnerLoopUBLarger;
1031
1032 rewriter.setInsertionPoint(outerLoop);
1033
1034 auto c0 =
1035 arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
1036 auto c1 =
1037 arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 1);
1038 auto spillLoopBound = arith::ConstantIndexOp::create(
1039 rewriter, outerLoop.getLoc(), 16 * blockingFactor);
1040
1041 Value spillOuterLoop = arith::SubIOp::create(
1042 rewriter, outerLoop.getLoc(), outerLoop.getUpperBound(), c1);
1043 Value spillInnerLoop =
1044 arith::SubIOp::create(rewriter, innerLoop.getLoc(),
1045 innerLoop.getUpperBound(), spillLoopBound);
1046 auto bufferType =
1047 MemRefType::get({2, 32, (blockingFactor * 16)}, ipType);
1048 auto packedBuffer =
1049 memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
1050
1051 // First Shuffling outside the reduction loops
1052 IRMapping rhsMapping;
1053 rhsMapping.map(
1054 vectorOpRhs->getOperand(
1055 getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
1056 c0);
1057 rhsMapping.map(
1058 vectorOpRhs->getOperand(
1059 getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
1060 c0);
1061 auto rhsClone = rewriter.clone(*vectorOpRhs, rhsMapping);
1062
1063 performShuffle(rewriter, outerLoop.getLoc(), rhsClone->getResult(0),
1064 ipType, blockingFactor, packedBuffer, c0);
1065
1066 // First Set of Loops
1067 auto newLoopNonSpill = scf::ForOp::create(
1068 rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
1069 spillOuterLoop, outerLoop.getStep(), loopItrArgs,
1070 [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
1071 Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
1072 auto newInnerLoop1 = createLoops(
1073 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1074 spillInnerLoop, innerLoop.getStep(), iterArgsOuterLoop,
1075 ipType, opType, blockingFactor, isVnni, vectorOpLhs,
1076 vectorOpRhs, contractOp, outerLoop, innerLoop, ops,
1077 ivOuterLoop, packedBuffer, true, spillLoopBound,
1078 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1079
1080 auto newInnerLoop = createLoops(
1081 rewriter, innerLoop.getLoc(), spillInnerLoop,
1082 innerLoop.getUpperBound(), innerLoop.getStep(),
1083 newInnerLoop1.getResults(), ipType, opType, blockingFactor,
1084 isVnni, vectorOpLhs, vectorOpRhs, contractOp, outerLoop,
1085 innerLoop, ops, ivOuterLoop, packedBuffer, true, c0,
1086 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1087
1088 scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
1089 newInnerLoop.getResults());
1090 });
1091
1092 // Last set of Loops
1093 newLoop = scf::ForOp::create(
1094 rewriter, outerLoop.getLoc(), spillOuterLoop,
1095 outerLoop.getUpperBound(), outerLoop.getStep(),
1096 newLoopNonSpill.getResults(),
1097 [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
1098 Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
1099 auto newInnerLoop1 = createLoops(
1100 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1101 spillInnerLoop, innerLoop.getStep(), iterArgsOuterLoop,
1102 ipType, opType, blockingFactor, isVnni, vectorOpLhs,
1103 vectorOpRhs, contractOp, outerLoop, innerLoop, ops,
1104 ivOuterLoop, packedBuffer, true, spillLoopBound,
1105 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1106
1107 auto newInnerLoop = createLoops(
1108 rewriter, innerLoop.getLoc(), spillInnerLoop,
1109 innerLoop.getUpperBound(), innerLoop.getStep(),
1110 newInnerLoop1.getResults(), ipType, opType, blockingFactor,
1111 isVnni, vectorOpLhs, vectorOpRhs, contractOp, outerLoop,
1112 innerLoop, ops, ivOuterLoop, packedBuffer, false, c0,
1113 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1114
1115 scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
1116 newInnerLoop.getResults());
1117 });
1118 }
1119 }
1120
1121 // Case 2b: Reduction loop depth is 1.
1122 if (loopLists.size() == 1) {
1123 innerLoop = loopLists[0];
1124
1125 SmallVector<Value> loopItrArgs = createTileZeros(
1126 rewriter, innerLoop.getLoc(), opType, innerLoop, ops.size());
1127
1128 if (isVnni) {
1129
1130 newLoop = createLoops(
1131 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1132 innerLoop.getUpperBound(), innerLoop.getStep(), loopItrArgs, ipType,
1133 opType, blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
1134 contractOp, nullptr, innerLoop, ops, nullptr, nullptr, true,
1135 nullptr, false, false);
1136
1137 } else {
1138 bool isInnerLoopUBLarger = false;
1139 bool isInnerLoopUBHasOddQuot = false;
1140
1141 int64_t ubVal = 16 * blockingFactor;
1142 mlir::Value ub = innerLoop.getUpperBound();
1143 if (auto constOp = ub.getDefiningOp<mlir::arith::ConstantOp>()) {
1144 if (auto intAttr =
1145 llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue())) {
1146 ubVal = intAttr.getInt();
1147 }
1148 }
1149
1150 isInnerLoopUBLarger = ubVal > 16 * blockingFactor;
1151 isInnerLoopUBHasOddQuot =
1152 (((ubVal / (16 * blockingFactor)) % 2) == 1) && isInnerLoopUBLarger;
1153
1154 rewriter.setInsertionPoint(innerLoop);
1155 auto c0 =
1156 arith::ConstantIndexOp::create(rewriter, innerLoop.getLoc(), 0);
1157 auto spillLoopBound = arith::ConstantIndexOp::create(
1158 rewriter, innerLoop.getLoc(), 16 * blockingFactor);
1159
1160 Value spillInnerLoop =
1161 arith::SubIOp::create(rewriter, innerLoop.getLoc(),
1162 innerLoop.getUpperBound(), spillLoopBound);
1163
1164 auto bufferType =
1165 MemRefType::get({2, 32, (blockingFactor * 16)}, ipType);
1166 auto packedBuffer =
1167 memref::AllocaOp::create(rewriter, innerLoop.getLoc(), bufferType);
1168
1169 // First Shuffling outside the reduction loops
1170 IRMapping rhsMapping;
1171 rhsMapping.map(
1172 vectorOpRhs->getOperand(
1173 getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
1174 c0);
1175 auto rhsClone = rewriter.clone(*vectorOpRhs, rhsMapping);
1176
1177 performShuffle(rewriter, innerLoop.getLoc(), rhsClone->getResult(0),
1178 ipType, blockingFactor, packedBuffer, c0);
1179
1180 auto newLoopNonSpill = createLoops(
1181 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1182 spillInnerLoop, innerLoop.getStep(), loopItrArgs, ipType, opType,
1183 blockingFactor, isVnni, vectorOpLhs, vectorOpRhs, contractOp,
1184 nullptr, innerLoop, ops, nullptr, packedBuffer, true,
1185 spillLoopBound, isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1186
1187 newLoop = createLoops(rewriter, innerLoop.getLoc(), spillInnerLoop,
1188 innerLoop.getUpperBound(), innerLoop.getStep(),
1189 newLoopNonSpill.getResults(), ipType, opType,
1190 blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
1191 contractOp, nullptr, innerLoop, ops, nullptr,
1192 packedBuffer, false, c0, isInnerLoopUBLarger,
1193 isInnerLoopUBHasOddQuot);
1194 }
1195
1196 // This helps the final store back to the acc uses the same code for
1197 // the both reduction loop depth 1 or 2.
1198 outerLoop = innerLoop;
1199 }
1200
1201 // Copy the amx tile accumulation results to a MemRef buffer, add the
1202 // initial accumulation value, and store back to the C-Matrix
1203
1204 if (!isVnni) {
1205 Location loc = outerLoop.getLoc();
1206 Operation *accReadOp =
1207 traceToVectorReadLikeParentOperation(contractOp.getAcc());
1208
1209 Value srcBuffAcc;
1210 SmallVector<Value> indicesAcc;
1211
1212 llvm::TypeSwitch<Operation *>(accReadOp).Case<TransferReadOp, LoadOp>(
1213 [&](auto readOp) {
1214 srcBuffAcc = readOp.getOperand(0);
1215
1216 auto indices = readOp.getIndices();
1217 indicesAcc.reserve(indices.size());
1218
1219 llvm::transform(indices, std::back_inserter(indicesAcc),
1220 [&](OpFoldResult ofr) {
1222 rewriter, loc, ofr);
1223 });
1224 });
1225
1226 auto outputShapes =
1227 mlir::cast<mlir::MemRefType>(srcBuffAcc.getType()).getShape();
1228 unsigned int M = outputShapes[outputShapes.size() - 2];
1229 unsigned int N = outputShapes[outputShapes.size() - 1];
1230
1231 SmallVector<Value> dps = newLoop.getResults();
1232 auto bufferType = MemRefType::get({M, N}, opType);
1233 auto resultBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
1234
1235 // Store the amx tiled-dot product output into an MxN memref.
1236 for (unsigned int i = 0, k = 0; i < M; i = i + 16) {
1237 for (unsigned int j = 0; j < N; j = j + 16) {
1238 Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
1239 Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
1240 amx::TileStoreOp::create(rewriter, loc, resultBuffer,
1241 ValueRange{indexOp_i, indexOp_j}, dps[k]);
1242 k++;
1243 }
1244 }
1245 auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
1246 auto c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
1247 auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
1248 auto mBound = arith::ConstantIndexOp::create(rewriter, loc, N);
1249
1250 // Create a loop that iterates over the MxN memerf, retrives two rows +
1251 // shuffle them, add up the C element values and stores them back.
1252 scf::ForOp::create(
1253 rewriter, loc, c0, mBound, one, ValueRange{},
1254 [&](OpBuilder &nestedBuilder, Location loc, Value iv,
1255 ValueRange iterArgs) {
1256 auto row = vector::LoadOp::create(rewriter, loc,
1257 VectorType::get(16, opType),
1258 resultBuffer, ValueRange{iv, c0});
1259
1260 auto row2 = vector::LoadOp::create(
1261 rewriter, loc, VectorType::get(16, opType), resultBuffer,
1262 ValueRange{iv, c16});
1263
1264 auto shuffle1 = vector::ShuffleOp::create(
1265 rewriter, loc, VectorType::get(16, opType), row, row2,
1266 ArrayRef<int64_t>{0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20,
1267 21, 22, 23});
1268
1269 auto shuffle2 = vector::ShuffleOp::create(
1270 rewriter, loc, VectorType::get(16, opType), row, row2,
1271 ArrayRef<int64_t>{8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15,
1272 28, 29, 30, 31});
1273
1274 indicesAcc[indicesAcc.size() - 2] = iv;
1275 indicesAcc[indicesAcc.size() - 1] = c0;
1276
1277 Value valueCRow1 = vector::LoadOp::create(
1278 rewriter, loc, VectorType::get(16, opType), srcBuffAcc,
1279 indicesAcc);
1280 indicesAcc[indicesAcc.size() - 1] = c16;
1281
1282 Value valueCRow2 = vector::LoadOp::create(
1283 rewriter, loc, VectorType::get(16, opType), srcBuffAcc,
1284 indicesAcc);
1285
1286 Value addOp;
1287 Value addOp2;
1288
1289 if (ipType.isBF16()) {
1290 addOp =
1291 arith::AddFOp::create(rewriter, loc, shuffle1, valueCRow1);
1292
1293 addOp2 =
1294 arith::AddFOp::create(rewriter, loc, shuffle2, valueCRow2);
1295 }
1296
1297 if (ipType.isSignlessInteger(8)) {
1298 addOp =
1299 arith::AddIOp::create(rewriter, loc, shuffle1, valueCRow1);
1300
1301 addOp2 =
1302 arith::AddIOp::create(rewriter, loc, shuffle2, valueCRow2);
1303 }
1304 indicesAcc[indicesAcc.size() - 1] = c0;
1305 vector::StoreOp::create(rewriter, loc, addOp, srcBuffAcc,
1306 indicesAcc);
1307 indicesAcc[indicesAcc.size() - 1] = c16;
1308 vector::StoreOp::create(rewriter, loc, addOp2, srcBuffAcc,
1309 indicesAcc);
1310
1311 scf::YieldOp::create(nestedBuilder, loc);
1312 });
1313 }
1314
1315 auto bufferType = MemRefType::get({16, 16}, opType);
1316 auto resultBuffer =
1317 memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
1318 SmallVector<Value> dps = newLoop.getResults();
1319
1320 for (size_t i = 0; i < ops.size(); i++) {
1321 vector::ContractionOp contOp = ops[i];
1322 Operation *resultWriteOp =
1323 traceToVectorWriteLikeUserOperation(contOp.getResult());
1324 if (isVnni) {
1325 rewriter.setInsertionPoint(resultWriteOp);
1326
1327 Value indexOp_0 =
1328 arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
1329
1330 amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), resultBuffer,
1331 ValueRange{indexOp_0, indexOp_0}, dps[i]);
1332
1333 auto c0 =
1334 arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
1335 auto one =
1336 arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 1);
1337 auto mBound =
1338 arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 16);
1339
1340 scf::ForOp::create(
1341 rewriter, outerLoop.getLoc(), c0, mBound, one, ValueRange{},
1342 [&](OpBuilder &builder, Location loc, Value iv,
1343 ValueRange iterArgs) {
1344 auto resultAcc = vector::LoadOp::create(
1345 rewriter, loc, VectorType::get(16, opType), resultBuffer,
1346 ValueRange{iv, c0});
1347
1348 Operation *accReadOp =
1349 traceToVectorReadLikeParentOperation(ops[i].getAcc());
1350
1351 Value srcBuffAcc;
1352 SmallVector<Value> indicesAcc;
1353
1354 llvm::TypeSwitch<Operation *>(accReadOp)
1355 .Case<TransferReadOp, LoadOp>([&](auto readOp) {
1356 srcBuffAcc = readOp.getOperand(0);
1357
1358 auto indices = readOp.getIndices();
1359 indicesAcc.reserve(indices.size());
1360
1361 llvm::transform(
1362 indices, std::back_inserter(indicesAcc),
1363 [&](OpFoldResult ofr) {
1365 rewriter, loc, ofr);
1366 });
1367 });
1368
1369 Value sum =
1370 arith::AddIOp::create(builder, loc, iv, indicesAcc[0]);
1371 indicesAcc[indicesAcc.size() - 2] = sum;
1372
1373 auto acc = vector::LoadOp::create(rewriter, loc,
1374 VectorType::get(16, opType),
1375 srcBuffAcc, indicesAcc);
1376 Value addition;
1377 if (ipType.isBF16())
1378 addition = arith::AddFOp::create(rewriter, loc, resultAcc, acc);
1379
1380 if (ipType.isSignlessInteger(8))
1381 addition = arith::AddIOp::create(rewriter, loc, resultAcc, acc);
1382
1383 vector::StoreOp::create(builder, loc, addition, srcBuffAcc,
1384 indicesAcc);
1385
1386 scf::YieldOp::create(builder, outerLoop.getLoc());
1387 });
1388 }
1389
1390 rewriter.eraseOp(resultWriteOp);
1391 }
1392
1393 return success();
1394 }
1395};
1396
1397} // namespace
1398
1400 RewritePatternSet &patterns) {
1401 patterns.add<VectorContractToAMXDotProduct>(patterns.getContext());
1402}
return success()
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc, Value input, int64_t firstDimToCollapse)
Creates a memref.collapse_shape collapsing all inner dimensions of the input starting at firstDimToCo...
#define rem(a, b)
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
FloatType getF32Type()
Definition Builders.cpp:47
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
FloatType getBF16Type()
Definition Builders.cpp:41
IndexType getIndexType()
Definition Builders.cpp:55
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:376
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition Types.cpp:66
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
mlir::x86::AMXTileType TileType
Definition X86Dialect.h:40
Operation * traceToVectorWriteLikeUserOperation(Value v)
Definition X86Utils.cpp:194
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Definition X86Utils.cpp:42
Operation * traceToVectorReadLikeParentOperation(Value v)
Definition X86Utils.cpp:154
bool validatePairVectorContract(vector::ContractionOp contractOp, vector::ContractionOp pairContOp, bool rhsHasMultipleNonUnitDims, int64_t nonUnitDimValue)
Definition X86Utils.cpp:352
void populateVectorContractToAMXDotProductPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.