MLIR 23.0.0git
XeGPUTransformOps.cpp
Go to the documentation of this file.
1//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15#include "llvm/ADT/SmallVectorExtras.h"
16
17#include <optional>
18
19#include "llvm/Support/DebugLog.h"
20#define DEBUG_TYPE "xegpu-transforms"
21
22using namespace mlir;
23using namespace mlir::transform;
24
25/// Assuming that `ofr` is an index attr or a param of index type
26/// or a transform dialect handle mapped to exactly one op
27/// with one index result, get that value and cast it to int type.
29 transform::TransformState &state, TransformOpInterface transformOp,
31 for (OpFoldResult ofr : ofrs) {
32 // Attribute case.
33 if (auto attr = dyn_cast<Attribute>(ofr)) {
34 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
35 result.push_back(intAttr.getInt());
36 continue;
37 }
38 return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
39 }
40
41 // Transform param case.
42 Value transformValue = cast<Value>(ofr);
43 if (isa<TransformParamTypeInterface>(transformValue.getType())) {
44 ArrayRef<Attribute> params = state.getParams(transformValue);
45 if (params.size() != 1)
46 return transformOp.emitDefiniteFailure()
47 << "requires exactly one parameter associated";
48 result.push_back(
49 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
50 continue;
51 }
52
53 // Payload value case.
54 auto payloadOps = state.getPayloadOps(transformValue);
55 if (!llvm::hasSingleElement(payloadOps)) {
57 transformOp.emitSilenceableError()
58 << "handle must be mapped to exactly one payload op";
59 diag.attachNote(transformValue.getLoc())
60 << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
61 return diag;
62 }
63
64 Operation *op = *payloadOps.begin();
65 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
67 transformOp.emitSilenceableError()
68 << "payload op must have exactly 1 index result";
69 diag.attachNote(op->getLoc())
70 << "has " << op->getNumResults() << " results";
71 return diag;
72 }
73
74 IntegerAttr intAttr;
75 if (!matchPattern(op->getResult(0), m_Constant(&intAttr)))
76 return transformOp.emitSilenceableError()
77 << "requires param or handle to be the result of a constant like "
78 "op";
79
80 result.push_back(intAttr.getInt());
81 }
83}
84
85/// Find producer operation of type T for the given value.
86/// It's assumed that producer ops are chained through their first operand.
87/// Producer chain is traced trough loop block arguments (init values).
88template <typename T>
89static std::optional<T> findProducerOfType(Value val) {
90 Value currentValue = val;
91 if (!currentValue.getDefiningOp()) {
92 // Value may be a block argument initialized outside a loop.
93 if (val.getNumUses() == 0) {
94 LDBG() << "Failed to find producer op, value has no uses.";
95 return std::nullopt;
96 }
97 auto userOp = val.getUsers().begin();
98 auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>();
99 if (!parentLoop) {
100 LDBG() << "Failed to find producer op, not in a loop.";
101 return std::nullopt;
102 }
103 int64_t iterArgIdx;
104 if (auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) {
105 auto numInductionVars = parentLoop.getLoopInductionVars()->size();
106 iterArgIdx = iterArg.getArgNumber() - numInductionVars;
107 currentValue = parentLoop.getInits()[iterArgIdx];
108 } else {
109 LDBG() << "Failed to find producer op, value not in init values.";
110 return std::nullopt;
111 }
112 }
113 Operation *producerOp = currentValue.getDefiningOp();
114
115 if (auto matchingOp = dyn_cast<T>(producerOp))
116 return matchingOp;
117
118 if (producerOp->getNumOperands() == 0)
119 return std::nullopt;
120
121 return findProducerOfType<T>(producerOp->getOperand(0));
122}
123
124/// Create a layout attribute from the given parameters.
125static xegpu::LayoutAttr createLayoutAttr(
126 MLIRContext *ctx, ArrayRef<int32_t> sgLayout, ArrayRef<int32_t> sgData,
127 std::optional<ArrayRef<int32_t>> instData, ArrayRef<int32_t> order) {
128 return xegpu::LayoutAttr::get(
129 ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
130 DenseI32ArrayAttr::get(ctx, sgData),
131 instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
132 /*lane_layout=*/nullptr,
133 /*lane_data=*/nullptr,
134 /*order=*/order.empty() ? nullptr : DenseI32ArrayAttr::get(ctx, order));
135}
136
137/// Generate `xegpu::LayoutAttr` from op mixed layout values.
140 TransformOpInterface transformOp,
141 ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
143 ArrayRef<::mlir::OpFoldResult> mixedInstData,
144 ArrayRef<int32_t> order,
145 xegpu::LayoutAttr &layoutAttr) {
146 SmallVector<int32_t> sgLayout, sgData, instData;
147 auto status =
148 convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout);
149 if (!status.succeeded())
150 return status;
151
152 status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData);
153 if (!status.succeeded())
154 return status;
155
156 status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData);
157 if (!status.succeeded())
158 return status;
159 auto maybeInstData = instData.empty()
160 ? std::nullopt
161 : std::optional<ArrayRef<int32_t>>(instData);
162
163 layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData, order);
164
166}
167
168/// Replace xegpu.create_nd_desc op with a new one with the given layout.
169static xegpu::CreateNdDescOp
171 xegpu::CreateNdDescOp descOp,
172 xegpu::DistributeLayoutAttr layout) {
173 assert(descOp.getMixedOffsets().size() == 0 &&
174 "create desc op with offsets is not supported");
175 auto oldTensorDesc = descOp.getType();
176 auto descType = xegpu::TensorDescType::get(
177 oldTensorDesc.getShape(), oldTensorDesc.getElementType(),
178 /*array_length=*/oldTensorDesc.getArrayLength(),
179 /*boundary_check=*/oldTensorDesc.getBoundaryCheck(),
180 /*memory_space=*/oldTensorDesc.getMemorySpace(),
181 /*layout=*/layout);
182
183 rewriter.setInsertionPointAfter(descOp);
184 auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
185 descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
186 descOp.getMixedStrides());
187 return newDescOp;
188}
189
191transform::GetDescOp::apply(transform::TransformRewriter &rewriter,
194 auto targetValues = state.getPayloadValues(getTarget());
195 if (!llvm::hasSingleElement(targetValues)) {
196 return emitDefiniteFailure()
197 << "requires exactly one target value handle (got "
198 << llvm::range_size(targetValues) << ")";
199 }
200
201 auto maybeDescOp =
202 findProducerOfType<xegpu::CreateNdDescOp>(*targetValues.begin());
203 if (!maybeDescOp) {
204 return emitSilenceableFailure(getLoc())
205 << "Could not find a matching descriptor op when walking the "
206 "producer chain of the first operand.";
207 }
208
209 results.set(llvm::cast<OpResult>(getResult()), {*maybeDescOp});
211}
212
213void transform::SetDescLayoutOp::build(OpBuilder &builder,
215 ArrayRef<OpFoldResult> mixedSgLayout,
216 ArrayRef<OpFoldResult> mixedSgData,
217 ArrayRef<OpFoldResult> mixedInstData,
218 ArrayRef<int32_t> order,
219 ArrayRef<int64_t> sliceDims) {
220 SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
221 SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
222 dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
223 dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
224 dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
225 build(builder, result, target.getType(),
226 /*target=*/target,
227 /*sg_layout=*/dynamicSgLayout,
228 /*sg_data=*/dynamicSgData,
229 /*inst_data=*/dynamicInstData,
230 /*static_sg_layout=*/staticSgLayout,
231 /*static_sg_data=*/staticSgData,
232 /*static_inst_data=*/staticInstData,
233 /*order=*/order,
234 /*slice_dims=*/sliceDims);
235}
236
238transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
241 auto targetOps = state.getPayloadOps(getTarget());
242 if (!llvm::hasSingleElement(targetOps)) {
243 return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
244 << llvm::range_size(targetOps) << ")";
245 }
246 Operation *target = *targetOps.begin();
247
248 xegpu::LayoutAttr layoutAttr = nullptr;
249 auto status = getLayoutAttrFromOperands(
250 getContext(), state, (*this), getMixedSgLayout(), getMixedSgData(),
251 getMixedInstData(), getOrder(), layoutAttr);
252 if (!status.succeeded())
253 return status;
254
255 xegpu::DistributeLayoutAttr layout = layoutAttr;
256 auto sliceDims = getSliceDims();
257 if (sliceDims.size() > 0) {
258 // Wrap layoutAttr in a slice attribute.
259 layout = xegpu::SliceAttr::get(
260 getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
261 }
262
263 // For now only create_nd_desc op is supported.
264 auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
265 if (!descOp) {
266 auto diag = emitSilenceableFailure(getLoc())
267 << "Expected a xegpu.create_nd_desc op, but got: "
268 << target->getName();
269 diag.attachNote(target->getLoc()) << "target op";
270 return diag;
271 }
272
273 // Set layout attr in desc op's return type. Replaces old desc op.
274 auto newdescOp = setDescLayout(rewriter, descOp, layout);
275
276 // Map result handles.
277 results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
278
280}
281
282void transform::SetDescLayoutOp::getEffects(
284 consumesHandle(getTargetMutable(), effects);
285 onlyReadsHandle(getSgLayoutMutable(), effects);
286 onlyReadsHandle(getSgDataMutable(), effects);
287 onlyReadsHandle(getInstDataMutable(), effects);
288 producesHandle(getOperation()->getOpResults(), effects);
289 modifiesPayload(effects);
290}
291
292void transform::SetOpLayoutAttrOp::build(
293 OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
294 ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
295 ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int32_t> order,
296 ArrayRef<int64_t> sliceDims, bool result, bool operand) {
297 SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
298 SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
299 dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
300 dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
301 dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
302 build(builder, ostate, target.getType(),
303 /*target=*/target,
304 /*index=*/index,
305 /*sg_layout=*/dynamicSgLayout,
306 /*sg_data=*/dynamicSgData,
307 /*inst_data=*/dynamicInstData,
308 /*static_sg_layout=*/staticSgLayout,
309 /*static_sg_data=*/staticSgData,
310 /*static_inst_data=*/staticInstData,
311 /*order=*/order,
312 /*slice_dims=*/sliceDims,
313 /*result=*/result,
314 /*operand=*/operand);
315}
316
318transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
321 auto targetOps = state.getPayloadOps(getTarget());
322 if (!llvm::hasSingleElement(targetOps)) {
323 return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
324 << llvm::range_size(targetOps) << ")";
325 }
326 Operation *target = *targetOps.begin();
327
328 bool resultTarget = getResult();
329 bool operandTarget = getOperand();
330
332 if (resultTarget && index >= target->getNumResults()) {
333 return emitSilenceableFailure(getLoc())
334 << "Index exceeds the number of op results";
335 }
336 if (operandTarget && index >= target->getNumOperands()) {
337 return emitSilenceableFailure(getLoc())
338 << "Index exceeds the number of op operands";
339 }
340
341 xegpu::LayoutAttr layoutAttr = nullptr;
342 auto status = getLayoutAttrFromOperands(
343 getContext(), state, (*this), getMixedSgLayout(), getMixedSgData(),
344 getMixedInstData(), getOrder(), layoutAttr);
345 if (!status.succeeded())
346 return status;
347
348 xegpu::DistributeLayoutAttr layout = layoutAttr;
349 auto sliceDims = getSliceDims();
350 if (sliceDims.size() > 0) {
351 // Wrap layoutAttr in a slice attribute.
352 layout = xegpu::SliceAttr::get(
353 getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
354 }
355
356 // Set layout attribute
357 if (resultTarget) {
358 // op result
359 xegpu::setDistributeLayoutAttr(target->getResult(index), layout);
360 } else if (operandTarget) {
361 // op operand
362 xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layout);
363 } else if (auto dpasOp = dyn_cast<xegpu::DpasOp>(target)) {
364 // dpas op is a special case where layout needs to be set for A, B, and C
365 if (index == 0)
366 dpasOp.getProperties().layout_a = layout;
367 else if (index == 1)
368 dpasOp.getProperties().layout_b = layout;
369 else if (index == 2)
370 dpasOp.getProperties().layout_cd = layout;
371 else {
372 auto diag = emitSilenceableFailure(getLoc())
373 << "Invalid index for setting dpas op layout: " << index;
374 diag.attachNote(target->getLoc()) << "target op";
375 return diag;
376 }
377 } else {
378 // op's anchor layout.
379 auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(target);
380 if (!anchorOp) {
381 auto diag = emitSilenceableFailure(getLoc())
382 << "Cannot set anchor layout to op: " << target->getName();
383 diag.attachNote(target->getLoc()) << "target op";
384 return diag;
385 }
386 anchorOp.setAnchorLayout(layout);
387 }
389}
390
391void transform::SetOpLayoutAttrOp::getEffects(
393 onlyReadsHandle(getTargetMutable(), effects);
394 onlyReadsHandle(getSgLayoutMutable(), effects);
395 onlyReadsHandle(getSgDataMutable(), effects);
396 onlyReadsHandle(getInstDataMutable(), effects);
397 modifiesPayload(effects);
398}
399
400LogicalResult transform::SetOpLayoutAttrOp::verify() {
401 if (getResult() && getOperand()) {
402 return emitOpError("Cannot set both result and operand simultaneously.");
403 }
404 return success();
405}
406
407void transform::SetGPULaunchThreadsOp::build(
408 OpBuilder &builder, OperationState &ostate, Value target,
409 ArrayRef<OpFoldResult> mixedThreads) {
410 SmallVector<int64_t> staticThreads;
411 SmallVector<Value> dynamicThreads;
412 dispatchIndexOpFoldResults(mixedThreads, dynamicThreads, staticThreads);
413 build(builder, ostate, target.getType(),
414 /*target=*/target,
415 /*threads=*/dynamicThreads,
416 /*static_threads=*/staticThreads);
417}
418
420transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
423 auto targetOps = state.getPayloadOps(getTarget());
424 if (!llvm::hasSingleElement(targetOps)) {
425 return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
426 << llvm::range_size(targetOps) << ")";
427 }
428 Operation *target = *targetOps.begin();
429
430 auto launchOp = dyn_cast<gpu::LaunchOp>(target);
431 if (!launchOp) {
432 auto diag = emitSilenceableFailure(getLoc())
433 << "Expected a gpu.launch op, but got: " << target->getName();
434 diag.attachNote(target->getLoc()) << "target op";
435 return diag;
436 }
437
438 SmallVector<int32_t> threads;
440 convertMixedValuesToInt(state, (*this), threads, getMixedThreads());
441 if (!status.succeeded())
442 return status;
443
444 if (threads.size() != 3) {
445 return emitSilenceableFailure(getLoc())
446 << "Expected threads argument to consist of three values (got "
447 << threads.size() << ")";
448 }
449
450 rewriter.setInsertionPoint(launchOp);
451 auto createConstValue = [&](int value) {
452 return arith::ConstantIndexOp::create(rewriter, launchOp.getLoc(), value);
453 };
454
455 // Replace threads in-place.
456 launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0]));
457 launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1]));
458 launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2]));
459
461}
462
463void transform::SetGPULaunchThreadsOp::getEffects(
465 onlyReadsHandle(getTargetMutable(), effects);
466 onlyReadsHandle(getThreadsMutable(), effects);
467 modifiesPayload(effects);
468}
469
471transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
474 auto targetValues = state.getPayloadValues(getTarget());
475 if (!llvm::hasSingleElement(targetValues))
476 return emitDefiniteFailure()
477 << "requires exactly one target value handle (got "
478 << llvm::range_size(targetValues) << ")";
479 auto value = *targetValues.begin();
480
481 int64_t nbPrefetch = getStaticNbPrefetch();
482 if (getDynamicNbPrefetch()) {
483 // Get dynamic prefetch count from transform param or handle.
484 SmallVector<int32_t> dynamicNbPrefetch;
485 auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch,
486 {getDynamicNbPrefetch()});
487 if (!status.succeeded())
488 return status;
489 if (dynamicNbPrefetch.size() != 1)
490 return emitDefiniteFailure()
491 << "requires exactly one value for dynamic_nb_prefetch";
492 nbPrefetch = dynamicNbPrefetch[0];
493 }
494 if (nbPrefetch <= 0)
495 return emitSilenceableFailure(getLoc())
496 << "nb_prefetch must be a positive integer.";
497
498 // Find load operation of the operand.
499 auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
500 if (!maybeLoadOp)
501 return emitSilenceableFailure(getLoc()) << "Could not find load op.";
502 auto loadOp = *maybeLoadOp;
503 if (loadOp.getMixedOffsets().size() == 0) {
504 auto diag = emitSilenceableFailure(getLoc())
505 << "Load op must have offsets.";
506 diag.attachNote(loadOp.getLoc()) << "load op";
507 return diag;
508 }
509
510 // Find the parent scf.for loop.
511 auto forOp = loadOp->getParentOfType<scf::ForOp>();
512 if (!forOp) {
513 auto diag = emitSilenceableFailure(getLoc())
514 << "Load op is not contained in a scf.for loop.";
515 diag.attachNote(loadOp.getLoc()) << "load op";
516 return diag;
517 }
518
519 // Find descriptor op.
520 auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value);
521 if (!maybeDescOp)
522 return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
523 auto descOp = *maybeDescOp;
524 if (descOp.getMixedOffsets().size() > 0) {
525 auto diag = emitSilenceableFailure(getLoc())
526 << "desc op with offsets is not supported.";
527 diag.attachNote(descOp.getLoc()) << "desc op";
528 }
529
530 // Clone desc op outside the loop.
531 rewriter.setInsertionPoint(forOp);
532 auto newDescOp =
533 cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation()));
534
535 // Clone reduction loop to emit initial prefetches.
536 // Compute upper bound of the init loop: start + nbPrefetch * step.
537 auto nbPrefetchCst =
538 arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch);
539 auto nbStep = rewriter.createOrFold<arith::MulIOp>(
540 forOp.getLoc(), nbPrefetchCst, forOp.getStep());
541 auto initUpBound = rewriter.createOrFold<arith::AddIOp>(
542 forOp.getLoc(), forOp.getLowerBound(), nbStep);
543 auto initForOp =
544 scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
545 initUpBound, forOp.getStep());
546
547 auto ctx = rewriter.getContext();
548 auto readCacheHint =
549 xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
550
551 // Modify loadOp mixedOffsets by replacing the for loop induction variable
552 // with the given value.
553 auto getPrefetchOffsets =
554 [&](Value replacementVal) -> SmallVector<OpFoldResult> {
555 IRMapping mapping;
556 mapping.map(forOp.getInductionVar(), replacementVal);
557 SmallVector<Value> dynamicOffsets =
558 llvm::map_to_vector(loadOp.getOffsets(), [&](Value v) {
559 return mapping.lookupOrDefault(v);
560 });
561 auto constOffsets = loadOp.getConstOffsets().value();
562 return getMixedValues(constOffsets, dynamicOffsets, ctx);
563 };
564
565 // Insert prefetch op in init loop.
566 // Replace induction var with the init loop induction var.
567 rewriter.setInsertionPointToStart(initForOp.getBody());
568 xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
569 newDescOp.getResult(),
570 getPrefetchOffsets(initForOp.getInductionVar()),
571 readCacheHint, readCacheHint, readCacheHint,
572 /*layout=*/nullptr);
573
574 // Insert prefetch op in main loop.
575 // Calculate prefetch offset after the init prefetches have been issued.
576 rewriter.setInsertionPointToStart(forOp.getBody());
577 auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
578 forOp.getInductionVar(), nbStep);
579 // Replace induction var with correct offset.
580 xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
581 newDescOp.getResult(),
582 getPrefetchOffsets(prefetchOffset), readCacheHint,
583 readCacheHint, readCacheHint, /*layout=*/nullptr);
584
585 // Unroll the init loop.
586 if (failed(loopUnrollFull(initForOp)))
587 return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop";
588
589 results.set(llvm::cast<OpResult>(getResult()), {newDescOp});
590
592}
593
594void transform::InsertPrefetchOp::getEffects(
596 onlyReadsHandle(getTargetMutable(), effects);
597 onlyReadsHandle(getDynamicNbPrefetchMutable(), effects);
598 producesHandle(getOperation()->getOpResults(), effects);
599 modifiesPayload(effects);
600}
601
602void transform::ConvertLayoutOp::build(
603 OpBuilder &builder, OperationState &ostate, Value target,
604 ArrayRef<OpFoldResult> mixedInputSgLayout,
605 ArrayRef<OpFoldResult> mixedInputSgData,
606 ArrayRef<OpFoldResult> mixedInputInstData, ArrayRef<int32_t> inputOrder,
607 ArrayRef<OpFoldResult> mixedTargetSgLayout,
608 ArrayRef<OpFoldResult> mixedTargetSgData,
609 ArrayRef<OpFoldResult> mixedTargetInstData, ArrayRef<int32_t> targetOrder) {
610 SmallVector<int64_t> staticInputSgLayout, staticInputSgData,
611 staticInputInstData;
612 SmallVector<Value> dynamicInputSgLayout, dynamicInputSgData,
613 dynamicInputInstData;
614 dispatchIndexOpFoldResults(mixedInputSgLayout, dynamicInputSgLayout,
615 staticInputSgLayout);
616 dispatchIndexOpFoldResults(mixedInputSgData, dynamicInputSgData,
617 staticInputSgData);
618 dispatchIndexOpFoldResults(mixedInputInstData, dynamicInputInstData,
619 staticInputInstData);
620 SmallVector<int64_t> staticTargetSgLayout, staticTargetSgData,
621 staticTargetInstData;
622 SmallVector<Value> dynamicTargetSgLayout, dynamicTargetSgData,
623 dynamicTargetInstData;
624 dispatchIndexOpFoldResults(mixedTargetSgLayout, dynamicTargetSgLayout,
625 staticTargetSgLayout);
626 dispatchIndexOpFoldResults(mixedTargetSgData, dynamicTargetSgData,
627 staticTargetSgData);
628 dispatchIndexOpFoldResults(mixedTargetInstData, dynamicTargetInstData,
629 staticTargetInstData);
630 build(builder, ostate, target.getType(),
631 /*target=*/target,
632 /*input_sg_layout=*/dynamicInputSgLayout,
633 /*input_sg_data=*/dynamicInputSgData,
634 /*input_inst_data=*/dynamicInputInstData,
635 /*target_sg_layout=*/dynamicTargetSgLayout,
636 /*target_sg_data=*/dynamicTargetSgData,
637 /*target_inst_data=*/dynamicTargetInstData,
638 /*input_order=*/inputOrder,
639 /*static_input_sg_layout=*/staticInputSgLayout,
640 /*static_input_sg_data=*/staticInputSgData,
641 /*static_input_inst_data=*/staticInputInstData,
642 /*static_target_sg_layout=*/staticTargetSgLayout,
643 /*static_target_sg_data=*/staticTargetSgData,
644 /*static_target_inst_data=*/staticTargetInstData,
645 /*target_order=*/targetOrder);
646}
647
649transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
652 auto targetValues = state.getPayloadValues(getTarget());
653 if (!llvm::hasSingleElement(targetValues))
654 return emitDefiniteFailure()
655 << "requires exactly one target value handle (got "
656 << llvm::range_size(targetValues) << ")";
657 auto value = *targetValues.begin();
658
659 // Construct layout attributes.
660 xegpu::LayoutAttr inputLayoutAttr = nullptr;
661 auto status = getLayoutAttrFromOperands(
662 getContext(), state, (*this), getMixedInputSgLayout(),
663 getMixedInputSgData(), getMixedInputInstData(), getInputOrder(),
664 inputLayoutAttr);
665 if (!status.succeeded())
666 return status;
667
668 xegpu::LayoutAttr targetLayoutAttr = nullptr;
670 getContext(), state, (*this), getMixedTargetSgLayout(),
671 getMixedTargetSgData(), getMixedTargetInstData(), getTargetOrder(),
672 targetLayoutAttr);
673 if (!status.succeeded())
674 return status;
675
676 // Find first user op to define insertion point for layout conversion.
677 if (value.use_empty())
678 return emitSilenceableFailure(getLoc())
679 << "Value has no users to insert layout conversion.";
680 Operation *userOp = *value.getUsers().begin();
681
682 // Emit convert_layout op.
683 rewriter.setInsertionPoint(userOp);
684 auto convLayoutOp =
685 xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(),
686 value, inputLayoutAttr, targetLayoutAttr);
687 // Replace load op result with the converted layout.
688 rewriter.replaceUsesWithIf(
689 value, convLayoutOp.getResult(), [&](OpOperand &use) {
690 return use.getOwner() != convLayoutOp.getOperation();
691 });
692
693 results.set(llvm::cast<OpResult>(getResult()), {convLayoutOp});
695}
696
697void transform::ConvertLayoutOp::getEffects(
699 onlyReadsHandle(getTargetMutable(), effects);
700 onlyReadsHandle(getInputSgLayoutMutable(), effects);
701 onlyReadsHandle(getInputSgDataMutable(), effects);
702 onlyReadsHandle(getInputInstDataMutable(), effects);
703 onlyReadsHandle(getTargetSgLayoutMutable(), effects);
704 onlyReadsHandle(getTargetSgDataMutable(), effects);
705 onlyReadsHandle(getTargetInstDataMutable(), effects);
706 producesHandle(getOperation()->getOpResults(), effects);
707 modifiesPayload(effects);
708}
709
710namespace {
711class XeGPUTransformDialectExtension
713 XeGPUTransformDialectExtension> {
714public:
715 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)
716
717 using Base::Base;
718
719 void init();
720};
721
/// Set up the extension: declare the dialects of ops that the transforms in
/// this file create (scf, arith and xegpu) and register the transform ops
/// listed in the generated `.cpp.inc` file.
void XeGPUTransformDialectExtension::init() {
  declareGeneratedDialect<scf::SCFDialect>();
  declareGeneratedDialect<arith::ArithDialect>();
  declareGeneratedDialect<xegpu::XeGPUDialect>();

  // With GET_OP_LIST defined, the included file supplies the op class list
  // used as template arguments here.
  registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
      >();
}
732} // namespace
733
734#define GET_OP_CLASSES
735#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
736
738 registry.addExtensions<XeGPUTransformDialectExtension>();
739}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
b getContext())
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
static std::optional< T > findProducerOfType(Value val)
Find producer operation of type T for the given value.
static xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef< int32_t > sgLayout, ArrayRef< int32_t > sgData, std::optional< ArrayRef< int32_t > > instData, ArrayRef< int32_t > order)
Create a layout attribute from the given parameters.
static DiagnosedSilenceableFailure convertMixedValuesToInt(transform::TransformState &state, TransformOpInterface transformOp, SmallVectorImpl< int32_t > &result, ArrayRef< OpFoldResult > ofrs)
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to e...
DiagnosedSilenceableFailure getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state, TransformOpInterface transformOp, ArrayRef<::mlir::OpFoldResult > mixedSgLayout, ArrayRef<::mlir::OpFoldResult > mixedSgData, ArrayRef<::mlir::OpFoldResult > mixedInstData, ArrayRef< int32_t > order, xegpu::LayoutAttr &layoutAttr)
Generate xegpu::LayoutAttr from op mixed layout values.
static xegpu::CreateNdDescOp setDescLayout(transform::TransformRewriter &rewriter, xegpu::CreateNdDescOp descOp, xegpu::DistributeLayoutAttr layout)
Replace xegpu.create_nd_desc op with a new one with the given layout.
MLIRContext * getContext() const
Definition Builders.h:56
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool succeeded() const
Returns true if this is a success.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:379
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
unsigned getNumOperands()
Definition Operation.h:375
user_range getUsers()
Returns a range of all users.
Definition Operation.h:902
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:433
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
bool isIndex() const
Definition Types.cpp:56
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
unsigned getNumUses() const
This method computes the number of uses of this Value.
Definition Value.cpp:52
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
void registerTransformDialectExtension(DialectRegistry &registry)
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult user should use setAnchorLayout...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
LogicalResult loopUnrollFull(scf::ForOp forOp)
Unrolls this loop completely.
Definition Utils.cpp:495
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
This represents an operation in an abstracted form, suitable for use with the builder APIs.