dol: initial dol commit
[jump.git] / dol / src / dol / visitor / cbe / CbeModuleVisitor.java
diff --git a/dol/src/dol/visitor/cbe/CbeModuleVisitor.java b/dol/src/dol/visitor/cbe/CbeModuleVisitor.java
new file mode 100644 (file)
index 0000000..7987bfc
--- /dev/null
@@ -0,0 +1,417 @@
+/* $Id: CbeModuleVisitor.java 1 2010-02-24 13:03:05Z haidw $ */
+package dol.visitor.cbe;
+
+import java.io.FileOutputStream;
+import java.io.OutputStream;
+import java.util.HashMap;
+import java.util.Vector;
+
+import dol.datamodel.pn.Channel;
+import dol.datamodel.pn.Port;
+import dol.datamodel.pn.Process;
+import dol.datamodel.pn.ProcessNetwork;
+import dol.main.UserInterface;
+import dol.util.CodePrintStream;
+import dol.visitor.PNVisitor;
+
+/**
+ * This class is a class for a visitor that is used to generate
+ * the main program.
+ *
+ * @author lschor, 2008-10-30
+ *
+ * Revision:
+ * 2008-10-30: Updated the file for the CBE
+ * 2008-11-08: Add double buffering
+ * 2008-11-16: Add new fifo implementation and defines for measurement
+ * 2008-11-21: Sink/Source do not run on the SPE, but on the PPE (as Linux
+ *             thread)
+ */
+public class CbeModuleVisitor extends PNVisitor {
+
+    /**
+     * Constructor.
+     *
+     * @param dir path of this file
+     */
+    public CbeModuleVisitor(String dir, HashMap<Port, Integer> portMap) {
+        _dir = dir;
+        _portMap = portMap;
+    }
+
+    /**
+     * Visit process network.
+     *
+     * @param x process network that needs to be rendered
+     */
+    public void visitComponent(ProcessNetwork x) {
+        try {
+            _ui = UserInterface.getInstance();
+            String filename = _dir + _delimiter + "ppu_main.c";
+            OutputStream file = new FileOutputStream(filename);
+            _mainPS = new CodePrintStream(file);
+
+            //create header section
+            _mainPS.println("// ========================");
+            _mainPS.println("// ppu_main.c file");
+            _mainPS.println("// ========================");
+            _mainPS.println("");
+
+            // Includes
+            _mainPS.println("#include \"ppu_main.h\"");
+            _mainPS.println("");
+
+            _mainPS.println("");
+            _mainPS.println("// include main function for a workloop");
+            _mainPS.println("#include \"ppu_main_workloop.h\"");
+            _mainPS.println("");
+
+            // Function to create and run one SPE thread
+            _mainPS.println("// create and run one SPE thread");
+            _mainPS.println("void *spu_pthread(void *arg) {");
+            _mainPS.println("\t spu_data_t *datp = (spu_data_t *)arg;");
+            _mainPS.println("\t uint32_t entry = SPE_DEFAULT_ENTRY;");
+            _mainPS.println("\t printf(\")PPE: spe thread starts\\n\");");
+            _mainPS.println("\t if (spe_context_run(datp->spe_ctx, "
+                    + "&entry, 0, datp->argp, NULL, NULL) < 0) {");
+            _mainPS.println("\t\t perror (\"Failed running context\"); "
+                    + "exit (1);");
+            _mainPS.println("\t}");
+            _mainPS.println("\t printf(\")PPE: spe thread stops\\n\");");
+            _mainPS.println("\t pthread_exit(NULL);");
+            _mainPS.println("}");
+            _mainPS.println("");
+
+            // Declaration of the Header function for the PPE-Wrappers
+            Vector<String> processList = new Vector<String>();
+            for (Process p : x.getProcessList()) {
+                String basename = p.getBasename();
+                if (!processList.contains(basename)) {
+                    processList.add(basename);
+                    if (!(p.getNumOfInports() > 0
+                            && p.getNumOfOutports() > 0))
+                        _mainPS.println("void *" + basename
+                                + "_wrapper( void *ptr );");
+                }
+            }
+            _mainPS.println();
+
+            // Create the port_id and the port_queue_id arrays to send
+            // over the DMA
+            for (Process process : x.getProcessList()) {
+                String processName = process.getName();
+                _mainPS.println("volatile uint32_t "+ processName
+                        + "_port_id["
+                        + roundDMA(process.getPortList().size())
+                        + "] __attribute__ ((aligned(16))); ");
+                _mainPS.println("volatile uint32_t " + processName
+                        + "_port_queue_id["
+                        + roundDMA(process.getPortList().size())
+                        + "] __attribute__ ((aligned(16)));");
+                _mainPS.println("volatile char " + processName
+                        + "_name[256] __attribute__ ((aligned(16)));");
+            }
+            _mainPS.println();
+
+            // Create the main function
+            _mainPS.println("int main()");
+            _mainPS.println("{");
+
+            // For Measure
+            _mainPS.println("#ifdef MEASURE_APPLICATION");
+            _mainPS.println("\tstruct timeval t_ppe_start, t_ppe_end;");
+            _mainPS.println("\tgettimeofday(&t_ppe_start,NULL);");
+            _mainPS.println("#endif");
+            _mainPS.println();
+
+            _mainPS.println("#ifdef MEASURE_SET_UP_SPE_THREAD");
+            _mainPS.println("\tstruct timeval t_ppe_setup_start, "
+                    + "t_ppe_setup_end;");
+            _mainPS.println("\tgettimeofday(&t_ppe_setup_start,NULL);");
+            _mainPS.println("#endif");
+            _mainPS.println();
+
+            // List with all process to be open
+            _mainPS.println("\tchar spe_names[NUM_SPES][60] = {");
+            int count = 0;
+            for (Process process : x.getProcessList()) {
+                if (process.getNumOfInports() > 0
+                        && process.getNumOfOutports() > 0) {
+                    count++;
+                    String processName = process.getBasename();
+                    _mainPS.println("\t\t\"spu_" + processName + "/spu_"
+                            + processName + "_wrapper\""
+                            + (count == x.getProcessList().size()
+                            ? "" : ", ") );
+                }
+            }
+            _mainPS.println("\t};");
+            _mainPS.println();
+            _mainPS.println("\t// Initialize the fifo, we use");
+            _mainPS.println("\tint j; ");
+            _mainPS.println("\tfor (j = 0; j < NUM_FIFO; j++)");
+            _mainPS.println("\t{");
+            _mainPS.println("\t\tlocBuf[j] = "
+                    + "(char*)malloc(MAXELEMENT * sizeof(char));");
+            _mainPS.println("\t\tlocBufCount[j] = 0;");
+            _mainPS.println("\t\tlocBufStart[j] = 0;");
+            _mainPS.println("\t\tpthread_mutex_init(&(mutex[j]), NULL);");
+            _mainPS.println("\t}");
+
+            //connect ports to channels
+            HashMap<Channel, Integer> channel_map =
+                new HashMap<Channel, Integer>();
+
+            int j = 0;
+            for (Channel c : x.getChannelList()) {
+                channel_map.put(c, j++);
+            }
+
+            // Init the SPE control structure
+            _mainPS.println("\t//Initiate SPEs control structure");
+            _mainPS.println("\tint num = 0; ");
+            _mainPS.println("\tfor( num=0; num<NUM_SPES; num++){");
+            _mainPS.println("\t\tdata[num].argp = (void *)&(ctx[num]);");
+            _mainPS.println("\t}");
+            _mainPS.println();
+
+            // Add to each process the ports and the queues
+            _mainPS.println("\t //Add to each process the ports and the "
+                    + "queues");
+            j = 0;
+            for (Process process : x.getProcessList()) {
+                String processName = process.getName();
+                int i = 0;
+                for (Port port : process.getPortList()) {
+                    Channel c = (Channel)(port.getPeerResource());
+
+                    _mainPS.println("\t" + processName + "_port_id["
+                                    + i + "] = "
+                                    + _portMap.get(port) + ";");
+                    _mainPS.println("\t" + processName + "_port_queue_id["
+                                    + i + "] = "
+                                    + channel_map.get(c) + ";");
+                    i++;
+                }
+
+                // Normal process
+                if (process.getNumOfInports() > 0
+                        && process.getNumOfOutports() > 0) {
+                    _mainPS.println("\tctx[" + j + "]"
+                                    + ".port_id = (uint64_t)"
+                                    + processName + "_port_id;");
+                    _mainPS.println("\tctx[" + j + "]"
+                                    + ".port_queue_id = (uint64_t)"
+                                    + processName + "_port_queue_id;");
+                    _mainPS.println("\tctx[" + j + "]"
+                                    + ".number_of_ports = " + i + ";");
+                    _mainPS.println("\tctx[" + j + "]"
+                                    + ".is_detached = 0;");
+                    _mainPS.println("\tstrcpy((char *)" + processName
+                            + "_name, " + "\""
+                            + processName + "\");");
+                    _mainPS.println("\tctx[" + j + "]"
+                            + ".processName = (uint64_t) " + processName
+                            + "_name;");
+                    _mainPS.println("\tctx[" + j + "]"
+                            + ".processNameLen = ((strlen((char *)"
+                            + processName + "_name) + 15) & ~15);");
+                    _mainPS.println();
+                    j++;
+                }
+                // Process is Sink or source
+                else {
+                    _mainPS.println("\tProcessWrapper *"
+                            + process.getName()
+                            + "_Process_Wrapper = (ProcessWrapper*)"
+                            + "malloc(sizeof(ProcessWrapper)); ");
+                    _mainPS.println("\t" + process.getName()
+                            + "_Process_Wrapper->port_id = " + processName
+                            + "_port_id;");
+                    _mainPS.println("\t" + process.getName()
+                            + "_Process_Wrapper->port_queue_id = "
+                            + processName + "_port_queue_id;");
+                    _mainPS.println("\t" + process.getName()
+                            + "_Process_Wrapper->number_of_ports = "
+                            + i + ";");
+                    _mainPS.println("\t" + process.getName()
+                            + "_Process_Wrapper->is_detached = 0;");
+                    _mainPS.println("\t" + process.getName()
+                            + "_Process_Wrapper->name = (char*)malloc("
+                            + "strlen(\"" + processName + "\"));");
+                    _mainPS.println("\tstrcpy(" + process.getName()
+                            + "_Process_Wrapper->name, \"" + processName
+                            + "\");");
+                    _mainPS.println("\t" + process.getName()
+                            + "_Process_Wrapper->locBuf = locBuf;");
+                    _mainPS.println("\t" + process.getName()
+                            + "_Process_Wrapper->MAXELEMENT = "
+                            + "MAXELEMENT;");
+                    _mainPS.println("\t" + process.getName()
+                            + "_Process_Wrapper->locBufCount = "
+                            + "locBufCount;");
+                    _mainPS.println("\t" + process.getName()
+                            + "_Process_Wrapper->locBufStart = "
+                            + "locBufStart;");
+                    _mainPS.println("\t" + process.getName()
+                            + "_Process_Wrapper->processFinished = "
+                            + "&processFinished;");
+                    _mainPS.println("\t" + process.getName()
+                            + "_Process_Wrapper->mutex = mutex;");
+                    _mainPS.println("\t" + process.getName()
+                            + "_Process_Wrapper->mutexProcessNr = "
+                            + "&mutexProcessNr;");
+                    _mainPS.println();
+                }
+            }
+
+            _mainPS.println("\t// Loop on all SPEs and for each perform "
+                    + "three steps:");
+            _mainPS.println("\t// - create SPE context");
+            _mainPS.println("\t// - open images of SPE programs into main "
+                    + "storage");
+            _mainPS.println("\t//         <spe_names> variable store the "
+                    + "executable name");
+            _mainPS.println("\t// - Load SPEs objects into SPE context "
+                    + "local store");
+            _mainPS.println("\tfor( num=0; num<NUM_SPES; num++){");
+            _mainPS.println("\t\tif ((data[num].spe_ctx = "
+                    + "spe_context_create(0, NULL)) == NULL) {");
+            _mainPS.println("\t\t\tperror(\"Failed creating context\"); "
+                    + "exit(1);");
+            _mainPS.println("\t\t}");
+            _mainPS.println("\t\tif (!(program[num] = spe_image_open("
+                    + "&spe_names[num][0]))) {");
+            _mainPS.println("\t\t\t perror(\"Fail opening image\"); "
+                    + "exit(1);");
+            _mainPS.println("\t\t}");
+            _mainPS.println("\t\tif (spe_program_load(data[num].spe_ctx, "
+                    + "program[num])) {");
+            _mainPS.println("\t\t\tperror(\"Failed loading program\"); "
+                    + "exit(1);");
+            _mainPS.println("\t\t}      ");
+            _mainPS.println("\t}");
+            _mainPS.println("");
+
+            _mainPS.println("\t// create PPE pthreads");
+            for (Process process : x.getProcessList()) {
+                if (!(process.getNumOfInports() > 0
+                        && process.getNumOfOutports() > 0)) {
+                    _mainPS.println("\tpthread_t thread_"
+                            + process.getName() + ";");
+                    _mainPS.println("\tpthread_create( &thread_"
+                            + process.getName() + ", NULL, "
+                            + process.getBasename() + "_wrapper, "
+                            + process.getName() + "_Process_Wrapper);");
+                }
+            }
+            _mainPS.println("\t");
+
+            _mainPS.println("\t// create SPE pthreads");
+            _mainPS.println("\tfor( num=0; num<NUM_SPES; num++){");
+            _mainPS.println("\t\tif(pthread_create(&data[num].pthread, "
+                    + "NULL, &spu_pthread, &data[num])){");
+            _mainPS.println("\t\t\tperror(\"Failed creating thread\"); "
+                    + "exit(1);");
+            _mainPS.println("\t\t}");
+            _mainPS.println("\t}");
+
+            _mainPS.println();
+            _mainPS.println("#ifdef MEASURE_SET_UP_SPE_THREAD");
+            _mainPS.println("\tgettimeofday(&t_ppe_setup_end, NULL);");
+            _mainPS.println("\tprintf(\"PPE_SETUP;%f\\n\", "
+                    + "(t_ppe_setup_end.tv_sec "
+                    + "- t_ppe_setup_start.tv_sec) + 0.000001 "
+                    + "* (t_ppe_setup_end.tv_usec "
+                    + "- t_ppe_setup_start.tv_usec));");
+            _mainPS.println("#endif");
+            _mainPS.println();
+
+            _mainPS.println("\t//Start the main loop");
+            _mainPS.println("\tworkloop();");
+            _mainPS.println("");
+
+            _mainPS.println("\t// Loop on all SPEs and for each perform "
+                    + "two steps:");
+            _mainPS.println("\t//   - wait for all the SPE pthread to "
+                    + "complete");
+            _mainPS.println("\t//   - destroy the SPE contexts");
+            _mainPS.println("\tfor( num=0; num<NUM_SPES; num++){");
+            _mainPS.println("\t\tif(pthread_join(data[num].pthread, "
+                    + "NULL)) {");
+            _mainPS.println("\t\t\tperror(\"Failed joining thread\"); "
+                    + "exit(1);");
+            _mainPS.println("\t\t}");
+            _mainPS.println("\t\t");
+            _mainPS.println("\t\tif (spe_context_destroy("
+                    + "data[num].spe_ctx)) {");
+            _mainPS.println("\t\t\tperror(\"Failed "
+                    + "spe_context_destroy\"); exit(1);");
+            _mainPS.println("\t\t}");
+            _mainPS.println("\t}");
+            _mainPS.println("\tprintf(\")PPE:) Complete running "
+                    + "all SPEs\\n\");");
+            _mainPS.println("");
+            _mainPS.println("\tfor (j = 0; j < NUM_FIFO; j++)");
+            _mainPS.println("\t{");
+            _mainPS.println("\t\tfree(locBuf[j]);");
+            _mainPS.println("\t\tpthread_mutex_destroy (&(mutex[j]));");
+            _mainPS.println("\t}");
+            _mainPS.println("");
+            _mainPS.println("#ifdef MEASURE_APPLICATION");
+            _mainPS.println("\tgettimeofday(&t_ppe_end, NULL);");
+            _mainPS.println("\tprintf(\"PPE;%f\\n\",(t_ppe_end.tv_sec "
+                    + "- t_ppe_start.tv_sec) + 0.000001 "
+                    + "* (t_ppe_end.tv_usec - t_ppe_start.tv_usec));");
+            _mainPS.println("#endif");
+            _mainPS.println("");
+            _mainPS.println("\treturn (0);");
+            _mainPS.println("}");
+        }
+        catch (Exception e) {
+            System.out.println("CbeDBModuleVisitor: exception occured: "
+                    + e.getMessage());
+            e.printStackTrace();
+        }
+    }
+
+    /**
+     *
+     * @param x process that needs to be processed
+     */
+    public void visitComponent(Process x) {
+    }
+
+    /**
+     *
+     * @param x channel that needs to be processed
+     */
+    public void visitComponent(Channel x) {
+    }
+
+    /**
+     * Round the parameter to the next DMA-number up.
+     * Example: 23 gives 32.
+     * @param number number to round up
+     * @return next DMA-number
+     */
+    protected int roundDMA(int number) {
+        if (number > 16) {
+            return number + 16 - (number % 16);
+        } else if (number > 8) {
+            return 16;
+        } else if (number > 4) {
+            return 8;
+        } else if (number > 2) {
+            return 4;
+        } else if (number > 1) {
+            return 2;
+        } else {
+            return 1;
+        }
+    }
+
+    protected CodePrintStream _mainPS = null;
+    protected String _dir = null;
+    protected HashMap<Port, Integer> _portMap;
+}